import argparse
import numpy as np
import pandas as pd
import os
import torch
import torch.nn as nn
import datetime
import time
import matplotlib.pyplot as plt
from torchinfo import summary
import yaml
import json
import sys
import copy
import util

sys.path.append("..")
from mylib.utils import (
    MaskedMAELoss,
    print_log,
    seed_everything,
    set_cpu_num,
    CustomJSONEncoder,
)
from mylib.metrics import RMSE_MAE_MAPE
from mylib.data_prepare import get_dataloaders_from_index_data
from mylib.data_prepare import load_adj
import mylib.data_prepare as data_prepare
from model.MvHSTM import MvHSTM


@torch.no_grad()
def eval_model(model, valset_loader, criterion):
    """Return the mean per-batch loss of ``model`` over ``valset_loader``.

    The model is switched to eval mode and gradients are disabled by the
    decorator. Each prediction is passed through ``SCALER.inverse_transform``
    before being compared with the target by ``criterion``.

    Reads the module-level globals ``DEVICE`` and ``SCALER``.

    Args:
        model: Callable network; invoked as ``model(x_batch)``.
        valset_loader: Iterable yielding ``(x_batch, y_batch)`` pairs.
        criterion: Callable ``criterion(prediction, target)`` whose result
            exposes ``.item()``.

    Returns:
        ``np.mean`` of the collected per-batch loss values.
    """
    model.eval()

    losses = []
    for inputs, targets in valset_loader:
        inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
        # Loss is measured after undoing the scaler's transform.
        preds = SCALER.inverse_transform(model(inputs))
        losses.append(criterion(preds, targets).item())

    return np.mean(losses)


@torch.no_grad()
def predict(model, loader):
    """Run ``model`` over ``loader`` and return stacked targets and predictions.

    Predictions are passed through ``SCALER.inverse_transform`` before being
    moved to the CPU. Reads the module-level globals ``DEVICE`` and ``SCALER``.

    Args:
        model: Callable network; invoked as ``model(x_batch)``.
        loader: Iterable yielding ``(x_batch, y_batch)`` tensor pairs.

    Returns:
        Tuple ``(y, out)`` of NumPy arrays: all targets and all predictions,
        concatenated along the first axis. A trailing size-1 axis is dropped
        from arrays with more than three dimensions
        (presumably ``(samples, steps, nodes, 1)`` — TODO confirm); every
        other axis is preserved.
    """

    def _drop_trailing_unit_axis(arr):
        # A blanket ``squeeze()`` would also collapse a size-1 sample, step or
        # node axis, which breaks per-step indexing such as ``y_pred[:, i, :]``.
        if arr.ndim > 3 and arr.shape[-1] == 1:
            return arr.squeeze(-1)
        return arr

    model.eval()
    y = []
    out = []

    for x_batch, y_batch in loader:
        x_batch = x_batch.to(DEVICE)

        out_batch = model(x_batch)
        out_batch = SCALER.inverse_transform(out_batch)

        out.append(out_batch.cpu().numpy())
        # Targets are only collected, so they never need to visit DEVICE.
        y.append(y_batch.cpu().numpy())

    out = _drop_trailing_unit_axis(np.vstack(out))
    y = _drop_trailing_unit_axis(np.vstack(y))

    return y, out


def train_one_epoch(
        model, trainset_loader, optimizer, scheduler, criterion, clip_grad, log=None
):
    """Run one optimisation pass over ``trainset_loader``.

    For every batch: forward pass, ``SCALER.inverse_transform`` on the output,
    loss, backward pass, optional gradient-norm clipping and an optimizer
    step. The scheduler is stepped once, after the last batch.

    Reads the module-level globals ``DEVICE`` and ``SCALER``.

    Args:
        model: Network to train; put into train mode here.
        trainset_loader: Iterable yielding ``(x_batch, y_batch)`` tensor pairs.
        optimizer: Optimizer whose gradients are zeroed and stepped per batch.
        scheduler: Learning-rate scheduler, stepped once per call.
        criterion: Callable ``criterion(prediction, target)`` returning a
            tensor that supports ``.item()`` and ``.backward()``.
        clip_grad: Max gradient norm; clipping is skipped when falsy
            (``0`` or ``None``).
        log: Unused; accepted so callers can pass it uniformly.

    Returns:
        ``np.mean`` of the per-batch loss values for this epoch.
    """
    model.train()
    batch_loss_list = []
    for x_batch, y_batch in trainset_loader:
        x_batch = x_batch.to(DEVICE)
        # NOTE(review): targets are cast to float64 here but not in
        # eval_model/predict — confirm this asymmetry is intended.
        y_batch = y_batch.to(DEVICE).double()
        out_batch = model(x_batch)
        out_batch = SCALER.inverse_transform(out_batch)

        loss = criterion(out_batch, y_batch)
        batch_loss_list.append(loss.item())

        optimizer.zero_grad()
        loss.backward()
        if clip_grad:
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad)
        optimizer.step()

    epoch_loss = np.mean(batch_loss_list)
    scheduler.step()
    return epoch_loss


