import os
import yaml
import numpy as np
from ml_collections import ConfigDict

import torch
import torch_geometric
from tqdm import tqdm

from sacred import Experiment

from utils.data import load_dataset, make_dataset_splits, load_dataset_splits
from utils.split import SplitManager, node_induced_subgraph
from utils.storage import TensorHash
from utils.attack import load_attack_class

from robust_diffusion.data import SparseGraph, count_edges_for_idx
from robust_diffusion.helper import utils as robust_utils
from robust_diffusion.train import train, train_inductive
from robust_diffusion.helper.utils import calculate_loss


from robust_diffusion.models.gcn import GCN, DenseGCN
from robust_diffusion.models.gprgnn import GPRGNN
from robust_diffusion.models.gprgnn_dense import DenseGPRGNN
from robust_diffusion.models.chebynet2 import ChebNetII
from robust_diffusion.models.gat_weighted import GAT
from robust_diffusion.models.rgnn import RGNN
from robust_diffusion.models import create_model, GPRGNN, DenseGPRGNN, ChebNetII
from robust_diffusion.models.gprgnn import GPR_prop
from robust_diffusion.helper.utils import accuracy as accuracy_metric
from copy import deepcopy

from utils.general import accuracy, model_storage_label

def load_model_class(model_name):
    """Return the model class registered under ``model_name``.

    ``"APPNP"`` maps to ``GPRGNN`` and ``"SoftMedian_GDC"`` maps to ``RGNN``;
    every other supported name maps to the class of the same name.

    Raises:
        ValueError: if ``model_name`` is not one of the supported names.
    """
    supported = ("GCN", "DenseGCN", "GAT", "GPRGNN", "DenseGPRGNN",
                 "APPNP", "ChebNetII", "SoftMedian_GDC")
    # Validate first so an unknown name fails with a clear message before any
    # class object is looked up.
    if model_name not in supported:
        raise ValueError(
            f"Unknown model name {model_name!r}; expected one of: {', '.join(supported)}")

    model_classes = {
        "GCN": GCN,
        "DenseGCN": DenseGCN,
        "GAT": GAT,
        "GPRGNN": GPRGNN,
        "DenseGPRGNN": DenseGPRGNN,
        "APPNP": GPRGNN,
        "ChebNetII": ChebNetII,
        "SoftMedian_GDC": RGNN,
    }
    return model_classes[model_name]
    

def create_model_instance(model_name, model_params, dataset_info, 
                          training_attr, training_adj, validation_attr, validation_adj,
                          labels, training_idx, validation_idx, 
                          test_attr, test_adj, test_mask, unlabeled_mask, inductive,
                          split_name, models_root="models", default_model_configs=None, suffix='',
                          device='cpu'):
    """Train a fresh model on the given split, evaluate it and save it to disk.

    Args:
        model_name: name understood by ``load_model_class``.
        model_params: ``ConfigDict`` of constructor + training settings (``lr``,
            ``weight_decay``, ``patience``, ``max_epochs`` are read here). When
            None, ``default_model_configs[model_name]`` is used, completed with
            ``dataset_info.n_features`` / ``dataset_info.n_classes``.
        inductive: train with ``train_inductive`` on separate training/validation
            graphs instead of ``train`` on a single graph.
        models_root: directory the trained model is saved into (created if needed).
        suffix: currently unused.
        device: device the model is created on.

    Returns:
        dict with ``"model"``, ``"model_params"``, ``"accuracy"`` (accuracy on
        ``test_mask`` when inductive, else on ``test_mask | unlabeled_mask``)
        and ``"model_storage_name"`` (file stem of the saved ``.pt`` file).

    Raises:
        ValueError: if ``model_params`` is None and no default config is available
            for ``model_name``.
    """
    if model_params is None:
        if default_model_configs is None:
            raise ValueError("default_model_configs is required when model_params is None")
        default_params = default_model_configs.get(model_name)
        if default_params is None:
            raise ValueError(f"No default config found for model {model_name!r}")
        model_params = ConfigDict(default_params)
        model_params.n_features = dataset_info.n_features
        model_params.n_classes = dataset_info.n_classes

    model = load_model_class(model_name)(**model_params.to_dict()).to(device)
    if not inductive:
        train(
            model=model, attr=training_attr, adj=training_adj, labels=labels, idx_train=training_idx, idx_val=validation_idx, display_step=100,
            lr=model_params.lr, weight_decay=model_params.weight_decay, patience=model_params.patience,
            max_epochs=model_params.max_epochs)
    else:
        train_inductive(
            model=model, attr_training=training_attr, attr_validation=validation_attr, 
            adj_training=training_adj, adj_validation=validation_adj,  labels_training=labels, labels_validation=labels, 
            idx_train=training_idx, idx_val=validation_idx, display_step=100,
            lr=model_params.lr, weight_decay=model_params.weight_decay, patience=model_params.patience,
            max_epochs=model_params.max_epochs)

    # Transductive evaluation also scores the unlabeled nodes of the same graph.
    eval_mask = test_mask if inductive else (test_mask | unlabeled_mask)
    acc = accuracy(model, test_attr, test_adj, labels, eval_mask)

    model_storage_name = model_storage_label(model_name=model_name, 
                                            model_params=model_params, 
                                            dataset_info=dataset_info, 
                                            inductive=inductive, 
                                            split_name=split_name)

    # Saving is best-effort: a failed save is reported but the trained model is still returned.
    try: 
        os.makedirs(models_root, exist_ok=True)
        torch.save(model, os.path.join(models_root, f"{model_storage_name}.pt"))
    except RuntimeError as e:
        print(f"Error saving the model: {e}")

    return {
        "model": model,
        "model_params": model_params,
        "accuracy": acc,
        "model_storage_name": model_storage_name
    }

