"""
Main training script for the image-based jump-diffusion model.
This version is adapted for local execution, saving models and plots
to a local directory instead of using MLflow.
"""
import torch
import numpy as np
import yaml
import matplotlib.pyplot as plt
from pathlib import Path
from types import SimpleNamespace
import argparse
import random
import json
from datetime import datetime
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Import project modules
from data import JumpDataLoader, generate_and_save_data, DataConfig, random_subsample

from network import UNetModel
from alternative_network import UNet
from jump_utils import loss_calc_jump
from ema import EMA

def apply_overrides(config, overrides):
    """Apply ``section.key=value`` overrides to the config namespace in place.

    Each override names an existing attribute via a dotted path; the new value
    is cast to the type of the value it replaces. Overrides that are
    malformed, or that point to a missing section or parameter, are reported
    with a warning and skipped.

    Args:
        config: Nested ``SimpleNamespace``-style object to modify.
        overrides: Iterable of ``"a.b.c=value"`` strings, or ``None``/empty
            for a no-op.
    """
    if not overrides:
        return

    true_words = ('true', '1', 'yes', 'on')
    false_words = ('false', '0', 'no', 'off')

    for override in overrides:
        # Split on the first '=' only, so the value itself may contain '='.
        keys, separator, value = override.partition('=')
        if not separator:
            print(f"Warning: Override '{override}' is not of the form key=value. Skipping.")
            continue
        attrs = keys.split('.')

        # Navigate to the correct nested namespace
        base = config
        for key in attrs[:-1]:
            if not hasattr(base, key):
                print(f"Warning: Override key '{keys}' contains non-existent section '{key}'.")
                base = None
                break
            base = getattr(base, key)

        if base is None:
            continue

        attr_name = attrs[-1]
        if not hasattr(base, attr_name):
            print(f"Warning: Override key '{keys}' refers to non-existent parameter '{attr_name}'.")
            continue

        # Set the final value, attempting to cast it to the correct type
        original_value = getattr(base, attr_name)
        try:
            if isinstance(original_value, bool):
                # bool("false") is True, so booleans are parsed by keyword.
                lowered = value.strip().lower()
                if lowered in true_words:
                    new_value = True
                elif lowered in false_words:
                    new_value = False
                else:
                    raise ValueError(value)
            elif isinstance(original_value, (list, tuple)):
                # list("[1, 2]") would split into characters; parse the
                # literal instead. JSONDecodeError is a ValueError subclass.
                parsed = json.loads(value)
                if not isinstance(parsed, list):
                    raise ValueError(value)
                new_value = type(original_value)(parsed)
            else:
                new_value = type(original_value)(value)
        except (ValueError, TypeError):
            print(f"Warning: Could not cast override value '{value}' for key '{keys}'. Using as string.")
            new_value = value
        setattr(base, attr_name, new_value)
