import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os
import math
import pdb
import argparse
from lib.models.translate import TranslationMLP, SensorTimeEncoder
from lib.models.revin import RevIN
from lib.utils.checkpoint import EarlyStopping
from lib.utils.env import seed_all_rng
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d
from scipy import fft
from scipy import signal

"""
Run as follows:
python train_translation.py --mode en2z --cuda-id 2 --seed 2021
"""


def hilbert_envelope(x):
    """
    Compute the amplitude envelope of ``x`` along the time axis (dim 1).

    The envelope is ``|hilbert(x)|``, where ``scipy.signal.hilbert`` returns
    the analytic signal of the real input computed over ``axis=1``.

    Args:
        x: real-valued tensor (or array) of shape (M, T, S).
    Returns:
        y: torch tensor of shape (M, T, S) holding the hilbert envelope of x.
            When ``x`` is a torch tensor the result has the same dtype as
            ``x``. Without the cast, the precision SciPy returns depends on
            its version, and could silently change a float32 target into
            float64.
    """
    if isinstance(x, torch.Tensor):
        out_dtype = x.dtype
        # scipy works on numpy arrays; detach/cpu so tensors that track
        # gradients or live on a GPU are accepted too.
        x = x.detach().cpu().numpy()
    else:
        out_dtype = None
    analytic = signal.hilbert(x, axis=1)
    amp_env = torch.from_numpy(np.abs(analytic))
    if out_dtype is not None:
        amp_env = amp_env.to(dtype=out_dtype)
    return amp_env


def normalize(x):
    """
    Standardize every (example, channel) series to zero mean and unit std.

    Mean and (unbiased) std are taken over dim 1 (time). ``x`` is modified
    in place and also returned, presumably to avoid a second full-size copy
    of large datasets. Pass ``x.clone()`` if the raw values are still needed.

    Args:
        x: floating-point tensor of shape (N, T, 3) (any (N, T, S) works).
    Returns:
        The same tensor object, normalized.
    """
    x -= x.mean(dim=1, keepdim=True)
    std = x.std(dim=1, keepdim=True)
    # A constant series has std == 0. A single time step (T == 1) has an
    # undefined (NaN) unbiased std. In both cases keep the centered values
    # (all zeros) instead of dividing by 0 / NaN.
    std[(std == 0) | torch.isnan(std)] = 1.0
    x /= std
    return x


def create_time_series_dataloader(
    datapath="/data/dummy/foo",
    batchsize=8,
    mode="en2z",
    scale=True,
):
    """
    Build train/val/test dataloaders for sensor-to-sensor translation.

    For each split this loads two files from ``datapath``:
      * ``<split>_data.npy``: raw series stored as (N, S, T).
      * ``<split>_codes.npy``: code ids stored as (N, S, T // C).
    Both are transposed to (N, T, S) and must have S == 3 sensors. Sensors
    0-1 and sensor 2 are split into input / output according to ``mode``.
    The output series is replaced by its Hilbert amplitude envelope.

    Args:
        datapath: directory containing the ``.npy`` files.
        batchsize: batch size used for every split.
        mode: "en2z" (sensors 0,1 -> sensor 2) or "z2en" (sensor 2 -> 0,1).
        scale: if True, standardize each series over time with ``normalize``
            before splitting.
    Returns:
        dict mapping "train" / "val" / "test" to a DataLoader yielding
        (input_time, input_code, output_time, output_code) batches. Only
        "train" is shuffled and drops its last incomplete batch.
    Raises:
        ValueError: if ``mode`` is not recognized (checked before any file is
            loaded).
    """
    if mode not in ("en2z", "z2en"):
        raise ValueError("Unkwown mode %s" % (mode))

    dataloaders = {}
    for split in ["train", "val", "test"]:
        # original time series
        time_file = os.path.join(datapath, "%s_data.npy" % split)
        time = np.load(time_file)  # (N, S, T)
        time = np.swapaxes(time, 1, 2)  # (N, T, S)
        time = torch.from_numpy(time).to(dtype=torch.float32)
        assert time.shape[-1] == 3

        # codes
        code_file = os.path.join(datapath, "%s_codes.npy" % (split))
        code = np.load(code_file)  # (N, S, T//C)
        code = np.swapaxes(code, 1, 2)  # (N, T // C, S)
        code = torch.from_numpy(code).to(dtype=torch.int64)
        assert code.shape[-1] == 3

        print("[Dataset][%s] %d of examples" % (split, time.shape[0]))

        if scale:
            time = normalize(time)

        if mode == "en2z":
            input_time = time[:, :, :2]
            output_time = time[:, :, 2].unsqueeze(2)
            input_code = code[:, :, :2]
            output_code = code[:, :, 2].unsqueeze(2)
        else:  # mode == "z2en" (validated above)
            input_time = time[:, :, 2].unsqueeze(2)
            output_time = time[:, :, :2]
            input_code = code[:, :, 2].unsqueeze(2)
            output_code = code[:, :, :2]

        # The regression target is the amplitude envelope of the output
        # sensor(s). Cast explicitly: depending on the SciPy version the
        # envelope may come back as float64, and a float32 prediction paired
        # with a float64 target fails in the loss's backward pass.
        output_time = hilbert_envelope(output_time).to(dtype=torch.float32)

        dataset = torch.utils.data.TensorDataset(
            input_time, input_code, output_time, output_code
        )
        is_train = split == "train"
        dataloaders[split] = torch.utils.data.DataLoader(
            dataset,
            batch_size=batchsize,
            shuffle=is_train,
            num_workers=1,
            drop_last=is_train,
        )

    return dataloaders


