import torch
from loguru import logger
import time
import datetime
from utils.utils import *
from utils.training_utils import *
from utils.model_utils import *
from utils.data_utils import *
import utils.loss_utils as lu
from omegaconf import OmegaConf
import wandb
from collections import defaultdict 

def get_metrics(pred, target, round=False, name="iso", sampler_type="val"):
    """Return the mean squared error between batched predictions and targets.

    Args:
        pred: Sequence of per-batch prediction tensors; concatenated on dim 0.
        target: Sequence of per-batch target tensors; concatenated on dim 0.
        round: If true, predictions go through ``torch.round`` before the
            error is computed. The name shadows the builtin but is part of
            the keyword interface callers use, so it is kept.
        name: Metric label embedded in the key of the returned dict.
        sampler_type: Split label embedded in the key of the returned dict.

    Returns:
        A one-entry dict ``{"<sampler_type>_<name>_mse": <float>}``.
    """
    predictions = torch.cat(pred, dim=0)
    if round:
        predictions = torch.round(predictions)
    ground_truth = torch.cat(target, dim=0)

    error = torch.nn.functional.mse_loss(ground_truth, predictions, reduction="mean")
    metric_key = f"{sampler_type}_{name}_mse"
    return {metric_key: error.item()}


def ff_evaluate(model, sampler, sampler_type="val"):
    """Evaluate ``model`` on every batch of ``sampler`` and report MSE metrics.

    The model is called once per corpus batch together with the sampler's
    query graphs and returns a pair ``(iso_out, ff_pred)``. Three metrics are
    computed against the batch targets:

    - ``iso``: an integer prediction derived from ``iso_out`` (see the inline
      comment below),
    - ``ff-float``: ``ff_pred`` as returned by the model,
    - ``ff-int``: ``ff_pred`` rounded to the nearest integer.

    The model is switched to eval mode and is NOT switched back; callers that
    continue training must call ``model.train()`` themselves.

    Args:
        model: Callable module exposing ``eval()`` and a ``delta`` attribute.
        sampler: Dataset object providing ``create_batches``,
            ``fetch_batched_data_by_id`` and the packed query-graph attributes.
        sampler_type: Split label used as a prefix in the metric keys.

    Returns:
        A tuple of three single-entry dicts, in the order
        ``(iso, ff-float, ff-int)``, keyed ``"<sampler_type>_<name>_mse"``.
    """
    model.eval()

    iso_preds = []
    ff_preds = []
    targets = []

    n_batches = sampler.create_batches(shuffle=False)
    # Nothing is back-propagated during evaluation, so disable autograd for
    # the forward passes instead of building a graph and discarding it.
    with torch.no_grad():
        for i in range(n_batches):
            (
                corpus_batch_data,
                corpus_batch_data_node_sizes,
                corpus_batch_data_edge_sizes,
                batch_target,
                corpus_batch_adj,
            ) = sampler.fetch_batched_data_by_id(i)
            iso_out, ff_pred = model(
                sampler.packed_query_graphs,
                sampler.query_graph_node_sizes,
                sampler.query_graph_edge_sizes,
                sampler.query_adj_list,
                corpus_batch_data,
                corpus_batch_data_node_sizes,
                corpus_batch_data_edge_sizes,
                corpus_batch_adj,
            )
            # Flag each adjacent pair along dim 1 whose drop exceeds
            # model.delta, take the argmax of the 0/1 flags over the last
            # dim, and offset by 2. NOTE(review): this relies on argmax
            # returning the first flagged position (and 0 when nothing is
            # flagged, i.e. a prediction of 2) -- confirm against the torch
            # version in use.
            drops = iso_out[:, :-1] - iso_out[:, 1:]
            iso_pred = (drops > model.delta).long().argmax(-1) + 2
            iso_preds.append(iso_pred.detach())
            ff_preds.append(ff_pred.detach())
            targets.append(batch_target)

    return (
        get_metrics(iso_preds, targets, name="iso", sampler_type=sampler_type),
        get_metrics(ff_preds, targets, round=False, name="ff-float", sampler_type=sampler_type),
        get_metrics(ff_preds, targets, round=True, name="ff-int", sampler_type=sampler_type),
    )






