import os
import random
import yaml
from train_model import MODEL_MAP, StopOnMinLR
import data_utils as topo_data
import models as models
from pytorch_lightning.loggers import WandbLogger
import numpy as np
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning import seed_everything

# Trade a little float32 matmul precision for speed: "high" lets PyTorch use
# faster reduced-precision internal kernels (e.g. TF32) where available.
torch.set_float32_matmul_precision('high')

def read_yaml(yaml_file_path):
    """Load a YAML file and return its parsed contents.

    Uses ``yaml.safe_load``, so only standard YAML tags are accepted and no
    arbitrary Python objects are constructed.

    Args:
        yaml_file_path: Path (``str`` or ``os.PathLike``) of the file to read.

    Returns:
        The parsed document (typically a ``dict``), or ``None`` for an empty
        file.
    """
    # Open in binary mode so PyYAML detects the encoding itself (UTF-8 or
    # UTF-16 via BOM) instead of relying on the platform-dependent default
    # text encoding of open().
    with open(yaml_file_path, 'rb') as f:
        return yaml.safe_load(f)

def dump_yaml(data, yaml_file_path):
    """Write ``data`` to ``yaml_file_path`` as block-style YAML.

    Any existing file at that path is overwritten.
    """
    out_file = open(yaml_file_path, 'w')
    with out_file:
        # default_flow_style=False -> nested collections are emitted in
        # indented block style rather than inline {...}/[...] form.
        yaml.dump(data, stream=out_file, default_flow_style=False)

# Hyperparameter search space for random search. Each key maps to a list of
# candidate values; get_random_config() draws one value per key with
# random.choice, visiting keys in insertion order (so reordering keys changes
# which configurations a given seed produces). Single-element lists are
# effectively fixed settings.
hyp_params = {
    "model": ["TopoGNN"],
    # Alternative datasets -- keep exactly one "dataset" entry uncommented.
    #"dataset": ["ENZYMES"],
    #"dataset": ["DBLP"],
    #"dataset": ["REDDIT-BINARY"],
    "dataset": ['REDDIT-5K'],
    #"dataset": ["Necklaces"],
    #"dataset": ["PATTERN"],
    #"dataset": ["CLUSTER"],
    #"dataset": ["MNIST"],
    #"dataset": ["DD"],
    #"dataset": ["PROTEINS"],
    #"dataset": ["PROTEINS_full"],
    #"dataset": ['IMDB-BINARY'],
    "print_stats": [False],
    "training_seed": [42],  # passed to seed_everything() in __main__
    "num_repeats": [20],  # number of sampled configurations trained in __main__
    "max_epochs": [1000],
    "paired": [False],
    "merged": [False],

    "use_node_attributes": [True],
    "fold": [0],
    "seed": [42],
    "batch_size": [64],
    "legacy": [True],
    "benchmark_idx": [True],
    "lift_to_simplex": [True],
    "max_simplex_dim": [2],
    # NOTE: other values require reprocessing the dataset if a version lifted
    # to that simplex dimension is not available yet.
    #"max_simplex_dim": [2, 3],

    "hidden_dim": [32, 64, 128, 256],
    "depth": [3],
    "lr": [0.005, 0.001, 0.0005, 0.0001],
    "lr_patience": [15, 25],
    "min_lr": [0.00001],  # passed to the StopOnMinLR callback in __main__
    "dropout_p": [0.0, 0.1, 0.2],
    # GNN layer-type flags; presumably exactly one should be True -- confirm
    # against the model implementation.
    "GIN": [False],
    "GAT": [True],
    "GCN": [False],
    "train_eps": [True, False],
    "batch_norm": [True, False],
    "residual": [True, False],
    "save_filtration": [False],
    "add_mlp": [True, False],
    "weight_decay": [0.0, 0.1, 0.2],
    "dropout_input_p": [0.0, 0.1, 0.2],
    "dropout_edges_p": [0.0, 0.1, 0.2],
    "num_heads_gnn": [1, 2, 3],
    "strong_reg_check": [False],

    "filtration_hidden": [8, 16, 24, 32, 64],
    "num_filtrations": [2, 4, 8, 16],
    # get_random_config() keeps at most one of these two enabled.
    "tanh_filtrations": [True, False],
    "relu_filtrations": [True, False],
    "deepset_type": ["full", "shallow", "linear"],
    "full_deepset_highdims": [True, False],
    "swap_bn_order": [True, False],
    "dim1": [True],
    "higher_dims": [2],  # upper bound k; get_random_config() expands it to [2, ..., k]
    #"higher_dims": [2, 3, 4],
    "num_coord_funs": [1, 2, 3, 4],
    "togl_position": [0, 1, 2],  # re-drawn from 0..depth if it exceeds depth
    "residual_and_bn": [True, False],
    "share_filtration_parameters": [True, False],
    "separate_filtration_functions": [True, False],
    "fake": [False],
    "deepset": [True],
    "dim0_out_dim": [8, 16, 32, 64],
    "dim1_out_dim": [8, 16, 32, 64],
    "higher_dims_out_dim": [8, 16, 32, 64],  # repeated once per entry of higher_dims
    "dist_dim1": [True, False],  # forced to True when dist_dimh is True
    "dist_dimh": [True, False],
    "clique_persistence": [True],  # NOTE(review): author flagged this as unclear ("WHAT") -- confirm its effect
    "mlp_combine_dims_clique_persistence": [True, False],
    "aggregation_fn": ["mean", "max", "sum"],
}


