import torch
import time
from tqdm import tqdm
import torch_geometric
import yaml
import os
from src.utils.setup_arg_parser import setup_arg_parser
from src.utils.helpers import overwrite_conf, set_seed
from src.scalegmn.models import ScaleGMN_custom
from src.data import dataset
from torch.utils.data import random_split, DataLoader
from src.utils.optim import setup_optimization
from torch.nn import CrossEntropyLoss as criterion
from src.utils.helpers import overwrite_conf, set_seed, mask_input, mask_hidden
from src.utils.CNN_helpers import DifferentiableCNN
import torchvision.transforms as transforms
import torchvision
from torch.nn import CrossEntropyLoss as criterion
from torch.cuda.amp import GradScaler, autocast

def load_config(args):
    """Read the YAML file at ``args.conf`` and merge the command-line arguments into it.

    The parsed YAML and ``vars(args)`` are handed to ``overwrite_conf``; whatever it
    returns is printed as block-style YAML and returned to the caller.
    """
    with open(args.conf, 'r') as conf_file:
        yaml_conf = yaml.safe_load(conf_file)

    merged_conf = overwrite_conf(yaml_conf, vars(args))

    # Echo the effective configuration so the run can be reproduced from its log.
    print("Configuration for Evaluation:")
    print(yaml.dump(merged_conf, default_flow_style=False))
    return merged_conf

def setup_dataloaders(conf):
    """Build the DataLoaders for the CNN-graph dataset and for grayscale CIFAR-10.

    Args:
        conf: Configuration dict. Reads ``data``, ``debug``,
            ``scalegmn_args.direction``, ``batch_size``, ``num_workers``,
            ``cifar_data.dataset_path`` and ``cifar_data.batch_fraction``.

    Returns:
        A tuple ``(cnn_loaders, cifar_loaders)``. ``cnn_loaders`` maps
        ``'train'``/``'val'``/``'test'`` to torch_geometric DataLoaders (only the
        train loader shuffles). ``cifar_loaders`` maps ``'train'``/``'val'`` to
        torch DataLoaders; the validation loader yields its whole split as a
        single batch.
    """
    # CNN DataLoaders
    equiv_on_hidden = mask_hidden(conf)
    get_first_layer_mask = mask_input(conf)

    # All three splits are built with identical arguments apart from `split`.
    cnn_sets = {
        split: dataset(
            conf['data'],
            split=split,
            debug=conf["debug"],
            direction=conf['scalegmn_args']['direction'],
            equiv_on_hidden=equiv_on_hidden,
            get_first_layer_mask=get_first_layer_mask,
        )
        for split in ('train', 'val', 'test')
    }
    print(
        f"CNN Dataset -> Train: {len(cnn_sets['train'])}, "
        f"Val: {len(cnn_sets['val'])}, "
        f"Test: {len(cnn_sets['test'])}"
    )
    cnn_loaders = {
        split: torch_geometric.loader.DataLoader(
            dataset=split_set,
            batch_size=conf['batch_size'],
            shuffle=(split == 'train'),
            num_workers=conf["num_workers"],
            pin_memory=True,
        )
        for split, split_set in cnn_sets.items()
    }

    # CIFAR-10 is reduced to one channel and normalised with mean 0.5 / std 0.5.
    transform = transforms.Compose([
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])

    cifar_root = conf['cifar_data']['dataset_path']
    train_set_cifar = torchvision.datasets.CIFAR10(
        root=cifar_root, train=True, download=True, transform=transform)
    full_testset_cifar = torchvision.datasets.CIFAR10(
        root=cifar_root, train=False, download=True, transform=transform)

    # The official CIFAR-10 test set is split in half: one half for validation,
    # the other held out (no loader is built for it here).
    val_size = len(full_testset_cifar) // 2
    test_size = len(full_testset_cifar) - val_size
    val_set_cifar, _held_out_test_set_cifar = random_split(
        full_testset_cifar, [val_size, test_size])

    # A very small batch_fraction would truncate to 0, which DataLoader rejects;
    # always use at least one image per batch.
    train_batch_size = max(
        1, int(conf['cifar_data']['batch_fraction'] * len(train_set_cifar)))

    cifar_loaders = {
        'train': DataLoader(
            train_set_cifar,
            batch_size=train_batch_size,
            shuffle=True,
            num_workers=conf["num_workers"],
        ),
        'val': DataLoader(
            val_set_cifar,
            batch_size=len(val_set_cifar),
            shuffle=False,
            num_workers=conf["num_workers"],
        ),
    }
    return cnn_loaders, cifar_loaders

