import os
import yaml
from tqdm import tqdm
import numpy as np
import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
import torch_geometric
from torch.cuda.amp import autocast

from src.utils.setup_arg_parser import setup_arg_parser
from src.utils.helpers import overwrite_conf, set_seed, mask_input, mask_hidden
from src.scalegmn.models import ScaleGMN_custom
from src.data import dataset
from src.utils.CNN_helpers import DifferentiableCNN


def load_config(args):
    """Read the YAML file at ``args.conf`` and apply command-line overrides.

    The parsed YAML and ``vars(args)`` are handed to ``overwrite_conf``; the
    resulting configuration is echoed to stdout as block-style YAML and
    then returned.
    """
    with open(args.conf, 'r') as config_file:
        file_conf = yaml.safe_load(config_file)

    cli_values = vars(args)
    merged_conf = overwrite_conf(file_conf, cli_values)

    # Echo the effective configuration so the run is self-describing in logs.
    print("Configuration for Evaluation:")
    print(yaml.dump(merged_conf, default_flow_style=False))

    return merged_conf


def setup_dataloaders(conf):
    """Build the two loaders consumed by the evaluation loop.

    Returns:
        A ``(cnn_loader, cifar_test_loader)`` tuple. The first is a
        ``torch_geometric`` loader over the ``'test'`` split of the CNN
        dataset; the second is a plain ``DataLoader`` that serves the
        held-out half of the CIFAR-10 test set as one single batch.
    """
    hidden_mask_flag = mask_hidden(conf)
    first_layer_mask_flag = mask_input(conf)

    cnn_dataset_kwargs = {
        'split': 'test',
        'debug': conf["debug"],
        'direction': conf['scalegmn_args']['direction'],
        'equiv_on_hidden': hidden_mask_flag,
        'get_first_layer_mask': first_layer_mask_flag,
    }
    cnn_test_set = dataset(conf['data'], **cnn_dataset_kwargs)

    cnn_loader = torch_geometric.loader.DataLoader(
        dataset=cnn_test_set,
        batch_size=conf['batch_size'],
        shuffle=False,
        num_workers=conf["num_workers"],
        pin_memory=True,
    )

    # Images are reduced to one grayscale channel and normalised with
    # mean 0.5 / std 0.5.
    to_gray = transforms.Grayscale(num_output_channels=1)
    to_tensor = transforms.ToTensor()
    normalise = transforms.Normalize((0.5,), (0.5,))
    image_pipeline = transforms.Compose([to_gray, to_tensor, normalise])

    cifar_test_full = torchvision.datasets.CIFAR10(
        root=conf['cifar_data']['dataset_path'],
        train=False,
        download=True,
        transform=image_pipeline,
    )

    # Split the official test set in two and keep only the second part.
    # No generator is passed, so the split is drawn from the global torch
    # RNG and depends on the seed state at the time of this call.
    # NOTE(review): presumably the first half is the validation set used
    # during training - confirm the training script splits identically.
    total_images = len(cifar_test_full)
    val_size = total_images // 2
    split_lengths = [val_size, total_images - val_size]
    cifar_eval_part = random_split(cifar_test_full, split_lengths)[1]

    # batch_size equals the subset size: the whole split arrives as one batch.
    cifar_test_loader = DataLoader(
        cifar_eval_part,
        batch_size=len(cifar_eval_part),
        shuffle=False,
        num_workers=conf["num_workers"],
    )

    return cnn_loader, cifar_test_loader


def _score_cnn_batch(CNN, images, targets_exp, nodes, edges, num_cnns, l1_lambda):
    """Score one batch of CNN parameter sets on a fixed image batch.

    Returns:
        ``(loss, ce, l1, accs)``: the batch-mean total loss, cross-entropy
        and L1 term as Python floats, plus a NumPy array holding one
        accuracy value per CNN in the batch.
    """
    outputs = CNN(images, nodes, edges)
    ce = F.cross_entropy(outputs, targets_exp)
    l1 = CNN.sum_abs_params(nodes, edges).mean()
    loss = ce + (l1_lambda * l1)

    _, predicted = outputs.max(1)
    # targets_exp is the image targets tiled once per CNN, so reshaping the
    # hit mask to (num_cnns, images_per_cnn) gives one row per CNN.
    # NOTE(review): this relies on CNN returning its logits grouped per CNN in
    # the same order as the tiled targets - confirm against DifferentiableCNN.
    images_per_cnn = targets_exp.numel() // num_cnns
    hits = predicted.eq(targets_exp).view(num_cnns, images_per_cnn)
    accs = hits.float().mean(dim=1)

    return loss.item(), ce.item(), l1.item(), accs.cpu().numpy()