def _build_robust_optimizer(model, train_params):
    """Create the Adam optimizer used by ``create_robust_model_instance``.

    * GPR-GNN (sparse or dense) whose ``prop1`` is a ``GPR_prop`` with ``norm == True``:
      ``prop1.temp`` gets weight decay 0; all other parameters use
      ``train_params['weight_decay']``.
    * ChebNetII: ``lin1``/``lin2`` use the training lr/weight decay, ``prop1`` uses
      ``model.prop_lr`` / ``model.prop_wd``.
    * Any other model: a single group with the training lr/weight decay.
    """
    lr = train_params['lr']
    weight_decay = train_params['weight_decay']
    if (isinstance(model, (GPRGNN, DenseGPRGNN)) and isinstance(model.prop1, GPR_prop)
            and model.prop1.norm == True):  # noqa: E712 - keep the original comparison
        print('Excluding GPR-GNN coefficients from weight decay as we use normalization')
        named_params = list(model.named_parameters())
        return torch.optim.Adam([
            {"params": [p for n, p in named_params if n != 'prop1.temp'],
             "weight_decay": weight_decay, 'lr': lr},
            {"params": [p for n, p in named_params if n == 'prop1.temp'],
             "weight_decay": 0.0, 'lr': lr},
        ])
    if isinstance(model, ChebNetII):
        return torch.optim.Adam([
            {'params': model.lin1.parameters(), 'weight_decay': weight_decay, 'lr': lr},
            {'params': model.lin2.parameters(), 'weight_decay': weight_decay, 'lr': lr},
            {'params': model.prop1.parameters(), 'weight_decay': model.prop_wd, 'lr': model.prop_lr}])
    return torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)


def _self_training_labels(model_name, model_params, training_attr, training_adj,
                          validation_attr, validation_adj, labels, training_idx,
                          validation_idx, inductive, device):
    """Train a fresh baseline model and return its predictions as training labels.

    Nodes in ``training_idx`` keep their ground-truth label; every other node gets
    the baseline's argmax prediction on the training graph.
    """
    baseline_model = load_model_class(model_name)(**model_params.to_dict()).to(device)
    if not inductive:
        train(
            model=baseline_model, attr=training_attr, adj=training_adj, labels=labels, idx_train=training_idx, idx_val=validation_idx, display_step=100,
            lr=model_params.lr, weight_decay=model_params.weight_decay, patience=model_params.patience,
            max_epochs=model_params.max_epochs)
    else:
        train_inductive(
            model=baseline_model, attr_training=training_attr, attr_validation=validation_attr, 
            adj_training=training_adj, adj_validation=validation_adj,  labels_training=labels, labels_validation=labels, 
            idx_train=training_idx, idx_val=validation_idx, display_step=100,
            lr=model_params.lr, weight_decay=model_params.weight_decay, patience=model_params.patience,
            max_epochs=model_params.max_epochs)

    # Predict in eval mode (e.g. dropout disabled) and without building an autograd graph.
    baseline_model.eval()
    with torch.no_grad():
        logits = baseline_model(training_attr, training_adj)
    pseudolabels = torch.argmax(logits, dim=1)
    pseudolabels[training_idx] = labels[training_idx]
    return pseudolabels


