import os
import yaml
from tqdm import tqdm
import numpy as np

import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import random_split
from torch_geometric.loader import DataLoader

from src.scalegmn.models import ScaleGMN_equiv
from src.data.cifar10_dataset import MLPGraphDataset
from src.utils.helpers import (mask_hidden, mask_input, overwrite_conf,
                               set_seed)
from src.utils.setup_arg_parser import setup_arg_parser
from src.utils.MLP_helpers import BatchedMLP, WB_Batch, calculate_l1_sum

def load_config(args):
    """Read the YAML configuration file and apply command-line overrides.

    Args:
        args: Parsed argument namespace. ``args.conf`` is the path of the YAML
            file; the full namespace (as a dict via ``vars``) is forwarded to
            ``overwrite_conf`` as the override source.

    Returns:
        The configuration returned by ``overwrite_conf``. It is also echoed to
        stdout in block YAML style.
    """
    cli_overrides = vars(args)
    with open(args.conf, 'r') as config_file:
        base_conf = yaml.safe_load(config_file)
    merged_conf = overwrite_conf(base_conf, cli_overrides)

    print("Configuration for Evaluation:")
    print(yaml.dump(merged_conf, default_flow_style=False))
    return merged_conf

def load_mlp_data(path, min_accuracy_threshold):
    """Load MLP checkpoints from a directory, keeping those above an accuracy bar.

    Each regular file in ``path`` is read with ``torch.load``. It must contain
    a ``'model_params'`` entry. ``'test_accuracy'`` is optional and defaults
    to 1.0, so checkpoints without a recorded accuracy are always kept.
    Subdirectories are skipped, because ``torch.load`` cannot read them.

    Args:
        path: Directory containing the checkpoint files.
        min_accuracy_threshold: Strict lower bound. Only checkpoints whose
            accuracy is ``> min_accuracy_threshold`` are returned.

    Returns:
        A list of ``{'params': ..., 'accuracy': ...}`` dicts, in
        ``os.listdir`` order.
    """
    param_data = []
    # NOTE(review): os.listdir order is filesystem-dependent. The test split is
    # later drawn positionally with random_split (presumably MLPGraphDataset
    # keeps param_data order). It is therefore reproducible only if this order
    # matches the one seen at training time. Sorting here alone would
    # desynchronise it from a training script that does not also sort.
    for fname in os.listdir(path):
        ckpt_path = os.path.join(path, fname)
        if not os.path.isfile(ckpt_path):
            # Stray folders (e.g. editor/notebook caches) would crash torch.load.
            continue
        ckpt = torch.load(ckpt_path)
        accuracy = ckpt.get('test_accuracy', 1.0)
        if accuracy > min_accuracy_threshold:
            param_data.append({
                'params': ckpt['model_params'],
                'accuracy': accuracy,
            })
    return param_data

def setup_dataloaders(conf, param_data, device):
    """Build the held-out test loaders for the MLP zoo and for CIFAR-10.

    The MLP graph dataset is split into train/val/test by the configured ratios.
    Only the test part is wrapped in a loader. The CIFAR-10 test set is
    converted to grayscale, resized to 8 px on its short side, and normalised
    to [-1, 1]. It is then split in half, and the second half is returned as a
    single full-size batch.

    Both splits use ``random_split`` without an explicit generator, so they draw
    from the global torch RNG. Keep the MLP split before the CIFAR split, or the
    resulting subsets change for a given seed.

    Args:
        conf: Configuration dict (reads ``scalegmn_args``, ``data``,
            ``batch_size``, ``cifar_data`` and ``num_workers``).
        param_data: List of MLP checkpoint records passed to ``MLPGraphDataset``.
        device: Device forwarded to ``MLPGraphDataset``.

    Returns:
        ``(mlp_test_loader, cifar_test_loader)``.
    """
    scalegmn_args = conf['scalegmn_args']
    mlp_dataset = MLPGraphDataset(
        param_data,
        direction=scalegmn_args['direction'],
        equiv_on_hidden=mask_hidden(conf),
        get_first_layer_mask=mask_input(conf),
        device=device,
    )

    total = len(mlp_dataset)
    data_conf = conf['data']
    train_count = int(data_conf['train_ratio'] * total)
    val_count = int(data_conf['val_ratio'] * total)
    split_lengths = [train_count, val_count, total - train_count - val_count]
    mlp_test_set = random_split(mlp_dataset, split_lengths)[2]

    mlp_test_loader = DataLoader(
        mlp_test_set, batch_size=conf['batch_size'], shuffle=False
    )
    print(f"MLP Test Dataset size: {len(mlp_test_set)}")

    cifar_transform = transforms.Compose([
        transforms.Grayscale(),
        transforms.Resize(8, antialias=True),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])
    cifar_full_test = torchvision.datasets.CIFAR10(
        root=conf['cifar_data']['dataset_path'],
        train=False,
        download=True,
        transform=cifar_transform,
    )

    # First half is discarded (reserved as a validation half); second half is the test set.
    first_half = len(cifar_full_test) // 2
    _, cifar_test_set = random_split(
        cifar_full_test, [first_half, len(cifar_full_test) - first_half]
    )
    cifar_test_loader = DataLoader(
        cifar_test_set,
        batch_size=len(cifar_test_set),
        shuffle=False,
        num_workers=conf["num_workers"],
    )

    return mlp_test_loader, cifar_test_loader

