import os
import time
import yaml
from tqdm import tqdm
import numpy as np

import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import random_split
from torch_geometric.loader import DataLoader

from src.scalegmn.models import ScaleGMN_equiv
from src.data.cifar10_dataset import MLPGraphDataset
from src.utils.helpers import (mask_hidden, mask_input, overwrite_conf,
                               set_seed)
from src.utils.optim import setup_optimization
from src.utils.setup_arg_parser import setup_arg_parser
from src.utils.MLP_helpers import BatchedMLP, WB_Batch, calculate_l1_sum
from torch.cuda.amp import GradScaler, autocast

def load_config(args):
    """Read the YAML config named by ``args.conf`` and merge CLI overrides into it.

    The merged configuration is echoed to stdout as YAML before being returned.

    Args:
        args: Parsed command-line namespace; ``args.conf`` is the YAML path and
            ``vars(args)`` is passed to ``overwrite_conf`` as the override map.

    Returns:
        The merged configuration dict.
    """
    with open(args.conf, 'r') as config_file:
        file_conf = yaml.safe_load(config_file)

    merged_conf = overwrite_conf(file_conf, vars(args))

    print("Configuration for Evaluation:")
    print(yaml.dump(merged_conf, default_flow_style=False))
    return merged_conf

def load_mlp_data(path, min_accuracy_threshold, map_location=None):
    """Load MLP parameters and accuracies from every checkpoint file in ``path``.

    Checkpoints are visited in sorted filename order so that the returned list
    (and therefore any seeded ``random_split`` applied to it downstream) is
    reproducible across machines; ``os.listdir`` order is arbitrary.
    Sub-directories are skipped.

    Args:
        path: Directory containing ``torch.save``-d checkpoint dicts with a
            ``'model_params'`` entry and an optional ``'test_accuracy'`` entry.
        min_accuracy_threshold: Only checkpoints whose accuracy is strictly
            greater than this value are kept. A checkpoint without
            ``'test_accuracy'`` is treated as accuracy 1.0.
        map_location: Forwarded to ``torch.load``. ``None`` (default) keeps
            tensors on the device they were saved from; pass ``"cpu"`` to load
            GPU-saved checkpoints on a CPU-only host.

    Returns:
        A list of ``{'params': ..., 'accuracy': ...}`` dicts in filename order.

    Raises:
        KeyError: If a kept checkpoint has no ``'model_params'`` entry.
    """
    param_data = []
    for fname in sorted(os.listdir(path)):
        ckpt_path = os.path.join(path, fname)
        # Directories (e.g. stray sub-folders) cannot be torch.load-ed.
        if not os.path.isfile(ckpt_path):
            continue
        ckpt = torch.load(ckpt_path, map_location=map_location)
        accuracy = ckpt.get('test_accuracy', 1.0)
        if accuracy > min_accuracy_threshold:
            param_data.append({
                'params': ckpt['model_params'],
                'accuracy': accuracy,
            })
    return param_data

def _setup_mlp_loaders(conf, param_data, device):
    """Build the MLP graph dataset, split it train/val/test and wrap each split in a DataLoader."""
    full_dataset_mlp = MLPGraphDataset(
        param_data,
        direction=conf['scalegmn_args']['direction'],
        equiv_on_hidden=mask_hidden(conf),
        get_first_layer_mask=mask_input(conf),
        device=device
    )
    n_total = len(full_dataset_mlp)
    train_ratio = conf['data']['train_ratio']
    val_ratio = conf['data']['val_ratio']
    n_train = int(train_ratio * n_total)
    n_val = int(val_ratio * n_total)
    n_test = n_total - n_train - n_val
    # A negative split size (negative ratios, or ratios summing past 1) is not a
    # meaningful split; fail loudly instead of handing it to random_split.
    if n_train < 0 or n_val < 0 or n_test < 0:
        raise ValueError(
            f"Invalid MLP split ratios: train_ratio={train_ratio}, val_ratio={val_ratio} "
            f"give sizes (train={n_train}, val={n_val}, test={n_test}) for {n_total} models; "
            "both ratios must be non-negative and sum to at most 1."
        )
    train_set_mlp, val_set_mlp, test_set_mlp = random_split(full_dataset_mlp, [n_train, n_val, n_test])

    batch_size = conf['batch_size']
    mlp_loaders = {
        'train': DataLoader(train_set_mlp, batch_size=batch_size, shuffle=True),
        'val': DataLoader(val_set_mlp, batch_size=batch_size, shuffle=False),
        'test': DataLoader(test_set_mlp, batch_size=batch_size, shuffle=False)
    }
    print(f"MLP Dataset -> Train: {len(train_set_mlp)}, Val: {len(val_set_mlp)}, Test: {len(test_set_mlp)}")
    return mlp_loaders


