import math
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import sys, os
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from src.modules.sonic import Sonic
import argparse
import os
from pathlib import Path
import matplotlib.pyplot as plt
from typing import Dict
from thop import profile


def to_bchw(x: torch.Tensor) -> torch.Tensor:
    """Return ``x`` as a 4D (B, C, H, W) tensor.

    A 4D tensor is returned unchanged (the very same object). A 5D tensor
    (B, C, D, H, W) must have D == 1; that axis is squeezed out. Any other
    rank raises ``ValueError``.
    """
    rank = x.dim()
    if rank == 5:
        depth = x.shape[2]
        if depth != 1:
            raise ValueError(f"Expected D==1 for 5D input, got D={depth}.")
        return x.squeeze(2)
    if rank != 4:
        raise ValueError(f"Expected 4D or 5D tensor, got shape {tuple(x.shape)}.")
    return x


class SyntheticDataGenerator:
    """Procedurally generates noisy RGB shape scenes with per-pixel shape labels.

    Shapes are drawn in the (H, W) plane and repeated through every depth slice
    of a (D, H, W) volume. All randomness comes from the global ``random`` and
    ``np.random`` generators, so seed both to reproduce a sample.
    """

    @staticmethod
    def generate_single_multi_class_volume(volume_size, num_objects, noise_sigma, segmentation_criterion="shape"):
        """Generate one RGB volume and its semantic mask.

        Args:
            volume_size: (D, H, W). Each 2D shape is extruded through all D slices.
            num_objects: number of shapes to place. Up to 5 distinct shapes are
                sampled without replacement; more than 5 are sampled with replacement.
            noise_sigma: std of additive Gaussian noise applied to the whole volume
                before clipping to [0, 1].
            segmentation_criterion: accepted for API compatibility; not used here
                (labels are always assigned by shape).

        Returns:
            (volume, mask): float32 array (D, H, W, 3) clipped to [0, 1], and int32
            array (D, H, W) with 0 = background, 1 = circle, 2 = triangle,
            3 = square, 4 = cross, 5 = star.

        A shape that cannot be placed without overlapping an earlier one within
        100 attempts is silently skipped.
        """
        D, H, W = volume_size
        vol_final = np.zeros((D, H, W, 3), dtype=np.float32)
        semantic_mask = np.zeros((D, H, W), dtype=np.int32)
        collision_mask = np.zeros((D, H, W), dtype=bool)
        # Open grids: yy has shape (1, H, 1) and xx (1, 1, W), so the analytic
        # masks below come out as (1, H, W) and are broadcast to (D, H, W) later.
        _, yy, xx = np.ogrid[:D, :H, :W]
        all_shapes = ["circle", "triangle", "square", "cross", "star"]
        shape_class_dict = {"circle": 1, "triangle": 2, "square": 3, "cross": 4, "star": 5}
        base_colors = [[1.0,0.0,0.0],[0.0,0.0,1.0],[0.0,1.0,0.0],[0.0,1.0,1.0],[1.0,0.5,0.0],[1.0,1.0,0.0],[1.0,0.0,1.0],[0.5,0.5,0.5]]
        # Distinct shapes when possible; repeats only if more objects than shape kinds.
        shapes_to_place = (random.sample(all_shapes, k=min(num_objects, len(all_shapes))) if num_objects <= len(all_shapes) else random.choices(all_shapes, k=num_objects))

        min_hw = min(H, W)
        # Exclusive upper bound for the triangle/cross/star size draws (lower bound 5).
        # Clamped so np.random.randint(5, size_hi) stays valid when min(H, W) < 20
        # (it used to raise ValueError there); identical to min(H, W)//5 + 1 otherwise.
        size_hi = max(6, min_hw // 5 + 1)

        for shape_label in shapes_to_place:
            # Rejection sampling: up to 100 tries to place this shape without overlap.
            for _ in range(100):
                current_shape_mask = np.zeros((D, H, W), dtype=bool)
                is_valid = False
                if shape_label == "circle":
                    r_min = max(1, min_hw // 20)
                    r_max = min_hw // 5
                    if r_min >= r_max: r_max = r_min + 1
                    r = np.random.randint(r_min, r_max)
                    if H > 2 * r and W > 2 * r:
                        cy, cx = np.random.randint(r, H - r), np.random.randint(r, W - r)
                        current_shape_mask = ((yy - cy)**2 + (xx - cx)**2) <= r**2
                        is_valid = True
                elif shape_label == "square":
                    s_min = 3
                    s_max = min_hw // 6 + 1
                    if s_min >= s_max: s_max = s_min + 1
                    s = np.random.randint(s_min, s_max)
                    if H > s and W > s:
                        y0, x0 = np.random.randint(0, H - s), np.random.randint(0, W - s)
                        current_shape_mask[:, y0:y0+s, x0:x0+s] = True
                        is_valid = True
                elif shape_label == "triangle":
                    base = np.random.randint(5, size_hi)
                    h = int(base * np.sqrt(3) / 2)
                    if h > 0 and H > h and W > base:
                        cy = np.random.randint(0, H - h)
                        cx = np.random.randint(base//2, W - base//2)
                        # Half-width shrinks linearly from base/2 at row cy to 0 at row cy + h.
                        frac = (yy - cy) / float(h)
                        half = np.maximum((base / 2.0) * (1 - frac), 0)
                        current_shape_mask = (np.abs(xx - cx) <= half) & (yy >= cy) & (yy < cy + h)
                        is_valid = True
                elif shape_label == "cross":
                    a = np.random.randint(5, size_hi)
                    w = np.random.randint(max(1, a//4), a//2 + 1)
                    if H > 2*a and W > 2*a:
                        cy, cx = np.random.randint(a, H - a), np.random.randint(a, W - a)
                        horiz = (yy >= cy - w//2) & (yy < cy + w//2 + w % 2) & (xx >= cx - a) & (xx < cx + a)
                        vert = (xx >= cx - w//2) & (xx < cx + w//2 + w % 2) & (yy >= cy - a) & (yy < cy + a)
                        current_shape_mask = horiz | vert
                        is_valid = True
                elif shape_label == "star":
                    r = np.random.randint(5, size_hi)
                    if H > 2*r and W > 2*r:
                        cy, cx = np.random.randint(r, H - r), np.random.randint(r, W - r)
                        dist_yx = np.sqrt((yy - cy)**2 + (xx - cx)**2)
                        angle = np.arctan2(yy - cy, xx - cx)
                        # Five-pointed outline: radius oscillates between 0.4*r and r.
                        inner_r = 0.4 * r
                        max_r = inner_r + (r - inner_r) * (np.cos(5 * angle) + 1) / 2
                        current_shape_mask = dist_yx <= max_r
                        is_valid = True

                # Extrude the 2D footprint through every depth slice so boolean
                # indexing of the (D, H, W[, 3]) arrays also works when D > 1.
                current_shape_mask = np.broadcast_to(current_shape_mask, (D, H, W))
                if is_valid and not np.any(current_shape_mask & collision_mask):
                    chosen_base_color = random.choice(base_colors)
                    final_color = np.array(chosen_base_color) * np.random.uniform(0.7, 1.2) + np.random.uniform(-0.1, 0.1, 3)
                    vol_final[current_shape_mask] = final_color
                    semantic_mask[current_shape_mask] = shape_class_dict[shape_label]
                    collision_mask |= current_shape_mask
                    break
        vol_final += noise_sigma * np.random.randn(D, H, W, 3).astype(np.float32)
        return np.clip(vol_final, 0, 1.0), semantic_mask

    @staticmethod
    def generate_data(batch_size, volume_size, num_objects=5, noise_sigma=0.1, segmentation_criterion="shape"):
        """Generate a batch of volumes as torch tensors.

        Returns:
            (x, y_mask): x is float32 (batch_size, 3, D, H, W) with values in
            [0, 1]; y_mask is float32 (batch_size, 1, D, H, W) holding the
            integer class labels produced by ``generate_single_multi_class_volume``.
        """
        D, H, W = volume_size
        x = torch.zeros(batch_size, 3, D, H, W)
        y_mask = torch.zeros(batch_size, 1, D, H, W)
        for i in range(batch_size):
            vol_np, mask_np = SyntheticDataGenerator.generate_single_multi_class_volume(
                volume_size, num_objects, noise_sigma, segmentation_criterion
            )
            # (D, H, W, 3) -> channels-first (3, D, H, W).
            x[i] = torch.from_numpy(vol_np.transpose(3, 0, 1, 2).astype(np.float32))
            y_mask[i, 0] = torch.from_numpy(mask_np.astype(np.float32))
        return x, y_mask

class ConvBlock(nn.Module):
    """Stack of (3x3 Conv -> GroupNorm(8) -> GELU) units held in ``self.block``.

    The first unit maps ``in_channels`` to ``out_channels``; every later unit
    keeps ``out_channels``. Convolutions are bias-free with padding 1, so the
    spatial size is preserved. At least one unit is built even when
    ``num_layers`` < 1. ``out_channels`` must be divisible by 8 (GroupNorm groups).
    """

    def __init__(self, in_channels, out_channels, num_layers=4):
        super().__init__()
        units = []
        width_in = in_channels
        for _ in range(max(num_layers, 1)):
            units += [
                nn.Conv2d(width_in, out_channels, 3, padding=1, bias=False),
                nn.GroupNorm(8, out_channels),
                nn.GELU(),
            ]
            width_in = out_channels
        self.block = nn.Sequential(*units)

    def forward(self, x, **kwargs):
        """Apply the stack to a (B, C, H, W) tensor; extra keyword arguments are ignored."""
        out = self.block(x)
        return out

class ViTBlock(nn.Module):
    """Patch-based transformer stage that returns a dense feature map.

    The input is split into non-overlapping ``patch_size`` patches by a strided
    conv, a learned positional embedding (initialised to zeros, sized for an
    ``img_size`` x ``img_size`` input) is added, ``depth`` ``TransformerBlock``
    layers and a final LayerNorm are applied, and the token grid is resized back
    to the input resolution with bilinear interpolation. For other input sizes
    the positional embedding is bicubically resized to the new patch grid.
    """

    def __init__(self, img_size, patch_size, in_channels, embed_dim, depth, num_heads):
        super().__init__()
        self.patch_embed = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        grid = img_size // patch_size
        self.num_patches_H = grid
        self.num_patches_W = grid
        self.pos_embed = nn.Parameter(torch.zeros(1, grid * grid, embed_dim))
        self.blocks = nn.ModuleList([TransformerBlock(dim=embed_dim, num_heads=num_heads) for _ in range(depth)])
        self.norm = nn.LayerNorm(embed_dim)
        self.embed_dim = embed_dim

    def forward(self, x, **kwargs):
        """Map (B, C, H, W) to (B, embed_dim, H, W); extra keyword arguments are ignored."""
        batch, _, height, width = x.shape
        grid_h, grid_w = height // self.patch_size, width // self.patch_size
        tokens = self.patch_embed(x).flatten(2).transpose(1, 2)

        pos = self.pos_embed
        if (height, width) != (self.img_size, self.img_size):
            # Resize the learned (square) positional grid to this input's patch grid.
            pos_grid = pos.transpose(1, 2).reshape(1, self.embed_dim, self.num_patches_H, self.num_patches_W)
            pos_grid = F.interpolate(pos_grid, size=(grid_h, grid_w), mode='bicubic', align_corners=False)
            pos = pos_grid.flatten(2).transpose(1, 2)
        tokens = tokens + pos

        for layer in self.blocks:
            tokens = layer(tokens)
        tokens = self.norm(tokens)

        feature_map = tokens.transpose(1, 2).reshape(batch, self.embed_dim, grid_h, grid_w)
        return F.interpolate(feature_map, size=(height, width), mode='bilinear', align_corners=False)

class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention over a (B, N, C) token sequence.

    A single linear layer produces queries, keys and values; heads have width
    C // num_heads and scores are scaled by head_width ** -0.5. ``dim`` must be
    divisible by ``num_heads`` for the head reshape to succeed.
    """

    def __init__(self, dim, num_heads=8):
        super().__init__()
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        """Return the attended tokens, same shape as ``x``."""
        batch, n_tokens, channels = x.shape
        head_width = channels // self.num_heads
        # (B, N, 3C) -> (3, B, heads, N, head_width)
        projected = self.qkv(x).reshape(batch, n_tokens, 3, self.num_heads, head_width)
        query, key, value = projected.permute(2, 0, 3, 1, 4).unbind(0)
        weights = torch.softmax(query @ key.transpose(-2, -1) * self.scale, dim=-1)
        merged = (weights @ value).transpose(1, 2).reshape(batch, n_tokens, channels)
        return self.proj(merged)

class Mlp(nn.Module):
    """Two-layer feed-forward network: Linear -> GELU -> Linear.

    The hidden width defaults to ``in_features`` when ``hidden_features`` is
    falsy (None or 0); the output width always equals ``in_features``.
    """

    def __init__(self, in_features, hidden_features=None):
        super().__init__()
        width = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, width)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(width, in_features)

    def forward(self, x):
        hidden = self.act(self.fc1(x))
        return self.fc2(hidden)

class TransformerBlock(nn.Module):
    """Pre-norm transformer encoder layer.

    Computes ``x + Attention(LayerNorm(x))`` followed by
    ``y + Mlp(LayerNorm(y))``; the MLP hidden width is 4 * dim.
    """

    def __init__(self, dim, num_heads):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(dim, num_heads=num_heads)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = Mlp(dim, int(dim * 4.0))

    def forward(self, x):
        after_attn = x + self.attn(self.norm1(x))
        return after_attn + self.mlp(self.norm2(after_attn))

class SynthshapeModel(nn.Module):
    """Two-stage per-pixel classifier with a pluggable feature extractor.

    Data flow (see ``forward``): stage1 -> 1x1 ``intermediate_mixer`` -> stage2,
    plus a 1x1 residual projection of the raw input, summed and fed to a 1x1
    conv ``head`` that emits ``num_classes`` logits per pixel.

    ``model_type`` (case-insensitive) selects the stages:
      * ``'sonic'``: two project ``Sonic`` modules (``modes`` kwarg -> ``M_modes``, default 256).
      * ``'conv'``: two ``ConvBlock``s with ``conv_layers // 2`` units each (``conv_layers`` default 4).
      * ``'vit'``: a 1x1 input projection to K1 channels, then two ``ViTBlock``s with
        ``vit_depth // 2`` layers each, using ``H`` as the square reference size.
    ``W`` is accepted but not used by any branch. K1 and K2 must be divisible by 8
    (GroupNorm with 8 groups). Keyword arguments a branch does not read are ignored.

    Raises:
        ValueError: if ``model_type`` is not 'sonic', 'conv' or 'vit'.
    """

    _SUPPORTED_TYPES = ('sonic', 'conv', 'vit')

    def __init__(self, model_type: str, H: int, W: int, in_channels: int = 3, num_classes: int = 6, K1: int = 32, K2: int = 128, **kwargs):
        super().__init__()
        self.model_type = model_type.lower()
        # Fail fast: an unknown type used to build a model without stage1/stage2
        # and only crash with AttributeError at the first forward pass.
        if self.model_type not in self._SUPPORTED_TYPES:
            raise ValueError(f"Unknown model_type {model_type!r}; expected one of {self._SUPPORTED_TYPES}.")
        # Shared layers are created before the stage-specific ones; keeping this
        # order keeps parameter initialisation (RNG consumption) unchanged.
        self.intermediate_mixer = nn.Sequential(nn.Conv2d(K1, K1, 1, bias=False), nn.GroupNorm(8, K1), nn.GELU(), nn.Conv2d(K1, K1, 1, bias=False), nn.GroupNorm(8, K1), nn.GELU())
        self.res_proj = nn.Conv2d(in_channels, K2, 1, bias=False)
        self.head = nn.Sequential(nn.Conv2d(K2, K2, 1, bias=False), nn.GroupNorm(8, K2), nn.GELU(), nn.Conv2d(K2, K2, 1, bias=False), nn.GroupNorm(8, K2), nn.GELU(), nn.Conv2d(K2, num_classes, 1, bias=True))
        if self.model_type == 'sonic':
            self.stage1 = Sonic(dim=2, in_channels=in_channels, num_hidden=K1, M_modes=kwargs.get('modes', 256))
            self.stage2 = Sonic(dim=2, in_channels=K1, num_hidden=K2, M_modes=kwargs.get('modes', 256))
        elif self.model_type == 'conv':
            c = kwargs.get('conv_layers', 4) // 2
            self.stage1 = ConvBlock(in_channels, K1, c)
            self.stage2 = ConvBlock(K1, K2, c)
        elif self.model_type == 'vit':
            v = kwargs.get('vit_depth', 4) // 2
            self.vit_in_proj = nn.Conv2d(in_channels, K1, 1)
            self.stage1 = ViTBlock(H, kwargs.get('patch_size', 8), K1, K1, v, kwargs.get('vit_heads', 4))
            self.stage2 = ViTBlock(H, kwargs.get('patch_size', 8), K1, K2, v, kwargs.get('vit_heads', 4))

    def forward(self, x, **kwargs):
        """Return per-pixel class logits of shape (B, num_classes, H, W).

        The output shape assumes the stages preserve spatial size. ``x`` may be
        (B, C, H, W) or (B, C, 1, H, W). ``kwargs`` reach the stages only for the
        'sonic' model (e.g. the dx/dy values passed by ``run_inference_experiment``);
        other stage types never receive them.
        """
        x = to_bchw(x)
        x_res = self.res_proj(x)
        stage_kwargs = kwargs if self.model_type == 'sonic' else {}
        h1 = self.stage1(self.vit_in_proj(x) if self.model_type == 'vit' else x, **stage_kwargs)
        h2 = self.stage2(self.intermediate_mixer(h1), **stage_kwargs)
        return self.head(h2 + x_res)


def _soft_multiclass_dice(logits, labels, num_classes, ignore_index=0, eps=1e-6):
    """Differentiable mean Dice from softmax probabilities (loss counterpart of ``multiclass_dice``).

    Args:
        logits: raw class scores, shape (B, num_classes, H, W).
        labels: integer class map, shape (B, H, W).
        num_classes: number of classes.
        ignore_index: class left out of the mean (background by default); None keeps all.
        eps: smoothing added to numerator and denominator.

    Returns:
        0-dim float32 tensor: mean over kept classes of
        (2 * sum(p * y) + eps) / (sum(p) + sum(y) + eps).
    """
    # Work in float32 even when the logits come out of fp16 autocast.
    probs = logits.float().softmax(dim=1)
    onehot = F.one_hot(labels.long(), num_classes).permute(0, 3, 1, 2).float()
    if ignore_index is not None:
        keep = [c for c in range(num_classes) if c != ignore_index]
        probs, onehot = probs[:, keep], onehot[:, keep]
    if onehot.shape[1] == 0:
        # Nothing to score (e.g. a single class that is ignored).
        return torch.ones((), device=logits.device)
    dims = (0, 2, 3)
    inter = (probs * onehot).sum(dim=dims)
    sums = probs.sum(dim=dims) + onehot.sum(dim=dims)
    return ((2.0 * inter + eps) / (sums + eps)).mean()


def train_model(args, model_type):
    """Seed all RNGs, build a ``SynthshapeModel`` and train it on freshly generated batches.

    Each of ``args.epochs`` iterations draws a new batch of ``args.bs`` samples
    (5 shapes, noise ``args.train_noise``). It then takes one AdamW step on
    0.5 * class-weighted cross-entropy + 0.5 * (1 - soft Dice). Class weights are
    inverse pixel frequencies estimated from 512 generated samples and normalised
    to mean 1. The learning rate follows a OneCycle schedule stepped once per
    iteration. Mixed precision (autocast + GradScaler) is enabled only on CUDA.

    Args:
        args: parsed CLI namespace; every field is also forwarded to ``SynthshapeModel``.
        model_type: 'sonic', 'conv' or 'vit'.

    Returns:
        (model, dice_log): the trained model and the hard (argmax) Dice of each
        training batch.
    """
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)
    device = "cpu" if args.cpu else ("cuda" if torch.cuda.is_available() else "cpu")
    model = SynthshapeModel(**vars(args), model_type=model_type).to(device)
    model.train()
    print(f"Training model: {model_type.upper()} on {device.upper()}")
    with torch.no_grad():
        # Inverse-frequency class weights from a large generated sample.
        _, yb = SyntheticDataGenerator.generate_data(512, (1, args.H, args.W), 5, args.train_noise)
        hist = torch.bincount(to_bchw(yb).long().flatten(), minlength=args.num_classes)
        freq = hist.float() / hist.sum().clamp_min(1)
        cls_weights = 1.0 / freq.clamp_min(1e-6)
        cls_weights = (cls_weights / cls_weights.mean()).to(device)
    ce_loss_fn = nn.CrossEntropyLoss(weight=cls_weights)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wd)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=args.lr, total_steps=args.epochs)
    scaler = torch.cuda.amp.GradScaler(enabled=(device == 'cuda'))
    dice_log = []
    for ep in range(args.epochs):
        x, y = SyntheticDataGenerator.generate_data(args.bs, (1, args.H, args.W), 5, args.train_noise)
        x, y = x.to(device), y.to(device)
        target = to_bchw(y).squeeze(1).long()
        with torch.amp.autocast(device_type=device, enabled=(device == 'cuda')):
            logits = model(x)
            ce = ce_loss_fn(logits, target)
        # The Dice term must come from probabilities: the argmax-based score has
        # no gradient, so it previously contributed nothing to training.
        soft_dice = _soft_multiclass_dice(logits, target, args.num_classes, 0)
        loss = 0.5 * ce + 0.5 * (1.0 - soft_dice)
        with torch.no_grad():
            # Hard Dice is only monitored/logged, never back-propagated.
            preds = torch.argmax(logits, dim=1, keepdim=True)
            dice_score = multiclass_dice(preds, y, args.num_classes, 0)
        dice_log.append(dice_score.item())
        optimizer.zero_grad(set_to_none=True)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()
        if (ep + 1) % args.log_every == 0:
            print(f"[{model_type.upper()}] Ep {ep+1:04d}/{args.epochs} | Loss: {loss.item():.4f} | Dice: {dice_score.item():.3f}")
    print(f"Finished training for {model_type.upper()}.")
    return model, dice_log

def multiclass_dice(preds, labels, num_classes, ignore_index=0, eps=1e-6):
    """Mean hard Dice score across classes for integer segmentation maps.

    Args:
        preds: predicted class indices, shape (B, H, W) or (B, 1, H, W).
        labels: ground-truth class indices, shape (B, H, W), (B, 1, H, W) or
            (B, 1, 1, H, W); float tensors holding integral values are accepted.
        num_classes: number of classes for one-hot encoding (indices must be < num_classes).
        ignore_index: class excluded from the score (background by default); None keeps all.
        eps: smoothing added to numerator and denominator.

    Returns:
        0-dim tensor: Dice averaged over the non-ignored classes that occur in
        ``preds`` or ``labels``; 1.0 when no such class occurs. Not differentiable.
    """
    if preds.dim() == 4 and preds.shape[1] == 1:
        preds = preds.squeeze(1)
    # (B, H, W) labels are used as-is, mirroring preds; 4D/5D go through to_bchw.
    if labels.dim() != 3:
        labels = to_bchw(labels).squeeze(1)
    preds_oh = F.one_hot(preds.long(), num_classes).permute(0, 3, 1, 2).float()
    labels_oh = F.one_hot(labels.long(), num_classes).permute(0, 3, 1, 2).float()
    if ignore_index is not None:
        preds_oh[:, ignore_index] = 0
        labels_oh[:, ignore_index] = 0
    dims = (0, 2, 3)
    inter = (preds_oh * labels_oh).sum(dim=dims)
    sums = preds_oh.sum(dim=dims) + labels_oh.sum(dim=dims)
    # Only classes present in prediction or ground truth contribute to the mean.
    valid = sums > 0
    if not valid.any():
        return torch.tensor(1.0, device=preds.device)
    return ((2.0 * inter[valid] + eps) / (sums[valid] + eps)).mean()

def plot_comparison(results, out_path):
    """Draw every model's per-epoch Dice curve on one set of axes and save the figure.

    Args:
        results: mapping of model name -> sequence of per-epoch Dice values.
        out_path: destination passed to ``savefig`` (written at 300 dpi).
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    for name, curve in results.items():
        ax.plot(curve, label=name.upper(), alpha=0.8)
    ax.set_title("Model Comparison: Dice Score During Training", fontsize=16)
    ax.set_xlabel("Epochs", fontsize=12)
    ax.set_ylabel("Mean Dice Score", fontsize=12)
    ax.grid(True, which='both', linestyle='--', linewidth=0.5)
    ax.legend(fontsize=12)
    ax.set_ylim(0, 1.05)
    fig.tight_layout()
    fig.savefig(out_path, dpi=300)
    plt.close(fig)
    print(f"\nTraining progress plot saved to: {out_path}")

def visualize_final_prediction(model, model_name, device, args):
    """Save a three-panel figure (input, ground truth, prediction) for one synthetic test sample.

    Puts ``model`` in eval mode and reseeds torch, NumPy and ``random`` with
    ``args.seed + 999``, so every call sees the same sample (generated with
    ``args.train_noise``). The figure is written to
    ``<args.out>/final_prediction_<model_name>.png``.
    """
    print(f" Generating final prediction plot for {model_name.upper()}...")
    model.eval()

    sample_seed = args.seed + 999
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(sample_seed)

    x, y = SyntheticDataGenerator.generate_data(1, (1, args.H, args.W), 5, args.train_noise)
    x = x.to(device)
    y = to_bchw(y).to(device)

    amp_enabled = device == 'cuda'
    with torch.no_grad(), torch.amp.autocast(device_type=device, enabled=amp_enabled):
        pred = model(x).argmax(dim=1)

    # Batch size is 1: take sample 0 and move everything to NumPy for plotting.
    image = to_bchw(x)[0].permute(1, 2, 0).cpu().numpy()
    gt_map = y[0, 0].cpu().numpy()
    pred_map = pred[0].cpu().numpy()

    label_style = {'vmin': 0, 'vmax': args.num_classes - 1, 'cmap': 'tab20'}
    panels = (
        ('Test Input', np.clip(image, 0, 1), {}),
        ('Ground Truth', gt_map, label_style),
        ('Model Prediction', pred_map, label_style),
    )

    fig, axs = plt.subplots(1, 3, figsize=(15, 5))
    fig.suptitle(f'Final Test Prediction: {model_name.upper()}', fontsize=16)
    for ax, (title, data, style) in zip(axs, panels):
        ax.imshow(data, **style)
        ax.set_title(title)
        ax.axis('off')

    fig.tight_layout(rect=[0, 0, 1, 0.95])
    plot_path = Path(args.out) / f"final_prediction_{model_name}.png"
    fig.savefig(plot_path, dpi=200)
    plt.close(fig)
    print(f"   -> Saved to {plot_path}")

def run_inference_experiment(model: nn.Module, model_name: str, device: str, args: argparse.Namespace):
    """Evaluate ``model`` across input resolutions and noise levels and save a 3x4 grid plot.

    One clean synthetic sample is generated with all RNGs seeded from
    ``args.seed + 10``.
      * Rows: the sample bilinearly resized by factors 1.5, 1.0 and 0.75 of (args.H, args.W).
      * Column 0: the ground truth resized (nearest) to that resolution.
      * Columns 1-3: predictions after adding Gaussian noise with sigma =
        args.train_noise, args.train_noise / 2 and 0, clamped to [0, 1].

    Dice scoring depends on the model type:
      * ViT: scored at the input resolution against the resized ground truth.
      * Other models: predictions are resized (nearest) back to (args.H, args.W)
        and scored against the original ground truth.

    Sonic models additionally receive ``dx = dy = 1 / res_factor``. The figure
    is saved as ``<args.out>/experiment_grid_<model_name>.png``.
    """
    print(f"\n Running inference experiment for model: {model_name.upper()}")
    model.eval()
    resolution_factors = [1.5, 1.0, 0.75]
    noise_levels = [args.train_noise, args.train_noise / 2, 0.0]
    fig, axs = plt.subplots(3, 4, figsize=(16, 12), gridspec_kw={'width_ratios': [1.5, 5, 5, 5]})
    fig.suptitle(f"Robustness Experiment: {model_name.upper()} Model", fontsize=18)
    # The base sample is drawn from NumPy / `random` (SyntheticDataGenerator) and
    # the per-cell noise from torch, so seed all three. Seeding torch alone left
    # the sample dependent on whatever had consumed those RNGs before this call.
    experiment_seed = args.seed + 10
    torch.manual_seed(experiment_seed)
    np.random.seed(experiment_seed)
    random.seed(experiment_seed)
    base_image_clean_5d, base_gt_5d = SyntheticDataGenerator.generate_data(1, (1, args.H, args.W), 5, 0.0)
    base_image_clean, base_gt = to_bchw(base_image_clean_5d).to(device), to_bchw(base_gt_5d).to(device)
    amp_enabled = device == 'cuda'
    for row, res_factor in enumerate(resolution_factors):
        new_H, new_W = int(args.H * res_factor), int(args.W * res_factor)
        # Nearest-neighbour resizing keeps the labels integral.
        gt_resized = F.interpolate(base_gt.float(), size=(new_H, new_W), mode='nearest')
        ax_gt = axs[row, 0]
        ax_gt.imshow(gt_resized.squeeze().cpu().numpy(), vmin=0, vmax=args.num_classes-1, cmap='tab20')
        ax_gt.axis('off')
        row_label = f"Res: {new_H}x{new_W}"
        if row == 0:
            ax_gt.set_title(f"Ground Truth\n{row_label}", fontsize=12, pad=5)
        else:
            ax_gt.set_title(row_label, fontsize=12, pad=5)
        # Resizing and forward kwargs depend only on the row, not on the noise level.
        input_img_resized = F.interpolate(base_image_clean, (new_H, new_W), mode='bilinear', align_corners=False)
        # NOTE(review): dx/dy = 1/res_factor presumably tells Sonic the grid spacing
        # relative to the training resolution — confirm against Sonic's forward.
        forward_kwargs = {'dx': 1/res_factor, 'dy': 1/res_factor} if model.model_type == 'sonic' else {}
        for col, noise_sigma in enumerate(noise_levels):
            ax = axs[row, col + 1]
            input_img_noisy = torch.clamp(input_img_resized + torch.randn_like(input_img_resized) * noise_sigma, 0, 1)
            with torch.no_grad(), torch.amp.autocast(device_type=device, enabled=amp_enabled):
                logits = model(input_img_noisy, **forward_kwargs)
            pred = torch.argmax(logits, dim=1, keepdim=True)
            if model.model_type == 'vit':
                dice = multiclass_dice(pred, gt_resized.long(), args.num_classes, 0).item()
                img_to_display = pred.squeeze().cpu().numpy()
            else:
                pred_resized_orig = F.interpolate(pred.float(), (args.H, args.W), mode='nearest').long()
                dice = multiclass_dice(pred_resized_orig, base_gt, args.num_classes, 0).item()
                img_to_display = pred_resized_orig.squeeze().cpu().numpy()
            ax.imshow(img_to_display, vmin=0, vmax=args.num_classes-1, cmap='tab20')
            ax.axis('off')
            ax.set_title(f"Dice: {dice:.3f}", fontsize=11)
    # Prefix the first-row titles with the noise level of each column.
    cols = [f"Noise σ={n:.2f}" for n in noise_levels]
    for col_idx, col_title in enumerate(cols):
        ax = axs[0, col_idx + 1]
        current_title = ax.get_title()
        ax.set_title(f"{col_title}\n{current_title}", fontsize=12, pad=5)
    plt.tight_layout(rect=[0, 0, 1, 0.96])
    plot_path = Path(args.out) / f"experiment_grid_{model_name}.png"
    plt.savefig(plot_path, dpi=200)
    plt.close(fig)
    print(f"Saved experiment plot to: {plot_path}")

def _time_inference_ms(model, dummy_input, device, warmup=20, iters=100):
    """Return the mean latency in milliseconds of ``model(dummy_input)``.

    On CUDA, each timed run is bracketed by CUDA events with a synchronize, as
    before. On CPU, ``time.perf_counter`` is used instead, so the benchmark no
    longer crashes with ``--cpu`` or on machines without CUDA. The caller is
    expected to wrap this in ``torch.no_grad()``.
    """
    import time

    for _ in range(warmup):
        model(dummy_input)
    timings = []
    if device == 'cuda':
        # Drain queued warm-up kernels so they do not leak into the first timing.
        torch.cuda.synchronize()
        starter = torch.cuda.Event(enable_timing=True)
        ender = torch.cuda.Event(enable_timing=True)
        for _ in range(iters):
            starter.record()
            model(dummy_input)
            ender.record()
            torch.cuda.synchronize()
            timings.append(starter.elapsed_time(ender))
    else:
        for _ in range(iters):
            t0 = time.perf_counter()
            model(dummy_input)
            timings.append((time.perf_counter() - t0) * 1000.0)
    return float(np.mean(timings))


if __name__ == "__main__":
    parser = argparse.ArgumentParser("Geosynth Training & Inference Experiment Script")
    parser.add_argument("--H", type=int, default=64)
    parser.add_argument("--W", type=int, default=64)
    parser.add_argument("--num_classes", type=int, default=6)
    parser.add_argument("--K1", type=int, default=64)
    parser.add_argument("--K2", type=int, default=64)
    parser.add_argument("--modes", type=int, default=64)
    parser.add_argument("--patch_size", type=int, default=8)
    parser.add_argument("--vit_depth", type=int, default=4)
    parser.add_argument("--vit_heads", type=int, default=4)
    parser.add_argument("--conv_layers", type=int, default=4)
    parser.add_argument("--epochs", type=int, default=1000)
    parser.add_argument("--bs", type=int, default=24)
    parser.add_argument("--lr", type=float, default=1e-2)
    parser.add_argument("--wd", type=float, default=1e-3)
    parser.add_argument("--train_noise", type=float, default=0.)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--cpu", action="store_true")
    parser.add_argument("--out", type=str, default="./results/synthshape_results")
    parser.add_argument("--log_every", type=int, default=250)
    args = parser.parse_args()

    Path(args.out).mkdir(parents=True, exist_ok=True)

    models_to_run = ['sonic', 'conv']
    trained_models: Dict[str, nn.Module] = {}
    training_results: Dict[str, list] = {}
    model_stats: Dict[str, dict] = {}
    # Same device rule as train_model, so benchmark inputs land on the model's device.
    device = "cpu" if args.cpu else ("cuda" if torch.cuda.is_available() else "cpu")

    # --- 1. Train each model and collect size / cost / latency statistics ---
    for model_name in models_to_run:
        model, dice_log = train_model(args, model_name)
        trained_models[model_name] = model
        training_results[model_name] = dice_log

        model.eval()
        with torch.no_grad():
            # Hard Dice of the last training batch (not a held-out score).
            final_dice = dice_log[-1] if dice_log else 0.0
            params_m = sum(p.numel() for p in model.parameters() if p.requires_grad) / 1_000_000
            dummy_input = torch.randn(1, 3, args.H, args.W).to(device)
            # thop reports multiply-accumulates; count 2 FLOPs per MAC.
            macs, _ = profile(model, inputs=(dummy_input,), verbose=False)
            gflops = (macs * 2) / 1_000_000_000
            avg_time_ms = _time_inference_ms(model, dummy_input, device)

            model_stats[model_name] = {
                "Final Dice": final_dice,
                "Params (M)": params_m,
                "GFLOPs": gflops,
                "Inference Time (ms)": avg_time_ms
            }

    # --- 2. Print the statistics table ---
    print("\n" + "="*50)
    print(" " * 15 + "MODEL COMPARISON STATS")
    print("="*50)
    for name, stats in model_stats.items():
        print(f"▶ MODEL: {name.upper()}")
        for key, value in stats.items():
            print(f"  - {key:<20}: {value:.3f}")
        print("-"*50)

    # --- 3. Plot the training Dice curves ---
    comparison_plot_path = Path(args.out) / "comparison_curves.png"
    plot_comparison(training_results, comparison_plot_path)

    # --- 4. Plot one test prediction per model ---
    print("\n" + "="*50)
    print(" " * 12 + "GENERATING FINAL PREDICTION PLOTS")
    print("="*50)
    for model_name, model in trained_models.items():
        visualize_final_prediction(model, model_name, device, args)

    # --- 5. Run the resolution/noise robustness experiment on each model ---
    for model_name, model in trained_models.items():
        run_inference_experiment(model, model_name, device, args)