def train(
        model,
        trainset_loader,
        valset_loader,
        optimizer,
        scheduler,
        criterion,
        clip_grad=0,
        max_epochs=200,
        early_stop=10,
        verbose=1,
        plot=False,
        log=None,
        save=None,
):
    """Train ``model`` with early stopping and return it with the best weights.

    Each epoch runs ``train_one_epoch`` followed by ``eval_model``. The state
    dict with the lowest validation loss is kept; training stops once the
    validation loss has failed to improve for ``early_stop`` consecutive
    epochs or after ``max_epochs`` epochs. The best weights are then loaded,
    train/validation RMSE/MAE/MAPE are logged, and the weights are optionally
    saved.

    Reads the module-level global ``DEVICE``.

    Args:
        model: Network to train; moved to ``DEVICE``.
        trainset_loader: Training data loader.
        valset_loader: Validation data loader.
        optimizer: Optimizer passed through to ``train_one_epoch``.
        scheduler: LR scheduler passed through to ``train_one_epoch``.
        criterion: Loss callable used for both training and validation.
        clip_grad: Max gradient norm; falsy disables clipping.
        max_epochs: Upper bound on the number of epochs.
        early_stop: Patience, in epochs without validation improvement.
        verbose: Log every ``verbose`` epochs; ``0`` disables per-epoch logs.
        plot: When true, show a train/validation loss curve.
        log: Log sink forwarded to ``print_log``.
        save: Optional path; when truthy the best state dict is saved there.

    Returns:
        ``model`` with the best recorded state dict loaded (or unchanged when
        no epoch was run).
    """
    model = model.to(DEVICE)

    wait = 0
    min_val_loss = np.inf
    # Initialised up front so they exist even when the validation loss never
    # beats ``inf`` (e.g. NaN from the first epoch) or no epoch runs at all.
    best_epoch = None
    best_state_dict = None

    train_loss_list = []
    val_loss_list = []

    for epoch in range(max_epochs):
        train_loss = train_one_epoch(
            model, trainset_loader, optimizer, scheduler, criterion, clip_grad, log=log
        )
        train_loss_list.append(train_loss)

        val_loss = eval_model(model, valset_loader, criterion)
        val_loss_list.append(val_loss)

        # ``verbose=0`` means silent; guarding also avoids a modulo-by-zero.
        if verbose and (epoch + 1) % verbose == 0:
            print_log(
                datetime.datetime.now(),
                "Epoch",
                epoch + 1,
                " \tTrain Loss = %.5f" % train_loss,
                "Val Loss = %.5f" % val_loss,
                log=log,
            )

        improved = val_loss < min_val_loss
        if improved or best_state_dict is None:
            # The first epoch is always snapshotted as a fallback, so a NaN
            # validation loss still leaves something to restore and report.
            best_epoch = epoch
            best_state_dict = copy.deepcopy(model.state_dict())

        if improved:
            wait = 0
            min_val_loss = val_loss
        else:
            wait += 1
            if wait >= early_stop:
                break

    if best_state_dict is None:
        # max_epochs <= 0: nothing to restore or report.
        print_log("No epochs were run; model left unchanged.", log=log)
        if save:
            torch.save(model.state_dict(), save)
        return model

    model.load_state_dict(best_state_dict)
    train_rmse, train_mae, train_mape = RMSE_MAE_MAPE(*predict(model, trainset_loader))
    val_rmse, val_mae, val_mape = RMSE_MAE_MAPE(*predict(model, valset_loader))

    out_str = f"Early stopping at epoch: {epoch + 1}\n"
    out_str += f"Best at epoch {best_epoch + 1}:\n"
    out_str += "Train Loss = %.5f\n" % train_loss_list[best_epoch]
    out_str += "Train RMSE = %.5f, MAE = %.5f, MAPE = %.5f\n" % (
        train_rmse,
        train_mae,
        train_mape,
    )
    out_str += "Val Loss = %.5f\n" % val_loss_list[best_epoch]
    out_str += "Val RMSE = %.5f, MAE = %.5f, MAPE = %.5f" % (
        val_rmse,
        val_mae,
        val_mape,
    )
    print_log(out_str, log=log)

    if plot:
        plt.plot(range(0, epoch + 1), train_loss_list, "-", label="Train Loss")
        plt.plot(range(0, epoch + 1), val_loss_list, "-", label="Val Loss")
        plt.title("Epoch-Loss")
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.legend()
        plt.show()

    if save:
        torch.save(best_state_dict, save)
    return model


