import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os
import math
import pdb
import argparse
from lib.models.revin import RevIN
from lib.models.classif import SimpleMLP, EEGNet, SensorTimeEncoder
from lib.utils.checkpoint import EarlyStopping
from lib.utils.env import seed_all_rng


def _load_npy_tensor(directory, filename, dtype):
    """Load ``directory/filename`` with NumPy and return it as a tensor of ``dtype``."""
    array = np.load(os.path.join(directory, filename))
    return torch.from_numpy(array).to(dtype=dtype)


def _load_split(datapath, datapath2, split):
    """Load one split (``"train"``, ``"val"`` or ``"test"``).

    Returns:
        tuple: ``(x, codes, y)`` where ``x`` (float32) and ``codes`` (int64)
        are read from ``datapath`` and ``y`` (float32) from ``datapath2``.
    """
    x = _load_npy_tensor(datapath, "%s_x_original.npy" % split, torch.float32)
    y = _load_npy_tensor(datapath2, "%s_labels.npy" % split, torch.float32)
    codes = _load_npy_tensor(datapath, "%s_x_codes.npy" % split, torch.int64)
    return x, codes, y


def _print_label_balance(tag, num, labels):
    """Print the number of examples and the fraction of 0 / 1 labels."""
    print(
        "[%s] %d examples: %.3f 0s - %.3f 1s"
        % (tag, num, (labels == 0).sum() / num, (labels == 1).sum() / num)
    )


def create_dataloader(datapath, datapath2, batchsize=8):
    """Build the training and test data loaders.

    The train and val splits are concatenated into a single training set;
    no separate validation loader is produced.

    Args:
        datapath: Directory holding ``{train,val,test}_x_original.npy`` and
            ``{train,val,test}_x_codes.npy``.
        datapath2: Directory holding ``{train,val,test}_labels.npy``.
        batchsize: Batch size used for both loaders.

    Returns:
        tuple: ``(train_dataloader, test_dataloader)``. Every batch is a
        ``(x, codes, y)`` triple of float32 / int64 / float32 tensors.
    """
    # train + val, merged along the example dimension
    trainx, traincodes, trainy = _load_split(datapath, datapath2, "train")
    valx, valcodes, valy = _load_split(datapath, datapath2, "val")

    trainvalx = torch.cat((trainx, valx), dim=0)
    trainvaly = torch.cat((trainy, valy), dim=0)
    trainvalcodes = torch.cat((traincodes, valcodes), dim=0)
    _print_label_balance("Train", trainvalx.shape[0], trainvaly)

    train_dataset = torch.utils.data.TensorDataset(trainvalx, trainvalcodes, trainvaly)
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batchsize,
        shuffle=True,
        num_workers=1,
        drop_last=True,
    )

    # test
    testx, testcodes, testy = _load_split(datapath, datapath2, "test")
    _print_label_balance("Test", testx.shape[0], testy)

    test_dataset = torch.utils.data.TensorDataset(testx, testcodes, testy)
    # Evaluation must see every test example exactly once: keep the final
    # partial batch and do not shuffle, otherwise the reported accuracy is
    # computed on a random, truncated subset that changes between epochs.
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=batchsize,
        shuffle=False,
        num_workers=1,
        drop_last=False,
    )

    return train_dataloader, test_dataloader