def save_best_model(net, conf, output_path):
    """Save ``net.state_dict()`` under a file name that encodes the run's hyperparameters.

    Args:
        net: Model whose ``state_dict()`` is written with ``torch.save``.
        conf: Configuration dict. Reads the optimizer ``lr`` and ``weight_decay``,
            ``batch_size``, ``scalegmn_args.num_layers``,
            ``scalegmn_args.gnn_args.dropout``, ``cifar_data.batch_fraction``,
            ``l1_lambda`` and ``data.activation_function``; optionally
            ``scalegmn_args.mlp_args.break_symmetry``.
        output_path: Directory to write into. It is created if it does not exist.

    Raises:
        KeyError: If a required configuration entry is missing.
    """
    optimizer_args = conf['optimization']['optimizer_args']
    name_parts = [
        f"gnn_cnn_lr{optimizer_args['lr']}",
        f"batch{conf['batch_size']}",
        f"layers{conf['scalegmn_args']['num_layers']}",
        f"wd{optimizer_args['weight_decay']}",
        f"dropout{conf['scalegmn_args']['gnn_args']['dropout']}",
        f"cifarbatch{conf['cifar_data']['batch_fraction']}",
        f"l1{conf['l1_lambda']}",
        f"{conf['data']['activation_function']}",
    ]

    # Runs with symmetry breaking enabled get a distinguishing suffix.
    break_symmetry = conf.get("scalegmn_args", {}).get("mlp_args", {}).get("break_symmetry", False)
    if break_symmetry:
        name_parts.append("broken")
    model_name = "_".join(name_parts) + ".pt"

    # torch.save does not create missing parent directories, so make sure the
    # target exists. An empty path means "current directory" and needs nothing.
    if output_path:
        os.makedirs(output_path, exist_ok=True)
    save_path = os.path.join(output_path, model_name)

    torch.save(net.state_dict(), save_path)
    print(f"Model saved to {save_path}")