@torch.no_grad()
def test_model(model, testset_loader, log=None):
    """Evaluate ``model`` on ``testset_loader`` and log the metrics.

    Logs RMSE/MAE/MAPE over all prediction steps, then one line per step
    (steps are indexed along axis 1 of the predictions), followed by the
    wall-clock time spent inside ``predict``.

    Args:
        model: Network to evaluate; put into eval mode here.
        testset_loader: Loader forwarded to ``predict``.
        log: Log sink forwarded to ``print_log``.
    """
    model.eval()
    print_log("--------- Test ---------", log=log)

    t0 = time.time()
    y_true, y_pred = predict(model, testset_loader)
    elapsed = time.time() - t0

    # Collect report lines first and emit them in a single print_log call.
    report = [
        "All Steps RMSE = %.5f, MAE = %.5f, MAPE = %.5f\n"
        % tuple(RMSE_MAE_MAPE(y_true, y_pred))
    ]
    for step in range(1, y_pred.shape[1] + 1):
        step_metrics = RMSE_MAE_MAPE(y_true[:, step - 1, :], y_pred[:, step - 1, :])
        report.append(
            "Step %d RMSE = %.5f, MAE = %.5f, MAPE = %.5f\n" % (step, *step_metrics)
        )

    print_log("".join(report), log=log, end="")
    print_log("Inference time: %.2f s" % elapsed, log=log)


