import os
import random
import yaml
import argparse
import wandb
from train_model import MODEL_MAP, StopOnMinLR
#from test import test
import data_utils as topo_data
import models as models
from pytorch_lightning.loggers import WandbLogger
import numpy as np
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
#from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning import seed_everything

# Per the PyTorch docs, 'high' permits float32 matmuls to use faster internal
# precision modes (e.g. TF32) where the hardware supports them.
torch.set_float32_matmul_precision(precision='high')

# Default hyperparameters / the canonical key set for a trial.
# model_pipline() overwrites every key from wandb.config before training, and
# dump_yaml() writes exactly these keys when saving a best config.
hparams_dict = dict(
    GAT=True,
    GCN=False,
    GIN=False,
    add_mlp=False,
    aggregation_fn="max",
    batch_norm=True,
    batch_size=32,
    benchmark_idx=True,
    clique_persistence=True,
    dataset="ENZYMES",
    deepset=True,
    deepset_type="full",
    depth=3,
    dim0_out_dim=32,
    dim1=True,
    dim1_out_dim=8,
    dist_dim1=True,
    dist_dimh=True,
    dropout_edges_p=0.2,
    dropout_input_p=0.2,
    dropout_p=0.0,
    fake=False,
    filtration_hidden=64,
    fold=0,
    full_deepset_highdims=True,
    hidden_dim=256,
    higher_dims=[2],
    higher_dims_out_dim=[64],
    legacy=True,
    lift_to_simplex=True,
    lr=0.001,
    lr_patience=10,
    max_epochs=1000,
    max_simplex_dim=2,
    merged=False,
    min_lr=0.0001,
    mlp_combine_dims_clique_persistence=True,
    model="TopoGNN",
    num_coord_funs=4,
    num_filtrations=2,
    num_heads_gnn=3,
    paired=False,
    print_stats=False,
    relu_filtrations=False,
    residual=False,
    residual_and_bn=False,
    save_filtration=False,
    seed=42,
    separate_filtration_functions=False,
    share_filtration_parameters=False,
    strong_reg_check=False,
    swap_bn_order=False,
    tanh_filtrations=False,
    togl_position=1,
    train_eps=False,
    training_seed=42,
    use_node_attributes=False,
    weight_decay=0,
)

def read_yaml(yaml_file_path):
    """Load a YAML file with ``yaml.safe_load`` and return the parsed object.

    Args:
        yaml_file_path: Path (``str`` or ``os.PathLike``) of the file to read.

    Returns:
        The Python object built by ``yaml.safe_load`` (``None`` for an empty
        document).
    """
    # Decode explicitly as UTF-8 so parsing does not depend on the locale's
    # default encoding (e.g. cp1252 on Windows).
    with open(yaml_file_path, 'r', encoding='utf-8') as f:
        return yaml.safe_load(f)

