
import torch
from loguru import logger
import time
import datetime
from utils.utils import *
from utils.training_utils import *
from utils.model_utils import *
from utils.data_utils import *
import utils.loss_utils as lu
from omegaconf import OmegaConf
import wandb
from collections import defaultdict 

def get_metrics(pred, target, round=False, name="iso", sampler_type="val"):
    """Compute the MSE between predictions and targets, plus its standard error.

    Args:
        pred: list of prediction tensors, concatenated along dim 0.
        target: list of target tensors, concatenated along dim 0.
        round: if True, round predictions to the nearest integer before scoring.
        name: metric label; currently unused (kept for call-site compatibility).
        sampler_type: split label; currently unused (kept for call-site
            compatibility).

    Returns:
        ``(std_err, mse)``: the standard deviation of the element-wise squared
        errors divided by ``sqrt(len(target))``, and the mean squared error,
        both as Python floats.
    """
    all_pred = torch.cat(pred, dim=0)
    all_target = torch.cat(target, dim=0)

    # Predictions can be integer tensors (e.g. argmax-derived labels), and
    # std()/mean() only support floating dtypes, so score in a common float type.
    dtype = torch.promote_types(all_pred.dtype, all_target.dtype)
    if not dtype.is_floating_point:
        dtype = torch.get_default_dtype()
    all_pred = all_pred.to(dtype)
    all_target = all_target.to(dtype)

    if round:
        all_pred = torch.round(all_pred)

    sq_err = (all_target - all_pred) ** 2
    mse = sq_err.mean().item()
    sum_p_std = (sq_err.std() / (len(all_target) ** 0.5)).item()
    return sum_p_std, mse



def test():
    """Evaluate the checkpoint saved for ``conf.dataset.name`` on the test split.

    Reads module-level globals assigned in the ``__main__`` section:
    ``conf``, ``gmn_config``, ``CKPT_PATH`` and ``all_ckpts``. The selected
    checkpoint's file name must carry a ``τ=<number>`` tag; that value is
    written to ``conf.model.sinkhorn_temp`` before the model is constructed.
    Metrics are logged via loguru and reported to wandb.

    Raises:
        RuntimeError: if not exactly one checkpoint name contains the dataset
            name, or the matched name has no parsable ``τ=`` tag.
    """
    import re

    test_data = CliqueDataset(conf, "test", logger.info)

    # NOTE(review): selection is a substring match on the dataset name, so a
    # name that is contained in another dataset's name can match both.
    matching = [name for name in all_ckpts if conf.dataset.name in name]
    if len(matching) != 1:
        raise RuntimeError(
            f"Multiple or no checkpoints found for {conf.dataset.name}: {matching}"
        )
    ckpt = matching[0]

    # The checkpoint file name carries the temperature as a "τ=<value>" tag.
    match = re.search(r'τ=([0-9]*\.?[0-9]+)', ckpt)
    if match is None:
        raise RuntimeError(f"No 'τ=<value>' tag found in checkpoint name {ckpt!r}")
    tau = float(match.group(1))
    conf.model.sinkhorn_temp = tau
    logger.info(f"Using tau: {tau}")

    device = conf.training.device
    model = get_class(f"{conf.model.classPath}.{conf.model.name}")(conf, gmn_config).to(device)
    # map_location lets a checkpoint saved on another device load on this one.
    checkpoint = torch.load(os.path.join(CKPT_PATH, ckpt), map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])

    # Timing covers evaluation only, not dataset/model loading.
    start_time = time.time()
    (iso_std, iso_mse), (int_std, int_mse), (float_std, float_mse) = ff_evaluate(model, test_data, "test")

    logger.info(f"TEST | iso_mse: {iso_mse:.6f} ff-float_mse: {float_mse:.6f} ff-int_mse: {int_mse:.6f}  iso_std: {iso_std:.6f} ff-float_std: {float_std:.6f} ff-int_std: {int_std:.6f} Time: {time.time()-start_time:.6f}")
    wandb.log({'test_iso_mse_loss': iso_mse, 'test_ff-float_mse_loss': float_mse, 'test_ff-int_mse_loss': int_mse, 'test_iso_std': iso_std, 'test_ff-float_std': float_std, 'test_ff-int_std': int_std})



