import torch
from loguru import logger
import time
import datetime
from utils.utils import *
from utils.training_utils import *
from utils.model_utils import *
from utils.data_utils import *
import utils.loss_utils as lu
from omegaconf import OmegaConf
import wandb
from collections import defaultdict
from models.GFNET_baseline import get_data_loaders, get_test_data_loaders
from GFlowNet_CombOpt.gflownet.main import get_logr_scaler, rollout
from GFlowNet_CombOpt.gflownet.util import MaxCliqueMDP
import dgl
from einops import rearrange

# Directory that holds the trained GFlowNet checkpoint files.
CKPT_PATH = "best_weights/GFNET"
# Filenames found in CKPT_PATH, listed once when this module is imported;
# test() selects the single entry whose name contains the dataset name.
all_ckpts = os.listdir(path=CKPT_PATH)

def test():
    """Evaluate the saved GFlowNet checkpoint for ``conf.dataset.name`` on its test split.

    Reads the module-level ``conf`` (assigned in the ``__main__`` block), builds
    the model, restores ``model_state_dict`` and ``model_flow_state_dict`` from
    the single file in ``CKPT_PATH`` whose name contains the dataset name, runs
    ``evaluate`` on the test loader, and reports the metrics to the log and to
    wandb.

    Raises:
        RuntimeError: if zero or more than one checkpoint filename matches.
    """
    device = conf.training.device
    model = get_class(f"{conf.model.classPath}.{conf.model.name}")(conf).to(device)
    gfnet_conf = model.conf
    test_loader = get_test_data_loaders(gfnet_conf, conf)

    # Ground-truth values for the test graphs; `with` closes the file handle
    # promptly. NOTE: pickle can execute arbitrary code on load — only use
    # trusted dataset files.
    with open(f'{conf.dataset.path}/{conf.dataset.name}/truth/test.pkl', "rb") as fh:
        test_gt = pickle.load(fh)

    # Pick the checkpoint whose filename mentions this dataset (substring
    # match, so exactly one candidate must remain). Explicit raise instead of
    # assert so the check survives `python -O`.
    matches = [name for name in all_ckpts if conf.dataset.name in name]
    if len(matches) != 1:
        raise RuntimeError(
            f"Multiple or no checkpoints found for {conf.dataset.name}: {matches}"
        )

    # map_location lets a checkpoint saved on a different device (e.g. another
    # GPU or CPU) load onto the configured evaluation device.
    save_dict = torch.load(os.path.join(CKPT_PATH, matches[0]), map_location=device)
    model.alg.model.load_state_dict(save_dict["model_state_dict"])
    model.alg.model_flow.load_state_dict(save_dict["model_flow_state_dict"])
    # NOTE(review): the model is never switched to eval() mode here; confirm the
    # sampler has no dropout/batch-norm layers whose behavior would differ.

    logr_scaler = get_logr_scaler(gfnet_conf, process_ratio=1, reward_exp=None)

    std, mse, ratio = evaluate(None, None, None, logr_scaler, test_loader, test_gt, model, gfnet_conf)
    # float() accepts Python numbers and 0-d tensors alike, so wandb always
    # receives plain scalars.
    std, mse, ratio = float(std), float(mse), float(ratio)
    logger.info(f"TEST mse: {mse:.6f} ratio: {ratio:.6f} std: {std:.6f}")
    wandb.log({'test_mse_loss': mse, 'test_ratio': ratio, 'test_std': std})




