import logging
import os
import time
from typing import Any, Dict
import numpy as np
import torch
from torch_geometric.graphgym.checkpoint import (
    clean_ckpt,
    load_ckpt,
    save_ckpt,
)
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.loss import compute_loss
from torch_geometric.graphgym.register import register_train
from torch_geometric.graphgym.utils.epoch import is_ckpt_epoch, is_eval_epoch
import torch.nn as nn
from tqdm import tqdm
from ..utils import add_full_rrwp, cfg_to_dict, flatten_dict, make_wandb_name

def train_epoch(logger, loader, model, optimizer, scheduler):
    """Train ``model`` for one pass over ``loader`` and return the mean loss.

    Every batch is moved to ``cfg.device``, passed through ``add_full_rrwp``
    with ``cfg.posenc_RRWP.ksteps`` walk steps, and used for one optimizer
    step on the loss chosen by ``reg_loss``. Per-batch statistics are
    forwarded to ``logger.update_stats``; the scheduler is stepped once after
    the last batch.

    Args:
        logger: Object exposing ``update_stats(true, pred, loss, lr,
            time_used, params)``.
        loader: Iterable of batches exposing ``.to(device)``.
        model: Module whose forward call returns a ``(pred, true)`` pair.
        optimizer: Optimizer driving the parameter updates.
        scheduler: Learning-rate scheduler; ``get_last_lr()[0]`` is logged.

    Returns:
        Mean of the per-batch loss values, or NaN when ``loader`` yields no
        batches.
    """
    device = torch.device(cfg.device)
    model.train()
    model.to(device)
    time_start = time.time()
    # Synchronizing without a CUDA runtime raises, so only do it when one is
    # actually available (CPU-only runs must still work).
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    losses = []
    for batch in loader:
        batch = batch.to(device)
        batch = add_full_rrwp(batch, walk_length=cfg.posenc_RRWP.ksteps)
        optimizer.zero_grad()
        pred, true = model(batch)
        loss = reg_loss(pred=pred, true=true)
        loss_value = loss.item()
        losses.append(loss_value)
        loss.backward()
        optimizer.step()
        logger.update_stats(true=true.detach().cpu(),
                            pred=pred.detach().cpu(), loss=loss_value,
                            lr=scheduler.get_last_lr()[0],
                            time_used=time.time() - time_start,
                            params=cfg.params)
        # Restart the timer so ``time_used`` is measured per batch.
        time_start = time.time()
    scheduler.step()
    # Avoid NumPy's "mean of empty slice" warning for an empty loader.
    return np.mean(losses) if losses else float('nan')


def reg_loss(**kwargs):
    """Compute the regression loss selected by ``cfg.model.loss_fun``.

    Args:
        **kwargs: Forwarded unchanged to the chosen loss; callers pass
            ``pred`` and ``true``.

    Returns:
        The result of ``mae_loss`` when ``cfg.model.loss_fun`` is ``'l1'``,
        or of ``mse_loss`` when it is ``'l2'``.

    Raises:
        ValueError: If ``cfg.model.loss_fun`` is neither ``'l1'`` nor
            ``'l2'``.
    """
    loss_name = cfg.model.loss_fun
    if loss_name == 'l1':
        return mae_loss(**kwargs)
    if loss_name == 'l2':
        return mse_loss(**kwargs)
    # Falling through would hand ``None`` to the caller, which then fails
    # with an unrelated AttributeError; fail here with the real cause.
    raise ValueError(
        f"Unsupported cfg.model.loss_fun {loss_name!r}; expected 'l1' or 'l2'."
    )

def mae_loss(pred, true):
    """Return the mean absolute error between ``pred`` and ``true``.

    Uses the default ``'mean'`` reduction, so the result is a scalar tensor.
    """
    return nn.functional.l1_loss(pred, true)

def mse_loss(pred, true):
    """Return the mean squared error between ``pred`` and ``true``.

    Uses the default ``'mean'`` reduction, so the result is a scalar tensor.
    """
    return nn.functional.mse_loss(pred, true)


def eval_epoch(logger, loader, model):
    """Evaluate ``model`` on ``loader`` and return the mean per-batch MAE.

    The model is switched to eval mode and the whole pass runs without
    gradient tracking. Each batch is moved to ``cfg.device`` and passed
    through ``add_full_rrwp`` before the forward call. The loss is always
    ``mae_loss``, independent of ``cfg.model.loss_fun``. Per-batch statistics
    are forwarded to ``logger.update_stats`` with ``lr=0``.

    Args:
        logger: Object exposing ``update_stats(true, pred, loss, lr,
            time_used, params)``.
        loader: Iterable of batches exposing ``.to(device)``.
        model: Module whose forward call returns a ``(pred, true)`` pair.

    Returns:
        Mean of the per-batch MAE values, or NaN when ``loader`` yields no
        batches.
    """
    device = torch.device(cfg.device)
    model.eval()
    time_start = time.time()
    losses = []
    # No backward pass happens here, so skip building autograd graphs.
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            batch = add_full_rrwp(batch, walk_length=cfg.posenc_RRWP.ksteps)
            pred, true = model(batch)
            loss_value = mae_loss(pred, true).item()
            losses.append(loss_value)
            logger.update_stats(true=true.detach().cpu(),
                                pred=pred.detach().cpu(), loss=loss_value,
                                lr=0, time_used=time.time() - time_start,
                                params=cfg.params)
            # Restart the timer so ``time_used`` is measured per batch.
            time_start = time.time()
    # Avoid NumPy's "mean of empty slice" warning for an empty loader.
    return np.mean(losses) if losses else float('nan')