def _summarise_eval(weighted_sums, accuracies, num_cnns):
    """Turn graph-weighted metric sums and per-CNN accuracies into the stats tuple."""
    avg_loss, avg_ce, avg_l1 = (float(total) / num_cnns for total in weighted_sums)
    return (
        avg_loss,
        avg_ce,
        avg_l1,
        np.mean(accuracies),
        np.var(accuracies),
        np.max(accuracies),
        np.min(accuracies),
    )


def evaluate(net, CNN, cnn_loader, cifar_loader, device, l1_lambda, use_amp=True):
    """Evaluate the CNNs on CIFAR before and after the GNN-predicted update.

    Only the first batch yielded by ``cifar_loader`` is used as the image
    set. For every batch of CNN graphs, the original node/edge features are
    scored, then ``net`` predicts residuals that are added to those features
    and the updated parameters are scored on the same images.

    Args:
        net: Model returning ``(delta_nodes, delta_edges)`` for a graph batch.
        CNN: Functional CNN called as ``CNN(images, nodes, edges)`` that also
            provides ``sum_abs_params(nodes, edges)``.
        cnn_loader: Loader yielding batched CNN graphs.
        cifar_loader: Loader whose first batch is ``(images, targets)``.
        device: Device the tensors are moved to.
        l1_lambda: Weight of the L1 term in the total loss.
        use_amp: Request autocast; it is only enabled when the data actually
            lives on a CUDA device.

    Returns:
        ``(stats_before, stats_after)``. Each is a 7-tuple
        ``(total_loss, ce_loss, l1, acc_mean, acc_var, acc_max, acc_min)``.
        The three loss entries are averages per CNN: every batch mean is
        weighted by the number of graphs in that batch, so a smaller final
        batch does not skew the result.

    Raises:
        ValueError: If ``cnn_loader`` yields no graphs.
    """
    net.eval()
    CNN.eval()

    cifar_images, cifar_targets = next(iter(cifar_loader))
    cifar_images, cifar_targets = cifar_images.to(device), cifar_targets.to(device)

    # Requesting CUDA autocast without CUDA tensors only produces a warning,
    # so enable it just when the data really is on a CUDA device.
    amp_enabled = bool(use_amp) and cifar_images.is_cuda

    # Graph-weighted running sums of (total loss, cross-entropy, L1).
    sums_before = np.zeros(3, dtype=np.float64)
    sums_after = np.zeros(3, dtype=np.float64)
    individual_accuracies_before = []
    individual_accuracies_after = []
    num_cnns_seen = 0

    with torch.no_grad():
        progress_bar = tqdm(cnn_loader, desc="Evaluating on Test Set")
        for cnn_batch in progress_bar:
            cnn_batch = cnn_batch.to(device)

            with autocast(enabled=amp_enabled):
                # --- Extract original CNN params ---
                # The flat node/edge tensors are reshaped to one row per
                # graph, which requires every graph in the batch to have the
                # same number of nodes and edges.
                batch_size = cnn_batch.num_graphs
                num_nodes_total, node_dim = cnn_batch.x.shape
                nodes_per_graph = num_nodes_total // batch_size
                num_edges_total, edge_dim = cnn_batch.edge_attr.shape
                edges_per_graph = num_edges_total // batch_size

                node_features_batched = cnn_batch.x.view(batch_size, nodes_per_graph, node_dim)
                edge_features_batched = cnn_batch.edge_attr.view(batch_size, edges_per_graph, edge_dim)

                # Repeat targets for each CNN
                targets_exp = cifar_targets.repeat(batch_size)

                # === 1. Evaluate BEFORE transformation ===
                loss_b, ce_b, l1_b, acc_b = _score_cnn_batch(
                    CNN, cifar_images, targets_exp,
                    node_features_batched, edge_features_batched,
                    batch_size, l1_lambda,
                )
                sums_before += batch_size * np.array([loss_b, ce_b, l1_b])
                individual_accuracies_before.extend(acc_b)

                # === 2. Predict residuals and evaluate AFTER transformation ===
                # The batch is cloned before it is handed to the GNN.
                delta_nodes, delta_edges = net(cnn_batch.clone())
                new_nodes = node_features_batched + delta_nodes
                new_edges = edge_features_batched + delta_edges

                loss_a, ce_a, l1_a, acc_a = _score_cnn_batch(
                    CNN, cifar_images, targets_exp,
                    new_nodes, new_edges,
                    batch_size, l1_lambda,
                )
                sums_after += batch_size * np.array([loss_a, ce_a, l1_a])
                individual_accuracies_after.extend(acc_a)

            num_cnns_seen += batch_size

    if num_cnns_seen == 0:
        raise ValueError("cnn_loader yielded no CNN graphs; nothing to evaluate.")

    stats_before = _summarise_eval(sums_before, individual_accuracies_before, num_cnns_seen)
    stats_after = _summarise_eval(sums_after, individual_accuracies_after, num_cnns_seen)

    return stats_before, stats_after


