import argparse
import json
import os

import numpy as np
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau

from model_nlp import Transformer
from data import Dataset_NLP, DataLoader
from utils import append_positional_encoding, identity_pe, get_pe
from utils import get_nlp_loss as get_loss


def _train_step(model, optimizer, inputs, y, pos_enc_base, criterion,
                num_additional_node, n, target):
    """Run one optimisation step of ``model`` on a single batch.

    Args:
        model: transformer to update.
        optimizer: optimizer owning ``model``'s parameters.
        inputs: batch fed to ``model`` (already on the right device).
        y: targets for the batch.
        pos_enc_base: positional encoding passed to the model as ``p``.
        criterion: base loss forwarded to ``get_loss``.
        num_additional_node, n, target: forwarded unchanged to ``get_loss``.

    Returns:
        float: the batch loss as a Python number.
    """
    optimizer.zero_grad()
    out = model(inputs, p=pos_enc_base)
    loss = get_loss(criterion, out, y, num_additional_node, n, target)
    loss.backward()
    optimizer.step()
    return loss.item()


def _evaluate(model_s, model_p, loader, pos_enc_base, criterion, device,
              num_additional_node, n, target):
    """Return the mean per-batch loss of both transformers over ``loader``.

    The standard model receives the input with the positional encoding
    appended as extra features. The positional model receives the raw input
    and gets the encoding only through ``p``. Both models are switched to eval
    mode, which disables train-time-only behaviour such as dropout if the
    Transformer has any. Gradients are disabled. The caller must switch the
    models back to train mode before training again.

    Returns:
        tuple[float, float]: (standard loss, positional loss), each averaged
        over the number of batches in ``loader``. An empty loader raises
        ``ZeroDivisionError``.
    """
    model_s.eval()
    model_p.eval()
    total_s, total_p = 0.0, 0.0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            x_app = append_positional_encoding(x, pos_enc_base)
            out = model_s(x_app, p=pos_enc_base)
            total_s += get_loss(
                criterion, out, y, num_additional_node, n, target
            ).item()
            out = model_p(x, p=pos_enc_base)
            total_p += get_loss(
                criterion, out, y, num_additional_node, n, target
            ).item()
    num_batches = len(loader)
    return total_s / num_batches, total_p / num_batches