def train_one_epoch(
    dataloader,
    model,
    codebook,
    optimizer,
    scheduler,
    epoch,
    device,
):
    """Run one optimisation pass over ``dataloader``.

    Args:
        dataloader: Yields ``(x, code_ids, y)`` batches; per the shapes used
            below, ``x`` is (B, T, S), ``code_ids`` is (B, TC, S) and ``y``
            is (B,).
        model: A ``SimpleMLP``, ``EEGNet`` or ``SensorTimeEncoder`` carrying a
            ``revin`` attribute.
        codebook: 2-D tensor indexed by code id; each row is one codeword.
        optimizer: Optimiser stepped once per batch.
        scheduler: LR scheduler stepped once per batch, or ``None``.
        epoch: Epoch index, used only in the log line.
        device: Device the batch tensors are moved to.

    Raises:
        ValueError: If ``model`` is none of the supported classes.
    """
    running_loss = 0.0
    running_acc = 0.0
    num_batches = len(dataloader)
    log_every = max(num_batches // 3, 3)

    for i, (x, code_ids, y) in enumerate(dataloader):
        # ----- LOAD DATA ------ #
        # x: (B, T, S)
        # code_ids: (B, TC, S)  TC = T // compression
        # y: (B,)
        x = x.to(device)
        code_ids = code_ids.to(device)
        y = y.to(device)

        B, TC, S = code_ids.shape

        # get codewords for input x
        xcodes = codebook[code_ids.flatten()]  # (B*TC*S, D)
        xcodes = xcodes.reshape((B, TC, S, xcodes.shape[-1]))  # (B, TC, S, D)

        # The normalised series itself is not consumed; the call is kept
        # because the SensorTimeEncoder branch reads model.revin.mean and
        # model.revin.stdev right after it.
        # NOTE(review): the MLP / EEGNet baselines are fed the raw `x`, not the
        # RevIN-normalised series -- confirm this is intentional.
        model.revin(x, "norm")

        if isinstance(model, SimpleMLP):
            predy = model(x.flatten(start_dim=1))
        elif isinstance(model, EEGNet):
            eeg_in = torch.permute(x, (0, 2, 1))  # (B, S, T)
            eeg_in = eeg_in.unsqueeze(1)  # (B, 1, S, T)
            predy = model(eeg_in)
        elif isinstance(model, SensorTimeEncoder):
            scale = torch.cat((model.revin.mean, model.revin.stdev), dim=1)
            scale = torch.permute(scale, (0, 2, 1))
            predy = model(xcodes, scale)
        else:
            raise ValueError("womp womp")

        loss = F.binary_cross_entropy_with_logits(predy, y)

        # optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # log -- accumulate plain Python floats so the running statistics do
        # not keep device tensors alive between iterations
        running_loss += loss.item()
        with torch.no_grad():
            running_acc += ((predy.sigmoid() > 0.5) == y).float().mean().item()
        if i % log_every == log_every - 1:
            last_loss = running_loss / log_every  # loss per batch
            last_acc = running_acc / log_every
            lr = optimizer.param_groups[0]["lr"]
            print(
                f"| epoch {epoch:3d} | {i+1:5d}/{num_batches:5d} batches | "
                f"lr {lr:02.5f} | loss {last_loss:5.3f} | acc {last_acc:5.3f}"
            )
            running_loss = 0.0
            running_acc = 0.0

        # the scheduler is advanced per iteration, not per epoch
        if scheduler is not None:
            scheduler.step()


def inference(
    data,
    model,
    codebook,
    device,
):
    """Compute the model's raw outputs for one batch.

    Args:
        data: ``(x, code_ids, label)`` triple; the label is ignored. Per the
            unpacking below, ``x`` is (B, T, S) and ``code_ids`` is
            (B, TC, S).
        model: A ``SimpleMLP``, ``EEGNet`` or ``SensorTimeEncoder`` carrying a
            ``revin`` attribute.
        codebook: 2-D tensor indexed by code id; each row is one codeword.
        device: Device the inputs are moved to.

    Returns:
        Whatever ``model`` returns for the batch (passed through ``sigmoid``
        by the caller in this file).

    Raises:
        ValueError: If ``model`` is none of the supported classes.
    """
    series, token_ids, _ = data
    series = series.to(device)
    token_ids = token_ids.to(device)

    # both inputs are expected to be rank 3
    batch, _, n_sensors = series.shape
    batch, n_tokens, n_sensors = token_ids.shape

    # look up one codeword per code id: (B*TC*S, D) -> (B, TC, S, D)
    embedded = codebook[token_ids.flatten()]
    embedded = embedded.reshape((batch, n_tokens, n_sensors, embedded.shape[-1]))

    # run RevIN so that model.revin.mean / model.revin.stdev are available to
    # the SensorTimeEncoder branch; the normalised output is not used
    model.revin(series, "norm")

    if isinstance(model, SimpleMLP):
        return model(series.flatten(start_dim=1))

    if isinstance(model, EEGNet):
        # (B, T, S) -> (B, S, T) -> (B, 1, S, T)
        return model(torch.permute(series, (0, 2, 1)).unsqueeze(1))

    if isinstance(model, SensorTimeEncoder):
        stats = torch.cat((model.revin.mean, model.revin.stdev), dim=1)
        return model(embedded, torch.permute(stats, (0, 2, 1)))

    raise ValueError("wamp wamp")


def train(args):
    """Train a classifier and report test accuracy after every epoch.

    Args:
        args: Parsed command-line namespace. Reads ``cuda_id``, ``seed``,
            ``data_type``, ``checkpoint``, ``checkpoint_path``,
            ``codebook_size``, ``onehot``, ``model_type``, ``Tin``,
            ``compression``, ``d_model``, ``d_hid``, ``nhead``, ``nlayers``,
            ``optimizer``, ``scheduler``, ``baselr``, ``epochs`` and
            ``steps``.

    Raises:
        ValueError: On an unknown model, optimizer or scheduler type.
        AssertionError: If the loaded codebook does not have
            ``args.codebook_size`` rows.
    """
    device = torch.device("cuda:%d" % (args.cuda_id))
    torch.cuda.set_device(device)

    # -------- SET SEED ------- #
    # a negative seed means "do not fix the seed"
    seed_all_rng(None if args.seed < 0 else args.seed)

    # -------- PARAMS ------- #
    params = get_params(args.data_type)
    batchsize = params["batchsize"]
    datapath = params["dataroot"]
    datapath2 = params["dataroot2"]
    expname = params["expname"]
    S = params["S"]

    # -------- CHECKPOINT ------- #
    # NOTE(review): the directory is created but nothing below writes to it.
    checkpath = None
    if args.checkpoint:
        checkpath = os.path.join(args.checkpoint_path, expname)
        os.makedirs(checkpath, exist_ok=True)

    # ------ DATA LOADERS ------- #
    train_dataloader, test_dataloader = create_dataloader(
        datapath=datapath, datapath2=datapath2, batchsize=batchsize
    )

    # -------- CODEBOOK ------- #
    # SECURITY NOTE: allow_pickle=True lets np.load execute arbitrary pickled
    # code; only point `dataroot` at trusted files.
    codebook = np.load(os.path.join(datapath, "codebook.npy"), allow_pickle=True)
    codebook = torch.from_numpy(codebook).to(device=device, dtype=torch.float32)
    vocab_size, vocab_dim = codebook.shape
    assert vocab_size == args.codebook_size
    dim = vocab_size if args.onehot else vocab_dim

    # ------- MODEL -------- #
    if args.model_type == "mlp":
        # time --> class (baseline)
        model = SimpleMLP(
            in_dim=S * args.Tin, out_dim=1, hidden_dims=[1024, 512, 256], dropout=0.0
        )
    elif args.model_type == "eeg":
        # time --> class (baseline)
        model = EEGNet(
            chunk_size=args.Tin,
            num_electrodes=S,
            F1=8,
            F2=16,
            D=2,
            kernel_1=64,
            kernel_2=16,
            dropout=0.25,
        )
    elif args.model_type == "xformer":
        # code --> class (ours)
        model = SensorTimeEncoder(
            d_in=dim,
            d_model=args.d_model,
            nheadt=args.nhead,
            nheads=args.nhead,
            d_hid=args.d_hid,
            nlayerst=args.nlayers,
            nlayerss=args.nlayers,
            seq_lent=args.Tin // args.compression,
            seq_lens=S,
            dropout=0.25,
        )
    else:
        # was `arts.model_type`, which raised NameError instead of ValueError
        raise ValueError("Unknown model type %s" % (args.model_type))
    model.revin = RevIN(num_features=S, affine=False)  # expects as input (B, T, S)
    model.to(device)

    # ------- OPTIMIZER -------- #
    # the scheduler is stepped once per batch, so the StepLR period (given in
    # epochs on the command line) is converted to iterations
    step_lr_in_iters = args.steps * len(train_dataloader)
    model_params = list(model.parameters())
    if args.optimizer == "sgd":
        optimizer = torch.optim.SGD(model_params, lr=args.baselr, momentum=0.9)
    elif args.optimizer == "adam":
        optimizer = torch.optim.Adam(model_params, lr=args.baselr)
    elif args.optimizer == "adamw":
        optimizer = torch.optim.AdamW(model_params, lr=args.baselr)
    else:
        raise ValueError("Uknown optimizer type %s" % (args.optimizer))
    if args.scheduler == "step":
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_lr_in_iters, gamma=0.1
        )
    elif args.scheduler == "onecycle":
        # The learning rate will increase from max_lr / div_factor to max_lr in
        # the first pct_start * total_steps steps, and then decrease smoothly
        # to max_lr / final_div_factor.
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer=optimizer,
            max_lr=args.baselr,
            steps_per_epoch=len(train_dataloader),
            epochs=args.epochs,
            pct_start=0.2,
        )
    else:
        raise ValueError("Uknown scheduler type %s" % (args.scheduler))

    # ------- TRAIN & EVAL -------- #
    for epoch in range(args.epochs):
        model.train()
        train_one_epoch(
            train_dataloader,
            model,
            codebook,
            optimizer,
            scheduler,
            epoch,
            device,
        )

        if test_dataloader is None:
            continue

        model.eval()
        num_correct = 0
        total_num = 0
        # Disable gradient computation and reduce memory consumption.
        with torch.no_grad():
            for tdata in test_dataloader:
                pred = inference(
                    tdata,
                    model,
                    codebook,
                    device,
                )
                y = tdata[-1].to(device)

                num_correct += ((pred.sigmoid() > 0.5) == y).sum().item()
                total_num += y.numel()

        # a loader that yields no batches would otherwise divide by zero
        if total_num == 0:
            print("| [Test] no test batches available, accuracy not computed")
        else:
            print(f"| [Test] acc {num_correct / total_num:5.3f}")