def create_robust_model_instance(
            model_name, model_params, dataset_info, 
            training_attr, training_adj, validation_attr, validation_adj,
            labels, training_idx, validation_idx, train_attack_name,
            test_attr, test_adj, test_mask, unlabeled_mask, inductive,
            split_name, train_params, train_attack_params, val_attack_params,
            make_undirected=True,
            models_root="models", default_model_configs=None, suffix='',
            self_training=False, robust_training=False, robust_epsilon=0.0,
            validate_every=10, loss_type='ce', binary_attr=False,
            device='cpu'):
    """Fine-tune a previously saved model, optionally adversarially, and save it.

    The vanilla model is loaded via ``load_model_instance`` (it must have been trained
    and saved before). For ``train_attack_name == "PGD"``, GCN/GPRGNN models are
    converted to their dense counterparts. The model is then trained for up to
    ``train_params['max_epochs']`` epochs with ``train_params['lr']`` /
    ``train_params['weight_decay']``.

    When ``robust_training`` is set and ``robust_epsilon > 0``, every epoch trains on a
    graph perturbed by the ``train_attack_name`` attack. The budget is
    ``round(robust_epsilon * m)``, where ``m`` is half of ``count_edges_for_idx`` for the
    attacked nodes. Every ``validate_every`` epochs the validation loss is computed on
    the (perturbed, if attacking) validation graph. That loss selects the best state, and
    training stops after ``train_params['patience']`` epochs without improvement.

    Args:
        self_training: replace the labels of non-training nodes by predictions of a
            freshly trained baseline model.
        inductive: attack/train on ``training_idx`` plus all unlabeled nodes; otherwise
            only on ``training_idx``.
        binary_attr: forwarded to both the training and the validation attack.
        suffix: currently unused.

    Returns:
        dict with ``"model"`` (best validation state, in eval mode), ``"model_params"``,
        ``"accuracy"`` (test accuracy of the loaded vanilla model, before fine-tuning),
        ``"clean_accuracy"`` (test accuracy after fine-tuning), ``"acc_trace_val_pert"``
        and ``"model_storage_name"``.

    Raises:
        ValueError: if the vanilla model cannot be found or loaded.
    """
    training_attr = training_attr.to(device)
    training_adj = training_adj.to(device)
    validation_attr = validation_attr.to(device)
    validation_adj = validation_adj.to(device)
    labels = labels.to(device)

    # Load the vanilla model saved by the training scripts.
    try:
        model_instance = load_model_instance(
            model_name=model_name, model_params=model_params,
            test_attr=test_attr, test_adj=test_adj, labels=labels, test_mask=test_mask, unlabeled_mask=unlabeled_mask,
            split_name=split_name, dataset_info=dataset_info, inductive=inductive, models_root=models_root,
            default_model_configs=default_model_configs, device=device)
    except FileNotFoundError as e:
        print(e)
        raise ValueError("Model not found. Run training scripts to train the model.") from e
    if model_instance is None:
        # load_model_instance returns None when torch.load raises a RuntimeError.
        raise ValueError("Model could not be loaded. Run training scripts to train the model.")

    model = model_instance["model"]
    acc = model_instance["accuracy"]
    model_params = model_instance["model_params"]

    # For PGD, swap in the dense counterpart (weights copied). The dense wrapper is
    # constructed without a device, so move it explicitly.
    if train_attack_name == "PGD" and model_name == "GCN":
        model = from_sparse_GCN(model, model_params).to(device)
    elif train_attack_name == "PGD" and model_name == "GPRGNN":
        model = from_sparse_GPRGNN(model, model_params).to(device)

    # Built once, after the optional dense conversion, so the per-group settings are
    # the ones actually used by the training loop.
    optimizer = _build_robust_optimizer(model, train_params)

    if self_training:
        train_labels = _self_training_labels(
            model_name, model_params, training_attr, training_adj, validation_attr,
            validation_adj, labels, training_idx, validation_idx, inductive, device)
    else:
        train_labels = labels

    if inductive:
        # Labeled training nodes plus every unlabeled node.
        adv_train_idx = torch.tensor(np.concatenate([training_idx, unlabeled_mask.nonzero(as_tuple=True)[0]])).to(device)
    else:
        # need to be discussed for transductive setting 
        adv_train_idx = training_idx.to(device)

    # The budgets are computed under exactly the condition under which attacks run.
    attack_enabled = robust_training and robust_epsilon > 0
    if attack_enabled:
        n_train_edges = count_edges_for_idx(training_adj, adv_train_idx) # num edges connected to train nodes
        m_train = int(n_train_edges) / 2
        n_perturbations_train = int(round(robust_epsilon * m_train))

        n_val_edges = count_edges_for_idx(validation_adj, validation_idx) # num edges connected to val nodes
        m_val = int(n_val_edges) / 2
        n_perturbations_val = int(round(robust_epsilon * m_val))

    # Graphs the model is trained/validated on; they start out clean. The clean
    # training_attr / validation_attr are never reassigned, so every attack and every
    # "clean" metric starts from the unperturbed graph.
    adj_attacked_train, attr_attacked_train = training_adj.detach(), training_attr
    adj_attacked_val, attr_attacked_val = validation_adj.detach(), validation_attr

    acc_trace_train = []
    acc_trace_val = []
    acc_trace_train_pert = []
    acc_trace_val_pert = []
    loss_trace = []
    loss_trace_val = []
    gamma_trace = []
    tracks_gammas = (isinstance(model, (GPRGNN, DenseGPRGNN))
                     and isinstance(getattr(model, 'prop1', None), GPR_prop))

    best_loss = np.inf
    best_epoch = 0
    # Start from the loaded weights so there is always a state to restore. clone()
    # matters: on CPU, .cpu() returns the live parameter tensor itself.
    best_state = {key: value.detach().cpu().clone() for key, value in model.state_dict().items()}

    for it in tqdm(range(train_params['max_epochs']), desc='Training...'):
        # Generate the adversarial training graph from the clean one.
        if attack_enabled:
            model.eval()
            adversary = load_attack_class(attack_name=train_attack_name)(
                attr=training_attr, adj=training_adj, labels=train_labels,
                model=model, idx_attack=np.array(adv_train_idx.cpu()),
                device=device, data_device=device, binary_attr=binary_attr,
                make_undirected=make_undirected, **train_attack_params)
            adversary.attack(n_perturbations_train)
            adj_attacked_train, attr_attacked_train = adversary.get_pertubations()
            del adversary

        # Train step.
        model.train()
        optimizer.zero_grad()
        logits = model(attr_attacked_train, adj_attacked_train)
        loss = calculate_loss(logits[adv_train_idx], train_labels[adv_train_idx], loss_type)
        loss.backward()
        optimizer.step()
        train_accuracy = accuracy_metric(logits.cpu(), train_labels.cpu(), adv_train_idx.cpu())
        acc_trace_train_pert.append(train_accuracy)
        if tracks_gammas:
            gamma_trace.append(model.prop1.temp.detach().cpu())
        loss_trace.append(loss.item())

        # Validation step.
        validated = it % validate_every == 0
        if validated:
            if attack_enabled:
                torch.cuda.empty_cache()
                adversary_val = load_attack_class(attack_name=train_attack_name)(
                    attr=validation_attr, adj=validation_adj, labels=labels, model=model,
                    idx_attack=np.array(validation_idx),
                    device=device, data_device=device, binary_attr=binary_attr,
                    make_undirected=make_undirected, **val_attack_params)
                model.eval()
                adversary_val.attack(n_perturbations_val)
                adj_attacked_val, attr_attacked_val = adversary_val.get_pertubations()
                del adversary_val

            with torch.no_grad():
                model.eval()
                logits_val = model(attr_attacked_val, adj_attacked_val)
                loss_val = calculate_loss(logits_val[validation_idx], labels[validation_idx], loss_type).item()
                loss_trace_val.append(loss_val)
                val_accuracy = accuracy_metric(logits_val.cpu(), labels.cpu(), validation_idx)
                acc_trace_val_pert.append(val_accuracy)

                # Clean accuracy on the unperturbed train graph.
                logits_clean_train = model(training_attr, training_adj)
                train_accuracy_clean = accuracy_metric(logits_clean_train.cpu(), train_labels.cpu(), training_idx)
                acc_trace_train.append(train_accuracy_clean)
                # Clean accuracy on the unperturbed val graph.
                logits_clean_val = model(validation_attr, validation_adj)
                val_accuracy_clean = accuracy_metric(logits_clean_val.cpu(), labels.cpu(), validation_idx)
                acc_trace_val.append(val_accuracy_clean)

                if tracks_gammas:
                    print(f'model gammas: {model.prop1.normalize_coefficients().detach().cpu()}')
                print(f'train acc (pert/clean): {train_accuracy} / {train_accuracy_clean}')
                print(f'val acc (pert/clean): {val_accuracy} / {val_accuracy_clean}')

        # Model selection only on epochs with a fresh validation loss; the patience
        # check runs every epoch otherwise.
        if validated and loss_val < best_loss:
            best_loss = loss_val
            best_epoch = it
            best_state = {key: value.detach().cpu().clone() for key, value in model.state_dict().items()}
        elif it >= best_epoch + train_params['patience']:
            break

    # Restore the best validation state.
    model.load_state_dict(best_state)
    model.eval()
    eval_mask = test_mask if inductive else (test_mask | unlabeled_mask)
    clean_acc = accuracy(model, test_attr, test_adj, labels, eval_mask)

    model_storage_name = model_storage_label(model_name=model_name,
                                            model_params=model_params,
                                            dataset_info=dataset_info,
                                            inductive=inductive,
                                            self_training=self_training,
                                            robust_training=robust_training,
                                            train_attack_name=train_attack_name,
                                            robust_epsilon=robust_epsilon,
                                            split_name=split_name)

    # Saving is best-effort: a failed save is reported but the model is still returned.
    try: 
        os.makedirs(models_root, exist_ok=True)
        torch.save(model, os.path.join(models_root, f"{model_storage_name}.pt"))
    except RuntimeError as e:
        print(f"Error saving the model: {e}")

    return {
        "model": model,
        "model_params": model_params,
        "accuracy": acc,
        "clean_accuracy": clean_acc,
        "acc_trace_val_pert": acc_trace_val_pert,
        "model_storage_name": model_storage_name
    }

