from omegaconf import OmegaConf
import datetime
from loguru import logger
import torch
import wandb
import time
import tqdm

from utils.utils import *
from utils.model_utils import *
from utils.training_utils import EarlyStoppingModule, pairwise_ranking_loss, evaluate_model
from utils.dataset_loader_mini import SubgraphIsomorphismDataset
import pyfiglet


def _rng_snapshot():
    """Capture the torch / CUDA / NumPy / ``random`` generator states.

    Returns a dict whose keys are the checkpoint fields read back by
    ``_restore_rng``. ``cuda_rng_state`` is ``None`` when CUDA is not
    available, so checkpointing also works on CPU-only machines.
    """
    return {
        'rng_state': torch.get_rng_state(),
        'cuda_rng_state': torch.cuda.get_rng_state() if torch.cuda.is_available() else None,
        'np_rng_state': np.random.get_state(),
        'random_state': random.getstate(),
    }


def _restore_rng(checkpoint):
    """Restore generator states previously saved by ``_rng_snapshot``.

    All four fields are read before any state is set, so a KeyError for a
    missing field leaves every generator untouched. The CUDA state is skipped
    when it was not saved (``None``) or when CUDA is unavailable now.
    """
    rng_state = checkpoint['rng_state']
    cuda_rng_state = checkpoint['cuda_rng_state']
    np_rng_state = checkpoint['np_rng_state']
    random_state = checkpoint['random_state']
    torch.set_rng_state(rng_state)
    if cuda_rng_state is not None and torch.cuda.is_available():
        torch.cuda.set_rng_state(cuda_rng_state)
    np.random.set_state(np_rng_state)
    random.setstate(random_state)


def _checkpoint_metric(checkpoint, key, device, default=0):
    """Return ``checkpoint[key]``, with tensors moved to ``device``.

    Plain numbers are returned unchanged. ``default`` is returned when the key
    is absent or stored as ``None``.
    """
    value = checkpoint.get(key)
    if value is None:
        return default
    return value.to(device) if torch.is_tensor(value) else value


def train(conf, gmn_config):
    """Train the configured model with early stopping, then evaluate it on the test split.

    Builds the train/val datasets, the model class named by
    ``conf.model.classPath``/``conf.model.name`` and an Adam optimizer. When
    ``conf.training.resume`` is set, it continues from the latest checkpoint.
    Each epoch trains with ``pairwise_ranking_loss``, validates, and passes the
    resulting state to ``EarlyStoppingModule.check``. Finally it reloads the
    best checkpoint and reports test AP/MAP to the logger and wandb.

    Args:
        conf: OmegaConf config. This function also writes
            ``conf.actual_max_node_set_size`` and ``conf.actual_max_edge_set_size``.
        gmn_config: model-specific config forwarded to the model constructor.
    """
    device = conf.training.device

    train_dataset = SubgraphIsomorphismDataset(conf, mode="train")
    val_dataset = SubgraphIsomorphismDataset(conf, mode="val")

    # Record the training split's maximum graph sizes on the config; this happens
    # before the model is constructed from ``conf`` below.
    conf.actual_max_node_set_size = train_dataset.max_node_set_size
    conf.actual_max_edge_set_size = train_dataset.max_edge_set_size

    es = EarlyStoppingModule(conf.base_dir, conf.task.name, patience=conf.training.patience, logger=logger)
    logger.info(es.__dict__)

    logger.info(f"This uses the {conf.model.classPath}.{conf.model.name} model")
    model = get_class(f"{conf.model.classPath}.{conf.model.name}")(conf, gmn_config).to(device)
    logger.info(model)
    if conf.training.wandb_watch:
        wandb.watch(model, log_freq=1)
    logger.info(f"no. of params in model: {sum(p.numel() for p in model.parameters())}")
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=conf.training.learning_rate,
        weight_decay=conf.training.weight_decay,
    )

    if not conf.training.resume:
        es.save_initial_state(
            {
                "model_state_dict": model.state_dict(),
                "optim_state_dict": optimizer.state_dict(),
                **_rng_snapshot(),
                'patience': es.patience,
                'best_scores': es.best_scores,
                'num_bad_epochs': es.num_bad_epochs,
                'should_stop_now': es.should_stop_now,
            }
        )

    best_val_ap = 0
    best_val_map = 0
    run = 0

    if conf.training.resume:
        banner = pyfiglet.figlet_format('Resuming training', font="slant", justify="center")
        logger.info(banner)
        checkpoint = es.load_latest_model()
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optim_state_dict'])
        run = checkpoint['epoch'] + 1

        # Regenerate and discard one batching per completed epoch, for
        # reproducibility of the dataset's shuffles. The global RNG states
        # restored right after overwrite anything this replay consumed.
        for _ in range(run):
            train_dataset.create_stratified_batches()
        _restore_rng(checkpoint)

        # Checkpoints may lack these fields (or hold them as plain floats rather
        # than tensors); fall back to 0 only for fields that are actually missing.
        missing = [key for key in ("best_val_ap", "best_val_map") if checkpoint.get(key) is None]
        if missing:
            logger.info(f"Could not load {', '.join(missing)} from checkpoint. Setting to 0")
        best_val_ap = _checkpoint_metric(checkpoint, "best_val_ap", device)
        best_val_map = _checkpoint_metric(checkpoint, "best_val_map", device)
        logger.info(f"Resuming training from epoch {run} with best val map: {best_val_map:.6f} and patience: {es.patience}")

    # NOTE(review): when conf.training.run_till_early_stopping is false this loop
    # never executes, so no training happens before the test evaluation below —
    # confirm this is intended.
    while conf.training.run_till_early_stopping and run < conf.training.num_epochs:
        model.train()

        num_batches = train_dataset.create_stratified_batches()
        epoch_loss = 0
        training_start_time = time.time()
        for batch_idx in tqdm.tqdm(range(num_batches)):
            batch_graphs, batch_graph_node_sizes, batch_graph_edge_sizes, labels = train_dataset.fetch_batch_by_id(batch_idx)

            optimizer.zero_grad()
            prediction = model(batch_graphs, batch_graph_node_sizes, batch_graph_edge_sizes)

            # Split scores by label (threshold 0.5) for the pairwise ranking loss.
            predictions_for_positives = prediction[labels > 0.5]
            predictions_for_negatives = prediction[labels < 0.5]
            losses = pairwise_ranking_loss(
                predictions_for_positives.unsqueeze(1),
                predictions_for_negatives.unsqueeze(1),
                conf.training.margin,
            )
            losses.backward()
            optimizer.step()
            epoch_loss += losses.item()

        epoch_training_time = time.time() - training_start_time
        logger.info(f"Run: {run} train loss: {epoch_loss} Time: {epoch_training_time}")

        model.eval()
        validation_start_time = time.time()
        ap_score, map_score = evaluate_model(model, val_dataset)
        epoch_validation_time = time.time() - validation_start_time
        logger.info(f"Run: {run} VAL ap_score: {ap_score} map_score: {map_score} Time: {epoch_validation_time}")

        state_dict = {
            "model_state_dict": model.state_dict(),
            "optim_state_dict": optimizer.state_dict(),
            "epoch": run,
            "best_val_ap": best_val_ap,
            "best_val_map": best_val_map,
            "val_ap_score": ap_score,
            "val_map_score": map_score,
            **_rng_snapshot(),
            'patience': es.patience,
            'best_scores': es.best_scores,
            'num_bad_epochs': es.num_bad_epochs,
            'should_stop_now': es.should_stop_now,
        }

        # es.check receives the validation MAP and returns the state dict from
        # which the running best scores are read back.
        state_dict = es.check([map_score], state_dict)
        best_val_ap = state_dict["best_val_ap"]
        best_val_map = state_dict["best_val_map"]
        logger.info(f"Run: {run} best_val_ap: {best_val_ap} best_val_map: {best_val_map}")
        wandb.log(
            {
                "train_loss": epoch_loss,
                "train_time": epoch_training_time,
                "val_ap_score": ap_score,
                "val_map_score": map_score,
                "best_val_ap": best_val_ap,
                "best_val_map": best_val_map,
            },
        )
        if es.should_stop_now:
            break
        run += 1

    # Evaluate the best checkpoint (as chosen by the early-stopping module) on test.
    ckpt = es.load_best_model()
    model.load_state_dict(ckpt["model_state_dict"])
    test_dataset = SubgraphIsomorphismDataset(conf, mode="test")
    model.eval()
    test_start_time = time.time()
    test_ap_score, test_map_score = evaluate_model(model, test_dataset)
    test_time = time.time() - test_start_time
    logger.info(f"Run: {run} TEST ap_score: {test_ap_score} map_score: {test_map_score} Time: {test_time}")

    wandb.log(
        {
            "test_ap_score": test_ap_score,
            "test_map_score": test_map_score,
        }
    )

    
            