def residual_param_update(weights, biases, delta_weights, delta_biases):
    """Add predicted per-layer residuals to MLP weights and biases.

    Args:
        weights: Per-layer weight sequence of the original MLPs.
        biases: Per-layer bias sequence of the original MLPs.
        delta_weights: Per-layer weight updates, indexed in lockstep with
            ``weights`` (must be at least as long; extra entries are ignored).
        delta_biases: Per-layer bias updates, indexed in lockstep with
            ``biases`` (same length rule).

    Returns:
        ``(new_weights, new_biases)`` as lists, where ``new[j] = old[j] + delta[j]``.
    """
    updated_weights = [w + delta_weights[idx] for idx, w in enumerate(weights)]
    updated_biases = [b + delta_biases[idx] for idx, b in enumerate(biases)]
    return updated_weights, updated_biases

def _score_mlps(MLP, images, targets_exp, w_b, l1_lambda):
    """Score one batch of MLPs on a fixed image set.

    Args:
        MLP: Batched functional MLP, called as ``MLP(images, w_b)``.
        images: Input images, already on the target device.
        targets_exp: Class labels repeated once per MLP, shape
            ``(batch_size, n_images)`` as built in ``evaluate``.
        w_b: ``WB_Batch`` holding the parameters of every MLP in the batch.
        l1_lambda: Weight of the L1 penalty in the total loss.

    Returns:
        ``((loss, ce_loss, l1), per_model_acc)``. The three loss terms are batch
        means as Python floats; ``l1`` is the L1 sum divided by the number of
        MLPs. ``per_model_acc`` is a numpy array with one accuracy per MLP.
    """
    batch_size = targets_exp.size(0)
    outputs = MLP(images, w_b)
    # Flatten (MLP, image) pairs so every single prediction is one CE sample.
    ce_loss = F.cross_entropy(
        outputs.reshape(-1, outputs.size(-1)),
        targets_exp.reshape(-1),
        reduction='mean',
    )
    l1 = calculate_l1_sum(w_b) / batch_size
    loss = ce_loss + (l1 * l1_lambda)

    _, predicted = outputs.max(2)
    per_model_acc = predicted.eq(targets_exp).float().mean(dim=1)
    return (loss.item(), ce_loss.item(), l1.item()), per_model_acc.cpu().numpy()


def _add_weighted(totals, terms, weight):
    """Add ``weight * term`` to each running total in ``totals`` (in place)."""
    for i, term in enumerate(terms):
        totals[i] += term * weight


def _summarize(totals, n_models, accuracies):
    """Turn weighted loss totals and per-model accuracies into a stats tuple.

    Returns:
        ``(avg_loss, avg_ce, avg_l1, acc_mean, acc_var, acc_max, acc_min,
        accuracies)``.
    """
    avg_loss, avg_ce, avg_l1 = (total / n_models for total in totals)
    return (
        avg_loss,
        avg_ce,
        avg_l1,
        np.mean(accuracies),
        np.var(accuracies),
        np.max(accuracies),
        np.min(accuracies),
        accuracies,
    )