def _print_stats_report(header, stats):
    """Print one evaluation report for a 7-tuple produced by ``evaluate``."""
    avg_loss, avg_ce, avg_l1, acc_mean, acc_var, acc_max, acc_min = stats

    print(header)
    print(f"Total Loss:           {avg_loss:.4f}")
    print(f"Cross-Entropy Loss:   {avg_ce:.4f}")
    print(f"L1 Penalty:           {avg_l1:.4f}")
    print("-" * 28)
    print(f"Average Accuracy:     {acc_mean*100:.2f}%")
    print(f"Min Accuracy:         {acc_min*100:.2f}%")
    print(f"Max Accuracy:         {acc_max*100:.2f}%")
    # NOTE(review): the variance is scaled by 100 and labelled "%", the same
    # way as the accuracies; a variance of percentages would scale by 100**2.
    # Output kept as-is - confirm which unit is intended.
    print(f"Variance of Accuracy: {acc_var*100:.4f}%")


def main(args):
    """Run the evaluation pipeline end to end.

    Loads the configuration, seeds the RNGs, builds the data loaders and
    both models, restores the GNN weights from a checkpoint whose file name
    is derived from the hyperparameters in the configuration, and prints the
    metrics before and after the predicted transformation. If the checkpoint
    file does not exist, an error line is printed and the function returns
    without evaluating.
    """
    print("--- SCRIPT STARTED ---")
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"

    conf = load_config(args)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    set_seed(conf['train_args']['seed'])
    cnn_loader, cifar_loader = setup_dataloaders(conf)

    # Resolve the activation name from the config to its callable.
    activation_function_str = conf['data']['activation_function']
    for known_name, candidate in (('relu', F.relu), ('tanh', torch.tanh)):
        if activation_function_str == known_name:
            activation_function = candidate
            break
    else:
        raise ValueError(f"Unknown activation function: {activation_function_str}")

    CNN = DifferentiableCNN(activation=activation_function).to(device)
    net = ScaleGMN_custom(conf['scalegmn_args']).to(device)

    # The checkpoint file name is assembled from the hyperparameter values
    # found in the configuration.
    name_parts = [
        f"gnn_cnn_lr{conf['optimization']['optimizer_args']['lr']}",
        f"_batch{conf['batch_size']}",
        f"_layers{conf['scalegmn_args']['num_layers']}",
        f"_wd{conf['optimization']['optimizer_args']['weight_decay']}",
        f"_dropout{conf['scalegmn_args']['gnn_args']['dropout']}",
        f"_cifarbatch{conf['cifar_data']['batch_fraction']}",
        f"_l1{conf['l1_lambda']}",
        f"_{conf['data']['activation_function']}",
    ]
    break_symmetry = conf.get("scalegmn_args", {}).get("mlp_args", {}).get("break_symmetry", False)
    if break_symmetry:
        name_parts.append("_broken")
    name_parts.append(".pt")
    model_name = "".join(name_parts)

    checkpoint_path = os.path.join(conf['train_args']['output_path'], model_name)
    try:
        # NOTE(review): torch.load unpickles the file; only point this at
        # checkpoints from a trusted source.
        state_dict = torch.load(checkpoint_path, map_location=device)
        net.load_state_dict(state_dict)
        print(f"Successfully loaded model from {checkpoint_path}")
    except FileNotFoundError:
        print(f"Error: Model checkpoint not found at {checkpoint_path}")
        return

    print("\nStarting evaluation...")
    stats_before, stats_after = evaluate(net, CNN, cnn_loader, cifar_loader, device, conf['l1_lambda'])

    _print_stats_report("\n--- Evaluation Results (BEFORE Transformation) ---", stats_before)
    _print_stats_report("\n--- Evaluation Results (AFTER Transformation) ---", stats_after)

if __name__ == "__main__":
    # Parse the command-line options and hand them to the evaluation entry point.
    cli_parser = setup_arg_parser()
    main(args=cli_parser.parse_args())