def _setup_cifar_loaders(conf):
    """Build grayscale 8x8 CIFAR-10 train/val DataLoaders (val is half of the official test set)."""
    transform = transforms.Compose([
        transforms.Grayscale(), transforms.Resize(8, antialias=True),
        transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))
    ])

    cifar_root = conf['cifar_data']['dataset_path']
    train_set_cifar = torchvision.datasets.CIFAR10(root=cifar_root, train=True, download=True, transform=transform)
    full_testset_cifar = torchvision.datasets.CIFAR10(root=cifar_root, train=False, download=True, transform=transform)

    # The official test set is halved into validation and held-out test; the
    # test half is not returned, but the split is kept so the global RNG stream
    # (and hence later shuffling) matches previous runs with the same seed.
    val_size = len(full_testset_cifar) // 2
    test_size = len(full_testset_cifar) - val_size
    val_set_cifar, _test_set_cifar = random_split(full_testset_cifar, [val_size, test_size])

    num_workers = conf["num_workers"]
    # Clamp to >= 1: a small batch_fraction would otherwise yield batch_size 0,
    # which DataLoader rejects.
    train_batch_size = max(1, int(conf['cifar_data']['batch_fraction'] * len(train_set_cifar)))
    return {
        'train': DataLoader(train_set_cifar, batch_size=train_batch_size, shuffle=True, num_workers=num_workers),
        # The whole validation half is served as a single batch.
        'val': DataLoader(val_set_cifar, batch_size=len(val_set_cifar), shuffle=False, num_workers=num_workers),
    }


def setup_dataloaders(conf, param_data, device):
    """Set up and split both the MLP-parameter and the CIFAR-10 datasets.

    Args:
        conf: Full run configuration dict.
        param_data: Output of ``load_mlp_data``.
        device: Device forwarded to ``MLPGraphDataset``.

    Returns:
        ``(mlp_loaders, cifar_loaders)``: ``mlp_loaders`` has ``'train'``,
        ``'val'`` and ``'test'`` loaders; ``cifar_loaders`` has ``'train'`` and
        ``'val'``.

    Raises:
        ValueError: If the MLP split ratios produce a negative split size.
    """
    # Order matters: random_split is called without an explicit generator, so
    # the MLP split must draw from the global RNG before the CIFAR split to keep
    # seeded runs reproducible.
    mlp_loaders = _setup_mlp_loaders(conf, param_data, device)
    cifar_loaders = _setup_cifar_loaders(conf)
    return mlp_loaders, cifar_loaders

def residual_param_update(weights, biases, delta_weights, delta_biases):
    """Add the predicted per-layer deltas to the current MLP parameters.

    Layer ``k`` of the result is ``weights[k] + delta_weights[k]`` (and
    likewise for biases). The number of layers is taken from ``weights`` /
    ``biases``; a shorter delta sequence raises ``IndexError``.

    Returns:
        ``(new_weights, new_biases)`` as two lists.
    """
    def _shift(base, delta):
        return [param + delta[layer] for layer, param in enumerate(base)]

    return _shift(weights, delta_weights), _shift(biases, delta_biases)