def _init_wandb_run():
    """Start a Weights & Biases run configured from ``cfg.wandb``.

    ``wandb`` is imported here so the package is only required when
    ``cfg.wandb.use`` is enabled.

    Returns:
        The run object returned by ``wandb.init``.

    Raises:
        ImportError: If the ``wandb`` package cannot be imported.
    """
    try:
        import wandb
    except ImportError as err:
        raise ImportError('WandB is not installed.') from err
    if cfg.wandb.name == '':
        wandb_name = make_wandb_name(cfg)
    else:
        wandb_name = cfg.wandb.name
    run = wandb.init(entity=cfg.wandb.entity, project=cfg.wandb.project,
                     name=wandb_name)
    run.config.update(cfg_to_dict(cfg))
    return run


def _save_final_ckpt(model, optimizer, scheduler):
    """Write model/optimizer/scheduler state to a timestamped file in ``cfg.out_dir``."""
    ckpt: Dict[str, Any] = {"model_state": model.state_dict()}
    if optimizer is not None:
        ckpt["optimizer_state"] = optimizer.state_dict()
    if scheduler is not None:
        ckpt["scheduler_state"] = scheduler.state_dict()

    os.makedirs(cfg.out_dir, exist_ok=True)
    date = time.strftime("%Y%m%d-%H%M%S")
    torch.save(ckpt, os.path.join(cfg.out_dir, f"{date}_last.ckpt"))


def _report_test_mae(test_loader, model):
    """Print the mean absolute error of ``model`` on ``test_loader``.

    Columns 0, 1 and 2 of the prediction/target tensors are scored separately
    and printed as Gain, PM and BW. The headline "MAE Loss" is the sum of
    those three per-column errors. All printed figures are averages of
    per-batch means.
    """
    # Make sure dropout/normalization layers are in inference mode even when
    # the training loop was skipped (e.g. the checkpoint was already final).
    model.eval()
    losses = []
    losses_gain = []
    losses_pm = []
    losses_bw = []
    with torch.no_grad():
        for batch in test_loader:
            batch = batch.to(cfg.device)
            batch = add_full_rrwp(batch, walk_length=cfg.posenc_RRWP.ksteps)
            # The embedding output is requested but not needed for scoring.
            p, t, _ = model(batch.clone(), return_embedding=True)
            loss_gain = mae_loss(p[:, 0], t[:, 0]).item()
            loss_pm = mae_loss(p[:, 1], t[:, 1]).item()
            loss_bw = mae_loss(p[:, 2], t[:, 2]).item()
            losses_gain.append(loss_gain)
            losses_pm.append(loss_pm)
            losses_bw.append(loss_bw)
            losses.append(loss_gain + loss_pm + loss_bw)
            # Per-batch MAE over all columns at once.
            print(mae_loss(p, t))
        print(f"MAE Loss: {np.mean(losses)}")
        print(f"MAE Loss Gain: {np.mean(losses_gain)}")
        print(f"MAE Loss PM: {np.mean(losses_pm)}")
        print(f"MAE Loss BW: {np.mean(losses_bw)}")