if __name__ == "__main__":
    # Script entry point: parse CLI args, load config and data, build the
    # hypergraph inputs, then train and test the model.
    #
    # DEVICE and SCALER are bound at module scope on purpose: the functions
    # defined above read them as globals.

    parser = argparse.ArgumentParser()
    # NOTE(review): "PEMS08" is rejected by the dataset check below, so this
    # default cannot run as-is — confirm the intended default dataset.
    parser.add_argument("-d", "--dataset", type=str, default="pems08")
    # MvHSTM is the only model accepted below; the previous default
    # ("STAEformer") always ended in the "not supported" error.
    parser.add_argument("-m", "--model", type=str, default="MvHSTM")
    parser.add_argument("-g", "--gpu_num", type=int, default=0)
    parser.add_argument("-n", "--n_cluster", type=int, default=4)
    args = parser.parse_args()

    # Pass a plain int rather than a 1-element tensor: ints are accepted by
    # the random / NumPy / PyTorch seeding APIs, whereas a tensor is not a
    # supported `random.seed` argument on recent Python versions.
    # (seed_everything presumably forwards to those — verify in mylib.utils.)
    seed = int(torch.randint(1000, (1,)).item())
    seed_everything(seed)
    set_cpu_num(1)

    GPU_ID = args.gpu_num
    os.environ["CUDA_VISIBLE_DEVICES"] = f"{GPU_ID}"
    DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    dataset = args.dataset.upper()
    data_path = f"../data/{dataset}"

    if args.model == "MvHSTM":
        model_name = MvHSTM.__name__
        model_class = MvHSTM
    else:
        raise ValueError(f"Model {args.model} is not supported.")

    with open(f"{model_name}.yaml", "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    cfg = cfg[dataset]

    # Validate the dataset before any expensive work (data loading and
    # clustering) instead of after it.
    if dataset in ("METRLA", "PEMSBAY"):
        criterion = MaskedMAELoss()
    else:
        raise ValueError("Unsupported dataset.")

    # Single source of truth for the batch size; it was previously read with
    # three different fallbacks (default 64, None, and a hard KeyError).
    batch_size = cfg.get("batch_size", 64)

    now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    log_path = "../logs/"
    os.makedirs(log_path, exist_ok=True)
    log_file = os.path.join(log_path, f"{model_name}-{dataset}-{now}.log")

    # "w" gives the same empty-file start as the former append+truncate, and
    # the context manager closes the log even when training raises.
    with open(log_file, "w", encoding="utf-8") as log:
        print_log(dataset, log=log)
        (
            trainset_loader,
            valset_loader,
            testset_loader,
            SCALER,
        ) = get_dataloaders_from_index_data(
            data_path,
            tod=cfg.get("time_of_day"),
            dow=cfg.get("day_of_week"),
            batch_size=batch_size,
            log=log,
        )
        print_log(log=log)

        # Only index 0 along axis 1 and axis 3 of each training batch is
        # needed for clustering, so slice per batch rather than concatenating
        # the full inputs first.
        all_inputs = np.concatenate(
            [inputs[:, 0, :, 0].cpu().numpy() for inputs, _ in trainset_loader],
            axis=0,
        )

        adj_mx, adj_origin = load_adj(data_path)
        supports = [torch.tensor(i).to(DEVICE) for i in adj_mx]
        # NOTE(review): the literal 4 is independent of --n_cluster — confirm
        # whether it should follow args.n_cluster.
        (
            spatial_H_a,
            spatial_H_b,
            spatial_H_T_new,
            spatial_lwjl,
            spatial_G0,
            spatial_G1,
        ) = util.load_hadj(data_path, 4)
        spatial_lwjl = (
            spatial_lwjl.t().unsqueeze(0).unsqueeze(3).repeat(batch_size, 1, 1, 1)
        )
        print("beginning clustering")
        semantic_H, semantic_H_T_new, semantic_G0, semantic_G1 = util.load_shadj(
            all_inputs, adj_origin, args.n_cluster
        )
        print("clustering finished")

        spatial_H_a = spatial_H_a.to(DEVICE)
        spatial_H_b = spatial_H_b.to(DEVICE)
        spatial_G0 = torch.tensor(spatial_G0).to(DEVICE)
        spatial_G1 = torch.tensor(spatial_G1).to(DEVICE)
        spatial_H_T_new = torch.tensor(spatial_H_T_new).to(DEVICE)
        spatial_lwjl = spatial_lwjl.to(DEVICE)

        semantic_H = torch.tensor(semantic_H).to(DEVICE)
        semantic_H_T_new = torch.tensor(semantic_H_T_new).to(DEVICE)
        semantic_G0 = torch.tensor(semantic_G0).to(DEVICE)
        semantic_G1 = torch.tensor(semantic_G1).to(DEVICE)

        # Keyword arguments for the model constructor. Update order matters:
        # config values may override the two placeholders, and the runtime
        # tensors below override any same-named config entries.
        model_kwargs = {"hyper_graph_lap": 0, "line_graph_lap": 0}
        model_kwargs.update(cfg["model_args"])
        model_kwargs.update({
            'device': DEVICE,
            'supports': supports,
            'spatial_H_a': spatial_H_a,
            'spatial_H_b': spatial_H_b,
            'spatial_G0': spatial_G0,
            'spatial_G1': spatial_G1,
            'spatial_H_T_new': spatial_H_T_new,
            'spatial_lwjl': spatial_lwjl,
            'semantic_H': semantic_H,
            'semantic_H_T_new': semantic_H_T_new,
            'semantic_G0': semantic_G0,
            'semantic_G1': semantic_G1,
        })
        model = model_class(**model_kwargs)

        save_path = "../saved_models/"
        os.makedirs(save_path, exist_ok=True)
        save = os.path.join(save_path, f"{model_name}-{dataset}-{now}.pt")

        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=cfg["lr"],
            weight_decay=cfg.get("weight_decay", 0),
            eps=cfg.get("eps", 1e-8),
        )
        # `verbose` is deprecated on LR schedulers and False was its default,
        # so it is simply omitted.
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=cfg["milestones"],
            gamma=cfg.get("lr_decay_rate", 0.1),
        )

        print_log("---------", model_name, "---------", log=log)
        print_log(
            json.dumps(cfg, ensure_ascii=False, indent=4, cls=CustomJSONEncoder),
            log=log,
        )
        print_log(
            summary(
                model,
                [
                    batch_size,
                    cfg["in_steps"],
                    cfg["num_nodes"],
                    next(iter(trainset_loader))[0].shape[-1],
                ],
                verbose=0,
            ),
            log=log,
        )
        print_log(log=log)

        print_log(f"Loss: {criterion._get_name()}", log=log)
        print_log(log=log)

        model = train(
            model,
            trainset_loader,
            valset_loader,
            optimizer,
            scheduler,
            criterion,
            clip_grad=cfg.get("clip_grad"),
            max_epochs=cfg.get("max_epochs", 200),
            early_stop=cfg.get("early_stop", 10),
            verbose=1,
            log=log,
            save=save,
        )

        print_log(f"Saved Model: {save}", log=log)

        test_model(model, testset_loader, log=log)