def run_experiment(target, data, device, model_savepath=None, run_id=0):
    """Train a standard and a positional transformer side by side and return their final losses.

    Both models are trained on the same batches. The "standard" model receives
    the positional encoding concatenated to its input features. The
    "positional" model receives the raw input and gets the encoding via ``p``.
    Both models are evaluated every 10 epochs and always at the final epoch.
    Evaluation uses a validation set drawn from the training range and a test
    set drawn from the test range. The training range is passed to the test
    dataset as ``reject_low``/``reject_high``.

    Args:
        target (str): target function (one of 'sum', 'min', 'median', 'sort', 'minsum')
        data (dict): experiment parameters. Keys read: num_train_samples,
            num_test_samples, num_cats, num_query_cats (optional),
            num_additional_node, lr, batch_size, shuffling, low_train,
            high_train, low_test, high_test, epochs, embed_dim, num_heads,
            model_num_layers (optional), mlp_hidden_dim, mlp_num_layers,
            weight_decay.
        device (torch.device): device to run the experiment on
        model_savepath (str, optional): directory where the trained weights are
            saved as ``run{run_id}_standard.pt`` and
            ``run{run_id}_positional.pt``. Nothing is saved when ``None``.
        run_id (int): identifier embedded in the saved file names.

    Returns:
        tuple[list[float], list[float]]: ``([train, val, test], [train, val, test])``
        for the standard and the positional transformer respectively. Each
        value is a mean per-batch loss from the final epoch. The train value
        is averaged over that epoch's training batches.

    Raises:
        ValueError: if ``data["epochs"]`` is less than 1, because no final
            losses would exist.
    """
    epochs = data["epochs"]
    if epochs < 1:
        raise ValueError(
            f"data['epochs'] must be >= 1 to produce final losses, got {epochs}"
        )

    num_train_samples = data["num_train_samples"]
    num_test_samples = data["num_test_samples"]
    num_additional_node = data["num_additional_node"]
    lr = data["lr"]
    batch_size = data["batch_size"]
    shuffling = data["shuffling"]
    low_train = data["low_train"]
    high_train = data["high_train"]

    # Arguments shared by the train/val/test datasets.
    dataset_kwargs = dict(
        num_cats=data["num_cats"],
        target=target,
        num_additional_node=num_additional_node,
        num_query_cats=data.get("num_query_cats"),
    )
    train_dataset = Dataset_NLP(
        num_samples=num_train_samples,
        low=low_train,
        high=high_train,
        **dataset_kwargs,
    )
    # Validation data uses the training range.
    val_dataset = Dataset_NLP(
        num_samples=num_test_samples,
        low=low_train,
        high=high_train,
        **dataset_kwargs,
    )
    # Test data uses the test range. The training range is given as the
    # rejection interval (presumably to keep test samples outside it).
    test_dataset = Dataset_NLP(
        num_samples=num_test_samples,
        low=data["low_test"],
        high=data["high_test"],
        reject_low=low_train,
        reject_high=high_train,
        **dataset_kwargs,
    )

    # Number of elements in the tokenized sequence excluding the scratchpad.
    n = train_dataset.list_size - num_additional_node
    pos_enc_base = identity_pe(n + num_additional_node).to(device)
    pos_dim = pos_enc_base.size(1)
    data_dim = 1

    if "model_num_layers" in data:
        num_layers = data["model_num_layers"]
    else:
        # Default depth: floor(log2(n)) + 1 layers.
        num_layers = np.log2(n).astype(int) + 1

    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=shuffling, variable_length=False
    )
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False, variable_length=False
    )
    test_loader = DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False, variable_length=False
    )

    # Hyper-parameters shared by both transformers.
    model_kwargs = dict(
        embed_dim=data["embed_dim"],
        out_dim=data_dim,
        num_heads=data["num_heads"],
        num_layers=num_layers,
        mlp_hidden_dim=data["mlp_hidden_dim"],
        mlp_num_layers=data["mlp_num_layers"],
        pos_dim=pos_dim,
    )
    # The standard model sees the encoding as extra input features. The
    # positional model sees only the raw data.
    model_s = Transformer(
        in_dim=data_dim + pos_dim, positional=False, **model_kwargs
    ).to(device)
    model_p = Transformer(in_dim=data_dim, positional=True, **model_kwargs).to(device)

    optimizer_s = torch.optim.Adam(
        model_s.parameters(), lr=lr, weight_decay=data["weight_decay"]
    )
    optimizer_p = torch.optim.Adam(
        model_p.parameters(), lr=lr, weight_decay=data["weight_decay"]
    )
    scheduler_s = ReduceLROnPlateau(
        optimizer_s, mode="min", patience=10, factor=0.9, min_lr=1.0e-6
    )
    scheduler_p = ReduceLROnPlateau(
        optimizer_p, mode="min", patience=10, factor=0.9, min_lr=1.0e-6
    )
    criterion = nn.MSELoss()

    for epoch in range(epochs):
        model_s.train()
        model_p.train()
        train_loss_s = 0.0
        train_loss_p = 0.0

        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            x_app = append_positional_encoding(x, pos_enc_base)
            train_loss_s += _train_step(
                model_s, optimizer_s, x_app, y, pos_enc_base, criterion,
                num_additional_node, n, target,
            )
            train_loss_p += _train_step(
                model_p, optimizer_p, x, y, pos_enc_base, criterion,
                num_additional_node, n, target,
            )

        # Schedulers step on the summed epoch loss. The batch count is fixed,
        # so this only rescales the monitored metric.
        scheduler_s.step(train_loss_s)
        scheduler_p.step(train_loss_p)

        is_last_epoch = epoch == epochs - 1
        # Evaluate every 10 epochs and always at the last epoch, so the final
        # losses are defined whatever the epoch count is.
        if epoch % 10 == 0 or is_last_epoch:
            val_loss_s, val_loss_p = _evaluate(
                model_s, model_p, val_loader, pos_enc_base, criterion, device,
                num_additional_node, n, target,
            )
            test_loss_s, test_loss_p = _evaluate(
                model_s, model_p, test_loader, pos_enc_base, criterion, device,
                num_additional_node, n, target,
            )
            train_loss_s /= len(train_loader)
            train_loss_p /= len(train_loader)
            print(
                f"Epoch {epoch}, standard train/val: {train_loss_s:.4e}/{val_loss_s:.4e}, positional train/val: {train_loss_p:.4e}/{val_loss_p:.4e}"
            )
            print(f"Test positional {test_loss_p:.4e}")
            print(f"Test standard: {test_loss_s:.4e}")

            if is_last_epoch:
                final_losses_s = [train_loss_s, val_loss_s, test_loss_s]
                final_losses_p = [train_loss_p, val_loss_p, test_loss_p]

        if epoch % 100 == 0:
            print(
                "Learning rate for standard transformer: ",
                optimizer_s.param_groups[0]["lr"],
            )
            print(
                "Learning rate for positional transformer: ",
                optimizer_p.param_groups[0]["lr"],
            )

    if model_savepath is not None:
        torch.save(
            model_s.state_dict(),
            os.path.join(model_savepath, f"run{run_id}_standard.pt"),
        )
        torch.save(
            model_p.state_dict(),
            os.path.join(model_savepath, f"run{run_id}_positional.pt"),
        )

    return final_losses_s, final_losses_p


