import argparse
import json
import os
import warnings

import numpy as np
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau

from model import Transformer
from data import Dataset_Basic, DataLoader
from utils import (
    append_positional_encoding,
    identity_pe,
    binary_pe,
    sinusoidal_pe,
    get_pe,
    get_loss,
)


def run_experiment(target, data, device, model_savepath=None, run_id=0):
    """Run experiment with given parameters and return final losses for standard and positional transformer

    Args:
        target (str): target function (one of 'sum', 'min', 'median', 'sort', 'minsum')
        data (dict): dictionary containing parameters for the experiment
        device (torch.device): device to run the experiment on
    """
    n = data["n"][0]
    num_train_samples = data["num_train_samples"][0]
    num_test_samples = data["num_test_samples"]
    num_additional_node = data["num_additional_node"]
    lr = data["lr"]
    batch_size = data["batch_size"]
    shuffling = data["shuffling"]
    low_train = data["low_train"]
    high_train = data["high_train"]
    cumulative = data["cumulative"]
    use_integer = data["use_integer"]
    variable_length = data["variable_length"] if "variable_length" in data else False
    pos_enc_type = data["pos_enc_type"] if "pos_enc_type" in data else "identity"

    if pos_enc_type == "identity":
        pos_enc_gen = identity_pe
    elif pos_enc_type == "sinusoidal":
        pos_enc_gen = lambda x: sinusoidal_pe(x, dim=min(n // 2, 2))
    elif pos_enc_type == "binary":
        pos_enc_gen = binary_pe
    else:
        pos_enc_gen = identity_pe
        warnings.warn(
            f"Unknown positional encoding type: {pos_enc_type}. Using identity positional encoding."
        )

    pos_enc_base = pos_enc_gen(n + num_additional_node).to(device)

    data_dim = 1
    in_dim_s = data_dim + pos_enc_base.size(1)
    in_dim_p = data_dim
    out_dim = data_dim
    embed_dim = data["embed_dim"]
    num_heads = data["num_heads"]
    use_rope = data["RoPE"] if "RoPE" in data else False
    num_layers = (
        np.log2(n).astype(int) + 1
        if "model_num_layers" not in data
        else data["model_num_layers"]
    )
    mlp_hidden_dim = data["mlp_hidden_dim"]
    mlp_num_layers = data["mlp_num_layers"]

    epochs = data["epochs"]

    train_dataset = Dataset_Basic(
        num_samples=num_train_samples,
        length=n,
        low=low_train,
        high=high_train,
        target=target,
        use_integer=use_integer,
        cumulative=cumulative,
        num_additional_node=num_additional_node,
        variable_length=variable_length,
    )
    val_dataset = Dataset_Basic(
        num_samples=num_test_samples,
        length=n,
        low=low_train,
        high=high_train,
        target=target,
        use_integer=use_integer,
        cumulative=cumulative,
        num_additional_node=num_additional_node,
        variable_length=variable_length,
    )
    test_datasets = [
        Dataset_Basic(
            num_samples=num_test_samples,
            length=n,
            low=low_test,
            high=high_test,
            target=target,
            use_integer=use_integer,
            cumulative=cumulative,
            num_additional_node=num_additional_node,
            reject_low=low_train,
            reject_high=high_train,
            variable_length=variable_length,
        )
        for low_test, high_test in zip(data["low_test"], data["high_test"])
    ]

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=shuffling,
        variable_length=variable_length,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        variable_length=variable_length,
    )
    test_loaders = [
        DataLoader(
            test_dataset,
            batch_size=batch_size,
            shuffle=False,
            variable_length=variable_length,
        )
        for test_dataset in test_datasets
    ]

    model_s = Transformer(
        in_dim=in_dim_s,
        embed_dim=embed_dim,
        out_dim=out_dim,
        num_heads=num_heads,
        num_layers=num_layers,
        mlp_hidden_dim=mlp_hidden_dim,
        mlp_num_layers=mlp_num_layers,
        positional=False,
        RoPE=use_rope,
        pos_dim=pos_enc_base.size(1),
    ).to(device)
    model_p = Transformer(
        in_dim=in_dim_p,
        embed_dim=embed_dim,
        out_dim=out_dim,
        num_heads=num_heads,
        num_layers=num_layers,
        mlp_hidden_dim=mlp_hidden_dim,
        mlp_num_layers=mlp_num_layers,
        positional=True,
        pos_dim=pos_enc_base.size(1),
    ).to(device)
    optimizer_s = torch.optim.Adam(
        model_s.parameters(), lr=lr, weight_decay=data["weight_decay"]
    )
    optimizer_p = torch.optim.Adam(
        model_p.parameters(), lr=lr, weight_decay=data["weight_decay"]
    )
    scheduler_s = ReduceLROnPlateau(
        optimizer_s, mode="min", patience=50, factor=0.9, min_lr=1.0e-6
    )
    scheduler_p = ReduceLROnPlateau(
        optimizer_p, mode="min", patience=50, factor=0.9, min_lr=1.0e-6
    )
    criterion = nn.MSELoss()

    for epoch in range(epochs):

        model_s.train()
        model_p.train()
        train_loss_s = 0
        train_loss_p = 0
        loss_s = 0
        loss_p = 0

        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            pos_enc = (
                get_pe(pos_enc_base, x, num_additional_node)
                if variable_length
                else pos_enc_base
            )
            x_app = append_positional_encoding(x, pos_enc)

            optimizer_s.zero_grad()
            out = model_s(x_app, p=pos_enc)
            loss_s = get_loss(criterion, out, y, num_additional_node, n, target)
            loss_s.backward()
            optimizer_s.step()
            train_loss_s += loss_s.item()

            optimizer_p.zero_grad()
            out = model_p(x, p=pos_enc)
            loss_p = get_loss(criterion, out, y, num_additional_node, n, target)
            loss_p.backward()
            optimizer_p.step()
            train_loss_p += loss_p.item()

        scheduler_s.step(train_loss_s)
        scheduler_p.step(train_loss_p)

        if epoch % 10 == 0:
            with torch.no_grad():
                val_loss_s, test_loss_s = 0, [0] * len(test_loaders)
                val_loss_p, test_loss_p = 0, [0] * len(test_loaders)
                for x, y in val_loader:
                    x, y = x.to(device), y.to(device)
                    pos_enc = (
                        get_pe(pos_enc_base, x, num_additional_node)
                        if variable_length
                        else pos_enc_base
                    )
                    x_app = append_positional_encoding(x, pos_enc)
                    out = model_s(x_app, p=pos_enc)
                    val_loss_s += get_loss(
                        criterion, out, y, num_additional_node, n, target
                    ).item()
                    out = model_p(x, p=pos_enc)
                    val_loss_p += get_loss(
                        criterion, out, y, num_additional_node, n, target
                    ).item()

                for i, test_loader in enumerate(test_loaders):
                    for x, y in test_loader:
                        x, y = x.to(device), y.to(device)
                        pos_enc = (
                            get_pe(pos_enc_base, x, num_additional_node)
                            if variable_length
                            else pos_enc_base
                        )
                        x_app = append_positional_encoding(x, pos_enc)
                        out = model_s(x_app, p=pos_enc)
                        test_loss_s[i] += get_loss(
                            criterion, out, y, num_additional_node, n, target
                        ).item()
                        out = model_p(x, p=pos_enc)
                        test_loss_p[i] += get_loss(
                            criterion, out, y, num_additional_node, n, target
                        ).item()

                train_loss_s /= len(train_loader)
                train_loss_p /= len(train_loader)
                val_loss_s /= len(val_loader)
                val_loss_p /= len(val_loader)
                print(
                    f"Epoch {epoch}, standard train/val: {train_loss_s:.4e}/{val_loss_s:.4e}, positional train/val: {train_loss_p:.4e}/{val_loss_p:.4e}/"
                )
                str_loss_p = "/".join([f"{loss:.4e}" for loss in test_loss_p])
                str_loss_s = "/".join([f"{loss:.4e}" for loss in test_loss_s])
                print(f"Test positional: {str_loss_p}")
                print(f"Test standard: {str_loss_s}")

                if epoch == epochs - 1:
                    final_losses_s = [train_loss_s] + [val_loss_s] + test_loss_s
                    final_losses_p = [train_loss_p] + [val_loss_p] + test_loss_p

        if epoch % 100 == 0:
            print(
                "Learning rate for standard transformer: ",
                optimizer_s.param_groups[0]["lr"],
            )
            print(
                "Learning rate for positional transformer: ",
                optimizer_p.param_groups[0]["lr"],
            )

    if model_savepath is not None:
        torch.save(model_s.state_dict(), model_savepath + f"/run{run_id}_standard.pt")
        torch.save(model_p.state_dict(), model_savepath + f"/run{run_id}_positional.pt")

    return final_losses_s, final_losses_p


if __name__ == "__main__":
    argparser = argparse.ArgumentParser()
    argparser.add_argument("--savepath", type=str, required=True)
    argparser.add_argument("--params", type=str, required=True)
    argparser.add_argument("--task", type=str, required=True)
    args = argparser.parse_args()

    print("PyTorch version:", torch.__version__)
    print("Access to GPU:", torch.cuda.is_available())
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    with open(args.params, "r") as fp:
        data = json.load(fp)

    os.makedirs(args.savepath, exist_ok=True)
    with open(args.savepath + "/params.json", "w") as fp:
        json.dump(data, fp)

    model_savepath = args.savepath + "/models"
    os.makedirs(model_savepath, exist_ok=True)

    assert len(data["n"]) == 1, "Only one value of n should be provided"
    assert (
        len(data["num_train_samples"]) == 1
    ), "Only one value of num_train_samples should be provided"
    assert len(data["low_test"]) == len(
        data["high_test"]
    ), "Length of low_test and high_test should be the same"
    experiment = "scale_generalization"
    times = len(data["low_test"])
    variable_length = data["variable_length"] if "variable_length" in data else False
    use_rope = data["RoPE"] if "RoPE" in data else False

    print(f"Experiment: {experiment}")
    print(f"Task: {args.task}")
    print(f"Variable length: {variable_length}")
    print(f"Using RoPE: {use_rope}")
    n = data["n"][0]
    num_train_samples = data["num_train_samples"][0]
    low_test = (
        data["low_test"][0]
        if experiment == "scale_generalization"
        else data["low_test"][0]
    )
    high_test = (
        data["high_test"][0]
        if experiment == "scale_generalization"
        else data["high_test"][0]
    )

    filename = f"/train_val_test_scale_n{n}_samples{num_train_samples}"
    filename_s = args.savepath + filename + "_standard.txt"
    filename_p = args.savepath + filename + "_positional.txt"
    os.makedirs(os.path.dirname(filename_s), exist_ok=True)
    os.makedirs(os.path.dirname(filename_p), exist_ok=True)

    print(
        f"n: {n}, Training samples: {num_train_samples}, Test range: [{low_test}, {high_test}]"
    )
    for run in range(data["runs"]):
        print(f"Run {run+1} / {data['runs']}:")
        final_losses_s, final_losses_p = run_experiment(
            args.task, data, device, model_savepath=model_savepath, run_id=run + 1
        )

        with open(filename_s, "a") as f:
            for loss in final_losses_s:
                print(f"{loss:.10e}\t", end="", file=f)
            print("", file=f)

        with open(filename_p, "a") as f:
            for loss in final_losses_p:
                print(f"{loss:.10e}\t", end="", file=f)
            print("", file=f)

    print("===============================================")
