"""
Main training script for the image-based SDE model.
This version is adapted for local execution, saving models and plots
to a local directory instead of using MLflow.
"""
import torch
import numpy as np
import yaml
import matplotlib.pyplot as plt
from pathlib import Path
from types import SimpleNamespace
import argparse
import random
import json
from datetime import datetime
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Import project modules
from data import JumpDataLoader, generate_and_save_data, DataConfig, random_subsample
from network import UNetModel
from alternative_network import UNet
from sde_utils import loss_calc_sde
from ema import EMA

def set_seed(seed: int):
    """Seed the torch, NumPy and Python ``random`` generators with ``seed``.

    When CUDA is available the CUDA generators are seeded as well (both the
    current device and all devices).
    """
    for seeder in (torch.manual_seed, np.random.seed, random.seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def _cast_override_value(raw, original):
    """Convert the override string ``raw`` to the type of ``original``.

    Raises ``ValueError``/``TypeError`` when the conversion is not possible;
    the caller decides how to fall back.
    """
    # bool must be handled explicitly: bool("False") is True.
    if isinstance(original, bool):
        lowered = raw.strip().lower()
        if lowered in ("true", "1", "yes", "on"):
            return True
        if lowered in ("false", "0", "no", "off"):
            return False
        raise ValueError(f"not a boolean: {raw!r}")
    # list("[1,2]") would split the text into characters, so parse it as JSON.
    if isinstance(original, (list, tuple)):
        parsed = json.loads(raw)
        if not isinstance(parsed, list):
            raise ValueError(f"not a list: {raw!r}")
        return type(original)(parsed)
    return type(original)(raw)


def apply_overrides(config, overrides):
    """Apply ``section.key=value`` overrides to a nested config namespace in place.

    Each override is split on its first ``=``; the left side is a dotted path
    into ``config`` and the right side is cast to the type of the value it
    replaces. Booleans accept true/false, 1/0, yes/no and on/off; lists and
    tuples are parsed as JSON arrays (e.g. ``model.channel_mult=[1,2,4]``).

    Overrides that are malformed or that point at a non-existent section or
    parameter are skipped with a printed warning. If a value cannot be cast it
    is stored as the raw string, also with a warning.

    Args:
        config: Nested namespace object to modify.
        overrides: Iterable of ``"a.b=value"`` strings, or ``None``/empty.
    """
    if not overrides:
        return
    for override in overrides:
        # Split on the first '=' only so that values may themselves contain '='.
        keys, sep, value = override.partition('=')
        if not sep:
            print(f"Warning: Override '{override}' is not of the form key=value. Skipping.")
            continue
        keys = keys.strip()
        attrs = keys.split('.')

        # Navigate to the namespace that holds the final attribute.
        base = config
        for key in attrs[:-1]:
            if not hasattr(base, key):
                print(f"Warning: Override key '{keys}' contains non-existent section '{key}'.")
                base = None
                break
            base = getattr(base, key)

        if base is None:
            continue

        # Only existing parameters can be overridden; their current value
        # determines the target type.
        attr_name = attrs[-1]
        if not hasattr(base, attr_name):
            print(f"Warning: Override key '{keys}' refers to non-existent parameter '{attr_name}'.")
            continue

        original_value = getattr(base, attr_name)
        try:
            new_value = _cast_override_value(value, original_value)
        except (ValueError, TypeError):
            print(f"Warning: Could not cast override value '{value}' for key '{keys}'. Using as string.")
            new_value = value
        setattr(base, attr_name, new_value)


def load_config(config_path: str) -> SimpleNamespace:
    """Load a YAML config file into nested ``SimpleNamespace`` objects.

    Every mapping in the file becomes a ``SimpleNamespace`` whose attributes
    are the mapping's keys; all other values (scalars, lists) are kept as
    parsed. Mappings nested inside lists are not converted.

    Args:
        config_path: Path to the YAML file.

    Returns:
        The configuration as a ``SimpleNamespace``.

    Raises:
        ValueError: If the file is empty or its top level is not a mapping.
    """
    with open(config_path, 'r', encoding='utf-8') as f:
        config_dict = yaml.safe_load(f)

    # An empty document parses to None; fail here with a clear message rather
    # than with an AttributeError on the first config access.
    if not isinstance(config_dict, dict):
        raise ValueError(
            f"Config file '{config_path}' must contain a top-level mapping, "
            f"got {type(config_dict).__name__}."
        )

    def dict_to_namespace(d):
        if not isinstance(d, dict):
            return d
        return SimpleNamespace(**{k: dict_to_namespace(v) for k, v in d.items()})

    return dict_to_namespace(config_dict)

def save_config(config, path):
    """Write the config namespace to ``path`` as JSON with 4-space indentation.

    Nested ``SimpleNamespace`` objects become JSON objects; tuples are turned
    into lists; every other value is passed to ``json.dump`` unchanged.
    """
    def to_plain(node):
        if isinstance(node, SimpleNamespace):
            return {name: to_plain(child) for name, child in vars(node).items()}
        return list(node) if isinstance(node, tuple) else node

    with open(path, 'w') as handle:
        payload = to_plain(config)
        json.dump(payload, handle, indent=4)

def calculate_val_loss(net, val_set_cpu, config, device):
    """Return the mean per-batch ``loss_calc_sde`` loss of ``net`` on the validation set.

    The network is put in eval mode and left there. The validation set is
    moved to ``device``, paired with a time grid 0..no_timesteps-1 that is
    broadcast over batch and image dimensions, subsampled with
    ``random_subsample`` (seeded with ``config.seed``) and evaluated in
    batches without gradient tracking. Returns 0.0 when there are no batches.
    """
    net.eval()

    n_steps = config.data.no_timesteps
    side = config.data.image_size
    val_set = val_set_cpu.to(device)

    # Time grid expanded to (num_trajectories, n_steps, 1, side, side).
    time_axis = torch.linspace(0., n_steps - 1, n_steps, device=device)
    time_grid = time_axis.view(1, -1, 1, 1, 1).expand(val_set.shape[0], -1, 1, side, side)

    subsampled_val = random_subsample(
        val_set, time_grid, config.data.subsample_time, random_seed=config.seed
    )
    loader = torch.utils.data.DataLoader(
        subsampled_val, batch_size=config.training.batch_size, shuffle=False
    )

    batch_losses = []
    with torch.no_grad():
        for batch in loader:
            # Each batch is split in two along dim 2: image data and its time values.
            frames, stamps = torch.chunk(batch, 2, dim=2)
            batch_loss = loss_calc_sde(
                frames, stamps, net, config.model.memory_length,
                sigma=config.model.sigma, rho=config.model.rho, device=device
            )
            batch_losses.append(batch_loss.item())

    return sum(batch_losses) / max(1, len(batch_losses))

def visualize_trajectories(trajectories, output_path, title, n_rows=4, n_cols=10):
    """Save a grid of trajectory frames as an image file.

    Row ``i`` shows up to ``n_cols`` entries along dimension 1 of trajectory
    ``i``; only index 0 of dimension 2 is drawn, in grayscale. Cells without
    data (fewer trajectories or entries than the grid holds) are left blank.

    Args:
        trajectories: Tensor indexed as ``[trajectory, step, channel, ...]``
            whose elements support ``.cpu()``.
        output_path: Destination passed to ``Figure.savefig``.
        title: Figure title.
        n_rows: Maximum number of trajectories to draw (default 4).
        n_cols: Maximum number of entries per trajectory to draw (default 10).
    """
    # squeeze=False keeps `ax` two-dimensional even for a single row or column.
    fig, ax = plt.subplots(n_rows, n_cols, figsize=(2 * n_cols, 2 * n_rows), squeeze=False)
    try:
        for i in range(min(n_rows, trajectories.shape[0])):
            for j in range(min(n_cols, trajectories.shape[1])):
                ax[i, j].imshow(trajectories[i, j, 0].cpu(), cmap="gray")
        # Hide axes on every cell, including unused ones, so empty cells do
        # not show stray ticks and frames.
        for cell in ax.flat:
            cell.axis("off")
        fig.suptitle(title, fontsize=16)
        fig.savefig(output_path)
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)