def load_model_instance(model_name, model_params, dataset_info,
                        test_attr, test_adj, labels, test_mask, unlabeled_mask,
                        split_name, inductive=False,
                        models_root="models",
                        default_model_configs=None, suffix='', device='cpu'):
    """Load a saved vanilla model and evaluate it on the test nodes.

    The file is ``<models_root>/<model_storage_label(...)>.pt``. When ``model_params``
    is None, it is built from ``default_model_configs[model_name]`` plus
    ``dataset_info.n_features`` / ``dataset_info.n_classes``, mirroring
    ``create_model_instance`` so that the storage labels match.

    Returns:
        dict with ``"model"``, ``"model_params"``, ``"model_storage_name"`` and
        ``"accuracy"`` (on ``test_mask`` when inductive, else on
        ``test_mask | unlabeled_mask``); or None if ``torch.load`` raises a
        RuntimeError (the error is printed).

    Raises:
        ValueError: if ``default_model_configs`` is None or ``model_name`` is unsupported.
        FileNotFoundError: if no saved model exists under the computed name.
    """
    if default_model_configs is None:
        raise ValueError("default_model_configs is required to load the model")
    
    if model_params is None:
        model_params = ConfigDict(default_model_configs.get(model_name))
        model_params.n_features = dataset_info.n_features
        model_params.n_classes = dataset_info.n_classes
    # Validate the model name only; the model itself comes from the checkpoint, so
    # building (and moving to device) a throwaway instance is unnecessary.
    load_model_class(model_name)
    
    # TODO: add a check to see if the model config hash is same as the hash of the current model config
    model_storage_name = model_storage_label(model_name=model_name,
                                            model_params=model_params,
                                            dataset_info=dataset_info,
                                            inductive=inductive,
                                            split_name=split_name)
    
    try:
        # map_location lets checkpoints saved on another device (e.g. GPU) load here.
        # NOTE(review): newer torch versions default torch.load to weights_only=True,
        # which rejects whole-module pickles like these — confirm the pinned version.
        model = torch.load(os.path.join(models_root, f"{model_storage_name}.pt"),
                           map_location=device).to(device)
    except RuntimeError as e:
        print(f"Error loading the model: {e}")
        return None
    
    eval_mask = test_mask if inductive else (test_mask | unlabeled_mask)
    acc = accuracy(model, test_attr, test_adj, labels, eval_mask)
    return {
        "model": model,
        "model_params": model_params,
        "model_storage_name": model_storage_name,
        "accuracy": acc
    }
    
