import torch
from loguru import logger
import time
import datetime
from utils.utils import *
from utils.training_utils import *
from utils.model_utils import *
from utils.data_utils import *
import utils.loss_utils as lu
from omegaconf import OmegaConf
import wandb


from models.DIFUSCO_baselines import DIFUSCODataset

import tqdm

def baseline_evaluate(model, sampler, is_test=False):
    """Run ``model.test_step`` over a held-out split and score its predictions.

    Args:
        model: model exposing ``args``, ``device`` and
            ``test_step(batch_data, batch_idx)``, which must return a
            ``(pred, non_decoder_pred)`` pair per batch.
        sampler: the dataset to evaluate; it is handed to the loader builders.
        is_test: if True, iterate the whole dataset via ``test_dataloader``;
            otherwise iterate the ``val_dataloader`` subset.

    Returns:
        ``(mse, ratio, non_decoder_mse, non_decoder_ratio)``. The MSEs are
        Python floats. The ratios are 0-dim tensors holding the mean of
        ``round(pred) / target`` over all batches.
    """
    model.eval()

    pred_list = []
    non_decoder_pred_list = []
    targets = []
    if is_test:
        sample_loader = test_dataloader(model.args, sampler)
    else:
        sample_loader = val_dataloader(model.args, sampler)
    # NOTE(review): evaluation runs with autograd enabled; wrapping this loop in
    # torch.no_grad() would save memory if test_step never needs gradients — confirm.
    # Wrap the loader itself (not the enumerate object) so tqdm can read len()
    # and display a total.
    for batch_idx, batch_data in enumerate(tqdm.tqdm(sample_loader)):
        batch_data = [x.to(model.device) for x in batch_data]
        pred, non_decoder_pred = model.test_step(batch_data, batch_idx)
        pred_list.append(pred)
        non_decoder_pred_list.append(non_decoder_pred)
        # Target is the sum of batch_data[1].x; both eval loaders use batch_size=1,
        # so this is one scalar per example.
        targets.append(batch_data[1].x.sum().item())
    all_pred = torch.tensor(pred_list, device=model.device)
    all_non_decoder_pred = torch.tensor(non_decoder_pred_list, device=model.device)
    all_target = torch.tensor(targets, device=model.device)
    # mse_loss takes (input, target); the value is symmetric, but keep the documented order.
    mse = torch.nn.functional.mse_loss(all_pred, all_target, reduction="mean").item()
    # NOTE: a zero target yields inf/nan in the ratio.
    ratio = (torch.round(all_pred) / all_target).mean()
    non_decoder_mse = torch.nn.functional.mse_loss(all_non_decoder_pred, all_target, reduction="mean").item()
    non_decoder_ratio = (torch.round(all_non_decoder_pred) / all_target).mean()
    return mse, ratio, non_decoder_mse, non_decoder_ratio


from torch_geometric.data import DataLoader as GraphDataLoader

def train_dataloader(args, dataset):
    """Build the shuffled training loader for ``dataset``.

    Args:
        args: config object providing ``batch_size`` and ``num_workers``.
        dataset: the training dataset handed to the PyG ``GraphDataLoader``.

    Returns:
        A ``GraphDataLoader`` that reshuffles every epoch, pins memory and
        drops the last incomplete batch.
    """
    num_workers = args.num_workers
    loader = GraphDataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
        # torch's DataLoader raises ValueError for persistent_workers=True when
        # num_workers == 0, so only request it when worker processes exist.
        persistent_workers=num_workers > 0,
        drop_last=True,
    )
    return loader

def test_dataloader(args, dataset):
    """Return a sequential loader over the whole ``dataset``, one example per batch.

    ``args`` is not read; it is accepted so all loader builders share a signature.
    """
    return GraphDataLoader(dataset, shuffle=False, batch_size=1)