def get_params(data_type, pt=None, Tin=None):
    """Return the dataset-specific settings for an experiment.

    Args:
        data_type: Dataset family; only ``"neuro"`` is supported.
        pt: Patient id (2, 5 or 12). When ``None`` it is read from the
            module-level ``args`` namespace, which is how the script entry
            point supplies it.
        Tin: Input length, used only to build the experiment name. When
            ``None`` it is read from the module-level ``args`` namespace.

    Returns:
        dict: Keys ``"dataroot"`` (coded data directory), ``"dataroot2"``
        (label directory), ``"batchsize"``, ``"S"`` and ``"expname"``.

    Raises:
        ValueError: For an unsupported data type or patient id, or when
            ``pt`` / ``Tin`` are omitted and no module-level ``args`` exists.
    """
    if data_type != "neuro":
        # `%s`, not `%d`: data_type is a string, so `%d` raised TypeError here
        raise ValueError("No support for data type %s" % (data_type,))

    if pt is None or Tin is None:
        # backward-compatible fallback to the namespace bound in __main__
        cli_args = globals().get("args")
        if cli_args is None:
            raise ValueError(
                "pt and Tin must be passed explicitly when no module-level "
                "`args` namespace is defined"
            )
        pt = cli_args.pt if pt is None else pt
        Tin = cli_args.Tin if Tin is None else Tin

    # patient id -> (codebook sub-directory, batch size, S)
    per_patient = {
        2: ("BS_2048", 16, 72),
        5: ("BS_4096", 16, 106),
        12: ("BS_4096", 32, 93),
    }
    if pt not in per_patient:
        raise ValueError("No support for patient %s" % (pt,))
    coded_subdir, batchsize, S = per_patient[pt]

    expname = "%s_pt%d_Tin%d" % (data_type, pt, Tin)
    dataroot2 = os.path.join(
        "/data/dummy/mts_v2_datasets/pipeline/neuro_tr_v_test_before_revin",
        "pt%d" % (pt),
    )
    dataroot = os.path.join(
        "/data/dummy/mts_v2_datasets/pipeline/vqvae_coded_data_classification/pt%d/%s"
        % (pt, coded_subdir)
    )

    return {
        "dataroot": dataroot,
        "dataroot2": dataroot2,
        "batchsize": batchsize,
        "S": S,
        "expname": expname,
    }