if __name__ == "__main__":
    # Assemble the run configuration. OmegaConf.merge lets later sources win, so
    # command-line overrides beat the model, dataset and base YAML files.
    base_conf = OmegaConf.load("configs/config.yaml")
    cli_overrides = OmegaConf.from_cli()
    dataset_yaml = f"configs/data_configs/{cli_overrides.dataset.rel_mode}/{cli_overrides.dataset.name}.yaml"
    dataset_conf = OmegaConf.load(dataset_yaml)
    model_yaml = f"configs/model_configs/{cli_overrides.model.name}.yaml"
    model_conf = OmegaConf.load(model_yaml)
    conf = OmegaConf.merge(base_conf, dataset_conf, model_conf, cli_overrides)

    # Every model hyper-parameter except the identifying fields goes into the run name.
    identity_fields = ("classPath", "name")
    hparam_pairs = [
        f"{key}={value}" for key, value in conf.model.items() if key not in identity_fields
    ]
    model_hparams = ",".join(hparam_pairs)

    conf.task.name = (
        f"{conf.model.name}_{conf.dataset.name}"
        f"_numC={conf.dataset.aug_num_cgraphs}"
        f"_MinR={conf.dataset.MinR}_MaxR={conf.dataset.MaxR}"
        f"_{model_hparams}"
        f",stemp={conf.training.sinkhorn_temp},margin={conf.training.margin}"
    )
    logger.info(f"Task name: {conf.task.name}")
    # Mirror all subsequent log output into a per-task log file.
    logger.add(f"{conf.log.dir}/{conf.task.name}.log")
    logger.info(OmegaConf.to_yaml(conf))

    # Turn TF32 off so matmuls and cuDNN ops run in full fp32 precision.
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False

    wandb.init(
        project=conf.task.wandb_project,
        name=conf.task.name,
        group=conf.task.wandb_group,
        config=dict(
            learning_rate=conf.training.learning_rate,
            weight_decay=conf.training.weight_decay,
            num_epochs=conf.training.num_epochs,
            seed=conf.training.seed,
            batch_size=conf.training.batch_size,
            # TODO: add ground truth type here (SubIso)
        ),
    )

    set_seed(conf.training.seed)
    gmn_config = modify_gmn_main_config(get_default_gmn_config(conf), conf, logger)
    train(conf, gmn_config)