import torch
from loguru import logger
import time
import datetime
from utils.utils import *
from utils.training_utils import *
from utils.model_utils import *
from utils.data_utils import *
import utils.loss_utils as lu
from omegaconf import OmegaConf
import wandb
from collections import defaultdict
from models.GFNET_baseline import get_data_loaders, get_test_data_loaders
from GFlowNet_CombOpt.gflownet.main import get_logr_scaler, rollout
from GFlowNet_CombOpt.gflownet.util import MaxCliqueMDP
import dgl
from einops import rearrange


@torch.no_grad()
def evaluate(ep, train_step, train_data_used, logr_scaler, loader, targets, model, gfc):
    """Score the sampler on ``loader`` against the ground-truth ``targets``.

    Each batch from ``loader`` is replicated ``gfc.num_repeat`` times, rolled
    out with ``rand_prob=0.`` until every environment reports done, and the
    largest metric obtained for each graph across the repeats is kept as that
    graph's prediction.

    Args:
        ep: Unused; kept so existing call sites keep working.
        train_step: Unused; kept so existing call sites keep working.
        train_data_used: Unused; kept so existing call sites keep working.
        logr_scaler: Unused; kept so existing call sites keep working.
        loader: Iterable of DGL graph batches to evaluate on.
        targets: Ground-truth values compared element-wise with the
            predictions (presumably one per graph, in loader order --
            verify against the dataset files).
        model: Object exposing ``alg.sample(...)`` and a ``device`` attribute.
        gfc: Config providing ``num_repeat`` and ``device``; it is also
            handed to ``MaxCliqueMDP``.

    Returns:
        Tuple ``(mse, ratio)`` of Python floats: the mean squared error
        between targets and predictions, and the mean of
        ``round(prediction) / target``.
    """
    torch.cuda.empty_cache()
    num_repeat = gfc.num_repeat
    best_per_graph = []

    for gbatch in loader:
        gbatch = gbatch.to(gfc.device)
        # Stack num_repeat copies so a single rollout yields several samples per graph.
        gbatch_rep = dgl.batch([gbatch] * num_repeat)

        env = MaxCliqueMDP(gbatch_rep, gfc)
        state = env.state
        while not all(env.done):
            action = model.alg.sample(gbatch_rep, state, env.done, rand_prob=0.)
            state = env.step(action)

        # The copies were concatenated repeat-major, so "(rep b)" unpacks them
        # into one row per graph holding its num_repeat results.
        metric_rep = torch.tensor(env.batch_metric(state))
        metric_rep = rearrange(metric_rep, "(rep b) -> b rep", rep=num_repeat).float()
        best_per_graph += metric_rep.max(dim=1).values.tolist()

    all_pred = torch.tensor(best_per_graph, device=model.device)
    # Cast targets to the prediction dtype so integer ground truth cannot
    # cause a dtype mismatch inside mse_loss.
    all_target = torch.as_tensor(targets, dtype=all_pred.dtype, device=model.device)

    mse = torch.nn.functional.mse_loss(all_target, all_pred, reduction="mean").item()
    # .item() so both returned values are plain floats rather than a float
    # paired with a device tensor.
    ratio = (torch.round(all_pred) / all_target).mean().item()

    return mse, ratio

    