def evaluate(net, MLP, mlp_loader, cifar_loader, device, l1_lambda):
    """Evaluate MLPs on CIFAR before and after the GNN's residual weight update.

    For every batch of MLPs, the original parameters are scored first. ``net``
    then predicts residual deltas, which are added to the parameters, and the
    updated MLPs are scored as well.

    Loss averages are weighted by the number of MLPs per batch. They are
    therefore true per-model means even when the final batch is smaller, and
    agree with the per-model accuracy statistics.

    Args:
        net: GNN returning ``(delta_weights, delta_biases)`` for a graph batch.
        MLP: Batched functional MLP evaluator.
        mlp_loader: Loader yielding ``(graph_batch, w_b)`` pairs.
        cifar_loader: Image loader. Only its first batch is used, so it
            presumably holds the whole evaluation split in one batch.
        device: Device to move graphs, parameters and images to.
        l1_lambda: Weight of the L1 penalty in the total loss.

    Returns:
        ``(stats_before, stats_after)``. Each is a tuple ``(avg_loss, avg_ce,
        avg_l1, acc_mean, acc_var, acc_max, acc_min, individual_accuracies)``.

    Raises:
        ValueError: If ``mlp_loader`` yields no MLPs.
    """
    net.eval()
    MLP.eval()

    # Running [loss, ce, l1] totals, each weighted by the batch's MLP count.
    totals_before = [0.0, 0.0, 0.0]
    totals_after = [0.0, 0.0, 0.0]
    individual_accuracies_before = []
    individual_accuracies_after = []
    n_models = 0

    cifar_images, cifar_targets = next(iter(cifar_loader))
    cifar_images, cifar_targets = cifar_images.to(device), cifar_targets.to(device)

    with torch.no_grad():
        for graph_batch, w_b in tqdm(mlp_loader, desc="Evaluating on Test Set"):
            graph_batch = graph_batch.to(device)
            weights = tuple(w.to(device) for w in w_b.weights)
            biases = tuple(b.to(device) for b in w_b.biases)

            batch_size = graph_batch.num_graphs
            n_models += batch_size
            # Same label vector for every MLP in the batch.
            targets_exp = cifar_targets.unsqueeze(0).expand(batch_size, -1)

            # --- 1. Original MLPs ---
            original_w_b = WB_Batch(weights=weights, biases=biases)
            terms, accs = _score_mlps(MLP, cifar_images, targets_exp, original_w_b, l1_lambda)
            _add_weighted(totals_before, terms, batch_size)
            individual_accuracies_before.extend(accs)

            # --- 2. MLPs after the GNN's residual update ---
            delta_weights, delta_biases = net(graph_batch, weights, biases)
            new_weights, new_biases = residual_param_update(
                weights, biases, delta_weights, delta_biases
            )
            new_w_b = WB_Batch(weights=tuple(new_weights), biases=tuple(new_biases))
            terms, accs = _score_mlps(MLP, cifar_images, targets_exp, new_w_b, l1_lambda)
            _add_weighted(totals_after, terms, batch_size)
            individual_accuracies_after.extend(accs)

    if n_models == 0:
        raise ValueError("MLP test loader is empty; nothing to evaluate.")

    stats_before = _summarize(totals_before, n_models, individual_accuracies_before)
    stats_after = _summarize(totals_after, n_models, individual_accuracies_after)
    return stats_before, stats_after

def _build_model_name(conf):
    """Return the GNN checkpoint filename (``*.pt``) that encodes the config's hyperparameters."""
    opt_args = conf['optimization']['optimizer_args']
    scalegmn_args = conf['scalegmn_args']
    name = (
        f"gnn_mlp_lr{opt_args['lr']}"
        f"_batch{conf['batch_size']}"
        f"_layers{scalegmn_args['num_layers']}"
        f"_wd{opt_args['weight_decay']}"
        f"_dropout{scalegmn_args['gnn_args']['dropout']}"
        f"_cifarbatch{conf['cifar_data']['batch_fraction']}"
        f"_l1{conf['l1_lambda']}"
        f"_{conf['data']['activation_function']}"
    )
    break_symmetry = (
        conf.get("scalegmn_args", {}).get("mlp_args", {}).get("break_symmetry", False)
    )
    if break_symmetry:
        name += "_broken"
    return name + ".pt"