def load_robust_model_instance(model_name, model_params, dataset_info,
                        test_attr, test_adj, labels, test_mask, unlabeled_mask,
                        split_name, inductive=False,
                        models_root="models", self_training=True, robust_training=True, train_attack_name='PRBCD', robust_epsilon=0.2,
                        default_model_configs=None, suffix='', device='cpu'):
    """Load a saved robustly trained model and evaluate it on the test nodes.

    The storage name additionally encodes ``self_training``, ``robust_training``,
    ``train_attack_name`` and ``robust_epsilon``, matching the name written by
    ``create_robust_model_instance``. When ``model_params`` is None, it is built from
    ``default_model_configs[model_name]`` plus ``dataset_info.n_features`` /
    ``dataset_info.n_classes``.

    Returns:
        dict with ``"model"``, ``"model_params"``, ``"model_storage_name"``,
        ``"accuracy"`` and ``"clean_accuracy"`` (both the same test accuracy: on
        ``test_mask`` when inductive, else on ``test_mask | unlabeled_mask``); or
        None if ``torch.load`` raises a RuntimeError (the error is printed).

    Raises:
        ValueError: if ``default_model_configs`` is None or ``model_name`` is unsupported.
        FileNotFoundError: if no saved model exists under the computed name.
    """
    if default_model_configs is None:
        raise ValueError("default_model_configs is required to load the model")
    
    if model_params is None:
        model_params = ConfigDict(default_model_configs.get(model_name))
        model_params.n_features = dataset_info.n_features
        model_params.n_classes = dataset_info.n_classes
    # Validate the model name only; the model itself comes from the checkpoint.
    load_model_class(model_name)
    
    # TODO: add a check to see if the model config hash is same as the hash of the current model config
    model_storage_name = model_storage_label(model_name=model_name,
        model_params=model_params,
        dataset_info=dataset_info,
        inductive=inductive,
        self_training=self_training,
        robust_training=robust_training,
        train_attack_name=train_attack_name,
        robust_epsilon=robust_epsilon,
        split_name=split_name)
    
    try:
        # map_location lets checkpoints saved on another device (e.g. GPU) load here.
        model = torch.load(os.path.join(models_root, f"{model_storage_name}.pt"),
                           map_location=device).to(device)
    except RuntimeError as e:
        print(f"Error loading the model: {e}")
        return None
    
    eval_mask = test_mask if inductive else (test_mask | unlabeled_mask)
    acc = accuracy(model, test_attr, test_adj, labels, eval_mask)

    return {
        "model": model,
        "model_params": model_params,
        "model_storage_name": model_storage_name,
        "accuracy": acc,
        "clean_accuracy": acc
    }