def get_random_config(search_space=None):
    """Sample one hyperparameter configuration from a search space.

    One candidate per key is drawn with ``random.choice``. Keys are visited in
    insertion order, so a fixed ``random`` seed reproduces the same sample.
    The raw draw is then post-processed to remove incompatible combinations:

    * ``dist_dimh`` truthy forces ``dist_dim1 = True``;
    * ``tanh_filtrations`` and ``relu_filtrations`` are never both enabled --
      if both were drawn, exactly one of them (chosen at random) is kept;
    * ``togl_position`` is re-drawn from ``0..depth`` if it exceeds ``depth``;
    * the scalar ``higher_dims`` ``k`` is expanded to ``[2, ..., k]`` and
      ``higher_dims_out_dim`` is repeated once per entry of that list.

    A fix-up is skipped when the keys it needs are absent from the space.

    Args:
        search_space: Mapping from hyperparameter name to a non-empty sequence
            of candidate values. Defaults to the module-level ``hyp_params``.

    Returns:
        A new ``dict`` with one concrete value per hyperparameter. The search
        space itself is not modified.
    """
    if search_space is None:
        search_space = hyp_params

    config = {
        name: random.choice(candidates)
        for name, candidates in search_space.items()
    }

    # The higher-dimensional distance option builds on the dim-1 one.
    if config.get("dist_dimh"):
        config["dist_dim1"] = True

    # At most one filtration activation may be active at a time.
    if config.get("tanh_filtrations") and config.get("relu_filtrations"):
        keep = random.choice(["tanh_filtrations", "relu_filtrations"])
        config["tanh_filtrations"] = keep == "tanh_filtrations"
        config["relu_filtrations"] = keep == "relu_filtrations"

    # The TOGL layer cannot be inserted beyond the last GNN layer.
    togl_position = config.get("togl_position")
    if togl_position is not None and "depth" in config and togl_position > config["depth"]:
        config["togl_position"] = random.choice(range(config["depth"] + 1))

    if "higher_dims" in config:
        config["higher_dims"] = list(range(2, config["higher_dims"] + 1))
        if "higher_dims_out_dim" in config:
            config["higher_dims_out_dim"] = (
                [config["higher_dims_out_dim"]] * len(config["higher_dims"])
            )

    return config