def save_best_model(net, conf, output_path):
    """Save ``net``'s state dict under a filename that encodes key hyperparameters.

    The filename contains the learning rate, MLP batch size, number of GNN
    layers, weight decay, dropout, CIFAR batch fraction, L1 weight and MLP
    activation function. It gets a ``_broken`` suffix when
    ``conf['scalegmn_args']['mlp_args']['break_symmetry']`` is truthy.

    Args:
        net: Module whose ``state_dict()`` is saved.
        conf: Full run configuration dict.
        output_path: Directory to write into; created (with parents) if it
            does not exist.

    Returns:
        The path of the written checkpoint file.
    """
    hyperparams = {
        'lr': conf['optimization']['optimizer_args']['lr'],
        'batch': conf['batch_size'],
        'layers': conf['scalegmn_args']['num_layers'],
        'wd': conf['optimization']['optimizer_args']['weight_decay'],
        'dropout': conf['scalegmn_args']['gnn_args']['dropout'],
        'cifarbatch': conf['cifar_data']['batch_fraction'],
        'l1': conf['l1_lambda'],
        'activation': conf['data']['activation_function'],
    }
    break_symmetry = conf.get("scalegmn_args", {}).get("mlp_args", {}).get("break_symmetry", False)
    model_name = (
        f"gnn_mlp_lr{hyperparams['lr']}"
        f"_batch{hyperparams['batch']}"
        f"_layers{hyperparams['layers']}"
        f"_wd{hyperparams['wd']}"
        f"_dropout{hyperparams['dropout']}"
        f"_cifarbatch{hyperparams['cifarbatch']}"
        f"_l1{hyperparams['l1']}"
        f"_{hyperparams['activation']}"
    )
    if break_symmetry:
        model_name += "_broken"
    model_name += ".pt"

    # torch.save does not create parent directories; without this the first
    # "new best" checkpoint crashes the run when the directory is missing.
    # An empty output_path means "current directory" and needs no creation.
    if output_path:
        os.makedirs(output_path, exist_ok=True)
    save_path = os.path.join(output_path, model_name)

    torch.save(net.state_dict(), save_path)
    print(f"Model saved to {save_path}")
    return save_path

def process_batch(net, MLP, mlp_batch, cifar_batch, device, l1_lambda):
    """
    Run one GNN + MLP forward pass on a batch and compute the objective.

    Shared by the training and validation loops. The GNN predicts residual
    parameter updates; the updated parameters are then used by ``MLP`` to
    classify the CIFAR images.

    Args:
        net: GNN called as ``net(graph_batch, weights, biases)`` and returning
            ``(delta_weights, delta_biases)``.
        MLP: Batched MLP called as ``MLP(images, WB_Batch)``.
        mlp_batch: ``(graph_batch, w_b)`` pair; ``w_b`` has ``.weights`` and
            ``.biases`` sequences of tensors.
        cifar_batch: ``(images, targets)`` pair.
        device: Device every input is moved to.
        l1_lambda: Weight of the L1 term in the total loss.

    Returns:
        ``(total_loss, ce_loss, l1)`` where ``total_loss = ce_loss + l1 * l1_lambda``
        and ``l1`` is the *unscaled* L1 sum of the updated parameters divided
        by the number of graphs.
    """
    graph_batch, w_b = mlp_batch
    images, targets = cifar_batch

    images = images.to(device)
    targets = targets.to(device)
    graph_batch = graph_batch.to(device)
    weights = tuple(tensor.to(device) for tensor in w_b.weights)
    biases = tuple(tensor.to(device) for tensor in w_b.biases)

    n_graphs = graph_batch.num_graphs

    delta_w, delta_b = net(graph_batch, weights, biases)
    updated_w, updated_b = residual_param_update(weights, biases, delta_w, delta_b)
    updated_params = WB_Batch(weights=tuple(updated_w), biases=tuple(updated_b))
    logits = MLP(images, updated_params)

    # The same targets are replicated once per graph, so the flattened logits
    # presumably have shape (n_graphs * n_images, n_classes) — TODO confirm
    # against BatchedMLP's output layout.
    flat_targets = targets.unsqueeze(0).expand(n_graphs, -1).reshape(-1)
    ce_term = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        flat_targets,
        reduction='mean'
    )

    l1_per_graph = calculate_l1_sum(updated_params) / n_graphs
    objective = ce_term + l1_per_graph * l1_lambda
    return objective, ce_term, l1_per_graph