def val_dataloader(args, dataset):
    """Return a sequential, batch-size-1 loader over a prefix of ``dataset``.

    Args:
        args: config object providing ``validation_examples``, the number of
            leading examples to evaluate.
        dataset: the validation dataset.

    Returns:
        A ``GraphDataLoader`` over the first
        ``min(args.validation_examples, len(dataset))`` examples.
    """
    # Clamp so that asking for more examples than the split holds does not make
    # Subset hand out out-of-range indices (IndexError mid-evaluation).
    n_examples = min(args.validation_examples, len(dataset))
    val_subset = torch.utils.data.Subset(dataset, range(n_examples))
    return GraphDataLoader(val_subset, batch_size=1, shuffle=False)

def _select_metrics(decoder_mse, decoder_ratio, non_decoder_mse, non_decoder_ratio):
    """Pick the (mse, ratio) pair to report for the configured model variant.

    Reads the module-level ``conf``: when ``conf.model.SUP`` is truthy the
    non-decoder metrics are used, otherwise the decoder metrics.
    """
    if conf.model.SUP:
        return non_decoder_mse, non_decoder_ratio
    return decoder_mse, decoder_ratio


def train():
    """Train the configured model with early stopping, then evaluate the best checkpoint.

    Uses the module-level ``conf`` built in the ``__main__`` block. Each epoch
    trains on the "train" split, evaluates on the "val" split via
    ``baseline_evaluate``, and passes the negated validation MSE to the
    early-stopping module. Afterwards the best checkpoint is reloaded and
    evaluated on the validation and test splits. Results go to the logger
    and to wandb.
    """
    train_data = DIFUSCODataset(conf, "train", logger.info)
    val_data = DIFUSCODataset(conf, "val", logger.info)
    es = EarlyStoppingModule(conf.base_dir, conf.task.name, logger=logger)
    logger.info(es.__dict__)
    logger.info(f"This uses the {conf.model.classPath}.{conf.model.name} model")
    logger.info(conf)
    model = get_class(f"{conf.model.classPath}.{conf.model.name}")(conf).to(conf.training.device)
    logger.info(model)
    if conf.training.wandb_watch:
        wandb.watch(model, log_freq=1)
    logger.info(f"no. of params in model: {sum([p.numel() for p in model.parameters()])}")
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=conf.training.learning_rate,
        weight_decay=conf.training.weight_decay,
    )
    # overwrite=True ignores any saved checkpoint and starts from scratch.
    checkpoint = None if conf.training.overwrite else es.load_latest_model()

    if not checkpoint:
        es.save_initial_model(model)
        run = 0
    elif es.should_stop_now:
        logger.info("Training has been completed. This logfile can be deleted.")
        return
    else:
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optim_state_dict"])
        run = checkpoint["epoch"] + 1

    # inf (not a finite sentinel) so the first epoch's metrics are always recorded.
    best_val_mse = float("inf")
    best_val_ratio = -1

    # NOTE: training only happens when run_till_early_stopping is set; otherwise the
    # function goes straight to evaluating the best saved checkpoint.
    while conf.training.run_till_early_stopping and run < conf.training.num_epochs:
        model.train()
        tr_loader = train_dataloader(model.args, train_data)
        epoch_loss = 0
        n_batches = 0

        start_time = time.time()
        for batch_idx, batch_data in enumerate(tr_loader):
            batch_data = [x.to(model.device) for x in batch_data]
            optimizer.zero_grad()
            losses = model(batch_data, batch_idx)
            losses.backward()
            optimizer.step()
            epoch_loss = epoch_loss + losses.item()
            n_batches += 1

        # Count batches explicitly: an empty loader (fewer examples than batch_size
        # with drop_last=True) would otherwise leave batch_idx unbound / divide by 0.
        train_loss = epoch_loss / n_batches if n_batches else float("nan")
        logger.info(
            f"Run: {run} train loss: {train_loss:.2f} Time: {time.time()-start_time:.2f}",
        )
        start_time = time.time()
        mse, ratio = _select_metrics(*baseline_evaluate(model, val_data))

        if mse < best_val_mse:
            best_val_mse = mse
            best_val_ratio = ratio

        logger.info(f"Run: {run} VAL mse: {mse:.6f}  ratio: {ratio:.6f} Time: {time.time()-start_time:.6f}")

        log_dict = {'val_mse': mse,
                    'best_val_mse': best_val_mse,
                    'best_val_ratio': best_val_ratio,
                    'train_loss': train_loss
                    }
        wandb.log(log_dict)

        # Higher score is better for the early-stopping module, hence the negation.
        es_score = -mse
        if es.check([es_score], model, run, optimizer):
            break
        if mse == 0:
            break
        run += 1

    ckpt = es.load_best_model()
    model.load_state_dict(ckpt["model_state_dict"])
    test_data = DIFUSCODataset(conf, "test", logger.info)

    # Reset the timer before each final evaluation. If the loop above never ran,
    # start_time would otherwise be unbound here.
    start_time = time.time()
    mse, ratio = _select_metrics(*baseline_evaluate(model, val_data))
    logger.info(f"Run: {run} VAL mse: {mse:.6f} ratio: {ratio:.6f} Time: {time.time()-start_time:.6f}")

    start_time = time.time()
    mse, ratio = _select_metrics(*baseline_evaluate(model, test_data, is_test=True))
    logger.info(f"Run: {run} TEST mse: {mse:.6f} ratio: {ratio:.6f} Time: {time.time()-start_time:.6f}")
    wandb.log({'test_mse_loss': mse, 'test_ratio': ratio})