if __name__ == "__main__":
    # CLI: output directory, JSON parameter file, and the target task name.
    parser = argparse.ArgumentParser()
    for flag in ("--savepath", "--params", "--task"):
        parser.add_argument(flag, type=str, required=True)
    args = parser.parse_args()

    print("PyTorch version:", torch.__version__)
    has_cuda = torch.cuda.is_available()
    print("Access to GPU:", has_cuda)
    device = torch.device("cuda" if has_cuda else "cpu")

    with open(args.params, "r") as params_file:
        params = json.load(params_file)

    # Keep a copy of the parameters next to the results.
    os.makedirs(args.savepath, exist_ok=True)
    with open(args.savepath + "/params.json", "w") as params_copy:
        json.dump(params, params_copy)

    models_dir = args.savepath + "/models"
    os.makedirs(models_dir, exist_ok=True)

    print(f"Experiment: NLP - {args.task}")
    num_cats = params["num_cats"]
    query_cats_label = params.get("num_query_cats", "unspecified")
    low_train, high_train = params["low_train"], params["high_train"]
    low_test, high_test = params["low_test"], params["high_test"]

    n_train = params["num_train_samples"]
    stem = args.savepath + f"/train_val_test_scale_nlp_samples{n_train}"
    standard_path = stem + "_standard.txt"
    positional_path = stem + "_positional.txt"
    for results_path in (standard_path, positional_path):
        os.makedirs(os.path.dirname(results_path), exist_ok=True)

    print(
        f"Number of categories: {num_cats}, Number of categories in each query: {query_cats_label}"
    )
    print(
        f"Training samples: {n_train}, Training range: [{low_train}, {high_train}]"
    )
    print(
        f"Testing samples: {params['num_test_samples']}, Testing range: [{low_test}, {high_test}]"
    )

    total_runs = params["runs"]
    for run_index in range(1, total_runs + 1):
        print(f"Run {run_index} / {total_runs}:")
        losses_standard, losses_positional = run_experiment(
            args.task, params, device, model_savepath=models_dir, run_id=run_index
        )

        # One line per run: each loss followed by a tab, then a newline.
        for results_path, losses in (
            (standard_path, losses_standard),
            (positional_path, losses_positional),
        ):
            with open(results_path, "a") as results_file:
                results_file.write("".join(f"{loss:.10e}\t" for loss in losses) + "\n")

    print("===============================================")