def _next_cifar_batch(cifar_train_iter, cifar_train_loader):
    """Return ``(batch, iterator)``, restarting the CIFAR train loader once it is exhausted."""
    try:
        return next(cifar_train_iter), cifar_train_iter
    except StopIteration:
        # DataLoader is exhausted, reset it for the next pass
        fresh_iter = iter(cifar_train_loader)
        return next(fresh_iter), fresh_iter


def _train_one_epoch(net, MLP, loader, cifar_batch, device, l1_lambda,
                     optimizer, scheduler, scaler, opt_conf, desc):
    """Run one optimization pass over ``loader``; return mean (total, ce, l1) per batch."""
    net.train()
    use_amp = device.type == 'cuda'
    sum_total, sum_ce, sum_l1 = 0.0, 0.0, 0.0
    for mlp_batch in tqdm(loader, desc=desc):
        optimizer.zero_grad()

        with autocast(enabled=use_amp):
            total_loss, ce_loss, l1_loss = process_batch(net, MLP, mlp_batch, cifar_batch, device, l1_lambda)
        scaler.scale(total_loss).backward()

        if opt_conf['clip_grad']:
            # Unscale first so the clipping threshold applies to the true gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(net.parameters(), opt_conf['clip_grad_max_norm'])

        scaler.step(optimizer)
        scaler.update()

        # setup_optimization's scheduler is indexed; its first element is
        # stepped once per batch (not per epoch).
        scheduler[0].step()

        sum_total += total_loss.item()
        sum_ce += ce_loss.item()
        sum_l1 += l1_loss.item()

    n_batches = len(loader)
    return sum_total / n_batches, sum_ce / n_batches, sum_l1 / n_batches


def _validate_one_epoch(net, MLP, loader, cifar_batch, device, l1_lambda, desc):
    """Evaluate ``net`` over ``loader`` without gradients; return mean (total, ce, l1) per batch."""
    net.eval()
    use_amp = device.type == 'cuda'
    sum_total, sum_ce, sum_l1 = 0.0, 0.0, 0.0
    progress_bar = tqdm(loader, desc=desc)
    with torch.no_grad():
        for mlp_batch in progress_bar:
            with autocast(enabled=use_amp):
                total_loss, ce_loss, l1_loss = process_batch(net, MLP, mlp_batch, cifar_batch, device, l1_lambda)

            sum_total += total_loss.item()
            sum_ce += ce_loss.item()
            sum_l1 += l1_loss.item()

    n_batches = len(loader)
    return sum_total / n_batches, sum_ce / n_batches, sum_l1 / n_batches


