import os
import yaml
from ml_collections import ConfigDict
from tqdm import tqdm

import torch
import torch_geometric

from utils.data import load_dataset, make_dataset_splits, load_dataset_splits, check_dataset_valid
from utils.split import SplitManager, node_induced_subgraph
from utils.storage import TensorHash
from utils.model import load_model_class, accuracy, load_model_instance, create_model_instance
from utils.attack import load_attack_class

from robust_diffusion.data import SparseGraph
from robust_diffusion.data import count_edges_for_idx
from robust_diffusion.helper import utils as robust_utils
from robust_diffusion.train import train, train_inductive

from sacred import Experiment

# Sacred experiment object; the config function and the main function below register on it.
experiment = Experiment(name="VanillaTraining")

@experiment.config
def default_config():
    """Default Sacred configuration for the VanillaTraining experiment.

    Each local variable assigned below is a config entry; Sacred injects the
    entries into ``run`` through its parameters of the same names.
    """
    ## Experiment configs
    dataset_name = "cora_ml"  # dataset to load; also the prefix of split file names in splits_root
    model_name = "GCN"  # model name passed to load_model_instance / create_model_instance
    recreate_splits = False  # NOTE(review): not read by run() at present
    n_splits = 10  # number of splits to train on; missing matching splits are created first

    # Split sizes/types forwarded to check_dataset_valid and make_dataset_splits.
    # NOTE(review): None presumably means "use the dataset defaults" - confirm in utils.data.
    training_split = None
    validation_split = None
    training_split_type = None
    validation_split_type = None
    test_split = None
    test_split_type = None
    # TODO: add unlabeled split

    model_params = None  # presumably None falls back to conf/model-configs.yaml defaults - confirm
    epsilon = 0.1  # NOTE(review): not read by run() at present

    attack_name = "PRBCD"  # NOTE(review): not read by run() at present
    attack_params = None  # NOTE(review): not read by run() at present

    inductive = False  # forwarded to split creation/loading and to model loading/creation

# NOTE: these helpers must stay above ``run``: ``experiment.automain`` starts the
# experiment as soon as the decorator executes when this file is run as a script,
# so anything defined below it would not exist yet.
def _load_yaml(path):
    """Parse the YAML file at ``path`` with ``yaml.safe_load`` and close it afterwards."""
    with open(path, encoding="utf-8") as yaml_file:
        return yaml.safe_load(yaml_file)


def _list_valid_splits(dataset_name, splits_root, **split_config):
    """Return the split files in ``splits_root`` that this experiment may use.

    A file qualifies when the text before its first "-" equals ``dataset_name``
    and ``check_dataset_valid`` accepts it for ``split_config`` (the
    training/validation/test split sizes and types). The names are sorted so
    that truncating to ``n_splits`` is deterministic (``os.listdir`` order is
    arbitrary).

    Returns an empty list if a ``FileNotFoundError`` occurs, e.g. when
    ``splits_root`` has not been created yet.
    """
    try:
        return sorted(
            split_record
            for split_record in os.listdir(splits_root)
            if split_record.split("-")[0] == dataset_name
            and check_dataset_valid(
                split_record=split_record, splits_root=splits_root, **split_config)
        )
    except FileNotFoundError:
        # A missing splits directory means no usable splits exist yet.
        return []