def from_sparse_GCN(sparse_model, args):
    """Build a ``DenseGCN`` that carries deep copies of ``sparse_model``'s modules.

    ``args`` is forwarded as keyword arguments to ``DenseGCN``. The pieces copied are:
    the model-level ``activation``; the linear map of each of the two GCN layers
    (sparse ``gcn_i.lin`` -> dense ``gcn_i._linear``); and the first layer's
    ``activation_0`` and ``dropout_0``.

    NOTE(review): only ``gcn_i.lin`` is copied. If the sparse convolution keeps a bias
    outside ``lin``, that bias is not transferred — confirm against the GCN classes.
    """
    dense_model = DenseGCN(**args)
    src_first, src_second = sparse_model.layers[0], sparse_model.layers[1]
    dst_first, dst_second = dense_model.layers[0], dense_model.layers[1]

    dense_model.activation = deepcopy(sparse_model.activation)
    dst_first.gcn_0._linear = deepcopy(src_first.gcn_0.lin)
    dst_first.activation_0 = deepcopy(src_first.activation_0)
    dst_first.dropout_0 = deepcopy(src_first.dropout_0)
    dst_second.gcn_1._linear = deepcopy(src_second.gcn_1.lin)
    return dense_model

def from_sparse_GPRGNN(sparse_model, args):
    """Build a ``DenseGPRGNN`` carrying deep copies of ``sparse_model``'s weights.

    ``args`` is forwarded as keyword arguments to ``DenseGPRGNN``. Copied are the
    propagation coefficients ``prop1.temp`` and the two linear layers ``lin1``/``lin2``.
    """
    dense_model = DenseGPRGNN(**args)
    dense_model.prop1.temp = deepcopy(sparse_model.prop1.temp)
    for layer_name in ("lin1", "lin2"):
        setattr(dense_model, layer_name, deepcopy(getattr(sparse_model, layer_name)))
    return dense_model