import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os
import math
import pdb
import argparse
from lib.models.revin import RevIN
from lib.models.zeroshot_classif import SimpleMLP, EEGNet, SensorTimeEncoder
from lib.utils.env import seed_all_rng


def _load_patient(datapath, datapath2, datapath3, num_sensors, splits):
    """Load and concatenate one patient's series, VQ code ids and labels.

    Args:
        datapath: directory containing ``<split>_x_original.npy``.
        datapath2: directory containing ``<split>_labels.npy``.
        datapath3: filename *prefix* (e.g. ``".../pt2_"``); ``<split>_x_codes.npy``
            is appended by plain string concatenation, not ``os.path.join``.
        num_sensors: expected size of the last axis of the loaded series.
        splits: split names to load and concatenate along the first axis.

    Returns:
        (x, code_ids, y, mask): float32 series, int64 code ids, float32 labels,
        and an all-ones ``(N, num_sensors)`` sensor mask.

    Raises:
        ValueError: the series' last axis does not equal ``num_sensors``.
    """
    xs, code_chunks, ys = [], [], []
    for split in splits:
        x = np.load(os.path.join(datapath, "%s_x_original.npy" % (split)))
        xs.append(torch.from_numpy(x).to(dtype=torch.float32))

        y = np.load(os.path.join(datapath2, "%s_labels.npy" % (split)))
        ys.append(torch.from_numpy(y).to(dtype=torch.float32))

        codes = np.load(datapath3 + "%s_x_codes.npy" % (split))
        code_chunks.append(torch.from_numpy(codes).to(dtype=torch.int64))

    x = torch.cat(xs, dim=0)  # expected (N, T, S)
    y = torch.cat(ys, dim=0)
    code_ids = torch.cat(code_chunks, dim=0)  # expected (N, T//C, S)
    if x.shape[-1] != num_sensors:
        raise ValueError(
            "expected %d sensors in %s, found %d"
            % (num_sensors, datapath, x.shape[-1])
        )
    mask = torch.ones((x.shape[0], num_sensors))
    return x, code_ids, y, mask


def _concat_padded(parts, max_sensors):
    """Zero-pad each patient's last (sensor) axis to ``max_sensors`` and concatenate.

    ``parts`` is a list of ``(x, code_ids, y, mask)`` tuples from ``_load_patient``.
    Labels are not padded. Returns the concatenated ``(x, code_ids, y, mask)``.
    """
    padded = []
    for x, code_ids, y, mask in parts:
        pad = (0, max_sensors - mask.shape[-1])
        padded.append((F.pad(x, pad), F.pad(code_ids, pad), y, F.pad(mask, pad)))
    return tuple(torch.cat(column) for column in zip(*padded))


def _print_label_balance(tag, y):
    """Print the example count and the fraction of 0 and 1 labels in ``y``."""
    num = y.shape[0]
    denom = max(num, 1)  # avoid 0/0 on an empty split
    zeros = (y == 0).sum().item() / denom
    ones = (y == 1).sum().item() / denom
    print("[%s] %d examples / %.2f 0s - %.2f 1s" % (tag, num, zeros, ones))