@register_train('deep_simulator')
def train_example(loggers, loaders, model, optimizer, scheduler):
    """Train, periodically evaluate, checkpoint and finally test ``model``.

    ``loggers`` and ``loaders`` are indexed by split: index 0 is used for
    training, index 1 is treated as validation (it drives best-epoch
    selection) and index 2 as test. The best-epoch report and the final test
    pass both read index 2, so three splits are expected.

    Flow:
      1. Optionally resume from a checkpoint (``cfg.train.auto_resume``).
      2. Optionally start a Weights & Biases run (``cfg.wandb.use``).
      3. For each epoch: train, evaluate the other splits on eval epochs,
         save periodic checkpoints, log to W&B, and on eval epochs report
         (and optionally checkpoint) the best epoch so far.
      4. Close the loggers, optionally clean checkpoints and save the final
         state, then print test-set MAE figures.

    Args:
        loggers: One logger per split, exposing ``write_epoch`` and ``close``.
        loaders: One data loader per split, in the same order as ``loggers``.
        model: The model to train.
        optimizer: Optimizer passed to ``train_epoch`` and checkpointing.
        scheduler: LR scheduler passed to ``train_epoch`` and checkpointing.
    """
    start_epoch = 0
    if cfg.train.auto_resume:
        start_epoch = load_ckpt(model, optimizer, scheduler,
                                cfg.train.epoch_resume)
    if start_epoch == cfg.optim.max_epoch:
        logging.info('Checkpoint found, Task already done')
    else:
        logging.info('Start from epoch %s', start_epoch)

    run = _init_wandb_run() if cfg.wandb.use else None

    num_splits = len(loggers)
    perf = [[] for _ in range(num_splits)]
    full_epoch_times = []
    # Epoch number of each entry in perf[1:]. Evaluation may be skipped on
    # some epochs, so positions in those lists are not epoch numbers.
    eval_epochs = []
    # Shown in the progress bar; stays NaN until the first evaluation runs
    # (e.g. when resuming on an epoch that is not an eval epoch).
    eval_loss = float('nan')

    with tqdm(range(start_epoch, cfg.optim.max_epoch), desc="Training") as pbar:
        for cur_epoch in pbar:
            start_time = time.perf_counter()
            train_loss = train_epoch(loggers[0], loaders[0], model, optimizer,
                                     scheduler)
            perf[0].append(loggers[0].write_epoch(cur_epoch))

            eval_now = is_eval_epoch(cur_epoch)
            if eval_now:
                for i in range(1, num_splits):
                    eval_loss = eval_epoch(loggers[i], loaders[i], model)
                    perf[i].append(loggers[i].write_epoch(cur_epoch))
                eval_epochs.append(cur_epoch)
            val_perf = perf[1]
            full_epoch_times.append(time.perf_counter() - start_time)

            pbar.set_postfix(train_loss=train_loss, eval_loss=eval_loss)
            if is_ckpt_epoch(cur_epoch):
                save_ckpt(model, optimizer, scheduler, cur_epoch)

            if cfg.wandb.use:
                run.log(flatten_dict(perf), step=cur_epoch)

            # Best-epoch bookkeeping only happens on eval epochs.
            if not eval_now:
                continue

            # ``best_idx`` is a position in the validation history, not an
            # epoch number.
            best_idx = np.array([vp['loss'] for vp in val_perf]).argmin()
            best_idx_by_loss = best_idx
            m = cfg.metric_best
            if m != 'auto':
                # Select again based on val perf of `cfg.metric_best`.
                best_idx = getattr(np.array([vp[m] for vp in val_perf]),
                                   cfg.metric_agg)()
                if cfg.best_by_loss:
                    best_idx = best_idx_by_loss

            # Translate the history position into the real epoch number and
            # into the matching training entry (perf[0] gains one entry per
            # epoch, starting at ``start_epoch``).
            best_epoch = eval_epochs[best_idx]
            train_best = perf[0][best_epoch - start_epoch]
            val_best = perf[1][best_idx]
            test_best = perf[2][best_idx]

            best_train = best_val = best_test = ""
            if m != 'auto':
                if m in train_best:
                    best_train = f"train_{m}: {train_best[m]:.4f}"
                else:
                    best_train = f"train_{m}: {0:.4f}"
                best_val = f"val_{m}: {val_best[m]:.4f}"
                best_test = f"test_{m}: {test_best[m]:.4f}"

                if cfg.wandb.use:
                    bstats = {"best/epoch": best_epoch}
                    split_stats = zip(['train', 'val', 'test'],
                                      (train_best, val_best, test_best))
                    for s, stats in split_stats:
                        bstats[f"best/{s}_loss"] = stats['loss']
                        if m in stats:
                            bstats[f"best/{s}_{m}"] = stats[m]
                            run.summary[f"best_{s}_perf"] = stats[m]
                        for x in ['hits@1', 'hits@3', 'hits@10', 'mrr']:
                            if x in stats:
                                bstats[f"best/{s}_{x}"] = stats[x]

                    run.log(bstats, step=cur_epoch)
                    run.summary["full_epoch_time_avg"] = np.mean(full_epoch_times)
                    run.summary["full_epoch_time_sum"] = np.sum(full_epoch_times)

            # Checkpoint the best epoch params (if enabled).
            if cfg.train.enable_ckpt and cfg.train.ckpt_best and \
                    best_epoch == cur_epoch:
                # Best checkpoints are not written during warm-up epochs.
                if cur_epoch >= cfg.optim.num_warmup_epochs:
                    save_ckpt(model, optimizer, scheduler, cur_epoch)
                if cfg.train.ckpt_clean:  # Delete old ckpt each time.
                    clean_ckpt()
            logging.info(
                f"> Epoch {cur_epoch}: took {full_epoch_times[-1]:.1f}s "
                f"(avg {np.mean(full_epoch_times):.1f}s) | "
                f"Best so far: epoch {best_epoch}\t"
                f"train_loss: {train_best['loss']:.4f} {best_train}\t"
                f"val_loss: {val_best['loss']:.4f} {best_val}\t"
                f"test_loss: {test_best['loss']:.4f} {best_test}\n"
                f"-----------------------------------------------------------"
            )

    for logger in loggers:
        logger.close()
    if cfg.train.ckpt_clean:
        clean_ckpt()

    if cfg.train.save_final_model:
        _save_final_ckpt(model, optimizer, scheduler)

    # Final test-set report.
    _report_test_mae(loaders[2], model)

    logging.info('Task done, results saved in %s', cfg.run_dir)