@experiment.automain
def run(
    dataset_name, model_name, recreate_splits, n_splits, 
    training_split, validation_split, training_split_type, validation_split_type, test_split, test_split_type,
    model_params, epsilon, attack_name, attack_params, inductive):
    """Train (or load) ``model_name`` on ``n_splits`` splits of ``dataset_name``.

    Splits matching the requested split configuration are reused; missing ones
    are created with ``make_dataset_splits``. For each split a stored model is
    loaded with ``load_model_instance``; if that raises ``FileNotFoundError`` a
    new model is built (and its artifacts stored) with ``create_model_instance``.
    Finally the mean and standard deviation of the per-split accuracies are printed.

    NOTE(review): ``recreate_splits``, ``epsilon``, ``attack_name`` and
    ``attack_params`` are Sacred config entries accepted here but not used.
    """
    ## Loading general configs (like dataset_root, etc.) and initial parameters
    general_config = _load_yaml("conf/general-config.yaml") or {}
    default_dataset_configs = _load_yaml("conf/data-configs.yaml").get("configs").get("default")
    default_model_configs = _load_yaml("conf/model-configs.yaml").get("configs")

    # extracting configs
    dataset_root = general_config.get("dataset_root", "data/")
    splits_root = general_config.get("splits_root", "splits/")
    models_root = general_config.get("models_root", "models/")

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # One definition of the requested split configuration, shared by split
    # validation and split creation so the two can never disagree.
    split_config = dict(
        training_split=training_split, validation_split=validation_split,
        training_split_type=training_split_type, validation_split_type=validation_split_type,
        test_split=test_split, test_split_type=test_split_type)

    print("Experiment Started")
    print("Loading dataset =", dataset_name)

    dataset_splits = _list_valid_splits(dataset_name, splits_root, **split_config)
    creating_splits = max(n_splits - len(dataset_splits), 0)

    # creating remaining needed dataset splits
    print(f"Found {len(dataset_splits)} splits, creating {creating_splits} more splits")
    for _ in tqdm(range(creating_splits)):
        torch.cuda.empty_cache()
        make_dataset_splits(
            dataset_name=dataset_name, **split_config,
            inductive=inductive,
            default_dataset_configs=default_dataset_configs, dataset_root=dataset_root,
            splits_root=splits_root, device=device)

    # Re-list with the SAME validity filter: splits of this dataset created with a
    # different split configuration must not be picked up here.
    dataset_splits = _list_valid_splits(dataset_name, splits_root, **split_config)[:n_splits]
    if len(dataset_splits) < n_splits:
        print(f"Warning: only {len(dataset_splits)} matching splits available, expected {n_splits}")
    print(f"Training {model_name} model on {dataset_name} dataset for {n_splits} splits")

    accs = []
    for split_file in tqdm(dataset_splits):
        split_code = split_file.split("-")[1].replace(".pt", "")

        data = load_dataset_splits(
            dataset_name, split_code, inductive=inductive, 
            dataset_root=dataset_root, splits_root=splits_root, device=device)

        training_attr = data["training_attr"]
        training_adj = data["training_adj"]
        validation_attr = data["validation_attr"]
        validation_adj = data["validation_adj"]
        labels = data["labels"]
        training_idx = data["training_idx"]
        validation_idx = data["validation_idx"]
        test_attr = data["test_attr"]
        test_adj = data["test_adj"]
        unlabeled_mask = data["unlabeled_mask"]
        test_mask = data["test_mask"]
        dataset_info = data["dataset_info"]
        split_name = data["split_name"]

        try:
            model_instance = load_model_instance(
                model_name=model_name, model_params=model_params, 
                test_attr=test_attr, test_adj=test_adj, labels=labels, test_mask=test_mask, unlabeled_mask=unlabeled_mask,
                split_name=split_name, dataset_info=dataset_info, inductive=inductive, models_root=models_root,
                default_model_configs=default_model_configs, device=device)
        except FileNotFoundError as e:
            print(e)
            print("Creating model from scratch")
            model_instance = create_model_instance(
                model_name=model_name, model_params=model_params, dataset_info=dataset_info, 
                training_attr=training_attr, training_adj=training_adj, 
                validation_attr=validation_attr, validation_adj=validation_adj,
                labels=labels, training_idx=training_idx, validation_idx=validation_idx,
                test_attr=test_attr, test_adj=test_adj, test_mask=test_mask, unlabeled_mask=unlabeled_mask,
                inductive=inductive, split_name=split_name,
                models_root=models_root, 
                default_model_configs=default_model_configs, 
                device=device)

        # ``model_params`` is deliberately NOT rebound from model_instance: every
        # split must receive the configured value, not the previous split's result.
        # float() accepts Python numbers as well as one-element tensors on any device.
        accs.append(float(model_instance["accuracy"]))

    if not accs:
        print("No splits were evaluated, so no accuracy statistics are available")
    else:
        accs_tensor = torch.tensor(accs)
        acc_mean = torch.mean(accs_tensor)
        acc_std = torch.std(accs_tensor)
        print(f"Mean accuracy: {acc_mean}, std: {acc_std}")

    print("Experiment Finished")