def loss_fn(type, beta=1.0):
    """
    Return the regression criterion selected by name.

    Args:
        type: "mse" for ``nn.MSELoss`` or "smoothl1" for ``nn.SmoothL1Loss``.
        beta: transition point of the smooth-L1 loss; ignored for "mse".
    Returns:
        An ``nn.Module`` loss with default (mean) reduction.
    Raises:
        ValueError: for any other ``type``.
    """
    if type == "smoothl1":
        return nn.SmoothL1Loss(beta=beta)
    if type == "mse":
        return nn.MSELoss()
    raise ValueError("Invalid type")


def train_one_epoch(
    dataloader,
    model,
    codebook,
    compression,
    optimizer,
    scheduler,
    epoch,
    device,
    loss_type: str = "smoothl1",
    beta: float = 1.0,
    plot=False,
):
    """
    Run one training epoch over ``dataloader``, updating ``model`` in place.

    Each batch is unpacked as ``(timex, codex, timey, _)``. The ids in
    ``codex`` are looked up in ``codebook`` and reshaped to code embeddings
    of shape (B, TCin, Sin, D). The prediction depends on ``model.type``:
      * "mlptime":     ``pred = model(x)``
      * "xformercode": ``pred = model(xcodes, scale)``, where ``scale`` stacks
        ``model.revin.mean`` and ``model.revin.stdev`` into (B, Sin, 2).
    The loss between ``pred`` and ``timey`` is back-propagated and the
    optimizer is stepped. ``scheduler`` (if not None) is stepped once per
    batch, not once per epoch.

    Args:
        dataloader: iterable of 4-tuples (timex, codex, timey, _) of tensors.
        model: module exposing ``type`` and ``revin`` attributes, plus
            ``Tin`` / ``Tout`` when ``plot`` is True.
        codebook: (vocab_size, D) tensor indexed by the code ids.
        compression: not used by this function.
        optimizer: torch optimizer over the model parameters.
        scheduler: LR scheduler stepped after every batch, or None.
        epoch: epoch index, used in log lines and the plot filename.
        device: device the batch tensors are moved to.
        loss_type: "smoothl1" or "mse" (passed to ``loss_fn``).
        beta: smooth-L1 ``beta``; ignored for "mse".
        plot: if True, plot the first ``plot_rows`` batches and save a PNG.
    Returns:
        None.
    Raises:
        NotImplementedError: for an unrecognized ``model.type``.
    """
    running_loss, last_loss = 0.0, 0.0
    # Aim for ~3 log lines per epoch. NOTE: with fewer than 3 batches,
    # log_every (3) exceeds the batch count and nothing is logged.
    log_every = max(len(dataloader) // 3, 3)

    if plot:
        plot_rows, plot_cols = 10, 3
        fig, axs = plt.subplots(plot_rows, plot_cols)
        plot_i = 0

    lossfn = loss_fn(loss_type, beta=beta)
    for i, data in enumerate(dataloader):
        # ----- LOAD DATA ------ #
        timex, codex, timey, _ = data
        # timex: (B, Tin, Sin)
        # codex: (B, Tin//C, Sin)
        # timey: (B, Tout, Sout)

        x = timex.to(device)
        y = timey.to(device)
        code_ids = codex.to(device)

        # Embed code ids via the codebook. This is done for every model type,
        # even though only "xformercode" consumes xcodes.
        B, TCin, Sin = codex.shape
        code_ids = code_ids.flatten()
        xcodes = codebook[code_ids]  # (B*TCin*Sin, D)
        xcodes = xcodes.reshape((B, TCin, Sin, xcodes.shape[-1]))  # (B, TCin, Sin, D)

        # Called for its side effect: presumably RevIN stores per-sample
        # statistics that are read below as model.revin.mean / .stdev. The
        # normalized output is discarded, so "mlptime" receives the raw x.
        # NOTE(review): confirm that is intended.
        _ = model.revin(x, "norm")

        if model.type == "mlptime":
            pred = model(x)
        elif model.type == "xformercode":
            scale = torch.cat(
                (model.revin.mean, model.revin.stdev), dim=1
            )  # (B, 2, Sin)
            scale = torch.permute(scale, (0, 2, 1))  # (B, Sin, 2)
            pred = model(xcodes, scale)
        else:
            raise NotImplementedError

        loss = lossfn(pred, y)

        # optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # log: average batch loss over the last log_every batches
        running_loss += loss.item()
        if i % log_every == log_every - 1:
            last_loss = running_loss / log_every  # loss per batch
            lr = optimizer.param_groups[0]["lr"]
            # lr = scheduler.get_last_lr()[0]
            print(
                f"| epoch {epoch:3d} | {i+1:5d}/{len(dataloader):5d} batches | "
                f"lr {lr:02.5f} | loss {last_loss:5.3f}"
            )
            running_loss = 0.0

        # Per-iteration scheduling (step sizes / total steps are set in
        # iterations by the caller).
        if scheduler is not None:
            scheduler.step()

        if plot:
            # First plot_rows batches, first sample of each: target (blue),
            # prediction (red), input averaged over sensors (green).
            if plot_i < plot_rows:
                axs[plot_i, 0].plot(
                    np.arange(model.Tout), y[0, :, 0].detach().cpu().numpy(), "b"
                )
                axs[plot_i, 1].plot(
                    np.arange(model.Tout),
                    pred[0, :, 0].detach().cpu().numpy(),
                    "r",
                )
                axs[plot_i, 2].plot(
                    np.arange(model.Tin), x[0].mean(dim=-1).detach().cpu().numpy(), "g"
                )
                plot_i += 1

    if plot:
        fig.savefig("/data/dummy/vis/train_translation_epoch%d.png" % (epoch))
        # Closes the current figure (the one created above).
        plt.close()


def inference(
    data,
    model,
    codebook,
    compression,
    device,
):
    """
    Compute the model's prediction for a single batch.

    Args:
        data: 4-tuple (timex, codex, timey, codey); only the first two are
            used. ``timex`` is (B, Tin, Sin) and ``codex`` holds integer code
            ids of shape (B, TCin, Sin).
        model: object with a ``type`` attribute ("mlptime" | "xformercode")
            and a ``revin`` callable whose ``mean`` / ``stdev`` attributes
            are read after calling it in "norm" mode.
        codebook: (vocab_size, D) tensor of code embeddings.
        compression: not used.
        device: device the batch tensors are moved to.
    Returns:
        out: (B, Tout, Sout)
    Raises:
        NotImplementedError: for any other ``model.type``.
    """
    time_in, code_in, _, _ = data
    x = time_in.to(device)

    batch, n_steps, n_sensors = code_in.shape
    flat_ids = code_in.to(device).flatten()
    embedded = codebook[flat_ids]  # (B*TCin*Sin, D)
    embedded = embedded.reshape((batch, n_steps, n_sensors, embedded.shape[-1]))

    # Run RevIN in "norm" mode only for its side effect (presumably refreshing
    # model.revin.mean / model.revin.stdev); its return value is not used.
    model.revin(x, "norm")

    kind = model.type
    if kind == "xformercode":
        stats = torch.cat((model.revin.mean, model.revin.stdev), dim=1)  # (B, 2, Sin)
        return model(embedded, stats.permute(0, 2, 1))  # scale: (B, Sin, 2)
    if kind == "mlptime":
        return model(x)
    raise NotImplementedError


def train(args, plot=False):
    """
    Train and evaluate a translation model as configured by ``args``.

    Steps:
      * Load the codebook and the train/val/test dataloaders from the dataset
        selected by ``get_params(args)``.
      * Build the model named by ``args.model_type``, attaching
        ``model.type`` and a ``RevIN`` module as ``model.revin``.
      * Train for up to ``args.epochs`` epochs. After every epoch, print
        per-element MSE / MAE on the validation and test splits and feed the
        validation metrics to ``EarlyStopping``.
      * Return early once ``early_stopping.early_stop`` is set.

    Args:
        args: parsed namespace from ``default_argument_parser``.
        plot: if True, save prediction plots for the train and test splits
            every epoch.
    Returns:
        None.
    Raises:
        AssertionError: if the codebook size differs from
            ``args.codebook_size``.
        NotImplementedError: for an unknown ``args.model_type``.
        ValueError: for an unknown optimizer or scheduler (and, via the
            helpers, an unknown data type or mode).
    """
    # Always runs on the GPU selected by --cuda-id; there is no CPU fallback.
    device = torch.device("cuda:%d" % (args.cuda_id))
    torch.cuda.set_device(device)

    # -------- SET SEED ------- #
    # A negative --seed passes None to seed_all_rng (no fixed seed).
    seed_all_rng(None if args.seed < 0 else args.seed)

    # -------- PARAMS ------- #
    params = get_params(args)
    batchsize = params["batchsize"]
    datapath = params["dataroot"]
    Sin = params["Sin"]
    Sout = params["Sout"]
    expname = "translation_mode%s_model%s" % (args.mode, args.model_type)

    # -------- CHECKPOINT ------- #
    # Only with --checkpoint is a per-experiment directory created. Otherwise
    # EarlyStopping receives path=None; how it handles None lives in
    # lib.utils.checkpoint and is not visible here.
    checkpath = None
    if args.checkpoint:
        checkpath = os.path.join(args.checkpoint_path, expname)
        os.makedirs(checkpath, exist_ok=True)
    early_stopping = EarlyStopping(patience=args.patience, path=checkpath)
    # -------- CODEBOOK ------- #
    # NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects,
    # so only point this at trusted files.
    codebook = np.load(os.path.join(datapath, "codebook.npy"), allow_pickle=True)
    codebook = torch.from_numpy(codebook).to(device=device, dtype=torch.float32)
    vocab_size, vocab_dim = codebook.shape
    assert vocab_size == args.codebook_size
    dim = vocab_dim

    # ------ DATA LOADERS ------- #
    dataloaders = create_time_series_dataloader(
        datapath=datapath,
        batchsize=batchsize,
        mode=args.mode,
        scale=True,
    )
    train_dataloader = dataloaders["train"]
    val_dataloader = dataloaders["val"]
    test_dataloader = dataloaders["test"]

    # ------- MODEL -------- #
    if args.model_type == "mlptime":
        model = TranslationMLP(
            Sin=Sin,
            Sout=Sout,
            Tin=args.Tin,
            Tout=args.Tout,
            hidden_dims=[1024, 1024, 1024],
            dropout=0.2,
        )
    elif args.model_type == "xformercode":
        model = SensorTimeEncoder(
            d_in=dim,
            d_model=args.d_model,
            nheadt=args.nhead,
            nheads=args.nhead,
            d_hid=args.d_hid,
            nlayerst=args.nlayers,
            nlayerss=args.nlayers // 2,
            time_in=args.Tin,
            time_out=args.Tout,
            compression=args.compression,
            sens_in=Sin,
            sens_out=Sout,
            dropout=0.25,
        )
    else:
        raise NotImplementedError("Not implemented yet")
    # train_one_epoch / inference dispatch on model.type and call model.revin.
    # Assigning the RevIN module as an attribute registers it as a submodule,
    # so model.to(device) also moves it.
    model.type = args.model_type
    model.revin = RevIN(num_features=Sin, affine=False)  # expects as input (B, T, S)
    model.to(device)

    # ------- OPTIMIZER -------- #
    # NOTE: num_iters is computed but not used below.
    num_iters = args.epochs * len(train_dataloader)
    # The scheduler is stepped once per batch (inside train_one_epoch), so the
    # StepLR period of args.steps epochs is expressed in iterations.
    step_lr_in_iters = args.steps * len(train_dataloader)
    model_params = list(model.parameters())
    if args.optimizer == "sgd":
        optimizer = torch.optim.SGD(model_params, lr=args.baselr, momentum=0.9)
    elif args.optimizer == "adam":
        optimizer = torch.optim.Adam(model_params, lr=args.baselr)
    elif args.optimizer == "adamw":
        optimizer = torch.optim.AdamW(model_params, lr=args.baselr)
    else:
        raise ValueError("Uknown optimizer type %s" % (args.optimizer))
    if args.scheduler == "step":
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_lr_in_iters, gamma=0.1
        )
    elif args.scheduler == "onecycle":
        # The learning rate increases from initial_lr = max_lr / div_factor to max_lr over the
        # first pct_start * total_steps steps, then anneals to initial_lr / final_div_factor.
        # total_steps = epochs * steps_per_epoch, matching one scheduler.step() per batch.
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer=optimizer,
            max_lr=args.baselr,
            steps_per_epoch=len(train_dataloader),
            epochs=args.epochs,
            pct_start=0.2,
        )
    else:
        raise ValueError("Uknown scheduler type %s" % (args.scheduler))

    # ------- TRAIN & EVAL -------- #
    for epoch in range(args.epochs):
        model.train()
        train_one_epoch(
            train_dataloader,
            model,
            codebook,
            args.compression,
            optimizer,
            scheduler,
            epoch,
            device,
            beta=args.beta,
            plot=plot,
        )

        if val_dataloader is not None:
            model.eval()
            # Accumulate summed squared / absolute errors over every element,
            # then divide by the element count for per-element MSE / MAE.
            # The results stay as (device) tensors.
            running_mse, running_mae = 0.0, 0.0
            total_num = 0.0
            # Disable gradient computation and reduce memory consumption.
            with torch.no_grad():
                for i, vdata in enumerate(val_dataloader):
                    pred_time = inference(
                        vdata,
                        model,
                        codebook,
                        args.compression,
                        device,
                    )
                    timey = vdata[2].to(device)

                    running_mse += F.mse_loss(pred_time, timey, reduction="sum")
                    running_mae += (pred_time - timey).abs().sum()
                    total_num += timey.numel()  # B * S * T
            running_mae = running_mae / total_num
            running_mse = running_mse / total_num
            print(f"| [Val] mse {running_mse:5.3f} mae {running_mae:5.3f}")

            # Presumably tracks val MSE for patience and checkpointing; see
            # lib.utils.checkpoint.EarlyStopping for the exact contract.
            early_stopping(running_mse, running_mae, {"model": model})

        # NOTE(review): the test split is evaluated every epoch, and the
        # early-stop check below lives inside this branch, so it only runs
        # when a test loader exists.
        if test_dataloader is not None:
            model.eval()
            running_mse, running_mae = 0.0, 0.0
            total_num = 0.0
            if plot:
                plot_rows, plot_cols = 10, 3
                fig, axs = plt.subplots(plot_rows, plot_cols)
                plot_i = 0
            # Disable gradient computation and reduce memory consumption.
            with torch.no_grad():
                for i, tdata in enumerate(test_dataloader):
                    pred_time = inference(
                        tdata,
                        model,
                        codebook,
                        args.compression,
                        device,
                    )
                    timey = tdata[2].to(device)
                    timex = tdata[0].to(device)

                    running_mse += F.mse_loss(pred_time, timey, reduction="sum")
                    running_mae += (pred_time - timey).abs().sum()
                    total_num += timey.numel()
                    if plot:
                        # First plot_rows batches: target (blue), prediction
                        # (red), input averaged over sensors (green).
                        if plot_i < plot_rows:
                            axs[plot_i, 0].plot(
                                np.arange(model.Tout),
                                timey[0, :, 0].cpu().numpy(),
                                "b",
                            )
                            axs[plot_i, 1].plot(
                                np.arange(model.Tout),
                                pred_time[0, :, 0].cpu().numpy(),
                                "r",
                            )
                            axs[plot_i, 2].plot(
                                np.arange(model.Tin),
                                timex[0].mean(dim=-1).cpu().numpy(),
                                "g",
                            )
                            plot_i += 1
            running_mae = running_mae / total_num
            running_mse = running_mse / total_num
            print(f"| [Test] mse {running_mse:5.3f} mae {running_mae:5.3f}")

            if plot:
                fig.savefig("/data/dummy/vis/test_translation_epoch%d.png" % (epoch))
                plt.close()
            if early_stopping.early_stop:
                print("Early stopping....")
                return


