import torch
from loguru import logger
import time
import datetime
from utils.utils import *
from utils.training_utils import *
from utils.model_utils import *
from utils.data_utils import *
import utils.loss_utils as lu
from omegaconf import OmegaConf
import wandb
from collections import defaultdict 
import os

# Directory holding the trained checkpoints to evaluate.
CKPT_PATH = 'best_weights/SCAT'
# File names in CKPT_PATH, listed once when the module is imported
# (raises FileNotFoundError at import time if the directory is missing).
all_ckpts = list(os.listdir(CKPT_PATH))

def _error_summary(target, pred):
    """Return ``(mse, mse_std_err, ratio)`` comparing ``pred`` against ``target``.

    * ``mse``: mean squared error over all elements.
    * ``mse_std_err``: standard deviation of the element-wise squared error
      divided by ``sqrt(len(target))``. When ``target`` is 1-D this is a
      standard-error estimate of the MSE. NOTE(review): for multi-dim tensors,
      ``len`` counts only dim 0, so confirm targets are 1-D.
    * ``ratio``: mean of ``round(pred) / target``. Any zero target yields
      inf/nan here.
    """
    sq_err = (target - pred) ** 2
    mse = torch.nn.functional.mse_loss(target, pred, reduction="mean").item()
    mse_std_err = (sq_err.std() / len(target) ** 0.5).item()
    ratio = (torch.round(pred) / target).mean().item()
    return mse, mse_std_err, ratio


def baseline_evaluate(model, sampler):
    """Run ``model`` over every batch of ``sampler`` and score both of its outputs.

    The model is switched to eval mode and called once per batch. It receives
    the sampler's packed query graphs together with that batch's corpus data,
    and it returns two predictions: a "solver" prediction and a "sum_p"
    prediction. Both are compared against the batch targets.

    Args:
        model: callable module returning ``(solver_pred, sum_p_pred)``.
        sampler: object exposing ``create_batches``, ``fetch_batched_data_by_id``
            and the ``packed_query_graphs`` / ``query_graph_*_sizes`` /
            ``query_adj_list`` attributes used below.

    Returns:
        Six Python floats: ``(solver_mse, solver_ratio, sum_p_mse, sum_p_ratio,
        solver_mse_std, sum_p_std)``.

    Raises:
        ValueError: if the sampler yields no batches.
    """
    model.eval()

    solver_preds, sum_p_preds, targets = [], [], []

    n_batches = sampler.create_batches(shuffle=False)
    if n_batches == 0:
        raise ValueError("sampler produced no batches; nothing to evaluate")

    # Evaluation needs no gradients; skipping autograd bookkeeping avoids
    # keeping a computation graph alive for every batch.
    with torch.no_grad():
        for i in range(n_batches):
            (
                corpus_batch_data,
                corpus_batch_data_node_sizes,
                corpus_batch_data_edge_sizes,
                batch_target,
                corpus_batch_adj,
            ) = sampler.fetch_batched_data_by_id(i)
            solver_pred, sum_p_pred = model(
                sampler.packed_query_graphs,
                sampler.query_graph_node_sizes,
                sampler.query_graph_edge_sizes,
                sampler.query_adj_list,
                corpus_batch_data,
                corpus_batch_data_node_sizes,
                corpus_batch_data_edge_sizes,
                corpus_batch_adj,
            )
            solver_preds.append(solver_pred)
            sum_p_preds.append(sum_p_pred)
            targets.append(batch_target)

    all_solver_pred = torch.cat(solver_preds, dim=0)
    all_sum_p_pred = torch.cat(sum_p_preds, dim=0)
    all_target = torch.cat(targets, dim=0)

    solver_mse, solver_mse_std, solver_ratio = _error_summary(all_target, all_solver_pred)
    sum_p_mse, sum_p_std, sum_p_ratio = _error_summary(all_target, all_sum_p_pred)
    return solver_mse, solver_ratio, sum_p_mse, sum_p_ratio, solver_mse_std, sum_p_std
    


