from gnnboundary import *
# from gnnboundary.models.gcn import GCNClassifier
# from gnnboundary.utils.criterion import (
#     WeightedCriterion, ClassScoreCriterion, EmbeddingCriterion,
#     MeanPenalty, NormPenalty, KLDivergencePenalty, BudgetPenalty
# )
from gnnboundary.utils.boundary_generator import GraphGenerator
from gnnboundary.datasets import ENZYMESDataset, CollabDataset, MotifDataset

import torch
import pandas as pd
import numpy as np
import argparse
import sys
import yaml
import os
from pathlib import Path


def load_config():
    """Read and return the parsed contents of ``../config/config.yml``.

    The path is resolved relative to the current working directory. If the
    file does not exist, an error is printed and the process exits with
    status 1.
    """
    config_file = Path('../config/config.yml')
    try:
        with open(config_file, 'r') as stream:
            return yaml.safe_load(stream)
    except FileNotFoundError:
        print(f"Error: config.yml not found at {config_file}")
        sys.exit(1)


def load_constants(dataset_name):
    """Return the class-index -> class-name mapping for ``dataset_name``.

    Reads the ``{dataset_name}_CLASS_NAMES`` entry from
    ``../config/constants.yml`` (resolved relative to the current working
    directory). If that file does not exist, a warning is printed and a
    built-in default mapping is returned for the known datasets.

    Raises:
        KeyError: constants.yml exists but has no ``{dataset_name}_CLASS_NAMES``
            entry (an empty file counts as having no entries).
        ValueError: constants.yml is missing and there is no built-in default
            for ``dataset_name``.
    """
    # Fallback names used only when constants.yml cannot be found.
    default_class_names = {
        "ENZYMES": {
            0: 'EC1', 1: 'EC2', 2: 'EC3',
            3: 'EC4', 4: 'EC5', 5: 'EC6'
        },
        "COLLAB": {
            0: "High Energy", 1: "Condensed Matter", 2: "Astro"
        },
        "Motif": {
            0: "house", 1: "house_x", 2: "comp_4", 3: "comp_5"
        },
    }

    constants_path = Path('../config/constants.yml')
    try:
        with open(constants_path, 'r') as f:
            constants = yaml.safe_load(f)
    except FileNotFoundError:
        print(f"Warning: constants.yml not found at {constants_path}. Using default class names.")
        if dataset_name not in default_class_names:
            # Previously this fell through and returned None, which only
            # failed later with an opaque TypeError at the first lookup.
            raise ValueError(
                f"No default class names for dataset {dataset_name!r}"
            ) from None
        return default_class_names[dataset_name]

    class_names_key = f"{dataset_name}_CLASS_NAMES"
    # An empty YAML file parses to None; treat it as having no keys rather
    # than letting `in None` raise a TypeError.
    if not constants or class_names_key not in constants:
        raise KeyError(f"{class_names_key} not found in constants.yml")

    return constants[class_names_key]


def parse_arguments():
    """Parse the command line for the graph-generation script.

    ``--dataset`` is mandatory and restricted to the supported datasets.
    The integer overrides are optional; when omitted they are ``None`` so
    ``main`` can distinguish "not given" from an explicit value.
    """
    cli = argparse.ArgumentParser(description='GNNInterpreter Graph Generation')
    cli.add_argument(
        '--dataset',
        type=str,
        required=True,
        choices=['ENZYMES', 'COLLAB', 'Motif'],
        help='Dataset to use',
    )

    # Optional integer flags that override values from config.yml.
    integer_overrides = (
        ('--num_runs', 'Override number of graph generation attempts'),
        ('--iterations', 'Override number of optimization steps'),
        ('--seed', 'Override random seed'),
    )
    for flag, description in integer_overrides:
        cli.add_argument(flag, type=int, help=description)

    return cli.parse_args()