def main(args):
    """Main function to run the training and evaluation pipeline.

    Loads the config and MLP checkpoints, builds the data loaders, then trains
    the GNN with mixed precision (on CUDA). Each epoch uses one CIFAR-10 train
    batch, evaluates on the full CIFAR validation half, and saves the model
    whenever the mean validation loss improves. Training stops early after
    ``optimization.scheduler_args.patience`` epochs without improvement.

    Raises:
        ValueError: If the MLP train or validation split is empty.
    """
    # Needed for deterministic cuBLAS; cuBLAS reads it when its workspace is
    # first created, so it is set before any CUDA work below.
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"

    # --- SETUP ---
    conf = load_config(args)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    set_seed(conf['train_args']['seed'])
    scaler = GradScaler(enabled=(device.type == 'cuda'))

    param_data = load_mlp_data(
        conf['data']['mlp_params_path'],
        conf['data']['min_accuracy_threshold'] # Use config for threshold
    )

    mlp_loaders, cifar_loaders = setup_dataloaders(conf, param_data, device)
    # Empty splits would only surface as a ZeroDivisionError when averaging
    # losses (for 'val', after a full training epoch); fail early instead.
    for split in ('train', 'val'):
        if len(mlp_loaders[split]) == 0:
            raise ValueError(
                f"MLP '{split}' split is empty; check data.train_ratio / data.val_ratio "
                "and how many checkpoints pass data.min_accuracy_threshold."
            )

    cifar_train_loader = cifar_loaders['train']
    cifar_train_iter = iter(cifar_train_loader)

    cifar_val_batch_cpu = next(iter(cifar_loaders['val'])) # Full valset
    cifar_val_images, cifar_val_targets = cifar_val_batch_cpu
    cifar_val_batch_gpu = (cifar_val_images.to(device), cifar_val_targets.to(device))

    net = ScaleGMN_equiv(conf['scalegmn_args']).to(device)
    MLP = BatchedMLP().to(device)
    print(f"Total GNN parameters: {sum(p.numel() for p in net.parameters())}")

    opt_conf = conf['optimization']
    model_params = [p for p in net.parameters() if p.requires_grad]
    optimizer, scheduler = setup_optimization(
        model_params,
        optimizer_name=opt_conf['optimizer_name'],
        optimizer_args=opt_conf['optimizer_args'],
        scheduler_args=opt_conf['scheduler_args']
    )
    l1_lambda = conf['l1_lambda']
    best_val_loss = float('inf')
    # NOTE(review): early stopping reuses the LR scheduler's patience setting.
    patience = opt_conf['scheduler_args']['patience']
    patience_counter = 0
    print(f"Early stopping patience: {patience}")

    num_epochs = conf['train_args']['num_epochs']
    print("\nStarting training loop...")
    for epoch in range(num_epochs):
        epoch_start_time = time.time()
        epoch_tag = f"Epoch {epoch+1}/{num_epochs}"

        # --- TRAINING ---
        # One CIFAR train batch per epoch, shared by every MLP batch in that epoch.
        (cifar_images, cifar_targets), cifar_train_iter = _next_cifar_batch(cifar_train_iter, cifar_train_loader)
        current_cifar_batch_gpu = (cifar_images.to(device), cifar_targets.to(device))

        avg_train_total_loss, avg_train_ce_loss, avg_train_l1 = _train_one_epoch(
            net, MLP, mlp_loaders['train'], current_cifar_batch_gpu, device, l1_lambda,
            optimizer, scheduler, scaler, opt_conf, desc=f"{epoch_tag} [Train]"
        )

        # --- VALIDATION ---
        avg_val_total_loss, avg_val_ce_loss, avg_val_l1 = _validate_one_epoch(
            net, MLP, mlp_loaders['val'], cifar_val_batch_gpu, device, l1_lambda,
            desc=f"{epoch_tag} [Val]"
        )

        # --- LOGGING ---
        print(f"\n--- Epoch {epoch + 1} Summary ---")
        print(f"Time: {time.time() - epoch_start_time:.2f}s")
        print(f"Train Loss -> CE: {avg_train_ce_loss:.4f}, L1: {avg_train_l1:.4f} | Total: {avg_train_total_loss:.4f}")
        print(f"Val Loss   -> CE: {avg_val_ce_loss:.4f}, L1: {avg_val_l1:.4f} | Total: {avg_val_total_loss:.4f}")

        if avg_val_total_loss < best_val_loss:
            best_val_loss = avg_val_total_loss
            patience_counter = 0
            print(f"New best validation loss: {best_val_loss:.4f}. Saving model...")
            save_best_model(net, conf, conf['train_args']['output_path'])
        else:
            patience_counter += 1

        print("-" * 25)

        if patience_counter >= patience:
            print(f"Early stopping triggered after {epoch + 1} epochs.")
            break

    print(f"\n✓ Training finished. Best validation loss: {best_val_loss:.4f}")
    
if __name__ == '__main__':
    # Script entry point: parse CLI arguments and run the training pipeline.
    main(setup_arg_parser().parse_args())