def train():
    """Train the configured GFlowNet model, then test the best checkpoint.

    Reads the module-level ``conf``. For every training batch it performs a
    rollout, adds the trajectories to ``model.buffer``, and runs up to
    ``gfnet_conf.tstep`` optimisation steps on disjoint random subsets of the
    buffer. On the first batch of each epoch and every
    ``gfnet_conf.eval_freq`` steps it saves a "latest" checkpoint, evaluates
    on the validation loader, and saves a "best" checkpoint whenever the
    validation MSE improves by more than 1e-5. After the loop the best
    checkpoint is reloaded and evaluated on the test loader. Per-epoch and
    test metrics are sent to wandb.
    """
    import os  # function-scope import: only needed to create checkpoint folders

    LATEST_MODEL_PATH = 'latestModels'
    BEST_MODEL_PATH = 'bestValidationModels'
    # torch.save does not create missing parent directories.
    os.makedirs(LATEST_MODEL_PATH, exist_ok=True)
    os.makedirs(BEST_MODEL_PATH, exist_ok=True)

    logger.info(f"This uses the {conf.model.classPath}.{conf.model.name} model")
    model = get_class(f"{conf.model.classPath}.{conf.model.name}")(conf).to(conf.training.device)
    logger.info(conf)
    logger.info(model)
    logger.info(f"no. of params in model: {sum([p.numel() for p in model.alg.parameters()])}")
    gfnet_conf = model.conf

    logger.info(f"GFNET config: {gfnet_conf}")
    train_loader, val_loader = get_data_loaders(gfnet_conf, conf)
    # NOTE: pickle executes arbitrary code on load; only use trusted dataset files.
    with open(f'{conf.dataset.path}/{conf.dataset.name}/truth/val.pkl', "rb") as val_file:
        val_gt = pickle.load(val_file)

    def checkpoint_state():
        # Same payload for the "latest" and "best" checkpoints so either one
        # can restore the policy model, the flow model and the optimizer.
        return {
            'model_state_dict': model.alg.model.state_dict(),
            'optimizer_state_dict': model.alg.optimizer.state_dict(),
            'model_flow_state_dict': model.alg.model_flow.state_dict(),
            'epoch': run,
        }

    best_val_mse = 1e5
    best_val_ratio = -1

    train_data_used = 0
    train_logr_scaled_ls = []
    train_metric_ls = []
    train_step = 0
    run = 0

    # NOTE(review): with run_till_early_stopping false this loop never runs,
    # leaving logr_scaler unbound and no best checkpoint for the test phase
    # below; no early-stopping criterion is implemented here either. Confirm
    # the intended condition.
    while conf.training.run_till_early_stopping and run < gfnet_conf.epochs:
        epoch_loss = 0
        for batch_idx, gbatch in enumerate(train_loader):
            reward_exp = None
            # Annealing progress in [0, 1], driven by the number of graphs seen so far.
            process_ratio = max(0., min(1., train_data_used / gfnet_conf.annend))
            logr_scaler = get_logr_scaler(gfnet_conf, process_ratio=process_ratio, reward_exp=reward_exp)

            # Keep only a sliding window of recent statistics for the progress log.
            train_logr_scaled_ls = train_logr_scaled_ls[-5000:]
            train_metric_ls = train_metric_ls[-5000:]
            gbatch = gbatch.to(gfnet_conf.device)
            if gfnet_conf.same_graph_across_batch:
                gbatch = dgl.batch([gbatch] * gfnet_conf.batch_size_interact)
            train_data_used += gbatch.batch_size

            ###### rollout
            batch, metric_ls = rollout(gbatch, gfnet_conf, model.alg)
            model.buffer.add_batch(batch)

            logr = logr_scaler(batch[-2][:, -1])
            train_logr_scaled_ls += logr.tolist()
            train_logr_scaled = logr.mean().item()
            train_metric_ls += metric_ls
            train_traj_len = batch[-1].float().mean().item()

            ##### train
            batch_size = min(len(model.buffer), gfnet_conf.batch_size)
            indices = list(range(len(model.buffer)))
            for _ in range(gfnet_conf.tstep):
                if not indices:
                    break
                curr_indices = random.sample(indices, min(len(indices), batch_size))
                batch = model.buffer.sample_from_indices(curr_indices)
                train_info = model.alg.train_step(*batch, reward_exp=reward_exp, logr_scaler=logr_scaler)
                epoch_loss += train_info["train/loss"]
                # Drop the indices just used so each buffer entry is drawn at
                # most once per batch; the set keeps this pass linear.
                drawn = set(curr_indices)
                indices = [i for i in indices if i not in drawn]

            if gfnet_conf.onpolicy:
                model.buffer.reset()

            if train_step % gfnet_conf.print_freq == 0:
                logger.info(f"Epoch {run:2d} Train_Step {train_step} Batch_idx {batch_idx} Data used {train_data_used:.3e}: loss={train_info['train/loss']:.2e}, "
                      + (f"LogZ={train_info['train/logZ']:.2e}, " if gfnet_conf.alg in ["tb", "tbbw"] else "")
                      + f"metric size={np.mean(train_metric_ls):.2f}+-{np.std(train_metric_ls):.2f}, "
                      + f"LogR scaled={train_logr_scaled:.2e} traj_len={train_traj_len:.2f}")

            train_step += 1

            ##### eval
            if batch_idx == 0 or train_step % gfnet_conf.eval_freq == 0:
                # saving latest model
                torch.save(checkpoint_state(), f'{LATEST_MODEL_PATH}/{conf.task.name}.pt')

                mse, ratio = evaluate(run, train_step, train_data_used, logr_scaler, val_loader, val_gt, model, gfnet_conf)
                # Plain float regardless of whether evaluate returned a tensor.
                ratio = float(ratio)

                if best_val_mse - mse > 1e-5:
                    best_val_mse = mse
                    best_val_ratio = ratio
                    torch.save(checkpoint_state(), f'{BEST_MODEL_PATH}/{conf.task.name}.pt')
                    logger.info(f"Saving best model at {BEST_MODEL_PATH}/{conf.task.name}.pt")

                logger.info(f"Run: {run} VAL mse (best): {mse:.6f} ({best_val_mse:.6f}) ratio (best): {ratio:.6f} ({best_val_ratio:.6f})")

        # 'val_mse' is the value from the most recent validation pass of this epoch.
        log_dict = {
            'val_mse': mse,
            'best_val_mse': best_val_mse,
            'best_val_ratio': best_val_ratio,
            'train_loss': epoch_loss / len(train_loader),
        }

        wandb.log(log_dict)
        run += 1

    # Testing code
    test_loader = get_test_data_loaders(gfnet_conf, conf)
    with open(f'{conf.dataset.path}/{conf.dataset.name}/truth/test.pkl', "rb") as test_file:
        test_gt = pickle.load(test_file)

    # loading best model
    save_dict = torch.load(f'{BEST_MODEL_PATH}/{conf.task.name}.pt')
    model.alg.model.load_state_dict(save_dict["model_state_dict"])
    model.alg.model_flow.load_state_dict(save_dict["model_flow_state_dict"])

    mse, ratio = evaluate(run, train_step, train_data_used, logr_scaler, test_loader, test_gt, model, gfnet_conf)
    ratio = float(ratio)
    wandb.log({'test_mse_loss': mse, 'test_ratio': ratio})
    logger.info(f"Run: {run} TEST mse: {mse:.6f} ratio: {ratio:.6f}")
    
    