def _print_stats(header, stats):
    """Print one block of loss/accuracy statistics as returned by ``evaluate``."""
    avg_loss, avg_ce, avg_l1, acc_mean, acc_var, acc_max, acc_min, _ = stats
    print(header)
    print(f"Total Loss:           {avg_loss:.4f}")
    print(f"Cross-Entropy Loss:   {avg_ce:.4f}")
    print(f"L1 Penalty:           {avg_l1:.4f}")
    print("-" * 28)
    print(f"Average Accuracy:     {acc_mean*100:.2f}%")
    print(f"Min Accuracy:         {acc_min*100:.2f}%")
    print(f"Max Accuracy:         {acc_max*100:.2f}%")
    # Accuracies are fractions in [0, 1], so the variance expressed in percent
    # units is var * 100**2; the std in percent is sqrt(var) * 100.
    print(f"Variance of Accuracy: {acc_var * 100 ** 2:.4f} (%^2)")
    print(f"Std of Accuracy:      {np.sqrt(acc_var) * 100:.2f}%")


def main(args):
    """Run the evaluation pipeline end to end.

    The pipeline loads the config and seeds RNGs, then builds the test loaders,
    the GNN and the batched MLP. It restores the GNN checkpoint named after the
    hyperparameters and evaluates MLPs before and after the GNN transformation.
    Finally it prints the statistics and saves the per-model accuracies to an
    ``.npz`` file next to the checkpoint.
    """
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"

    conf = load_config(args)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    set_seed(conf['train_args']['seed'])

    # Keep this order: the data splits (global RNG) are drawn right after
    # seeding, before the GNN's parameter initialisation consumes the RNG.
    param_data = load_mlp_data(
        conf['data']['mlp_params_path'],
        conf['data']['min_accuracy_threshold'],
    )
    mlp_test_loader, cifar_test_loader = setup_dataloaders(conf, param_data, device)

    net = ScaleGMN_equiv(conf['scalegmn_args']).to(device)
    MLP = BatchedMLP().to(device)
    print(f"Total GNN parameters: {sum(p.numel() for p in net.parameters())}")

    model_name = _build_model_name(conf)
    output_dir = conf['train_args']['output_path']
    checkpoint_path = os.path.join(output_dir, model_name)

    try:
        state_dict = torch.load(checkpoint_path, map_location=device)
    except FileNotFoundError:
        print(f"Error: Model checkpoint not found at {checkpoint_path}")
        return
    net.load_state_dict(state_dict)
    print(f"Successfully loaded model from {checkpoint_path}")
    param_sum = sum(p.sum().item() for p in net.parameters())
    print(f"[Check] Loaded GNN Model Parameter Sum: {param_sum}")

    print("\nStarting evaluation...")
    stats_before, stats_after = evaluate(
        net, MLP, mlp_test_loader, cifar_test_loader, device, conf['l1_lambda']
    )

    _print_stats("\n--- Evaluation Results (BEFORE Transformation) ---", stats_before)
    _print_stats("\n--- Evaluation Results (AFTER Transformation) ----", stats_after)

    # "Sparsity" = relative reduction of the average L1 norm caused by the GNN.
    avg_l1_b, avg_l1_a = stats_before[2], stats_after[2]
    if avg_l1_b:
        print(f"Sparsity: {(avg_l1_b - avg_l1_a) / avg_l1_b *100:.4f}%")
    else:
        print("Sparsity: n/a (L1 penalty before transformation is 0)")

    results_filename = f"evaluation_results_{os.path.splitext(model_name)[0]}.npz"
    results_filepath = os.path.join(output_dir, results_filename)

    np.savez_compressed(
        results_filepath,
        accuracies_before=np.array(stats_before[7]),
        accuracies_after=np.array(stats_after[7]),
    )
    print(f"Saved per-model accuracies to {results_filepath}")

if __name__ == '__main__':
    # Parse CLI arguments and hand them straight to the evaluation pipeline.
    main(setup_arg_parser().parse_args())