def set_seed(seed: int):
    """Seed the Python, NumPy and PyTorch random number generators.

    When CUDA is available, the CUDA generators are seeded as well.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def load_config(config_path: str) -> SimpleNamespace:
    """Read a YAML file and return its contents with attribute-style access.

    Every mapping in the parsed YAML (at any nesting depth reached through
    mappings) becomes a ``SimpleNamespace``; all other values, including
    lists and their contents, are returned as parsed.
    """
    def dict_to_namespace(node):
        if isinstance(node, dict):
            converted = {}
            for key, child in node.items():
                converted[key] = dict_to_namespace(child)
            return SimpleNamespace(**converted)
        return node

    with open(config_path, 'r') as handle:
        raw_config = yaml.safe_load(handle)
    return dict_to_namespace(raw_config)

def save_config(config, path):
    """Save the config namespace to a JSON file.

    Args:
        config: ``SimpleNamespace`` (possibly nested) holding the run
            configuration.
        path: Destination file path; the file is overwritten.
    """
    def namespace_to_dict(node):
        # Namespaces become dicts; containers are walked so that namespaces
        # nested inside lists, tuples or dicts are converted too.
        if isinstance(node, SimpleNamespace):
            return {k: namespace_to_dict(v) for k, v in vars(node).items()}
        if isinstance(node, dict):
            return {k: namespace_to_dict(v) for k, v in node.items()}
        if isinstance(node, (list, tuple)):
            # Tuples (like attention_resolutions) are written as JSON lists.
            return [namespace_to_dict(v) for v in node]
        return node

    with open(path, 'w') as f:
        json.dump(namespace_to_dict(config), f, indent=4)


def calculate_val_loss(net, val_set_cpu, config, device):
    """Return the mean per-batch jump loss of ``net`` on the validation set.

    The network is switched to eval mode and is not switched back here. The
    validation tensor is moved to ``device``, paired with a time grid running
    from 0 to ``no_timesteps - 1`` and passed through ``random_subsample``
    using ``config.seed``. Returns 0.0 when the loader yields no batches.
    """
    net.eval()

    n_steps = config.data.no_timesteps
    side = config.data.image_size

    val_set = val_set_cpu.to(device)
    time_axis = torch.linspace(0., n_steps - 1, n_steps, device=device)
    # Broadcast the 1-D time axis to one value per sample, step and pixel.
    time_grid = time_axis.view(1, -1, 1, 1, 1).expand(val_set.shape[0], -1, 1, side, side)
    subsampled_val = random_subsample(
        val_set, time_grid, config.data.subsample_time, random_seed=config.seed
    )
    loader = torch.utils.data.DataLoader(
        subsampled_val, batch_size=config.training.batch_size, shuffle=False
    )

    batch_losses = []
    with torch.no_grad():
        for batch in loader:
            # Each batch holds data and time stacked along dim 2.
            data, time = torch.chunk(batch, 2, dim=2)
            batch_loss = loss_calc_jump(
                data, time, net, config.model.memory_length,
                sigma=config.model.sigma, rho=config.model.rho,
                loss_function=config.training.loss_function, device=device
            )
            batch_losses.append(batch_loss.item())

    # Plain left-to-right accumulation, then average (guarding the empty case).
    total_loss = 0
    for value in batch_losses:
        total_loss += value
    return total_loss / max(1, len(batch_losses))

def visualize_trajectories(trajectories, output_path, title):
    """Visualizes a grid of trajectories and saves it to a file.

    Draws at most 4 rows (first index of ``trajectories``) by 10 columns
    (second index), showing element ``[i, j, 0]`` of each as a grayscale
    image. Cells without data are left blank.

    Args:
        trajectories: Tensor indexed as ``[i, j, 0]`` to obtain a 2-D image
            (presumably sample, time step, channel -- confirm with caller).
        output_path: Where the figure is written.
        title: Figure title.
    """
    n_rows, n_cols = 4, 10
    fig, ax = plt.subplots(n_rows, n_cols, figsize=(20, 8))
    for i in range(min(n_rows, trajectories.shape[0])):
        for j in range(min(n_cols, trajectories.shape[1])):
            ax[i, j].imshow(trajectories[i, j, 0].cpu(), cmap="gray")
    # Hide the frame and ticks of every cell, including unused ones, so a
    # short batch or trajectory does not leave empty framed axes in the grid.
    for cell in ax.flatten():
        cell.axis("off")
    fig.suptitle(title, fontsize=16)
    fig.savefig(output_path)
    plt.close(fig)

def euler_bridge_sampler(net, a, b, t1_val, t2_val, memory, config, device):
    """Simulate a jump path from ``a`` over ``[t1_val, t2_val)`` with ``net``.

    The interval is split into ``config.model.disc_steps`` equal steps. At
    each step the current state is recorded, the network is queried with the
    state concatenated to the memory frames plus the (end time, current time)
    pair, and its three output channels are read as log jump rate, proposal
    mean and log proposal scale. A Bernoulli draw with probability
    ``1 - exp(-rate * h)`` decides per element whether the state is replaced
    by a Gaussian proposal.

    Only a single sample is generated. ``b`` is squeezed but not otherwise
    used; the end point enters only through ``t2_val``.

    Returns:
        Tensor of shape ``(1, disc_steps, 1, image_size, image_size)`` holding
        the state recorded at the start of each step.
    """
    a = a.squeeze(1)
    b = b.squeeze(1)  # kept for parity with the start frame; not read below

    n_samples = 1
    side = config.data.image_size
    n_steps = config.model.disc_steps

    state = a.clone()
    recorded = torch.zeros((n_samples, n_steps, 1, side, side), device=device)
    end_time = t2_val * torch.ones(n_samples, device=device)
    history = memory.reshape(n_samples, -1, side, side)

    for step in range(n_steps):
        recorded[:, step] = state.unsqueeze(1)

        current_time = t1_val + (step / n_steps) * (t2_val - t1_val)
        now = current_time * torch.ones(n_samples, device=device)
        step_size = (t2_val - t1_val) / n_steps

        times = torch.cat([end_time, now], dim=0)
        net_input = torch.cat([state, history], dim=1)
        with torch.no_grad():
            prediction = net(net_input, times)

        log_rate, mean_net, log_scale = prediction.split(1, dim=1)
        rate = torch.exp(log_rate)
        scale = torch.exp(log_scale)

        # Draw order matters for reproducibility: proposal noise first,
        # then the jump indicator.
        proposal = mean_net + scale * torch.randn_like(mean_net)
        stay_prob = torch.exp(-rate * step_size).clamp(0, 1)
        jumped = torch.bernoulli(1 - stay_prob)
        state = (1 - jumped) * state + jumped * proposal

    return recorded

def visualize_bridge_comparison(memory, gt_bridge, gen_bridge, a, b, epoch, idx, output_dir):
    """Plot ground-truth and generated bridge frames in two rows and save them.

    Both rows show, left to right: the memory frames, the start frame ``a``,
    the bridge frames and the end frame ``b``. The top row uses ``gt_bridge``
    (drawn only for as many frames as it has, up to the generated length) and
    carries the column titles; the bottom row uses ``gen_bridge``. Only the
    first sample and first channel of each tensor are drawn. The figure is
    written to ``output_dir`` as
    ``bridge_comparison_epoch_{epoch}_sample_{idx}.png``.
    """
    n_memory = memory.shape[1]
    n_generated = gen_bridge.shape[1]
    n_columns = n_memory + 1 + n_generated + 1
    bridge_offset = n_memory + 1

    fig, axes = plt.subplots(2, n_columns, figsize=(n_columns * 1.5, 3.5))
    fig.suptitle(f"Bridge Visualization - Epoch {epoch}", fontsize=16)

    def draw(panel, image, caption=None):
        panel.imshow(image.cpu(), cmap="gray")
        if caption is not None:
            panel.set_title(caption)

    truth_row = axes[0]
    generated_row = axes[1]

    # Top row: ground truth, with titles.
    truth_row[0].set_ylabel("Ground Truth", fontsize=12)
    for k in range(n_memory):
        draw(truth_row[k], memory[0, k, 0], f"Mem {k}")
    draw(truth_row[n_memory], a[0, 0, 0], "Start (t1)")
    for k in range(n_generated):
        if k >= gt_bridge.shape[1]:
            continue
        draw(truth_row[bridge_offset + k], gt_bridge[0, k, 0])
    draw(truth_row[-1], b[0, 0, 0], "End (t2)")

    # Bottom row: generated bridge between the same memory and end points.
    generated_row[0].set_ylabel("Generated", fontsize=12)
    for k in range(n_memory):
        draw(generated_row[k], memory[0, k, 0])
    draw(generated_row[n_memory], a[0, 0, 0])
    for k in range(n_generated):
        draw(generated_row[bridge_offset + k], gen_bridge[0, k, 0])
    draw(generated_row[-1], b[0, 0, 0])

    for panel in axes.flatten():
        panel.axis("off")
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    fig.savefig(output_dir / f"bridge_comparison_epoch_{epoch}_sample_{idx}.png")
    plt.close(fig)

def train(net, optimizer, scheduler, ema, dataloader, val_set_cpu, train_data_full_cpu, config, device, output_dir):
    """Run the training loop and checkpoint the best model by validation loss.

    For each of ``config.training.no_epochs`` epochs: train on every batch of
    ``dataloader`` with gradient-norm clipping at 1.0 and an EMA update after
    each optimizer step, then compute the validation loss while the EMA
    shadow is applied to ``net``, step ``scheduler`` with that loss, and, if
    it improved, save the state dict (again with the shadow applied) to
    ``output_dir / "best_model.pt"``.

    Args:
        net: Model being trained.
        optimizer: Optimizer over ``net``'s parameters.
        scheduler: Scheduler stepped once per epoch with the validation loss.
        ema: Project ``EMA`` helper exposing ``update``, ``apply_shadow`` and
            ``restore``.
        dataloader: Iterable of training batches; each batch is split in two
            along dim 2 into data and time.
        val_set_cpu: Validation tensor passed to ``calculate_val_loss``.
        train_data_full_cpu: Only referenced by the disabled bridge
            visualization below; currently unused.
        config: Run configuration namespace.
        device: Device the batches are moved to.
        output_dir: ``Path`` where ``best_model.pt`` is written.

    Returns:
        The lowest validation loss observed (``inf`` if there were no epochs).
    """
    best_val_loss = float('inf')
    for epoch in range(config.training.no_epochs):
        net.train()
        avg_loss = 0
        for i, batch in enumerate(dataloader):
            data, time = torch.chunk(batch.to(device), 2, dim=2)
            loss = loss_calc_jump(
                data, time, net, config.model.memory_length,
                sigma=config.model.sigma, rho=config.model.rho,
                loss_function=config.training.loss_function,
                device=device,
            )
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(net.parameters(), 1.0)
            optimizer.step()
            ema.update(net)
            # Running mean of the batch losses seen so far this epoch.
            avg_loss = (i * avg_loss + loss.item()) / (i + 1)

        # Validate with the EMA shadow applied, then put the raw training
        # parameters back from the backup taken just before the swap.
        backup = {name: param.data.clone() for name, param in net.named_parameters()}
        ema.apply_shadow(net)
        avg_val_loss = calculate_val_loss(net, val_set_cpu, config, device)
        ema.restore(net, backup)

        scheduler.step(avg_val_loss)
        print(f"Epoch {epoch+1}/{config.training.no_epochs} | Avg Train Loss: {avg_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss

            print(f"  -> New best model saved with Val Loss: {best_val_loss:.4f}")
            # The checkpoint is written with the shadow applied (the same
            # state that was just validated), not the raw training parameters.
            backup = {name: param.data.clone() for name, param in net.named_parameters()}
            ema.apply_shadow(net)
            # --- Save model locally ---
            torch.save(net.state_dict(), output_dir / "best_model.pt")
            ema.restore(net, backup)

            # NOTE: bridge visualization is currently disabled; kept for reference.
            #if train_data_full_cpu is not None:
            #    print("  -> Generating bridge visualization...")
            #    backup = {name: param.data.clone() for name, param in net.named_parameters()}
            #    ema.apply_shadow(net)
            #    num_bridge_samples = 3
            #    for i in range(num_bridge_samples):
            #        traj_idx = random.randint(0, train_data_full_cpu.shape[0] - 1)
            #        full_traj = train_data_full_cpu[traj_idx:traj_idx+1].to(device)
            #        t1_idx = random.randint(config.model.memory_length, full_traj.shape[1] - 3)
            #        t2_idx = t1_idx + 1 
            #        a, b = full_traj[:, t1_idx:t1_idx+1], full_traj[:, t2_idx:t2_idx+1]
            #        gt_bridge = full_traj[:, t1_idx+1:t2_idx]
            #        memory = full_traj[:, t1_idx - config.model.memory_length+1 : t1_idx+1]
            #        gen_bridge = euler_bridge_sampler(net, a, b, t1_idx, t2_idx, memory, config, device)
            #        visualize_bridge_comparison(memory, gt_bridge, gen_bridge, a, b, epoch + 1, i, output_dir)
            #    ema.restore(net, backup)

    return best_val_loss

def main():
    """Command-line entry point: configure a run, build the model and train it.

    Steps: parse ``--config`` and ``--override``; load the YAML config and
    apply overrides; seed the RNGs; create a timestamped run directory under
    ``local_runs/jump_model`` and save the effective config there; generate
    and load the data; build the network selected by ``config.model.network``
    (``"oai"`` or ``"simple"``); then train.

    Raises:
        ValueError: If ``config.model.network`` is not a supported value.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='config.yaml', help='Path to the YAML configuration file.')
    # --- Command-line overrides ---
    parser.add_argument(
        '--override',
        nargs='*',
        help='Override config params, e.g., --override model.memory_length=8 training.lr=0.0003'
    )

    args = parser.parse_args()
    config = load_config(args.config)

    # --- Apply overrides from command line ---
    apply_overrides(config, args.override)
    set_seed(config.seed)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # --- Setup Local Run Directory ---
    base_run_id = datetime.now().strftime("%m%d_%H%M")
    # Sanitize each override string to be filename-friendly
    override_parts = [
        override.replace("=", "-").replace(".", "_")
        for override in (args.override or [])
    ]
    override_suffix = "_".join(override_parts)
    run_id = f"{base_run_id}_{override_suffix}" if override_suffix else base_run_id

    output_dir = Path(f"local_runs/jump_model/{run_id}")
    output_dir.mkdir(parents=True, exist_ok=True)
    print(f"Starting Run: {run_id}. Outputs will be saved to: {output_dir}")

    # Save the config file for this run
    save_config(config, output_dir / "config.json")

    generate_and_save_data(DataConfig(
        image_size=config.data.image_size,
        no_timesteps=config.data.no_timesteps,
        generation_batch_size=config.data.generation_batch_size,
        cube_size=config.data.cube_size,
        jump_prob=config.data.jump_prob,
    ))

    data_loader = JumpDataLoader(
        data_dir=Path("data"),
        seed=config.seed,
        image_size=config.data.image_size,
        jump_prob=config.data.jump_prob,
    )
    dataloader, val_set_cpu, _, _ = data_loader.get_data(
        config.data.subsample_time, config.data.no_timesteps,
        batch_size=config.training.batch_size, device="cpu",
        equidist=config.data.equidist
    )
    visualize_trajectories(val_set_cpu, output_dir / "ground_truth_trajectories.png", "Ground Truth Trajectories")

    train_data_path = Path("data") / f"train_data_size{config.data.image_size}_jump_prob{config.data.jump_prob}.pt"
    train_data_full_cpu = torch.load(train_data_path) if train_data_path.exists() else None

    if train_data_full_cpu is None:
        print("WARNING: Full training data not found, skipping bridge visualizations.")

    # Input channels: the current frame plus `memory_length` memory frames.
    # Output channels: 3 (consumed as three single-channel maps downstream).
    if config.model.network == "oai":
        net = UNetModel(
            image_size=config.data.image_size,
            in_channels=config.model.memory_length + 1,
            model_channels=config.model.model_channels,
            out_channels=3,
            num_res_blocks=config.model.num_res_blocks,
            attention_resolutions=tuple(config.model.attention_resolutions),
            channel_mult=tuple(config.model.channel_mult),
            dropout=config.training.dropout,
            use_scale_shift_norm=True
        ).to(device)
    elif config.model.network == "simple":
        net = UNet(
            in_channels=config.model.memory_length + 1,
            out_channels=3,
            base_channels=config.model.unet_base_channels
        ).to(device)
    else:
        # Fail here with a clear message instead of a NameError on `net` below.
        raise ValueError(
            f"Unknown config.model.network '{config.model.network}'; expected 'oai' or 'simple'."
        )

    optimizer = torch.optim.Adam(net.parameters(), lr=config.training.lr)
    # NOTE(review): `verbose` is deprecated in recent PyTorch releases --
    # confirm against the installed version before upgrading.
    scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=3, verbose=True)

    ema = EMA(net, decay=config.training.ema_decay)

    print("Starting training...")
    train(net, optimizer, scheduler, ema, dataloader, val_set_cpu, train_data_full_cpu, config, device, output_dir)
    print(f"\nTraining finished. Best model saved in {output_dir}")

# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()