def process_batch(net, CNN, cnn_batch, cifar_batch, device, l1_lambda, CE):
    """
    Forward one batch of CNN graphs through ``net`` and score the edited weights.

    ``net`` is called on a clone of the graph batch and returns node and edge
    feature tensors, to which the input features (reshaped to one row per graph)
    are added in place. ``CNN`` is then evaluated on the CIFAR images with those
    features to get a cross-entropy term, and ``CNN.sum_abs_params(...)``
    averaged with ``.mean()`` gives the L1 term.

    Returns ``(total_loss, ce_loss, l1_loss)`` with
    ``total_loss = ce_loss + l1_lambda * l1_loss``.
    Shared by the training and validation loops.
    """
    graphs = cnn_batch.to(device)
    net_input = graphs.clone().to(device)

    imgs, labels = cifar_batch
    imgs = imgs.to(device)
    labels = labels.to(device)

    num_graphs = graphs.num_graphs
    flat_nodes, flat_edges = graphs.x, graphs.edge_attr

    # Fold the flat (total, dim) feature matrices into (graph, per_graph, dim).
    # The integer division presumes every graph in the batch has the same number
    # of nodes and edges.
    total_nodes, node_dim = flat_nodes.shape
    total_edges, edge_dim = flat_edges.shape
    nodes_by_graph = flat_nodes.view(num_graphs, total_nodes // num_graphs, node_dim)
    edges_by_graph = flat_edges.view(num_graphs, total_edges // num_graphs, edge_dim)

    # The network output is treated as a residual on top of the input features.
    pred_nodes, pred_edges = net(net_input)
    pred_nodes += nodes_by_graph
    pred_edges += edges_by_graph

    # The same image labels are repeated once per graph to line up with the logits.
    logits = CNN(imgs, pred_nodes, pred_edges)
    ce_term = CE(logits, labels.repeat(num_graphs))
    l1_term = CNN.sum_abs_params(pred_nodes, pred_edges).mean()

    return ce_term + l1_lambda * l1_term, ce_term, l1_term

def _train_one_epoch(net, CNN, loader, cifar_batch, device, l1_lambda, CE,
                     optimizer, scheduler, scaler, opt_conf, desc):
    """Run one optimisation pass over ``loader``; return the mean (total, CE, L1) losses."""
    net.train()
    use_amp = device.type == 'cuda'
    total_sum, ce_sum, l1_sum = 0.0, 0.0, 0.0

    for cnn_batch in tqdm(loader, desc=desc, leave=False):
        optimizer.zero_grad()

        # Forward pass under autocast; mixed precision is only enabled on CUDA.
        with autocast(enabled=use_amp):
            total_loss, ce_loss, l1_loss = process_batch(
                net, CNN, cnn_batch, cifar_batch, device, l1_lambda, CE)

        scaler.scale(total_loss).backward()

        if opt_conf['clip_grad']:
            # Gradients are unscaled first so the clipping threshold applies to
            # their true magnitude.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(net.parameters(), opt_conf['clip_grad_max_norm'])

        scaler.step(optimizer)
        scaler.update()

        # The scheduler is advanced once per batch, not once per epoch.
        scheduler[0].step()

        total_sum += total_loss.item()
        ce_sum += ce_loss.item()
        l1_sum += l1_loss.item()

    num_batches = len(loader)
    return total_sum / num_batches, ce_sum / num_batches, l1_sum / num_batches


def _validate_one_epoch(net, CNN, loader, cifar_batch, device, l1_lambda, CE, desc):
    """Evaluate ``net`` over ``loader`` without gradients; return the mean (total, CE, L1) losses."""
    net.eval()
    use_amp = device.type == 'cuda'
    total_sum, ce_sum, l1_sum = 0.0, 0.0, 0.0

    progress_bar = tqdm(loader, desc=desc, leave=False)
    with torch.no_grad():
        for cnn_batch in progress_bar:
            with autocast(enabled=use_amp):
                total_loss, ce_loss, l1_loss = process_batch(
                    net, CNN, cnn_batch, cifar_batch, device, l1_lambda, CE)

            total_sum += total_loss.item()
            ce_sum += ce_loss.item()
            l1_sum += l1_loss.item()

    num_batches = len(loader)
    return total_sum / num_batches, ce_sum / num_batches, l1_sum / num_batches


def main(args):
    """Main function to run the training and evaluation pipeline.

    Loads the configuration, builds the data loaders, the GNN and the
    differentiable CNN, then alternates one training pass and one validation
    pass per epoch. The model is saved whenever the mean validation loss
    improves, and training stops early once it has failed to improve for
    ``optimization.scheduler_args.patience`` consecutive epochs.

    Args:
        args: Parsed command-line namespace; ``args.conf`` is the YAML config path.
    """
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
    CE = criterion()

    # --- SETUP ---
    conf = load_config(args)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    set_seed(conf['train_args']['seed'])
    scaler = GradScaler(enabled=(device.type == 'cuda'))

    cnn_loaders, cifar_loaders = setup_dataloaders(conf)
    cifar_train_loader = cifar_loaders['train']
    cifar_train_iter = iter(cifar_train_loader)

    # The first validation batch is fetched once and kept on the device for
    # every epoch (the loader is expected to yield the full validation split).
    cifar_val_images, cifar_val_targets = next(iter(cifar_loaders['val']))
    cifar_val_batch_gpu = (cifar_val_images.to(device), cifar_val_targets.to(device))

    net = ScaleGMN_custom(conf['scalegmn_args']).to(device)
    CNN = DifferentiableCNN().to(device)
    print(f"Total GNN parameters: {sum(p.numel() for p in net.parameters())}")

    opt_conf = conf['optimization']
    model_params = [p for p in net.parameters() if p.requires_grad]
    optimizer, scheduler = setup_optimization(
        model_params,
        optimizer_name=opt_conf['optimizer_name'],
        optimizer_args=opt_conf['optimizer_args'],
        scheduler_args=opt_conf['scheduler_args']
    )
    l1_lambda = conf['l1_lambda']
    best_val_loss = float('inf')
    # Early-stopping patience is read from the scheduler arguments.
    patience = opt_conf['scheduler_args']['patience']
    patience_counter = 0
    print(f"Early stopping patience: {patience}")

    print("\nStarting training loop...")
    num_epochs = conf['train_args']['num_epochs']
    for epoch in range(num_epochs):
        epoch_start_time = time.time()

        # One CIFAR batch is drawn per epoch and reused for every CNN batch in it.
        try:
            cifar_images, cifar_targets = next(cifar_train_iter)
        except StopIteration:
            # DataLoader is exhausted, reset it for the next pass
            cifar_train_iter = iter(cifar_train_loader)
            cifar_images, cifar_targets = next(cifar_train_iter)
        current_cifar_batch_gpu = (cifar_images.to(device), cifar_targets.to(device))

        # --- TRAINING ---
        avg_train_total_loss, avg_train_ce_loss, avg_train_l1 = _train_one_epoch(
            net, CNN, cnn_loaders['train'], current_cifar_batch_gpu, device, l1_lambda, CE,
            optimizer, scheduler, scaler, opt_conf,
            desc=f"Epoch {epoch+1}/{num_epochs} [Train]",
        )

        # --- VALIDATION ---
        avg_val_total_loss, avg_val_ce_loss, avg_val_l1 = _validate_one_epoch(
            net, CNN, cnn_loaders['val'], cifar_val_batch_gpu, device, l1_lambda, CE,
            desc=f"Epoch {epoch+1}/{num_epochs} [Val]",
        )

        # --- LOGGING ---
        print(f"\n--- Epoch {epoch + 1} Summary ---")
        print(f"Time: {time.time() - epoch_start_time:.2f}s")
        print(f"Train Loss -> CE: {avg_train_ce_loss:.4f}, L1: {avg_train_l1:.4f} | Total: {avg_train_total_loss:.4f}")
        print(f"Val Loss   -> CE: {avg_val_ce_loss:.4f}, L1: {avg_val_l1:.4f} | Total: {avg_val_total_loss:.4f}")

        # --- CHECKPOINTING / EARLY STOPPING ---
        if avg_val_total_loss < best_val_loss:
            best_val_loss = avg_val_total_loss
            patience_counter = 0
            print(f"New best validation loss: {best_val_loss:.4f}. Saving model...")
            save_best_model(net, conf, conf['train_args']['output_path'])
        else:
            patience_counter += 1

        print("-" * 25)

        if patience_counter >= patience:
            print(f"Early stopping triggered after {epoch + 1} epochs.")
            break

    print(f"\n✓ Training finished. Best validation loss: {best_val_loss:.4f}")
    
if __name__ == "__main__":
    # Script entry point: parse the command line and hand the namespace to main().
    main(setup_arg_parser().parse_args())