def _train_and_test(config, repeat):
    """Train one model with ``config`` and return its test-set metrics dict.

    Checkpoints (including ``last.ckpt``) and a YAML copy of ``config`` go to
    ``results/<model>_<dataset>/repeat_<repeat>``; CSV logs go to
    ``results/<model>_<dataset>/log``.
    """
    # Resolve classes per config: configs are resampled every repeat.
    model_cls = MODEL_MAP[config["model"]]
    dataset_cls = topo_data.get_dataset_class(**config)

    dataset = dataset_cls(**config)
    dataset.prepare_data()

    model = model_cls(
        **config,
        num_node_features=dataset.node_attributes,
        num_classes=dataset.num_classes,
        task=dataset.task,
    )
    print('Running with hyperparameters:')
    print(model.hparams)

    run_dir = os.path.join("results", f"{config['model']}_{config['dataset']}")
    # One checkpoint directory per repeat so runs do not mix their files.
    ckpt_dir = os.path.join(run_dir, f"repeat_{repeat}")
    checkpoint_cb = ModelCheckpoint(
        dirpath=ckpt_dir,
        monitor='val_loss',
        mode='min',
        # Required for ckpt_path='last' below: without it no last checkpoint
        # exists and Lightning falls back (with a warning) to in-memory weights.
        save_last=True,
        verbose=True,
    )

    # To log to Weights & Biases instead, pass a WandbLogger(...) as `logger`.
    trainer = pl.Trainer(
        accelerator='auto',
        devices=-1,
        strategy='ddp_find_unused_parameters_true',
        logger=CSVLogger(run_dir, name="log"),
        log_every_n_steps=5,
        max_epochs=config['max_epochs'],
        callbacks=[
            StopOnMinLR(config["min_lr"]),
            checkpoint_cb,
            LearningRateMonitor('epoch'),
        ],
    )
    trainer.fit(model, datamodule=dataset)

    # Keep the sampled hyperparameters next to the checkpoints (rank 0 only).
    if trainer.is_global_zero:
        os.makedirs(ckpt_dir, exist_ok=True)
        dump_yaml(config, os.path.join(ckpt_dir, "config.yaml"))

    print("Performance on validation set:")
    val_results = trainer.validate(dataloaders=dataset.val_dataloader(), ckpt_path='last')[0]
    print(val_results)

    print("Performance on test set:")
    # Evaluates the final weights; use ckpt_path='best' to compare against the
    # checkpoint with the lowest val_loss.
    return trainer.test(dataloaders=dataset.test_dataloader(), ckpt_path='last')[0]


if __name__ == "__main__":
    import pickle

    # Set to True to replay the configuration stored in tuning_config.pt
    # instead of the freshly sampled one. (pickle: only load files you wrote.)
    REUSE_CONFIG = False

    # Seed Python's RNG *before* sampling any configuration. With the DDP
    # strategy, Lightning starts the additional device processes by re-running
    # this script, so every rank must draw exactly the same configurations;
    # an unseeded first draw would give the ranks mismatched models.
    # Set the TUNING_SEED environment variable to explore a different sequence.
    search_seed = int(os.environ.get("TUNING_SEED", hyp_params["training_seed"][0]))
    seed_everything(search_seed)
    first_config = get_random_config()
    # Draw all configurations up front so the sequence does not depend on how
    # much randomness training consumes in each process.
    configs = [first_config] + [
        get_random_config() for _ in range(first_config["num_repeats"] - 1)
    ]

    res_repeats = []
    for repeat, config in enumerate(configs):
        if REUSE_CONFIG:
            with open("tuning_config.pt", "rb") as f:
                config = pickle.load(f)
        else:
            with open("tuning_config.pt", "wb") as f:
                pickle.dump(config, f)

        # Re-seed per repeat so each run is reproducible on its own.
        seed_everything(config["training_seed"])
        test_results = _train_and_test(config, repeat)
        res_repeats.append(test_results["test_acc"])

    print("--- Results over repeats:")
    # NOTE: every repeat uses a different sampled configuration, so this is the
    # spread across the search, not across seeds of a single configuration.
    accs = np.array(res_repeats)
    print("Test Acc:", accs.mean(), "+-", accs.std())
    print(res_repeats)