def train():
    """Train the configured model with early stopping, then test the best checkpoint.

    Reads the module-level ``conf`` and ``gmn_config`` globals, which the
    ``__main__`` block assigns before calling this function, and logs to
    wandb, so ``wandb.init`` must already have been called.

    Steps:
        1. Build the train/val datasets, the early-stopping module selected
           by ``conf.training.es_type``, the model and an Adam optimizer.
        2. Unless ``conf.training.overwrite`` is set, resume from the latest
           checkpoint; return immediately if that checkpoint is flagged as
           finished (``es.should_stop_now``).
        3. Run epochs until ``conf.training.num_epochs`` is reached or the
           early-stopping module asks to stop, logging train loss and
           validation MSEs every epoch.
        4. Reload the best checkpoint and log validation and test metrics.

    Raises:
        NotImplementedError: If ``conf.training.es_type`` is not one of
            ``"ISO"``, ``"FF"`` or ``"DUAL"``.
    """
    # Fail fast: an unknown type would otherwise leave `es` unbound and
    # surface later as a NameError.
    es_type = conf.training.es_type
    if es_type not in ("ISO", "FF", "DUAL"):
        raise NotImplementedError(f"Unsupported conf.training.es_type: {es_type!r}")
    is_dual = es_type == "DUAL"

    train_data = CliqueDataset(conf, "train", logger.info)
    val_data = CliqueDataset(conf, "val", logger.info)
    if is_dual:
        es = DualEarlyStoppingModuleWithIsoStabilizationFollowedByFF(conf.base_dir, conf.task.name, patience=conf.training.patience, logger=logger)
    else:
        es = EarlyStoppingModule(conf.base_dir, conf.task.name, patience=conf.training.patience, logger=logger)
    logger.info(es.__dict__)
    logger.info(f"This uses the {conf.model.classPath}.{conf.model.name} model")
    model = get_class(f"{conf.model.classPath}.{conf.model.name}")(conf, gmn_config).to(conf.training.device)
    logger.info(model)
    if conf.training.wandb_watch:
        wandb.watch(model, log_freq=1)
    logger.info(f"no. of params in model: {sum(p.numel() for p in model.parameters())}")
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=conf.training.learning_rate,
        weight_decay=conf.training.weight_decay,
    )

    # `overwrite` ignores any existing checkpoint and starts from scratch.
    checkpoint = None if conf.training.overwrite else es.load_latest_model()
    if not checkpoint:
        es.save_initial_model(model)
        run = 0
    elif es.should_stop_now:
        logger.info("Training has been completed. This logfile can be deleted.")
        return
    else:
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optim_state_dict"])
        run = checkpoint["epoch"] + 1

    # Sentinel values that are logged until the first "best model" update.
    best_iso_val_mse = 1e5
    best_ff_float_val_mse = 1e5
    best_ff_int_val_mse = 1e5

    # The loss is looked up by name in utils.loss_utils; only names starting
    # with 'loss' are accepted (this covers 'loss_iso_only'/'loss_mse_only').
    assert conf.training.loss_fn.startswith('loss'), conf.training.loss_fn
    loss_fn = getattr(lu, conf.training.loss_fn)

    # NOTE(review): when conf.training.run_till_early_stopping is false the
    # loop body never runs and execution goes straight to the final
    # evaluation -- confirm this is intended.
    while conf.training.run_till_early_stopping and run < conf.training.num_epochs:
        model.train()
        n_batches = train_data.create_batches(shuffle=True)
        epoch_loss = 0
        loss_components_dict = defaultdict(int)

        start_time = time.time()
        for i in range(n_batches):
            (
                corpus_batch_data,
                corpus_batch_data_node_sizes,
                corpus_batch_data_edge_sizes,
                batch_target,
                corpus_batch_adj,
            ) = train_data.fetch_batched_data_by_id(i)
            optimizer.zero_grad()
            iso_out, ff_out = model(
                train_data.packed_query_graphs,
                train_data.query_graph_node_sizes,
                train_data.query_graph_edge_sizes,
                train_data.query_adj_list,
                corpus_batch_data,
                corpus_batch_data_node_sizes,
                corpus_batch_data_edge_sizes,
                corpus_batch_adj,
            )
            losses, loss_components = loss_fn(conf, batch_target, iso_out, ff_out, train_data.max_query_template_size)
            losses.backward()
            optimizer.step()
            epoch_loss = epoch_loss + losses.item()
            for k, v in loss_components.items():
                loss_components_dict[k] += v

        logger.info(
            f"Run: {run} train loss: {epoch_loss/n_batches:.2f} Time: {time.time()-start_time:.2f}",
        )
        start_time = time.time()

        iso_val_dict, ff_float_val_dict, ff_int_val_dict = ff_evaluate(model, val_data, "val")
        iso_val_mse = iso_val_dict["val_iso_mse"]
        ff_float_val_mse = ff_float_val_dict["val_ff-float_mse"]
        ff_int_val_mse = ff_int_val_dict["val_ff-int_mse"]

        # DUAL runs additionally report the stabilization flag; it is read
        # here, before the early-stopping check below has a chance to update it.
        run_tag = f" (Stabilized: {es.has_stabilized})" if is_dual else ""
        logger.info(f"Run{run_tag}: {run} VAL | iso_mse: {iso_val_mse:.6f} ff-float_mse: {ff_float_val_mse:.6f} ff-int_mse: {ff_int_val_mse:.6f}  Time: {time.time()-start_time:.6f}")

        # Scores are negated MSEs, i.e. larger is better.
        if es_type == "ISO":
            stop_bool = es.check([-iso_val_mse], model, run, optimizer)
        elif es_type == "FF":
            stop_bool = es.check([-ff_float_val_mse], model, run, optimizer)
        else:  # DUAL
            stop_bool = es.dual_check([-iso_val_mse], [-ff_float_val_mse], model, run, optimizer)

        #NOTE: reusing es.num_bad_epochs variable which is reset to 0 every time best model is updated.
        if es.num_bad_epochs == 0:
            best_iso_val_mse = iso_val_mse
            best_ff_float_val_mse = ff_float_val_mse
            best_ff_int_val_mse = ff_int_val_mse

        # Re-read the flag: this line is logged after the early-stopping check.
        run_tag = f" (Stabilized: {es.has_stabilized})" if is_dual else ""
        logger.info(f"Run{run_tag}: {run} best_iso_val_mse: {best_iso_val_mse:.6f} best_ff_float_val_mse: {best_ff_float_val_mse:.6f} best_ff_int_val_mse: {best_ff_int_val_mse:.6f}")

        log_dict = {
            'train_loss': epoch_loss/n_batches,
            'best_iso_val_mse': best_iso_val_mse,
            'best_ff_float_val_mse': best_ff_float_val_mse,
            'best_ff_int_val_mse': best_ff_int_val_mse,
            'iso_val_mse': iso_val_mse,
            'ff_float_val_mse': ff_float_val_mse,
            'ff_int_val_mse': ff_int_val_mse,
        }
        # Per-batch loss components are reported as epoch averages.
        log_dict.update({k: total / n_batches for k, total in loss_components_dict.items()})
        wandb.log(log_dict)

        if conf.training.run_till_early_stopping and stop_bool:
            break
        run += 1

    # Final report always uses the best checkpoint, not the last epoch.
    ckpt = es.load_best_model()
    model.load_state_dict(ckpt["model_state_dict"])
    test_data = CliqueDataset(conf, "test", logger.info)
    iso_val_dict, ff_float_val_dict, ff_int_val_dict = ff_evaluate(model, val_data, "val")
    logger.info(f"VAL || iso: {iso_val_dict} ff_float: {ff_float_val_dict} ff_int: {ff_int_val_dict}")
    iso_test_dict, ff_float_test_dict, ff_int_test_dict = ff_evaluate(model, test_data, "test")
    combined_dict = {}
    combined_dict.update(iso_test_dict)
    combined_dict.update(ff_float_test_dict)
    combined_dict.update(ff_int_test_dict)
    wandb.log(combined_dict)
    logger.info(combined_dict)
    
    