def get_params(args):
    """
    Resolve the dataset root, batch size and sensor counts from CLI args.

    Args:
        args: namespace with ``data_type`` ("random" | "zeroshot") and
            ``mode`` ("en2z" | "z2en").
    Returns:
        dict with keys "dataroot", "batchsize", "Sin" (number of input
        sensors) and "Sout" (number of output sensors).
    Raises:
        ValueError: for an unknown ``data_type`` (checked first) or ``mode``.
    """
    roots = {
        "random": "/data/dummy/mts_v2_datasets/earthquake_clean_randomsplit/",
        "zeroshot": "/data/dummy/mts_v2_datasets/earthquake_clean_0shotsplit",
    }
    sensor_counts = {"en2z": (2, 1), "z2en": (1, 2)}

    if args.data_type not in roots:
        raise ValueError("Unknown data type %s" % (args.data_type))
    if args.mode not in sensor_counts:
        raise ValueError("Uknown mode %s" % (args.mode))

    n_in, n_out = sensor_counts[args.mode]
    return {
        "dataroot": roots[args.data_type],
        "batchsize": 64,
        "Sin": n_in,
        "Sout": n_out,
    }


def default_argument_parser(argv=None):
    """
    Build the command-line parser for translation training and parse it.

    Args:
        argv: optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (useful for tests and notebooks). The default
            ``None`` keeps the usual command-line behavior.

    Returns:
        argparse.Namespace: the parsed arguments (the parsed namespace, not
        the ``ArgumentParser`` itself).
    """
    parser = argparse.ArgumentParser(description="Code Prediction")
    # TODO add support for resume
    parser.add_argument(
        "--resume",
        action="store_true",
        help="whether to attempt to resume from the checkpoint directory",
    )
    parser.add_argument("--cuda-id", default=0, type=int)
    # ---------- SEED ---------- #
    # Set seed to negative to fully randomize everything.
    # Set seed to positive to use a fixed seed. Note that a fixed seed increases
    # reproducibility but does not guarantee fully deterministic behavior.
    parser.add_argument("--seed", default=-1, type=int)
    # ---------- MODEL ---------- #
    parser.add_argument(
        "--model-type",
        default="mlptime",
        type=str,
        help="One of 'mlptime' or 'xformercode'.",
    )
    # ---------- DATA ---------- #
    parser.add_argument(
        "--data-type", default="random", type=str, help="One of 'random' or 'zeroshot'."
    )
    parser.add_argument(
        "--mode", default="en2z", type=str, help="One of 'en2z' or 'z2en'."
    )
    parser.add_argument("--codebook_size", default=256, type=int)
    parser.add_argument("--compression", default=4, type=int)
    parser.add_argument("--Tin", default=3000, type=int)
    parser.add_argument("--Tout", default=3000, type=int)
    # ----------- CHECKPOINT ------------ #
    parser.add_argument("--checkpoint", action="store_true")
    parser.add_argument(
        "--checkpoint_path", default="/data/dummy/experiments", type=str
    )
    parser.add_argument("--patience", default=3, type=int)
    # ---------- TRAINING ---------- #
    parser.add_argument("--optimizer", default="adam", type=str)
    parser.add_argument("--scheduler", default="onecycle", type=str)
    parser.add_argument("--baselr", default=0.0001, type=float)
    parser.add_argument("--epochs", default=100, type=int)
    parser.add_argument("--steps", default=4, type=int, help="decay LR in epochs")
    parser.add_argument(
        "--beta", default=1.0, type=float, help="beta for smoothl1 loss"
    )
    # ----------- ENCODER PARAMS ------------ #
    parser.add_argument("--d-model", default=64, type=int)
    parser.add_argument("--d_hid", default=256, type=int)
    parser.add_argument("--nhead", default=4, type=int)
    parser.add_argument("--nlayers", default=4, type=int)

    # argv=None makes argparse fall back to sys.argv[1:], as before.
    return parser.parse_args(argv)


if __name__ == "__main__":
    # Script entry point: parse the CLI options, then train with plotting off.
    args = default_argument_parser()
    train(args=args, plot=False)