def test():
    """Evaluate the stored checkpoint for the configured dataset on its test split.

    Reads the module-level ``conf``, which is built in the ``__main__`` section.
    It also reads ``all_ckpts``, the listing of ``CKPT_PATH``. The single
    checkpoint whose file name contains ``conf.dataset.name`` is loaded into a
    freshly built model. The model is scored with ``baseline_evaluate``, and
    the metrics are written to the logger and to the active wandb run.

    Raises:
        FileNotFoundError: if no checkpoint file name contains the dataset name.
        RuntimeError: if more than one checkpoint file name contains it.
    """
    test_data = CliqueDataset(conf, "test", logger.info)
    test_data.data_type = "pyg"
    run = 'testing'

    # Explicit raises instead of `assert` so the check survives `python -O`
    # and the two failure modes are reported separately.
    matches = [name for name in all_ckpts if conf.dataset.name in name]
    if not matches:
        raise FileNotFoundError(f"No checkpoint found for {conf.dataset.name} in {CKPT_PATH}")
    if len(matches) > 1:
        raise RuntimeError(f"Multiple checkpoints found for {conf.dataset.name}: {sorted(matches)}")
    ckpt = matches[0]

    model = get_class(f"{conf.model.classPath}.{conf.model.name}")(conf).to(conf.training.device)
    # map_location lets a checkpoint saved on another device (e.g. a GPU)
    # be loaded onto the device this run is configured for.
    checkpoint = torch.load(os.path.join(CKPT_PATH, ckpt), map_location=conf.training.device)
    model.load_state_dict(checkpoint['model_state_dict'])

    start_time = time.perf_counter()
    solver_mse, solver_ratio, sum_p_mse, sum_p_ratio, solver_std, sum_p_std = baseline_evaluate(model, test_data)
    elapsed = time.perf_counter() - start_time
    logger.info(f"Run: {run} TEST solver_mse: {solver_mse:.6f} solver_ratio: {solver_ratio:.6f} sum_p_mse: {sum_p_mse:.6f} sum_p_ratio: {sum_p_ratio:.6f} solver_std: {solver_std:.6f} sum_p_std: {sum_p_std:.6f} Time: {elapsed:.6f}")
    wandb.log({
        'test_solver_mse': solver_mse,
        'test_solver_ratio': solver_ratio,
        'test_solver_std': solver_std,
        'test_sum_p_mse': sum_p_mse,
        'test_sum_p_ratio': sum_p_ratio,
        'test_sum_p_std': sum_p_std,
    })



if __name__ == "__main__":
    # Layered configuration: global defaults, then the dataset- and
    # model-specific files named on the command line, then the CLI overrides
    # themselves. Later arguments to OmegaConf.merge take precedence, so a
    # single merge with cli_conf last is enough.
    main_conf = OmegaConf.load("configs/config.yaml")
    cli_conf = OmegaConf.from_cli()
    data_conf = OmegaConf.load(f"configs/data_configs/{cli_conf.dataset.name}.yaml")
    model_conf = OmegaConf.load(f"configs/model_configs/{cli_conf.model.name}.yaml")
    conf = OmegaConf.merge(main_conf, data_conf, model_conf, cli_conf)

    conf.log.dir = 'test_logs'
    os.makedirs(conf.log.dir, exist_ok=True)

    # The run name encodes only model, dataset and decoder depth, so repeated
    # test runs of the same configuration share one log file / wandb group.
    decoder_steps = conf.model.decoder_steps
    conf.task.name = f"{conf.model.name}_{conf.dataset.name}_decoder_steps={decoder_steps}"

    print(OmegaConf.to_yaml(conf))

    # mode="w" truncates any log left over from a previous run with the same name.
    logger.add(f"{conf.log.dir}/{conf.task.name}.log", mode="w")
    logger.info(OmegaConf.to_yaml(conf))
    # Disable TF32 so CUDA matmuls / cuDNN convolutions run in full float32.
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False

    wandb.init(
        project=conf.task.wandb_project,
        name=conf.task.name,
        group=conf.task.name,
        config={
            'learning_rate': conf.training.learning_rate,
            'weight_decay': conf.training.weight_decay,
            'dropout': conf.training.dropout,
            'num_epochs': conf.training.num_epochs,
            'seed': conf.training.seed,
            'batch_size': conf.training.batch_size,
            'dataset_name': conf.dataset.name,
            'dataset_max_node_set_size': conf.dataset.max_node_set_size,
            'dataset_max_edge_set_size': conf.dataset.max_edge_set_size,
            'model_name': conf.model.name,
        }
    )

    set_seed(conf.training.seed, conf)
    test()
    wandb.finish()