def train(net, optimizer, scheduler, ema, dataloader, val_set_cpu, config, device, output_dir):
    """Train ``net`` and checkpoint the weights with the lowest validation loss.

    Each epoch makes one pass over ``dataloader``: the loss comes from
    ``loss_calc_sde``, gradients are clipped to a norm of 1.0, the optimizer
    steps and ``ema.update(net)`` is called. The epoch is then validated with
    ``calculate_val_loss`` while the EMA weights are applied to ``net``
    (presumably swapped in by ``ema.apply_shadow`` — confirm against ``EMA``).
    If the validation loss improves, that same state is written to
    ``output_dir / "best_model.pt"``. The scheduler is stepped on the
    validation loss.

    Args:
        net: Model to train.
        optimizer: Optimizer over ``net.parameters()``.
        scheduler: Scheduler whose ``step`` takes the validation loss.
        ema: Object providing ``update``, ``apply_shadow`` and ``restore``.
        dataloader: Iterable of training batches (tensors).
        val_set_cpu: Validation tensor passed to ``calculate_val_loss``.
        config: Config namespace (reads ``training.no_epochs`` and ``model.*``).
        device: Device the batches are moved to.
        output_dir: ``Path`` of the directory that receives ``best_model.pt``.

    Returns:
        The best validation loss seen (``inf`` if no epoch improved on it).
    """
    best_val_loss = float('inf')
    checkpoint_path = output_dir / "best_model.pt"

    for epoch in range(config.training.no_epochs):
        net.train()
        avg_loss = 0
        for i, batch in enumerate(dataloader):
            # Each batch is split in two along dim 2: image data and its time values.
            data, times = torch.chunk(batch.to(device), 2, dim=2)
            loss = loss_calc_sde(
                data, times, net, config.model.memory_length,
                sigma=config.model.sigma, rho=config.model.rho,
                device=device,
            )
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(net.parameters(), 1.0)
            optimizer.step()
            ema.update(net)
            # Running mean of the per-batch training loss.
            avg_loss = (i * avg_loss + loss.item()) / (i + 1)

        # Apply the EMA weights once for both validation and checkpointing,
        # and always put the raw training weights back, even on failure.
        backup = {name: param.data.clone() for name, param in net.named_parameters()}
        ema.apply_shadow(net)
        try:
            avg_val_loss = calculate_val_loss(net, val_set_cpu, config, device)
            improved = avg_val_loss < best_val_loss
            if improved:
                torch.save(net.state_dict(), checkpoint_path)
        finally:
            ema.restore(net, backup)

        scheduler.step(avg_val_loss)
        print(f"Epoch {epoch+1}/{config.training.no_epochs} | Avg Train Loss: {avg_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

        if improved:
            best_val_loss = avg_val_loss
            print(f"  -> New best model saved with Val Loss: {best_val_loss:.4f}")

    return best_val_loss

def _make_run_id(overrides):
    """Return the run directory name: ``MMDD_HHMM`` plus one sanitized part per override."""
    base_run_id = datetime.now().strftime("%m%d_%H%M")
    parts = []
    for override in overrides or []:
        part = override.replace("=", "-").replace(".", "_")
        # Replace anything else that is unsafe in a single path component
        # (e.g. "/" in a path-valued override) so the run stays in one directory.
        parts.append("".join(ch if ch.isalnum() or ch in "-_" else "_" for ch in part))
    override_suffix = "_".join(parts)
    return f"{base_run_id}_{override_suffix}" if override_suffix else base_run_id


def _build_network(config, device):
    """Instantiate the network selected by ``config.model.network`` on ``device``."""
    kind = config.model.network
    in_channels = config.model.memory_length + 1
    if kind == "oai":
        net = UNetModel(
            image_size=config.data.image_size,
            in_channels=in_channels,
            model_channels=config.model.model_channels,
            out_channels=1,
            num_res_blocks=config.model.num_res_blocks,
            attention_resolutions=tuple(config.model.attention_resolutions),
            channel_mult=tuple(config.model.channel_mult),
            dropout=config.training.dropout,
            use_scale_shift_norm=True,
        )
    elif kind == "simple":
        net = UNet(
            in_channels=in_channels,
            out_channels=1,
            base_channels=config.model.unet_base_channels,
        )
    else:
        # Without this branch an unknown name surfaced later as a NameError on `net`.
        raise ValueError(
            f"Unknown config.model.network '{kind}'; expected 'oai' or 'simple'."
        )
    return net.to(device)


def main():
    """Command-line entry point: load the config, prepare data and train.

    Steps, in order: parse ``--config`` and ``--override``, load the YAML
    config and apply the overrides, seed the RNGs, create the run directory
    ``local_runs/sde_model/<run_id>`` and save the effective config there as
    ``config.json``, generate and load the data, save a plot of the
    validation trajectories, build the network, optimizer, LR scheduler and
    EMA, and run ``train``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='config.yaml', help='Path to the YAML configuration file.')
    parser.add_argument(
        '--override',
        nargs='*',
        help='Override config params, e.g., --override model.memory_length=8 training.lr=0.0003'
    )
    args = parser.parse_args()

    config = load_config(args.config)
    apply_overrides(config, args.override)
    set_seed(config.seed)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # --- Setup Local Run Directory ---
    run_id = _make_run_id(args.override)
    output_dir = Path(f"local_runs/sde_model/{run_id}")
    output_dir.mkdir(parents=True, exist_ok=True)
    print(f"Starting Run: {run_id}. Outputs will be saved to: {output_dir}")

    # Saved after overrides are applied, so the file records the effective config.
    save_config(config, output_dir / "config.json")

    # --- Data ---
    generate_and_save_data(DataConfig(
        image_size=config.data.image_size,
        no_timesteps=config.data.no_timesteps,
        generation_batch_size=config.data.generation_batch_size,
        cube_size=config.data.cube_size,
        jump_prob=config.data.jump_prob,
    ))

    data_loader = JumpDataLoader(
        data_dir=Path("data"),
        seed=config.seed,
        image_size=config.data.image_size,
        jump_prob=config.data.jump_prob,
    )
    dataloader, val_set_cpu, _, _ = data_loader.get_data(
        config.data.subsample_time, config.data.no_timesteps,
        batch_size=config.training.batch_size, device="cpu",
        equidist=config.data.equidist
    )
    visualize_trajectories(val_set_cpu, output_dir / "ground_truth_trajectories.png", "Ground Truth Trajectories")

    # --- Model and optimization ---
    net = _build_network(config, device)

    optimizer = torch.optim.Adam(net.parameters(), lr=config.training.lr)
    # NOTE(review): `verbose` is deprecated in recent PyTorch releases —
    # confirm it is still accepted by the torch version this project pins.
    scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=3, verbose=True)

    ema = EMA(net, decay=config.training.ema_decay)

    print("Starting training...")
    train(net, optimizer, scheduler, ema, dataloader, val_set_cpu, config, device, output_dir)
    print(f"\nTraining finished. Best model saved in {output_dir}")

# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