def dump_yaml(data, yaml_file_path):
    """Write the tracked hyperparameters taken from ``data`` to a YAML file.

    Only the keys defined in the module-level ``hparams_dict`` are written, each
    with its value from ``data``; the global ``hparams_dict`` itself is left
    untouched. Missing parent directories of ``yaml_file_path`` are created, and
    an existing file is overwritten.

    Args:
        data: Mapping that contains every key of ``hparams_dict``.
        yaml_file_path: Destination path of the YAML file.

    Raises:
        KeyError: If ``data`` lacks a key that ``hparams_dict`` defines.
    """
    selected = {key: data[key] for key in hparams_dict}

    # Callers write into "./best_config/..."; open() does not create directories.
    parent_dir = os.path.dirname(yaml_file_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with open(yaml_file_path, 'w', encoding='utf-8') as yaml_file:
        yaml.dump(selected, yaml_file, default_flow_style=False)

# Screen a sampled sweep config for incompatible hyperparameter combinations (all checks are currently disabled).
def is_possible_config(config):
    """Decide whether ``config`` must be skipped as an incompatible combination.

    Despite the name, ``True`` means "not usable": the caller only trains when
    this returns a falsy value. Every compatibility check is currently
    disabled, so any config is accepted. The function always returns ``False``
    and neither reads nor modifies ``config``.

    Checks previously drafted here (disabled), for reference:
      * ``dist_dimh`` enabled while ``dist_dim1`` is disabled;
      * ``tanh_filtrations`` and ``relu_filtrations`` both enabled;
      * ``togl_position`` larger than ``depth``.
    The drafts also rewrote ``config`` (random GNN choice, expanding
    ``higher_dims`` into a list) rather than only flagging it.
    """
    # No active rule: report the configuration as compatible.
    return False

def model_pipline():
    """Train and evaluate one model for a single wandb sweep trial.

    Used as the ``function`` handed to ``wandb.agent`` in the ``__main__`` block.

    Steps:
      1. Start a wandb run and copy every key of the module-level
         ``hparams_dict`` from ``wandb.config``. This overwrites the global
         ``hparams_dict`` in place, so the sweep must supply every key.
         ``config`` is a shallow copy of the result.
      2. Seed everything with ``config["training_seed"]``, then build the
         dataset and the model class selected by ``config["model"]``.
      3. Unless ``is_possible_config`` flags the config, fit, validate and test.
         Whenever the validation loss beats the module-level ``min_val_loss``,
         dump the config to ``./best_config/...``. Flagged configs only log
         zero accuracies.

    Returns:
        None. Results go to wandb and stdout.
    """
    wandb.init()
    temp_config = wandb.config
    # Pull each tracked hyperparameter from the sweep's sampled config.
    for k, _ in hparams_dict.items():
        hparams_dict[k] = temp_config[k]
    config = hparams_dict.copy()
    
    # True means "skip this config" (see the `if not flag` below).
    flag = is_possible_config(config)
    seed_everything(config["training_seed"])
    model_cls = MODEL_MAP[config["model"]]
    dataset_cls = topo_data.get_dataset_class(**config)
    dataset = dataset_cls(**config)
    dataset.prepare_data()

    # Model dimensions come from the prepared dataset, not from the sweep.
    model = model_cls(
        **config,
        num_node_features=dataset.node_attributes,
        num_classes=dataset.num_classes,
        task=dataset.task
    )
    #print('Running with hyperparameters:')
    #print(model.hparams)

    # Loggers and callbacks
    wandb_logger = WandbLogger(
    name=f"{config['model']}_{config['dataset']}",
    project="topo_gnn",
    log_model=True,
    tags=[config['model'], config['dataset']],
    #save_dir=os.path.join(wandb_logger.experiment.dir, f"{config['model']}_{config['dataset']}")
    )

    stop_on_min_lr_cb = StopOnMinLR(config["min_lr"])
    lr_monitor = LearningRateMonitor('epoch')
    # Checkpoints are written into the wandb run's directory, ranked by val_loss.
    checkpoint_cb = ModelCheckpoint(
        #dirpath=os.path.join("results", f"{config['model']}_{config['dataset']}"),
        dirpath=wandb_logger.experiment.dir,
        monitor='val_loss',
        mode='min',
        verbose=True
    )
    trainer = pl.Trainer(
        accelerator='auto',
        #logger=CSVLogger(os.path.join("results", f"{config['model']}_{config['dataset']}"), name="log"),
        logger=wandb_logger,
        log_every_n_steps=5,
        max_epochs=config['max_epochs'],
        callbacks=[stop_on_min_lr_cb, checkpoint_cb, lr_monitor],
        #profiler="advanced"
    )
    if not flag:
        trainer.fit(model, datamodule=dataset)
        print("Performance on validation set:")
        # NOTE(review): ModelCheckpoint above is created without save_last=True,
        # so whether ckpt_path='last' loads a saved file or falls back to the
        # in-memory final-epoch weights depends on the Lightning version — confirm.
        val_results = trainer.validate(dataloaders=dataset.val_dataloader(), ckpt_path='last')[0]
        print(val_results)
        
        # NOTE(review): min_val_loss is only assigned by the __main__ block; calling
        # this function without it having been set raises NameError here.
        global min_val_loss
        if val_results["val_loss"] < min_val_loss:
            min_val_loss = val_results["val_loss"]
            # File name encodes dataset, GNN flags and node-attribute usage, so
            # each sweep group keeps its own best config.
            dump_yaml(
                config,    f"./best_config/hparams_{config['dataset']}_GAT_{config['GAT']}_GCN_{config['GCN']}_GIN_{config['GIN']}_use_node_attributes_{config['use_node_attributes']}.yaml")
            
        print("Performance on test set:")
        # test_results is not used below; presumably the metrics reach wandb via
        # the trainer's logger.
        test_results = trainer.test(dataloaders=dataset.test_dataloader(), ckpt_path='last')[0]

        #res_repeats.append(test_results["test_acc"])
    
        # Just for interest see if loading the state with lowest val loss actually
        # gives better generalization performance.
        # NOTE: the triple-quoted string below is an inert expression statement
        # (disabled code), not a docstring; it is never executed.
        """
        checkpoint_path = checkpoint_cb.best_model_path
        trainer2 = pl.Trainer(logger=False)
    
        model = model_cls.load_from_checkpoint(
            checkpoint_path)
        val_results = trainer2.test(
            model,
            dataloaders=dataset.val_dataloader()
        )[0]
    
        val_results = {
            name.replace('test', 'val'): value
            for name, value in val_results.items()
        }

        test_results = trainer2.test(
            model,
            dataloaders=dataset.test_dataloader()
        )[0]

        for name, value in {**test_results}.items():
            wandb_logger.experiment.summary['restored_' + name] = value
                """
        #print("--- Results over repeats:")
        #print("Test Acc:", np.array(res_repeats).mean(), "+-", np.array(res_repeats).std())
        #print(res_repeats)
    
    else:
        # Skipped config: record zero scores so the sweep still sees a result.
        #trainer.fit(model, datamodule=dataset)
        wandb_logger.experiment.log({"val_acc": 0, "test_acc": 0})

def _base_sweep_config(dataset, use_node_attributes):
    """Load ./tuning.yml and pin the dataset and node-attribute settings.

    Args:
        dataset: Dataset name forced as a constant sweep parameter.
        use_node_attributes: Value forced for ``use_node_attributes``.

    Returns:
        A freshly loaded sweep-config dict that the caller may mutate freely.
    """
    sweep_config = read_yaml('./tuning.yml')
    sweep_config["parameters"].update({'dataset': {'value': dataset}})
    sweep_config["parameters"].update({'use_node_attributes': {'value': use_node_attributes}})
    return sweep_config


def main():
    """Run one wandb sweep per (GNN type, filtration activation, dist_dimh) combo.

    Each sweep runs ``model_pipline`` for 10 trials. ``min_val_loss`` (read and
    updated by ``model_pipline``) is reset for each GNN type, so one best config
    is kept per GNN type.
    """
    # model_pipline() reads/updates this module-level name.
    global min_val_loss

    parser = argparse.ArgumentParser()
    # Required: the value is used in the wandb project name and the sweep config.
    parser.add_argument('--dataset', type=str, required=True, help='dataset')
    args = parser.parse_args()

    use_node_attributes = True
    gnn_models = ['GAT', 'GCN', 'GIN']
    act_funcs = ['tanh_filtrations', 'relu_filtrations']
    dist_dimhs = [True, False]

    for gnn_model in gnn_models:
        # inf (rather than a finite sentinel) guarantees the first finished
        # trial of each GNN type is recorded.
        min_val_loss = float('inf')
        for act_func in act_funcs:
            for dist_dimh in dist_dimhs:
                # Start from an unmodified tuning.yml so pins from the previous
                # combination do not leak into this one.
                sweep_config = _base_sweep_config(args.dataset, use_node_attributes)
                params = sweep_config["parameters"]
                if dist_dimh:
                    params['dist_dim1'] = {'value': True}
                # NOTE(review): only the chosen GNN / activation flag is pinned True;
                # the other flags keep whatever tuning.yml specifies — confirm intended.
                params[gnn_model] = {'value': True}
                params[act_func] = {'value': True}
                params['dist_dimh'] = {'value': dist_dimh}

                sweep_id = wandb.sweep(sweep=sweep_config, project="topo_gnn_sweeping_" + args.dataset)
                wandb.agent(sweep_id, function=model_pipline, count=10)


if __name__ == "__main__":
    main()
        