def create_dataloader(
    datapaths, batchsize=8, train_keys=("pt2", "pt5"), test_keys=("pt12",)
):
    """Build the train and test DataLoaders from per-patient arrays.

    The "train" split of every patient in ``train_keys`` is concatenated into a
    shuffled training set (last incomplete batch dropped). The "test" split of
    every patient in ``test_keys`` forms an ordered test set. Patients with
    fewer sensors are zero-padded on the last axis up to the largest declared
    sensor count among the selected patients. The returned mask is 1 for real
    sensors and 0 for padding.

    Args:
        datapaths: maps patient id to ``[datapath, datapath2, datapath3, S]``
            (see ``_load_patient``).
        batchsize: batch size of both loaders.
        train_keys: patient ids used for training; must be non-empty.
        test_keys: patient ids used for testing; if empty, no test loader.

    Returns:
        (train_dataloader, test_dataloader). Each batch is
        ``(x, code_ids, y, mask)``. ``test_dataloader`` is ``None`` when
        ``test_keys`` is empty.

    Raises:
        ValueError: ``train_keys`` is empty, or a patient's series width does not
            match its declared sensor count.
    """
    if not train_keys:
        raise ValueError("train_keys must name at least one patient")

    train_parts = [
        _load_patient(*datapaths[key], splits=("train",)) for key in train_keys
    ]
    test_parts = [
        _load_patient(*datapaths[key], splits=("test",)) for key in test_keys
    ]

    # Common padded width: the largest declared sensor count of any selected patient.
    max_sensors = max(datapaths[key][3] for key in (*train_keys, *test_keys))

    train_tensors = _concat_padded(train_parts, max_sensors)
    train_dataloader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(*train_tensors),
        batch_size=batchsize,
        shuffle=True,
        num_workers=1,
        drop_last=True,
    )
    _print_label_balance("Train", train_tensors[2])

    test_dataloader = None
    if test_parts:
        test_tensors = _concat_padded(test_parts, max_sensors)
        test_dataloader = torch.utils.data.DataLoader(
            torch.utils.data.TensorDataset(*test_tensors),
            batch_size=batchsize,
            shuffle=False,
            num_workers=1,
            drop_last=False,
        )
        _print_label_balance("Test", test_tensors[2])

    return train_dataloader, test_dataloader


