import os

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap
from tensorflow.keras.models import load_model
# Load the saved Keras model; it is loaded without being compiled.
model_path = "saved_model/3D_unet_100_epochs_2_batch_patch_training.keras"
model = load_model(model_path, compile=False)

# Directories holding the test images and the test masks.
DATA_ROOT = "glioma split data"
test_img_dir = os.path.join(DATA_ROOT, "test/images/")
test_mask_dir = os.path.join(DATA_ROOT, "test/masks/")

# Sorted names of the .npy files found in each directory. An image and its
# mask are looked up with the same index into these two lists.
test_img_list = sorted(
    name for name in os.listdir(test_img_dir) if name.endswith('.npy')
)
test_mask_list = sorted(
    name for name in os.listdir(test_mask_dir) if name.endswith('.npy')
)

# Index of the sample to visualize; change it to look at a different sample.
sample_idx = 1
test_image_path = os.path.join(test_img_dir, test_img_list[sample_idx])
test_image = np.load(test_image_path)
test_mask_path = os.path.join(test_mask_dir, test_mask_list[sample_idx])
test_mask = np.load(test_mask_path)

# Define patch size (must match model input size)
patch_size = (96, 96, 96)
# Function to extract center patch
def extract_center_patch(volume, patch_size):
    """Extract the center patch from a 3D volume.

    Args:
        volume: Array with at least three axes. Only the first three axes
            are cropped; any trailing axes are kept whole.
        patch_size: Sequence of three ints, the patch extent along axes
            0, 1 and 2 of ``volume``.

    Returns:
        The centered sub-array of ``volume`` whose first three axes have
        the lengths given by ``patch_size``.

    Raises:
        ValueError: If ``volume`` is shorter than ``patch_size`` along any
            of the first three axes.
    """
    starts = []
    for axis in range(3):
        extent = volume.shape[axis]
        wanted = patch_size[axis]
        if wanted > extent:
            # A negative start would silently yield a wrong-shaped crop.
            raise ValueError(
                f"patch size {wanted} exceeds volume size {extent} "
                f"along axis {axis}"
            )
        starts.append((extent - wanted) // 2)

    start_x, start_y, start_z = starts
    return volume[
        start_x:start_x + patch_size[0],
        start_y:start_y + patch_size[1],
        start_z:start_z + patch_size[2]
    ]
# The center patches are what the model is run on; the full volumes are kept
# separately (below) for visualization.
image_patch = extract_center_patch(test_image, patch_size)
mask_patch = extract_center_patch(test_mask, patch_size)

# For visualization, use the full brain volumes
viz_image = test_image
viz_mask = test_mask

# Add a leading batch axis, predict, then take the arg-max over the last
# axis and drop the batch axis again to get one class index per voxel.
input_image = np.expand_dims(image_patch, axis=0)
prediction = model.predict(input_image)
predicted_mask = np.argmax(prediction, axis=-1)[0]

# Convert ground truth to class indices if it's one-hot encoded
# (a 4-axis mask is treated as one-hot along its last axis).
if mask_patch.ndim == 4:
    mask_patch = np.argmax(mask_patch, axis=-1)
if viz_mask.ndim == 4:
    viz_mask = np.argmax(viz_mask, axis=-1)
# Slice indices to visualize, one figure row per entry.
slice_numbers = [50, 75, 90]

# Colormap for the masks: 0: background, 1: NETC, 2: SNFH, 3: ET
colors = ['black', 'red', 'green', 'blue']
cmap = ListedColormap(colors)

# 3 x 4 grid of panels with tight spacing between them.
fig, axes = plt.subplots(
    3, 4,
    figsize=(16, 12),
    gridspec_kw={'wspace': 0.1, 'hspace': 0.15},
)
fig.suptitle(
    f'Brain Tumor Segmentation Results - Sample: {test_img_list[sample_idx]}',
    fontsize=16, y=0.98, weight='bold',
)

# Column headers are drawn as text just above the panels of the first row.
column_titles = ['Original Image', 'Ground Truth', 'Predicted Mask', 'Difference Map']
for header_ax, header in zip(axes[0], column_titles):
    header_ax.text(
        0.5, 1.15, header,
        transform=header_ax.transAxes,
        fontsize=14, weight='bold', ha='center', va='bottom',
    )
def _style_panel(ax):
    """Hide the ticks of one panel and give it an equal aspect ratio."""
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_aspect('equal')


# The prediction only covers the center patch, so it is embedded in a
# full-size label volume that is zero (background) everywhere else. This does
# not depend on the slice being drawn, so it is built once, before the loop.
full_pred_mask = np.zeros(viz_image.shape[:3], dtype=predicted_mask.dtype)
# Calculate patch position in full brain (same centering as the extraction)
start_x = (viz_image.shape[0] - patch_size[0]) // 2
start_y = (viz_image.shape[1] - patch_size[1]) // 2
start_z = (viz_image.shape[2] - patch_size[2]) // 2
# Place the predicted patch in the full brain
full_pred_mask[
    start_x:start_x + patch_size[0],
    start_y:start_y + patch_size[1],
    start_z:start_z + patch_size[2]
] = predicted_mask

for i, slice_num in enumerate(slice_numbers):
    # Clamp the slice number into the valid range of the third axis
    slice_num = max(0, min(slice_num, viz_image.shape[2] - 1))

    # Get the slices from full brain (channel 0 when the image has 4 axes)
    if viz_image.ndim == 4:
        img_slice = viz_image[:, :, slice_num, 0]
    else:
        img_slice = viz_image[:, :, slice_num]
    gt_slice = viz_mask[:, :, slice_num]
    pred_slice = full_pred_mask[:, :, slice_num]

    # Rotate for proper orientation
    img_slice = np.rot90(img_slice)
    gt_slice = np.rot90(gt_slice)
    pred_slice = np.rot90(pred_slice)

    # Normalize image to [0, 1] for better contrast. A constant slice has no
    # range to stretch; show it as all zeros instead of dividing by zero.
    lowest = img_slice.min()
    value_range = img_slice.max() - lowest
    if value_range > 0:
        img_slice = (img_slice - lowest) / value_range
    else:
        img_slice = np.zeros(img_slice.shape, dtype=float)
    # Apply contrast enhancement
    img_slice = np.power(img_slice, 0.7)  # Gamma correction for better visibility

    # Create difference map (ground truth - predicted)
    diff_slice = gt_slice.astype(float) - pred_slice.astype(float)

    # Row label, drawn to the left of the first panel
    axes[i, 0].text(-0.15, 0.5, f'Slice {slice_num}', transform=axes[i, 0].transAxes,
                    fontsize=12, weight='bold', rotation=90, va='center', ha='center')

    # Column 1: Original Image with enhanced contrast
    axes[i, 0].imshow(img_slice, cmap='gray', vmin=0, vmax=1)
    _style_panel(axes[i, 0])

    # Column 2: Ground Truth overlay
    axes[i, 1].imshow(img_slice, cmap='gray', vmin=0, vmax=1)
    gt_display = axes[i, 1].imshow(gt_slice, cmap=cmap, vmin=0, vmax=3, alpha=0.6)
    _style_panel(axes[i, 1])

    # Column 3: Predicted Mask overlay
    axes[i, 2].imshow(img_slice, cmap='gray', vmin=0, vmax=1)
    pred_display = axes[i, 2].imshow(pred_slice, cmap=cmap, vmin=0, vmax=3, alpha=0.6)
    _style_panel(axes[i, 2])

    # Column 4: Difference Map, drawn over the image with a diverging colormap
    axes[i, 3].imshow(img_slice, cmap='gray', vmin=0, vmax=1)
    diff_display = axes[i, 3].imshow(diff_slice, cmap='RdBu_r', vmin=-3, vmax=3, alpha=0.8)
    _style_panel(axes[i, 3])
# Add colorbars in hand-placed axes along the bottom of the figure.
# Main colorbar for GT and Pred (both overlays share the same colormap)
cbar_ax1 = fig.add_axes([0.25, 0.02, 0.2, 0.03])  # Horizontal colorbar
cbar1 = fig.colorbar(gt_display, cax=cbar_ax1, orientation='horizontal', ticks=[0, 1, 2, 3])
cbar1.ax.set_xticklabels(['Background', 'NETC', 'SNFH', 'ET'], fontsize=12)
cbar1.set_label('Tumor Regions', fontsize=12, weight='bold')

# Difference map colorbar
cbar_ax2 = fig.add_axes([0.55, 0.02, 0.2, 0.03])  # Horizontal colorbar
cbar2 = fig.colorbar(diff_display, cax=cbar_ax2, orientation='horizontal')
cbar2.set_label('Difference (GT - Pred)', fontsize=12, weight='bold')
cbar2.ax.tick_params(labelsize=10)

# Add text annotations for better understanding
fig.text(0.02, 0.5, 'Axial Slices', rotation=90, fontsize=14, weight='bold',
         va='center', ha='center')

# Add legend for difference map. The map is GT minus prediction drawn with
# the diverging 'RdBu_r' colormap, so positive values are red, negative
# values are blue and zero sits at the white center.
legend_text = (
    "Difference Map:\n"
    "Red: GT label > Pred label\n"
    "Blue: Pred label > GT label\n"
    "White: labels agree"
)
fig.text(0.82, 0.08, legend_text, fontsize=10, va='bottom', ha='left',
         bbox=dict(boxstyle="round,pad=0.3", facecolor="lightgray", alpha=0.8))

plt.tight_layout(rect=[0.05, 0.08, 0.95, 0.95])

print("\nVisualization Statistics:")
print(f"Original brain shape: {viz_image.shape}")
print(f"Patch size used for prediction: {patch_size}")
print(f"Slices displayed: {slice_numbers}")
print(f"Ground truth classes present: {np.unique(viz_mask)}")
print(f"Predicted classes present: {np.unique(predicted_mask)}")

# Display the figure; called last so the statistics are printed before a
# blocking window opens when this runs as a plain script.
plt.show()