def ff_evaluate(model, sampler, sampler_type="val"):
    """Run ``model`` over every batch of ``sampler`` and score its predictions.

    Evaluation schema:
    - Current: Only uses iso_output
    -> Future ideas: track both ff output and iso output

    Args:
        model: network returning ``(iso_out, ff_pred)`` for a batch; must
            expose a ``delta`` threshold attribute.
        sampler: dataset object providing ``create_batches``,
            ``fetch_batched_data_by_id`` and the packed query-graph attributes.
        sampler_type: split label forwarded to ``get_metrics``.

    Returns:
        Three ``(std_err, mse)`` pairs from ``get_metrics``: iso-based
        predictions, rounded ff predictions ("ff-int"), and raw ff predictions
        ("ff-float").
    """
    model.eval()

    iso_preds = []
    ff_preds = []
    targets = []

    n_batches = sampler.create_batches(shuffle=False)
    # Inference only: no_grad avoids building an autograd graph for each batch.
    with torch.no_grad():
        for batch_idx in range(n_batches):
            (
                corpus_batch_data,
                corpus_batch_data_node_sizes,
                corpus_batch_data_edge_sizes,
                batch_target,
                corpus_batch_adj,
            ) = sampler.fetch_batched_data_by_id(batch_idx)
            iso_out, ff_pred = model(
                sampler.packed_query_graphs,
                sampler.query_graph_node_sizes,
                sampler.query_graph_edge_sizes,
                sampler.query_adj_list,
                corpus_batch_data,
                corpus_batch_data_node_sizes,
                corpus_batch_data_edge_sizes,
                corpus_batch_adj,
            )

            # Per row: index of the first consecutive pair whose drop
            # iso_out[j] - iso_out[j+1] exceeds model.delta (argmax returns the
            # first maximal index; it is 0 when no drop exceeds delta), offset by 2.
            # NOTE(review): the +2 presumably maps column 0 to the smallest
            # target label -- confirm against the target encoding.
            drops = iso_out[:, :-1] - iso_out[:, 1:]
            iso_pred = (drops > model.delta).long().argmax(-1) + 2

            iso_preds.append(iso_pred.detach())
            ff_preds.append(ff_pred.detach())
            targets.append(batch_target)

    return (
        get_metrics(iso_preds, targets, name="iso", sampler_type=sampler_type),
        get_metrics(ff_preds, targets, round=True, name="ff-int", sampler_type=sampler_type),
        get_metrics(ff_preds, targets, round=False, name="ff-float", sampler_type=sampler_type),
    )







if __name__ == "__main__":
    main_conf = OmegaConf.load("configs/config.yaml")
    cli_conf = OmegaConf.from_cli()
    data_conf = OmegaConf.load(f"configs/data_configs/{cli_conf.dataset.name}.yaml")
    model_conf = OmegaConf.load(f"configs/model_configs/{cli_conf.model.name}.yaml")
    # OmegaConf.merge gives precedence to later arguments, so CLI values
    # override the model, data and base configs; no second merge is needed.
    conf = OmegaConf.merge(main_conf, data_conf, model_conf, cli_conf)
    print(OmegaConf.to_yaml(conf))

    # Our Model. No meaning of EQ here
    if conf.model.EQ:
        raise ValueError("conf.model.EQ must be false for this script")

    def greek_letter_to_unicode_map(letter):
        """Map a hyper-parameter name to its Greek symbol; other names pass through.

        Not called in this script. The 'sinkhorn_temp' -> 'τ' entry matches the
        tag that ``test`` parses out of checkpoint file names.
        """
        return {
            'sinkhorn_temp': '\u03c4',
            'gamma': '\u03b3',
            'delta': '\u03b4',
            'LAMBDA': '\u039b',
            'LAMBDA2': '\u039b2',
        }.get(letter, letter)

    # es_type also names the checkpoint sub-directory (see CKPT_PATH below).
    es_name = conf.training.es_type
    if es_name not in ("ISO", "FF", "DUAL"):
        raise NotImplementedError(f"Unsupported training.es_type: {es_name!r}")

    conf.log.dir = 'our_test_logs'
    os.makedirs(conf.log.dir, exist_ok=True)

    conf.task.name = f"{conf.model.name}_{es_name}_{conf.dataset.name}"

    # Printed again because log.dir and task.name were just updated.
    print(OmegaConf.to_yaml(conf))

    # Module-level globals read by test().
    CKPT_PATH = f'our_best_models/{es_name}'
    all_ckpts = os.listdir(CKPT_PATH)

    log_path = f"{conf.log.dir}/{conf.task.name}.log"
    open(log_path, "w").close()  # Truncate any log left by a previous run.
    logger.add(log_path)
    logger.info(OmegaConf.to_yaml(conf))
    # Disable TF32 so CUDA matmuls and cuDNN convolutions use full FP32 precision.
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False

    wandb.init(
        project=conf.task.wandb_project,
        name=conf.task.name,
        group=conf.task.name,
        config={
            'learning_rate': conf.training.learning_rate,
            'weight_decay': conf.training.weight_decay,
            'dropout': conf.training.dropout,
            'num_epochs': conf.training.num_epochs,
            'seed': conf.training.seed,
            'batch_size': conf.training.batch_size,
            'dataset_name': conf.dataset.name,
            'dataset_max_node_set_size': conf.dataset.max_node_set_size,
            'dataset_max_edge_set_size': conf.dataset.max_edge_set_size,
            'model_name': conf.model.name,
            'gamma': conf.model.gamma,
            'delta': conf.model.delta,
            'LAMBDA': conf.model.LAMBDA,
            'LAMBDA2': conf.model.LAMBDA2,
            'es_type': conf.training.es_type,
        },
    )

    set_seed(conf.training.seed)
    # gmn_config is another module-level global read by test().
    gmn_config = modify_gmn_main_config(get_default_gmn_config(conf), conf, logger)
    test()