def _str2bool(value):
    """Convert a command-line string to a bool.

    ``type=bool`` cannot be used with argparse because ``bool("False")`` and
    ``bool("0")`` are both ``True`` (any non-empty string is truthy).
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("0", "false", "f", "no", "n", "off", ""):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got %r" % (value,))


def default_argument_parser(argv=None):
    """
    Build the command-line parser and parse the arguments.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            the arguments are taken from ``sys.argv``.

    Returns:
        argparse.Namespace: the parsed arguments.
    """
    parser = argparse.ArgumentParser(description="Classification")
    parser.add_argument(
        "--resume",
        action="store_true",
        help="whether to attempt to resume from the checkpoint directory",
    )
    parser.add_argument("--cuda-id", "--cuda_id", default=0, type=int)
    # ---------- SEED ---------- #
    # Set seed to negative to fully randomize everything.
    # Set seed to positive to use a fixed seed. Note that a fixed seed increases
    # reproducibility but does not guarantee fully deterministic behavior.
    parser.add_argument("--seed", default=-1, type=int)
    # ---------- MODEL ---------- #
    parser.add_argument(
        "--model-type",
        "--model_type",
        default="eeg",
        type=str,
        help="The model, one of 'mlp', 'eeg' or 'xformer'.",
    )
    # ---------- DATA ---------- #
    parser.add_argument("--data-type", "--data_type", default="neuro", type=str)
    parser.add_argument("--pt", default=-1, type=int)
    parser.add_argument("--codebook_size", "--codebook-size", default=256, type=int)
    parser.add_argument("--compression", default=4, type=int)
    parser.add_argument("--Tin", default=1001, type=int)
    # ----------- CHECKPOINT ------------ #
    parser.add_argument("--checkpoint", default=False, type=_str2bool)
    parser.add_argument(
        "--checkpoint_path",
        "--checkpoint-path",
        default="/data/dummy/experiments",
        type=str,
    )
    parser.add_argument("--patience", default=7, type=int)
    # ---------- TRAINING ---------- #
    parser.add_argument("--optimizer", default="adam", type=str)
    parser.add_argument("--scheduler", default="onecycle", type=str)
    parser.add_argument("--baselr", default=0.0001, type=float)
    parser.add_argument("--epochs", default=30, type=int)
    parser.add_argument("--steps", default=15, type=int, help="decay LR in epochs")
    # ----------- INPUT ------------ #
    parser.add_argument(
        "--onehot",
        action="store_true",
        help="use onehot representation if true, otherwise use codes",
    )

    # ----------- ENCODER PARAMS ------------ #
    parser.add_argument("--d-model", "--d_model", default=64, type=int)
    parser.add_argument("--d_hid", "--d-hid", default=128, type=int)
    parser.add_argument("--nhead", default=2, type=int)
    parser.add_argument("--nlayers", default=2, type=int)

    return parser.parse_args(argv)


if __name__ == "__main__":
    # Script entry point: parse the command line, then run training.
    # `args` has to stay bound at module level under exactly this name:
    # get_params() reads the global `args` (pt, Tin) rather than receiving it
    # as a parameter, so inlining or renaming this variable would break it.
    args = default_argument_parser()
    train(args)