if __name__ == "__main__":
    import os

    # Config precedence (later wins): base config < dataset config < model config < CLI.
    main_conf = OmegaConf.load("configs/config.yaml")
    cli_conf = OmegaConf.from_cli()
    data_conf = OmegaConf.load(f"configs/data_configs/{cli_conf.dataset.name}.yaml")
    model_conf = OmegaConf.load(f"configs/model_configs/{cli_conf.model.name}.yaml")
    # cli_conf is already the last (highest-priority) source, so one merge suffices.
    conf = OmegaConf.merge(main_conf, data_conf, model_conf, cli_conf)
    run_time = "{date:%Y-%m-%d||%H:%M:%S}".format(date=datetime.datetime.now())
    print(OmegaConf.to_yaml(conf))

    # Encode every model hyper-parameter except the identifying keys in the run name.
    excluded_keys = ("classPath", "name", "EQ")
    task_name = ",".join(
        "{}={}".format(key, value) for key, value in conf.model.items() if key not in excluded_keys
    )

    if conf.model.EQ:
        conf.task.name = f"{conf.model.name}_EQ_{conf.dataset.name}_{task_name}_{conf.training.loss_fn}_{run_time}"
    else:
        conf.task.name = f"{conf.model.name}_{conf.dataset.name}_{task_name}_{conf.training.loss_fn}_{run_time}"

    log_path = f"{conf.log.dir}/{conf.task.name}.log"
    # The truncating open() below fails with FileNotFoundError if the directory
    # is missing (it runs before loguru could create it), so create it first.
    os.makedirs(conf.log.dir, exist_ok=True)
    with open(log_path, "w"):
        pass  # Clear log file
    logger.add(log_path)
    logger.info(OmegaConf.to_yaml(conf))
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False

    wandb.init(
        project=conf.task.wandb_project,
        name=conf.task.name,
        group=conf.task.wandb_group,
        config={
            'learning_rate': conf.training.learning_rate,
            'weight_decay': conf.training.weight_decay,
            'dropout': conf.training.dropout,
            'num_epochs': conf.training.num_epochs,
            'seed': conf.training.seed,
            'batch_size': conf.training.batch_size,
            'dataset_name': conf.dataset.name,
            'dataset_max_node_set_size': conf.dataset.max_node_set_size,
            'dataset_max_edge_set_size': conf.dataset.max_edge_set_size,
            'model_name': conf.model.name,
        }
    )

    set_seed(conf.training.seed, conf)
    train()