def save_results_to_csv(results, filename, num_runs, iterations, dataset_name, class_names=None):
    """Summarise per-class generation results, write them to CSV and print them.

    Args:
        results: Mapping ``class_index -> (success_rate, convergence)`` where
            ``convergence`` is either a list/ndarray of per-run convergence
            iterations or a single number.
        filename: Destination CSV path. Its parent directory is created when
            the path has one; a bare filename is written to the current
            working directory.
        num_runs: Number of graphs per class (repeated in every row).
        iterations: Optimisation steps per run (repeated in every row).
        dataset_name: Used to look up display names via ``load_constants``
            when ``class_names`` is not supplied.
        class_names: Optional pre-loaded ``class_index -> name`` mapping.
            When given, ``load_constants`` is not called.

    Returns:
        The DataFrame (rounded to 4 decimals, indexed by class name) that was
        written to ``filename``.
    """
    if class_names is None:
        class_names = load_constants(dataset_name)

    class_ids = sorted(results.keys())

    # Single pass over the classes collecting every per-row statistic.
    success_rates = []
    conv_means = []
    conv_stds = []
    for cls_id in class_ids:
        success_rates.append(results[cls_id][0])
        convergence = results[cls_id][1]
        if isinstance(convergence, (list, np.ndarray)):
            if np.size(convergence) == 0:
                # Same NaN np.mean/np.std would return for an empty input,
                # without the "Mean of empty slice" RuntimeWarning.
                conv_means.append(np.nan)
                conv_stds.append(np.nan)
            else:
                conv_means.append(np.mean(convergence))
                conv_stds.append(np.std(convergence))
        else:
            # A single value is its own mean and has zero spread.
            conv_means.append(float(convergence))
            conv_stds.append(0.0)

    df = pd.DataFrame(index=[class_names[cls_id] for cls_id in class_ids])
    df['Success Rate'] = success_rates
    df['Average Convergence Iteration'] = conv_means
    df['Std Convergence Rate'] = conv_stds
    df['Number of Graphs'] = num_runs
    df['Number of Iterations'] = iterations

    # os.path.dirname('results.csv') is '' and os.makedirs('') raises
    # FileNotFoundError, so only create a directory when the path has one.
    out_dir = os.path.dirname(filename)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    df = df.round(4)
    df.to_csv(filename)
    print(f"\nResults saved to {filename}")
    print("\nResults Summary:")
    print(df)

    return df


def setup_model_and_dataset(config, dataset_name, seed):
    """Build the dataset and load the pretrained GCN classifier for it.

    Args:
        config: Parsed config.yml; ``config['GNNInterpreter'][dataset_name]
            ['model_path']`` must point at the saved model state dict.
        dataset_name: One of ``"ENZYMES"``, ``"COLLAB"`` or ``"Motif"``.
        seed: Passed to the dataset constructor.

    Returns:
        ``(dataset, model)`` with the checkpoint weights loaded into ``model``.

    Raises:
        ValueError: ``dataset_name`` is not a supported dataset.

    Exits the process with status 1 when the model settings cannot be read
    from ``../config/constants.yml`` or the checkpoint file does not exist.
    """
    dataset_classes = {
        "ENZYMES": ENZYMESDataset,
        "COLLAB": CollabDataset,
        "Motif": MotifDataset,
    }
    if dataset_name not in dataset_classes:
        raise ValueError(f"Unknown dataset: {dataset_name}")
    dataset = dataset_classes[dataset_name](seed=seed)

    # Load model settings from constants
    constants_path = Path('../config/constants.yml')
    try:
        with open(constants_path, 'r') as f:
            constants = yaml.safe_load(f)
        # Switch from the constants.yml to config.yml file to train other model configurations
        model_settings = constants['PRETRAINED_MODEL_SETTINGS'][dataset_name]
    except (FileNotFoundError, KeyError, TypeError) as e:
        # TypeError covers an empty constants.yml (safe_load returns None)
        # or a null PRETRAINED_MODEL_SETTINGS entry.
        print(f"Error loading model settings from constants.yml: {e}")
        sys.exit(1)

    model = GCNClassifier(
        node_features=len(dataset.NODE_CLS),
        num_classes=len(dataset.GRAPH_CLS),
        hidden_channels=model_settings['hidden_channels'],
        num_layers=model_settings['num_layers']
    )

    model_path = config['GNNInterpreter'][dataset_name]['model_path']
    try:
        # The model is never moved off the CPU in this script, so map the
        # checkpoint to CPU; this also lets GPU-saved checkpoints load on
        # machines without CUDA.
        state_dict = torch.load(model_path, map_location="cpu")
    except FileNotFoundError:
        print(f"Error: Model checkpoint not found at {model_path}")
        sys.exit(1)
    model.load_state_dict(state_dict)

    return dataset, model