@torch.no_grad()
def evaluate(ep, train_step, train_data_used, logr_scaler, loader, targets, model, gfc):
    """Roll out the GFlowNet sampler on every test batch and score the results.

    Each batch is replicated ``gfc.num_repeat`` times, ``model.alg.sample`` is
    called with ``rand_prob=0.`` until every ``MaxCliqueMDP`` episode is done,
    and for each graph the maximum ``env.batch_metric`` value over the repeats
    is kept as its prediction.

    Args:
        ep, train_step, train_data_used: Unused; kept for call compatibility.
        logr_scaler: Callable applied to ``env.get_log_reward()``.
        loader: Iterable of DGL graph batches.
        targets: Ground-truth metric per graph. Assumed to be in the same order
            the loader yields graphs — TODO confirm against the data pipeline.
        model: Object exposing ``model.alg.sample`` and ``model.device``.
        gfc: GFlowNet config; ``num_repeat`` and ``device`` are read.

    Returns:
        ``(sum_p_std, mse, ratio)`` as Python floats:
        - ``sum_p_std``: std of the per-graph squared error divided by sqrt(N),
          i.e. the standard error of ``mse``;
        - ``mse``: mean squared error between predictions and targets;
        - ``ratio``: mean of ``round(prediction) / target``.

    Raises:
        ValueError: if the number of predictions differs from the number of
            targets (``mse_loss`` would otherwise broadcast silently).
    """
    torch.cuda.empty_cache()
    num_repeat = gfc.num_repeat
    logger.info(f'Num repeat: {num_repeat}')
    best_per_graph = []  # max metric over the repeats, one entry per graph
    logr_ls = []  # NOTE(review): collected but never returned or logged

    for gbatch in loader:
        gbatch = gbatch.to(gfc.device)
        # Replicate the batch so every graph gets num_repeat independent rollouts.
        gbatch_rep = dgl.batch([gbatch] * num_repeat)

        env = MaxCliqueMDP(gbatch_rep, gfc)
        state = env.state
        while not all(env.done):
            action = model.alg.sample(gbatch_rep, state, env.done, rand_prob=0.)
            state = env.step(action)

        logr_ls += logr_scaler(env.get_log_reward()).tolist()
        metric_rep = torch.tensor(env.batch_metric(state))
        # dgl.batch concatenates the copies one after another, so the flat
        # vector is ordered (repeat, graph); reshape to (graph, repeat).
        metric_rep = rearrange(metric_rep, "(rep b) -> b rep", rep=num_repeat).float()
        best_per_graph += metric_rep.max(dim=1)[0].tolist()

    all_pred = torch.tensor(best_per_graph, device=model.device)
    # Match the prediction dtype so mse_loss never sees mixed int/float inputs,
    # and flatten so e.g. an (N, 1) target cannot broadcast against (N,).
    all_target = torch.tensor(targets, dtype=all_pred.dtype, device=model.device).reshape(-1)
    if all_target.numel() != all_pred.numel():
        raise ValueError(
            f"Got {all_pred.numel()} predictions but {all_target.numel()} targets"
        )

    mse = torch.nn.functional.mse_loss(all_target, all_pred, reduction="mean").item()
    ratio = (torch.round(all_pred) / all_target).mean().item()
    sq_err = (all_target - all_pred) ** 2
    sum_p_std = (sq_err.std() / np.sqrt(all_target.numel())).item()

    return sum_p_std, mse, ratio

    



if __name__ == "__main__":
    # Config precedence, lowest to highest: base config, dataset config,
    # model config, then command-line overrides.
    main_conf = OmegaConf.load("configs/config.yaml")
    cli_conf = OmegaConf.from_cli()
    data_conf = OmegaConf.load(f"configs/data_configs/{cli_conf.dataset.name}.yaml")
    model_conf = OmegaConf.load(f"configs/model_configs/{cli_conf.model.name}.yaml")
    # cli_conf is merged last, so CLI values already win; no second merge needed.
    conf = OmegaConf.merge(main_conf, data_conf, model_conf, cli_conf)

    conf.log.dir = 'test_logs'
    os.makedirs(conf.log.dir, exist_ok=True)

    # Test runs are named by model, dataset and the number of sampler repeats.
    decoder_steps = conf.model.num_repeat
    conf.task.name = f"{conf.model.name}_{conf.dataset.name}_decoder_steps={decoder_steps}"
    print(OmegaConf.to_yaml(conf))

    # mode="w" truncates any log left by a previous run with the same name.
    logger.add(f"{conf.log.dir}/{conf.task.name}.log", mode="w")
    logger.info(OmegaConf.to_yaml(conf))
    # Disable TF32 so CUDA matmuls and cuDNN convolutions run in full fp32.
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False

    wandb.init(
        project=conf.task.wandb_project,
        name=conf.task.name,
        group=conf.task.name,
        config={
            'learning_rate': conf.training.learning_rate,
            'weight_decay': conf.training.weight_decay,
            'dropout': conf.training.dropout,
            'num_epochs': conf.training.num_epochs,
            'seed': conf.training.seed,
            'batch_size': conf.training.batch_size,
            'dataset_name': conf.dataset.name,
            'dataset_max_node_set_size': conf.dataset.max_node_set_size,
            'dataset_max_edge_set_size': conf.dataset.max_edge_set_size,
            'model_name': conf.model.name,
        }
    )

    set_seed(conf.training.seed, conf)
    test()