if __name__ == "__main__":
    # Layer the configuration; later sources override earlier ones:
    # base config < dataset config < model config < command-line overrides.
    # The dataset and model config files are selected by the CLI values.
    main_conf = OmegaConf.load("configs/config.yaml")
    cli_conf = OmegaConf.from_cli()
    data_conf = OmegaConf.load(f"configs/data_configs/{cli_conf.dataset.name}.yaml")
    model_conf = OmegaConf.load(f"configs/model_configs/{cli_conf.model.name}.yaml")
    # cli_conf is already the last (highest-priority) layer here, so no
    # second merge with it is needed.
    conf = OmegaConf.merge(main_conf, data_conf, model_conf, cli_conf)
    run_time = "{date:%Y-%m-%d||%H:%M:%S}".format(date=datetime.datetime.now())
    print(OmegaConf.to_yaml(conf))

    # Model keys that are left out of the generated task name.
    name_removal_set = {'classPath', 'name', 'sinkhorn_num_iters', 'mask_sinkhorn'}

    # Our Model. No meaning of EQ here
    assert not conf.model.EQ

    # Only these early-stopping modes are implemented; the mode string itself
    # is used as the task-name prefix.
    if conf.training.es_type not in ("ISO", "FF", "DUAL"):
        raise NotImplementedError()
    es_name = conf.training.es_type

    # Task name: "<es_type><key>=<value>,<key>=<value>,..." over the
    # remaining model settings, with keys passed through
    # greek_letter_to_unicode_map.
    task_name = es_name + ",".join(
        "{}={}".format(greek_letter_to_unicode_map(key), value)
        for key, value in conf.model.items()
        if key not in name_removal_set
    )
    print(task_name)

    conf.task.name = f"{conf.model.name}_{conf.dataset.name}_{task_name}_{conf.training.loss_fn}_{run_time}"

    log_path = f"{conf.log.dir}/{conf.task.name}.log"
    with open(log_path, "w"):
        pass  # Clear log file
    logger.add(log_path)
    logger.info(OmegaConf.to_yaml(conf))
    # Disable TF32 for matmul and cuDNN kernels.
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False

    wandb.init(
        project=conf.task.wandb_project,
        name=conf.task.name,
        group=conf.task.wandb_group,
        config={
            'learning_rate': conf.training.learning_rate,
            'weight_decay': conf.training.weight_decay,
            'dropout': conf.training.dropout,
            'num_epochs': conf.training.num_epochs,
            'seed': conf.training.seed,
            'batch_size': conf.training.batch_size,
            'dataset_name': conf.dataset.name,
            'dataset_max_node_set_size': conf.dataset.max_node_set_size,
            'dataset_max_edge_set_size': conf.dataset.max_edge_set_size,
            'model_name': conf.model.name,
            'gamma': conf.model.gamma,
            'delta': conf.model.delta,
            'LAMBDA': conf.model.LAMBDA,
            'LAMBDA2': conf.model.LAMBDA2,
            'es_type': conf.training.es_type,
        },
    )

    # train() reads the module-level `conf` and `gmn_config` assigned here.
    set_seed(conf.training.seed)
    gmn_config = modify_gmn_main_config(get_default_gmn_config(conf), conf, logger)
    train()