def generate_graphs_for_class(cls, dataset, model, mean_embeds, config, dataset_name):
    """Train a graph sampler toward class ``cls`` and generate graphs with it.

    Hyper-parameters come from ``config['GNNInterpreter'][dataset_name]``
    plus the shared ``sampler``, ``optimizer``, ``training`` and
    ``experiment`` sections of ``config``.

    Args:
        cls: Target class index.
        dataset: Dataset object exposing ``NODE_CLS``.
        model: Trained classifier used as the discriminator.
        mean_embeds: Per-class mean embeddings; ``mean_embeds[cls]`` is the
            target of the embedding criterion.
        config: Parsed config.yml.
        dataset_name: Key into ``config['GNNInterpreter']``.

    Returns:
        The value returned by calling the ``GraphGenerator``; ``main`` unpacks
        it as ``(success_rate, conv_iter)``.
    """
    ds_cfg = config['GNNInterpreter'][dataset_name]
    sampler_cfg = config['sampler']
    opt_cfg = config['optimizer']
    budget_cfg = config['training']['budget']
    experiment_cfg = config['experiment']
    weights = ds_cfg['weights']
    budget_ctl = ds_cfg['budget_control']

    sampler = GraphSampler(
        max_nodes=ds_cfg['max_nodes'],
        temperature=sampler_cfg['temperature'],
        num_node_cls=len(dataset.NODE_CLS),
        learn_node_feat=sampler_cfg['learn_node_feat']
    )
    optimizer = torch.optim.SGD(sampler.parameters(), lr=opt_cfg['learning_rate'])

    # Objective terms, each scaled by its per-dataset weight from the config.
    # Objects are created in the same order as the original Trainer(...) call
    # evaluated them: criterion terms, then scheduler, then budget penalty.
    criterion = WeightedCriterion([
        dict(key="logits", criterion=ClassScoreCriterion(class_idx=cls), mode="maximize",
             weight=weights['logits']),
        dict(key="embeds", criterion=EmbeddingCriterion(target_embedding=mean_embeds[cls]),
             weight=weights['embeds']),
        dict(key="logits", criterion=MeanPenalty(),
             weight=weights['mean_penalty']),
        dict(key="omega", criterion=NormPenalty(order=1),
             weight=weights['omega_l1']),
        dict(key="omega", criterion=NormPenalty(order=2),
             weight=weights['omega_l2']),
        dict(key="theta_pairs", criterion=KLDivergencePenalty(binary=True),
             weight=weights['theta_pairs']),
    ])
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=opt_cfg['scheduler']['gamma'])
    budget_penalty = BudgetPenalty(
        budget=budget_cfg['initial'],
        order=budget_cfg['order'],
        beta=budget_cfg['beta']
    )

    trainer = Trainer(
        sampler=sampler,
        discriminator=model,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        dataset=dataset,
        budget_penalty=budget_penalty,
    )

    generator = GraphGenerator(trainer.sampler, dataset, trainer, model)
    return generator(
        num_graphs=experiment_cfg['num_runs'],
        iterations=experiment_cfg['iterations'],
        save_graphs=True,
        cls=cls,
        strategy="interpreter",
        target_probs={cls: tuple(ds_cfg['target_probs'])},
        target_size=ds_cfg['max_nodes'],
        w_budget_init=budget_ctl['w_init'],
        w_budget_inc=budget_ctl['w_inc'],
        w_budget_dec=budget_ctl['w_dec'],
        k_samples=ds_cfg['k_samples']
    )


def main():
    """Run graph generation for every class of the selected dataset.

    Parses CLI arguments and loads config.yml, with CLI values overriding
    ``experiment.num_runs``, ``experiment.iterations`` and ``training.seed``.
    It then loads the dataset and pretrained model and prints the model
    evaluation. For each graph class it generates graphs and prints a
    per-class summary. Finally it writes ``results.csv`` into the dataset's
    configured ``output_dir``.
    """
    args = parse_arguments()
    config = load_config()

    # Command-line flags take precedence over config.yml when given.
    if args.num_runs is not None:
        config['experiment']['num_runs'] = args.num_runs
    if args.iterations is not None:
        config['experiment']['iterations'] = args.iterations
    if args.seed is not None:
        config['training']['seed'] = args.seed

    dataset_name = args.dataset
    dataset_config = config['GNNInterpreter'][dataset_name]

    # Setup
    dataset, model = setup_model_and_dataset(config, dataset_name, config['training']['seed'])

    # Process dataset
    dataset_list_gt = dataset.split_by_class()
    # NOTE(review): the prediction-based split is not used below; the call is
    # kept in case split_by_pred has side effects the pipeline relies on.
    dataset_list_pred = dataset.split_by_pred(model)

    # Model evaluation
    evaluation = dataset.model_evaluate(model)
    print("\nModel Evaluation:")
    print(evaluation)

    # Mean embedding of each ground-truth class (targets for EmbeddingCriterion).
    mean_embeds = [d.model_transform(model, key="embeds").mean(dim=0) for d in dataset_list_gt]

    results = {}
    for cls in range(len(dataset.GRAPH_CLS)):
        print(f"\nProcessing Class {cls}...")
        success_rate, conv_iter = generate_graphs_for_class(cls, dataset, model, mean_embeds, config, dataset_name)
        results[cls] = (success_rate, conv_iter)

        # conv_iter may be a sequence or a single number (save_results_to_csv
        # accepts both); len() would raise TypeError on a scalar.
        conv_values = np.atleast_1d(np.asarray(conv_iter, dtype=float))
        if conv_values.size > 0:
            avg_conv_iter = conv_values.mean()
            print(f"Average Convergence Iteration: {avg_conv_iter:.2f}")
        print(f"Success rate: {success_rate:.2f}")

    # Save results; exist_ok makes a prior existence check unnecessary.
    output_dir = dataset_config['output_dir']
    os.makedirs(output_dir, exist_ok=True)
    save_path = os.path.join(output_dir, "results.csv")

    save_results_to_csv(results, save_path,
                        config['experiment']['num_runs'],
                        config['experiment']['iterations'],
                        dataset_name)


if __name__ == '__main__':
    # Execute the pipeline only when run as a script, not when imported.
    main()