if __name__ == "__main__":
    # Script entry point: assemble the run configuration, set up file and
    # wandb logging, seed the RNGs, and launch training.
    import os  # local import: only the entry point needs it (log directory creation)

    # Later sources override earlier ones: defaults < dataset < model < CLI.
    # The CLI values are already the last (highest-priority) source, so one
    # merge is sufficient.
    main_conf = OmegaConf.load("configs/config.yaml")
    cli_conf = OmegaConf.from_cli()
    data_conf = OmegaConf.load(f"configs/data_configs/{cli_conf.dataset.name}.yaml")
    model_conf = OmegaConf.load(f"configs/model_configs/{cli_conf.model.name}.yaml")
    conf = OmegaConf.merge(main_conf, data_conf, model_conf, cli_conf)
    run_time = "{date:%Y-%m-%d||%H:%M:%S}".format(date=datetime.datetime.now())
    print(OmegaConf.to_yaml(conf))

    # Encode every model hyper-parameter in the task name, except the keys
    # that are already represented elsewhere in it.
    task_name = ",".join(
        f"{key}={value}"
        for key, value in conf.model.items()
        if key not in ('classPath', 'name', 'EQ')
    )

    model_tag = f"{conf.model.name}_EQ" if conf.model.EQ else f"{conf.model.name}"
    conf.task.name = f"{model_tag}_{conf.dataset.name}_{task_name}_{conf.training.loss_fn}_{run_time}"

    # Make sure the log directory exists before creating the log file in it.
    if conf.log.dir:
        os.makedirs(conf.log.dir, exist_ok=True)
    with open(f"{conf.log.dir}/{conf.task.name}.log", "w"):
        pass  # Clear log file
    logger.add(f"{conf.log.dir}/{conf.task.name}.log")
    logger.info(OmegaConf.to_yaml(conf))
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False

    wandb.init(
        project=conf.task.wandb_project,
        name=conf.task.name,
        group=conf.task.wandb_group,
        config={
            'learning_rate': conf.training.learning_rate,
            'weight_decay': conf.training.weight_decay,
            'dropout': conf.training.dropout,
            'num_epochs': conf.training.num_epochs,
            'seed': conf.training.seed,
            'batch_size': conf.training.batch_size,
            'dataset_name': conf.dataset.name,
            'dataset_max_node_set_size': conf.dataset.max_node_set_size,
            'dataset_max_edge_set_size': conf.dataset.max_edge_set_size,
            'model_name': conf.model.name,
        },
    )

    set_seed(conf.training.seed, conf)
    train()