def train_one_epoch(
    dataloader,
    model,
    codebook,
    optimizer,
    scheduler,
    epoch,
    device,
):
    """Run one optimization pass over ``dataloader``.

    Each batch is ``(x, code_ids, y, mask)``. ``code_ids`` are looked up in
    ``codebook`` to obtain one code vector per index. ``model.revin`` is applied to
    ``x``, and the forward pass is dispatched on the model's class. The loss is
    ``binary_cross_entropy_with_logits(predy, y)``. ``scheduler`` (if not None) is
    stepped once per batch.

    Mean loss / accuracy over the last ``log_every`` batches is printed about
    three times per epoch.

    Raises:
        ValueError: ``model`` is not a SimpleMLP, EEGNet or SensorTimeEncoder.
    """
    running_loss = 0.0
    running_acc = 0.0
    # Log ~3 times per epoch, but never more often than every 3 batches.
    log_every = max(len(dataloader) // 3, 3)

    for i, (x, code_ids, y, mask) in enumerate(dataloader):
        # ----- LOAD DATA ------ #
        # x: (B, T, S)
        # codes: (B, TC, S)  TC = T // compression
        x = x.to(device)
        code_ids = code_ids.to(device)
        y = y.to(device)
        mask = mask.to(device)

        # reshape data (the unpacking also requires both tensors to be 3-D)
        B, T, S = x.shape
        B, TC, S = code_ids.shape

        # get codewords for input x
        xcodes = codebook[code_ids.flatten()]  # (B*TC*S, D)
        xcodes = xcodes.reshape((B, TC, S, xcodes.shape[-1]))  # (B, TC, S, D)

        # revin time series. The call is kept even though its output is unused:
        # model.revin.mean / .stdev (read by the xformer branch below) are
        # presumably populated by it.
        # NOTE(review): the normalized tensor is discarded, so the mlp/eeg
        # baselines receive the raw x. Confirm this is intended.
        model.revin(x, "norm")

        if isinstance(model, SimpleMLP):
            x = x.flatten(start_dim=1)
            predy = model(x)
        elif isinstance(model, EEGNet):
            x = torch.permute(x, (0, 2, 1))  # (B, S, T)
            x = x.unsqueeze(1)  # (B, 1, S, T)
            predy = model(x)
        elif isinstance(model, SensorTimeEncoder):
            scale = torch.cat((model.revin.mean, model.revin.stdev), dim=1)
            scale = torch.permute(scale, (0, 2, 1))
            predy = model(xcodes, scale, mask == 0)
        else:
            raise ValueError("womp womp")

        loss = F.binary_cross_entropy_with_logits(predy, y)

        # optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # log. Both accumulators are Python floats; accuracy previously piled up
        # as a device tensor.
        running_loss += loss.item()
        with torch.no_grad():
            running_acc += ((predy.sigmoid() > 0.5) == y).float().mean().item()
        if i % log_every == log_every - 1:
            last_loss = running_loss / log_every  # loss per batch
            last_acc = running_acc / log_every
            lr = optimizer.param_groups[0]["lr"]
            print(
                f"| epoch {epoch:3d} | {i+1:5d}/{len(dataloader):5d} batches | "
                f"lr {lr:02.5f} | loss {last_loss:5.3f} | acc {last_acc:5.3f}"
            )
            running_loss = 0.0
            running_acc = 0.0

        # Schedulers here are configured in iterations, so step per batch.
        if scheduler is not None:
            scheduler.step()


def inference(
    data,
    model,
    codebook,
    device,
):
    """Run a forward pass on one batch and return the model output.

    Args:
        data: a 4-tuple ``(x, code_ids, y, mask)``; the labels ``y`` are ignored.
        model: a ``SimpleMLP``, ``EEGNet`` or ``SensorTimeEncoder`` that carries a
            ``revin`` attribute.
        codebook: table of code vectors indexed by ``code_ids``.
        device: device the batch tensors are moved to.

    Returns:
        The model's raw output for the batch.

    Raises:
        ValueError: ``model`` is none of the supported classes.
    """
    x, code_ids, _labels, mask = data
    x, code_ids, mask = (t.to(device) for t in (x, code_ids, mask))

    # Shape unpacking; this also fails fast if either tensor is not 3-D.
    batch, _steps, n_sensors = x.shape
    batch, n_tokens, n_sensors = code_ids.shape

    # Gather one code vector per (sample, token, sensor) index.
    vectors = codebook[code_ids.flatten()]
    xcodes = vectors.reshape((batch, n_tokens, n_sensors, vectors.shape[-1]))

    # Output unused. The call presumably refreshes model.revin.mean / .stdev,
    # which the SensorTimeEncoder branch reads.
    model.revin(x, "norm")

    if isinstance(model, SimpleMLP):
        return model(x.flatten(start_dim=1))
    if isinstance(model, EEGNet):
        # (B, T, S) -> (B, S, T) -> (B, 1, S, T)
        return model(x.permute(0, 2, 1).unsqueeze(1))
    if isinstance(model, SensorTimeEncoder):
        stats = torch.cat((model.revin.mean, model.revin.stdev), dim=1)
        return model(xcodes, stats.permute(0, 2, 1), mask == 0)
    raise ValueError("wamp wamp")


def _default_expname(args):
    """Derive a checkpoint sub-directory name from the main run settings."""
    return "classif_%s_%s_%s_lr%g_ep%d_seed%d" % (
        args.model_type,
        args.optimizer,
        args.scheduler,
        args.baselr,
        args.epochs,
        args.seed,
    )


def _build_model(args, dim, maxS):
    """Instantiate the model named by ``args.model_type`` and attach a RevIN module.

    Raises:
        ValueError: unknown ``args.model_type``.
    """
    if args.model_type == "mlp":
        # time --> class (baseline)
        model = SimpleMLP(
            in_dim=maxS * args.Tin, out_dim=1, hidden_dims=[1024, 512, 256], dropout=0.0
        )
    elif args.model_type == "eeg":
        # time --> class (baseline)
        model = EEGNet(
            chunk_size=args.Tin,
            num_electrodes=maxS,
            F1=8,
            F2=16,
            D=2,
            kernel_1=64,
            kernel_2=16,
            dropout=0.25,
        )
    elif args.model_type == "xformer":
        # code --> class (ours)
        model = SensorTimeEncoder(
            d_in=dim,
            d_model=args.d_model,
            nheadt=args.nhead,
            nheads=args.nhead,
            d_hid=args.d_hid,
            nlayerst=args.nlayers,
            nlayerss=args.nlayers,
            seq_lent=args.Tin // args.compression,
            seq_lens=maxS,
            dropout=0.25,
        )
    else:
        raise ValueError("Unknown model type %s" % (args.model_type))
    model.revin = RevIN(num_features=maxS, affine=False)  # expects as input (B, T, S)
    return model


def _build_optimizer(args, params):
    """Create the optimizer named by ``args.optimizer`` with ``lr=args.baselr``."""
    if args.optimizer == "sgd":
        return torch.optim.SGD(params, lr=args.baselr, momentum=0.9)
    if args.optimizer == "adam":
        return torch.optim.Adam(params, lr=args.baselr)
    if args.optimizer == "adamw":
        return torch.optim.AdamW(params, lr=args.baselr)
    raise ValueError("Unknown optimizer type %s" % (args.optimizer))


def _build_scheduler(args, optimizer, iters_per_epoch):
    """Create the LR scheduler named by ``args.scheduler``.

    Both schedulers are sized in iterations, because ``train_one_epoch`` steps
    the scheduler once per batch.
    """
    if args.scheduler == "step":
        # Decay the LR by 10x every `args.steps` epochs.
        return torch.optim.lr_scheduler.StepLR(
            optimizer, args.steps * iters_per_epoch, gamma=0.1
        )
    if args.scheduler == "onecycle":
        # The learning rate increases from max_lr / div_factor to max_lr over the
        # first pct_start * total_steps steps, then decreases smoothly to
        # max_lr / final_div_factor.
        return torch.optim.lr_scheduler.OneCycleLR(
            optimizer=optimizer,
            max_lr=args.baselr,
            steps_per_epoch=iters_per_epoch,
            epochs=args.epochs,
            pct_start=0.2,
        )
    raise ValueError("Unknown scheduler type %s" % (args.scheduler))


def _evaluate(dataloader, model, codebook, device):
    """Return the fraction of elements where ``sigmoid(pred) > 0.5`` equals the label.

    Returns ``nan`` if the loader yields no labels.
    """
    model.eval()
    correct, total = 0, 0
    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        for batch in dataloader:
            pred = inference(batch, model, codebook, device)
            y = batch[-2].to(device)
            correct += ((pred.sigmoid() > 0.5) == y).sum().item()
            total += y.numel()
    return correct / total if total else float("nan")


def train(args):
    """Train the selected classifier and report test accuracy after every epoch.

    Builds the loaders with ``create_dataloader`` and loads ``codebook.npy``.
    Instantiates the model chosen by ``args.model_type``, then runs ``args.epochs``
    epochs of ``train_one_epoch``, each followed by an accuracy pass over the
    test loader.

    Args:
        args: namespace produced by ``default_argument_parser``.

    Raises:
        ValueError: unknown model, optimizer or scheduler type.
        AssertionError: the codebook size differs from ``args.codebook_size``.
    """
    device = torch.device("cuda:%d" % (args.cuda_id))
    torch.cuda.set_device(device)

    # -------- SET SEED ------- #
    seed_all_rng(None if args.seed < 0 else args.seed)

    # -------- CHECKPOINT ------- #
    # `expname` used to be an undefined name here (NameError whenever
    # --checkpoint was set). Use args.expname if present, else derive one.
    # NOTE(review): nothing is written to checkpath yet.
    checkpath = None
    if args.checkpoint:
        expname = getattr(args, "expname", None) or _default_expname(args)
        checkpath = os.path.join(args.checkpoint_path, expname)
        os.makedirs(checkpath, exist_ok=True)

    # ------ DATA LOADERS ------- #
    datapaths, codebookpath = get_datapaths()
    batchsize = 16

    train_dataloader, test_dataloader = create_dataloader(
        datapaths,
        batchsize=batchsize,
    )

    # -------- MAXIMUM ELECTRODES ----- #
    # create_dataloader zero-pads every patient to the largest declared sensor
    # count. Derive the model width the same way instead of hard-coding it.
    maxS = max(entry[3] for entry in datapaths.values())

    # -------- CODEBOOK ------- #
    codebook = np.load(os.path.join(codebookpath, "codebook.npy"), allow_pickle=True)
    codebook = torch.from_numpy(codebook).to(device=device, dtype=torch.float32)
    vocab_size, vocab_dim = codebook.shape
    assert vocab_size == args.codebook_size
    # NOTE(review): with --onehot the model expects d_in=vocab_size, but the
    # batches always carry codebook vectors of width vocab_dim. Confirm.
    dim = vocab_size if args.onehot else vocab_dim

    # ------- MODEL -------- #
    model = _build_model(args, dim, maxS)
    model.to(device)

    # ------- OPTIMIZER -------- #
    optimizer = _build_optimizer(args, list(model.parameters()))
    scheduler = _build_scheduler(args, optimizer, len(train_dataloader))

    # ------- TRAIN & EVAL -------- #
    for epoch in range(args.epochs):
        model.train()
        train_one_epoch(
            train_dataloader,
            model,
            codebook,
            optimizer,
            scheduler,
            epoch,
            device,
        )

        if test_dataloader is not None:
            test_acc = _evaluate(test_dataloader, model, codebook, device)
            print(f"| [Test] acc {test_acc:5.3f}")


def get_datapaths(root="/data/dummy/mts_v2_datasets/pipeline"):
    """Return the per-patient data locations and the codebook directory.

    Args:
        root: base directory of the preprocessing-pipeline outputs. The default
            reproduces the previously hard-coded locations exactly.

    Returns:
        (datapaths, codebookpath).

        ``datapaths`` maps patient id ("pt2", "pt5", "pt12") to
        ``[datapath, datapath2, datapath3, S]``, where:

        - ``datapath`` holds ``<split>_x_original.npy``;
        - ``datapath2`` holds ``<split>_labels.npy``;
        - ``datapath3`` is a filename *prefix* to which ``<split>_x_codes.npy``
          is appended;
        - ``S`` is the patient's sensor count.

        ``codebookpath`` is the directory (with a trailing separator) that
        holds ``codebook.npy``.
    """
    coded_root = os.path.join(root, "vqvae_coded_data_classification")
    combined_dir = os.path.join(
        coded_root, "neuro_combined_train2and5_test12", "BS_4096"
    )
    # Number of recorded sensors per patient.
    sensor_counts = {"pt2": 72, "pt5": 106, "pt12": 93}

    datapaths = {
        pid: [
            os.path.join(coded_root, pid, "BS_4096"),
            os.path.join(root, "neuro_tr_v_test_before_revin", pid),
            os.path.join(combined_dir, pid + "_"),
            num_sensors,
        ]
        for pid, num_sensors in sensor_counts.items()
    }
    codebookpath = os.path.join(combined_dir, "")  # keep the trailing separator
    return datapaths, codebookpath


def default_argument_parser():
    """
    Create a parser.

    Returns:
        argparse.ArgumentParser:
    """
    parser = argparse.ArgumentParser(description="Classification")
    parser.add_argument(
        "--resume",
        action="store_true",
        help="whether to attempt to resume from the checkpoint directory",
    )
    parser.add_argument("--cuda-id", default=0, type=int)
    # ---------- SEED ---------- #
    # Set seed to negative to fully randomize everything.
    # Set seed to positive to use a fixed seed. Note that a fixed seed increases
    # reproducibility but does not guarantee fully deterministic behavior.
    parser.add_argument("--seed", default=-1, type=int)
    # ---------- MODEL ---------- #
    parser.add_argument(
        "--model-type",
        default="eeg",
        type=str,
        help="The model, either 'eeg' or 'xformer'.",
    )
    # ---------- DATA ---------- #
    parser.add_argument("--codebook_size", default=256, type=int)
    parser.add_argument("--compression", default=4, type=int)
    parser.add_argument("--Tin", default=1001, type=int)
    # ----------- CHECKPOINT ------------ #
    parser.add_argument("--checkpoint", default=0, type=bool)
    parser.add_argument(
        "--checkpoint_path", default="/data/dummy/experiments", type=str
    )
    parser.add_argument("--patience", default=7, type=int)
    # ---------- TRAINING ---------- #
    parser.add_argument("--optimizer", default="adam", type=str)
    parser.add_argument("--scheduler", default="onecycle", type=str)
    parser.add_argument("--baselr", default=0.0001, type=float)
    parser.add_argument("--epochs", default=30, type=int)
    parser.add_argument("--steps", default=15, type=int, help="decay LR in epochs")
    # ----------- INPUT ------------ #
    parser.add_argument(
        "--onehot",
        action="store_true",
        help="use onehot representation if true, otherwise use codes",
    )

    # ----------- ENCODER PARAMS ------------ #
    parser.add_argument("--d-model", default=64, type=int)
    parser.add_argument("--d_hid", default=128, type=int)
    parser.add_argument("--nhead", default=2, type=int)
    parser.add_argument("--nlayers", default=2, type=int)

    return parser.parse_args()


if __name__ == "__main__":
    # Script entry point: parse the CLI and launch training.
    train(default_argument_parser())
