import time
import copy
import pickle
import scipy
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import argparse
from transformers import AutoTokenizer
from datasets import load_dataset

def parse_args():
    """Parse the command-line options of this experiment script.

    Returns:
        argparse.Namespace with integer attributes ``offset`` and ``alg``.
    """
    parser = argparse.ArgumentParser()
    # NOTE(review): --offset is parsed but never read anywhere in this file
    # (file_name does not include it) — confirm whether it should be.
    parser.add_argument('--offset', type=int, default=1, help='Offset to distinguish a saved file')
    # --alg indexes the algorithm tables (algs / num_Ts / alphas / ...), which
    # have three entries (0: Lazy, 1: AmIGO, 2: SOBA). Restricting the choices
    # makes a bad value fail here with a usage message instead of an IndexError
    # later in the configuration section.
    parser.add_argument('--alg', type=int, default=0, choices=(0, 1, 2), help='Algorithm type')
    return parser.parse_args()
args = parse_args()

# configuration
save_file = 0 # set to 1 to pickle the per-run curves to f'{file_name}.pkl'

algs = ["Lazy", "AmIGO", "SOBA"]
alg_flag = args.alg # algorithms: index into algs
m_flag = 0 # momentum flag
lazy_flag = 1 # lazy JVP

num_run = 20
num_Ts = [35000, 3300, 11000] # iterations per run, indexed by alg_flag
num_T = num_Ts[alg_flag]

# learning parameters (each list is indexed by alg_flag)
batch_size = 256
alphas = [0.00002, 0.00002, 0.00002] # for x (weighting)
betas  = [0.0000003, 0.0000003, 0.0000003] # for y (MLP)
gammas = [0.0000002, 0.0000002, 0.0000002] # for z (auxiliary)

mu = 0.8 # momentum
lazy_N = 5 # Lazy: z (and, with lazy_flag, the JVP) is refreshed every lazy_N iterations
amigo_M = 5 # AmIGO: z-steps per iteration
amigo_N = 5 # AmIGO: y-steps per iteration

# x_net
num_DS = 2 # number of datasets being weighted

# y_net
in_dim = 500
width = 500
depth = 5
out_dim = 1

# Algorithm-specific overrides. Order matters: every flag override must be
# applied before mu / lazy_N are derived from the final flags. Otherwise, with
# m_flag = 1 and AmIGO selected, momentum is switched off only after mu was
# resolved, leaving mu = 0.8 (and "mu8" in file_name) for a run without momentum.
if alg_flag != 0: # only SO-Lazy has lazy JVP
    lazy_flag = 0
    lazy_N = 0

if alg_flag == 1: # AmIGO does not have momentum
    m_flag = 0

if m_flag == 0:
    mu = 0.0

task = "RM"
M_str = ["", "-M"]
L_str = ["", "-L"]
file_name = task + "_" + algs[alg_flag] + L_str[lazy_flag] + M_str[m_flag] + f"_Run{num_run}_T{num_T}_DS{num_DS}_mu{int(mu*10)}_N{lazy_N}"
print(file_name)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def build_y(input_dim: int, hidden_width: int, hidden_depth: int, output_dim: int):
    """Build the lower-level MLP.

    Layout: ``Linear(input_dim, hidden_width) -> ReLU``, then ``hidden_depth``
    repetitions of ``Linear(hidden_width, hidden_width) -> ReLU``, then a
    bias-free ``Linear(hidden_width, output_dim)`` head.

    Returns:
        nn.Sequential containing the layers above, in order.
    """
    # Layers are instantiated strictly in forward order so parameter
    # initialization consumes the RNG in the same sequence as the layout.
    stack = [nn.Linear(input_dim, hidden_width, bias=True), nn.ReLU()]
    for _ in range(hidden_depth):
        stack.extend((nn.Linear(hidden_width, hidden_width, bias=True), nn.ReLU()))
    stack.append(nn.Linear(hidden_width, output_dim, bias=False))
    return nn.Sequential(*stack)

def build_x(dim: int):
    """Build the weighting network.

    A bias-free ``Linear(1, dim)`` followed by a softmax over dimension 1, so
    each input row is mapped to ``dim`` non-negative weights summing to one.
    """
    return nn.Sequential(
        nn.Linear(1, dim, bias=False),
        nn.Softmax(dim=1),
    )

def get_g_loss(y_opt, x_opt, y_model, x_model, loss_model, inputs_y, labels_y, one_tensor, mask_tensor):
    """Evaluate the lower-level (weighted) loss g on one batch.

    Both optimizers' gradients are cleared first. ``y_model`` predicts from
    ``inputs_y``; ``x_model(one_tensor, mask_tensor)`` supplies per-sample
    weights, passed as the third argument of ``loss_model``.

    Returns:
        The loss tensor (attached to the autograd graph unless the caller runs
        under ``torch.no_grad()``).
    """
    for opt in (y_opt, x_opt):
        opt.zero_grad()
    predictions = y_model(inputs_y)
    sample_weights = x_model(one_tensor, mask_tensor)
    return loss_model(predictions, labels_y, sample_weights)

def get_f_loss(y_opt, x_opt, y_model, loss_model, inputs_y, labels_y):
    """Evaluate the upper-level loss f on one batch.

    Both optimizers' gradients are cleared first; then ``loss_model`` compares
    ``y_model(inputs_y)`` against ``labels_y``. No weighting model is involved.

    Returns:
        The loss tensor (attached to the autograd graph unless the caller runs
        under ``torch.no_grad()``).
    """
    for opt in (y_opt, x_opt):
        opt.zero_grad()
    predictions = y_model(inputs_y)
    return loss_model(predictions, labels_y)

def sort_data(dataset, max_size):
    """Select the rows whose prompt and response together are shorter than ``max_size``.

    Despite the name, nothing is sorted: kept rows stay in dataset order.
    Length is ``len(row['prompt']) + len(row['response'])`` (character count
    for string fields).

    Returns:
        (idx, prompt, response): lists holding the kept row positions and their
        prompt / response values.
    """
    kept = [
        (row_no, row['prompt'], row['response'])
        for row_no, row in enumerate(dataset)
        if len(row['prompt']) + len(row['response']) < max_size
    ]
    idx = [entry[0] for entry in kept]
    prompt = [entry[1] for entry in kept]
    response = [entry[2] for entry in kept]
    return idx, prompt, response

def get_scores(dataset):
    """Stack every column after the first two into a (num_columns, num_rows) array.

    Columns are taken in ``dataset.features`` order and the first two are
    skipped (presumably the prompt / response text columns — TODO confirm
    against the dataset schema). Only the kept columns are read, so the
    skipped columns are never materialized.

    Returns:
        np.ndarray with one row per kept column and one column per dataset row.
    """
    score_columns = list(dataset.features)[2:]
    return np.array([dataset[name] for name in score_columns])

def combine_data(prompt, response, Tokenizer, max_len):
    """Tokenize (prompt, response) pairs and zero-pad each id sequence to ``max_len``.

    Args:
        prompt: sequence of prompt strings.
        response: sequence of response strings, aligned with ``prompt``.
        Tokenizer: callable invoked as ``Tokenizer(prompt_i, response_i)`` whose
            result has an ``input_ids`` sequence.
        max_len: width every row is padded to.

    Returns:
        (data, lengths): ``data`` is a float array of shape
        ``(len(prompt), max_len)`` holding each pair's token ids followed by
        zeros; ``lengths`` holds each unpadded id count.

    Raises:
        ValueError: if ``prompt`` and ``response`` differ in length, or a
            tokenized pair is longer than ``max_len`` and cannot be padded.
    """
    if len(prompt) != len(response):
        raise ValueError(
            f"prompt and response must have equal length, got {len(prompt)} and {len(response)}")

    # Preallocate the zero-padded buffer; each row only needs its prefix filled.
    data_H = np.zeros((len(prompt), max_len))
    data_H_len = []
    for row, (p, r) in enumerate(zip(prompt, response)):
        input_ids = Tokenizer(p, r).input_ids
        n_ids = len(input_ids)
        if n_ids > max_len:
            raise ValueError(
                f"pair {row} tokenizes to {n_ids} ids, which exceeds max_len={max_len}")
        data_H[row, :n_ids] = input_ids
        data_H_len.append(n_ids)
    return data_H, np.array(data_H_len)


class y_model(nn.Module):
    """Lower-level model y: thin ``nn.Module`` wrapper around the MLP from ``build_y``."""

    def __init__(self, input_dim, hidden_width, hidden_depth, output_dim):
        super().__init__()
        self.y_net = build_y(
            input_dim=input_dim,
            hidden_width=hidden_width,
            hidden_depth=hidden_depth,
            output_dim=output_dim,
        )

    def forward(self, feature):
        # Straight pass-through to the wrapped MLP.
        return self.y_net(feature)


class x_model(nn.Module):
    """Upper-level variable x: learnable per-dataset weights.

    ``x_net`` (from ``build_x``) maps its input rows to ``dim`` softmax weights;
    ``forward`` then picks, per row, the weight at the dataset index in ``idx``.
    """

    def __init__(self, dim):
        super().__init__()
        self.x_net = build_x(dim)

    def forward(self, one, idx):
        # gather along dim 1: idx must be an integer tensor with the same rank
        # as the weights (the training loop passes a (batch, 1) long tensor).
        return self.x_net(one).gather(1, idx)


class WeightedMSELoss(nn.Module):
    """Mean of ``weights * (inputs - targets)**2`` over all elements.

    ``weights`` must broadcast against ``inputs - targets``. The module has no
    parameters, so the default ``nn.Module`` constructor is used.
    """

    def forward(self, inputs, targets, weights):
        residual = inputs - targets
        return (weights * residual.pow(2)).mean()


Tokenizer_name = "microsoft/deberta-v3-large"
Tokenizer_D = AutoTokenizer.from_pretrained(Tokenizer_name)

# Load the dataset once and take both splits from the same object.
helpsteer = load_dataset("nvidia/HelpSteer")
ds_tra = helpsteer['train']
ds_val = helpsteer['validation']

# Keep only short examples (prompt + response length below 1000).
idx_tra, prompt_tra, response_tra = sort_data(ds_tra, 1000)
idx_val, prompt_val, response_val = sort_data(ds_val, 1000)

# One row per kept example, one column per score column of the dataset.
scores_tra = get_scores(ds_tra)[:, np.array(idx_tra)].T
scores_val = get_scores(ds_val)[:, np.array(idx_val)].T

# Pad token ids to the MLP input width so every row fits y_model's first layer.
train_data_H, train_data_H_len = combine_data(prompt_tra, response_tra, Tokenizer_D, in_dim)
test_data_H, test_data_H_len = combine_data(prompt_val, response_val, Tokenizer_D, in_dim)

# train data: the same token rows twice, labelled with score columns 2 and 4
train_data = np.concatenate((train_data_H, train_data_H), axis=0)
train_label = np.concatenate((scores_tra[:, 2], scores_tra[:, 4]), axis=0)[:, np.newaxis]

# NOTE(review): the upper-level ("test") objective uses the *training* rows
# with score column 2; test_data_H / scores_val above are never used.
# Confirm this is intentional.
test_data = train_data_H
test_label = scores_tra[:, 2][:, np.newaxis]

# dataset numbering: 0 for the first copy, 1 for the second. The split point
# is derived from the data rather than hard-coded, so it stays correct if the
# filtering, tokenizer, or dataset version changes the number of kept rows.
dataset_num = np.zeros([np.shape(train_data)[0], 1])
dataset_num[len(train_data_H):] = 1

# tensor
train_data = torch.as_tensor(train_data).float().to(device)
train_label = torch.as_tensor(train_label).float().to(device)
test_data = torch.as_tensor(test_data).float().to(device)
test_label = torch.as_tensor(test_label).float().to(device)
dataset_num = torch.as_tensor(dataset_num).long().to(device)
one_tensor = torch.as_tensor(np.ones([batch_size, 1])).float().to(device) # input to the weighting_model

train_loss_run = []
test_loss_run = []
soft_run = []
times_run = []

def _sync_device():
    """Block until all queued CUDA work has finished (no-op on CPU).

    CUDA kernels are launched asynchronously, so reading the clock right after
    issuing them mostly measures launch overhead, and the unfinished work is
    billed to whatever synchronizes next. Synchronizing around the timed region
    makes ``times_run`` reflect the real cost of each update.
    """
    if device.type == 'cuda':
        torch.cuda.synchronize()


for run in range(num_run):
    print(f"Run {run+1} / {num_run}")
    main_model = y_model(in_dim, width, depth, out_dim).to(device)
    weighting_model = x_model(num_DS).to(device)
    # Zero the weighting logits so every dataset starts with equal weight.
    weighting_model.state_dict()['x_net.0.weight'].fill_(0)

    # Auxiliary variable z: one random tensor per main_model parameter.
    z_params = []
    for param in main_model.parameters():
        z_params.append(torch.rand_like(param))

    # These optimizers are only used for zero_grad() inside get_g_loss /
    # get_f_loss; every parameter update below is a manual gradient step.
    y_optimizer = torch.optim.Adam(main_model.parameters())
    x_optimizer = torch.optim.Adam(weighting_model.parameters())
    f_loss_MSE  = torch.nn.MSELoss()
    g_loss_WMSE = WeightedMSELoss()

    train_loss_t = []
    test_loss_t = []
    soft_t = []
    times_t = []

    # Record the initial dataset weights.
    for param in weighting_model.parameters():
        x_param = param.tolist()
        soft_t.append(np.squeeze(scipy.special.softmax(x_param)))

    for t in range(num_T):
        batch_idx_train = np.random.choice(len(train_data), batch_size, replace=False)
        inputs_train = train_data[batch_idx_train]
        labels_train = train_label[batch_idx_train]
        train_mask_tensor = dataset_num[batch_idx_train]

        batch_idx_test = np.random.choice(len(test_data), batch_size, replace=False)
        inputs_test = test_data[batch_idx_test]
        labels_test = test_label[batch_idx_test]

        # Drain the batch-sampling work so it is not counted in this step.
        _sync_device()
        start_time = time.perf_counter()
        if alg_flag == 0:
            # Lazy: z (and, with lazy_flag, the JVP) is refreshed every lazy_N steps.
            if (t % lazy_N) == 0:
                # compute g_yy*z (first term of h_q)
                g_loss = get_g_loss(y_optimizer, x_optimizer, main_model, weighting_model,
                                    g_loss_WMSE, inputs_train, labels_train, one_tensor, train_mask_tensor)
                grad_g_y = torch.autograd.grad(g_loss, main_model.parameters(), create_graph=True)
                HVP = torch.autograd.grad(grad_g_y, main_model.parameters(), grad_outputs=z_params)

                # compute f_y (second term of h_q)
                f_loss = get_f_loss(y_optimizer, x_optimizer, main_model, f_loss_MSE, inputs_test, labels_test)
                grad_f_y = torch.autograd.grad(f_loss, main_model.parameters())

                with torch.no_grad():
                    # compute h_q
                    h_q = []
                    for a in range(len(grad_f_y)):
                        h_q.append(HVP[a] + grad_f_y[a])

                    # update z
                    for a in range(len(z_params)):
                        z_params[a] = z_params[a] - gammas[alg_flag] * h_q[a]

                if lazy_flag == 1:
                    # update v (compute g_xy*z); reused until the next refresh
                    g_loss = get_g_loss(y_optimizer, x_optimizer, main_model, weighting_model,
                                        g_loss_WMSE, inputs_train, labels_train, one_tensor, train_mask_tensor)
                    grad_g_y = torch.autograd.grad(g_loss, main_model.parameters(), create_graph=True)
                    JVP = torch.autograd.grad(grad_g_y, weighting_model.parameters(), grad_outputs=z_params)

            if lazy_flag == 0:
                # update v (compute g_xy*z) every iteration
                g_loss = get_g_loss(y_optimizer, x_optimizer, main_model, weighting_model,
                                    g_loss_WMSE, inputs_train, labels_train, one_tensor, train_mask_tensor)
                grad_g_y = torch.autograd.grad(g_loss, main_model.parameters(), create_graph=True)
                JVP = torch.autograd.grad(grad_g_y, weighting_model.parameters(), grad_outputs=z_params)

            # compute h_g
            g_loss = get_g_loss(y_optimizer, x_optimizer, main_model, weighting_model,
                                g_loss_WMSE, inputs_train, labels_train, one_tensor, train_mask_tensor)
            h_g = torch.autograd.grad(g_loss, main_model.parameters())

            # update y
            with torch.no_grad():
                for a, param in enumerate(main_model.parameters()):
                    param.data = param.data - betas[alg_flag] * h_g[a]

            with torch.no_grad():
                # compute h_f: get_f_loss never evaluates weighting_model, so f
                # has no direct x-gradient and h_f is just the JVP term.
                h_f = list(JVP)

                if m_flag == 1:
                    if t == 0:
                        bar_h_f = list(h_f)
                    else:
                        bar_h_f = [mu*h_f[a] + (1-mu)*bar_h_f[a]
                                   for a in range(len(h_f))]

                # update x
                for a, param in enumerate(weighting_model.parameters()):
                    if m_flag == 0:
                        param.data = param.data - alphas[alg_flag] * h_f[a]
                    elif m_flag == 1:
                        param.data = param.data - alphas[alg_flag] * bar_h_f[a]

        elif alg_flag == 1:
            # AmIGO: amigo_M z-steps with a fixed f_y, amigo_N y-steps, one x-step.
            # compute f_y (second term of h_q)
            f_loss = get_f_loss(y_optimizer, x_optimizer, main_model, f_loss_MSE, inputs_test, labels_test)
            grad_f_y = torch.autograd.grad(f_loss, main_model.parameters())

            for _ in range(amigo_M):
                # compute g_yy*z (first term of h_q)
                g_loss = get_g_loss(y_optimizer, x_optimizer, main_model, weighting_model,
                                    g_loss_WMSE, inputs_train, labels_train, one_tensor, train_mask_tensor)
                grad_g_y = torch.autograd.grad(g_loss, main_model.parameters(), create_graph=True)
                HVP = torch.autograd.grad(grad_g_y, main_model.parameters(), grad_outputs=z_params)

                with torch.no_grad():
                    # compute h_q
                    h_q = []
                    for a in range(len(grad_f_y)):
                        h_q.append(HVP[a] + grad_f_y[a])

                    # update z
                    for a in range(len(z_params)):
                        z_params[a] = z_params[a] - gammas[alg_flag] * h_q[a]

            for _ in range(amigo_N):
                # compute h_g
                g_loss = get_g_loss(y_optimizer, x_optimizer, main_model, weighting_model,
                                    g_loss_WMSE, inputs_train, labels_train, one_tensor, train_mask_tensor)
                h_g = torch.autograd.grad(g_loss, main_model.parameters())

                # update y
                with torch.no_grad():
                    for a, param in enumerate(main_model.parameters()):
                        param.data = param.data - betas[alg_flag] * h_g[a]

            # update v (compute g_xy*z)
            g_loss = get_g_loss(y_optimizer, x_optimizer, main_model, weighting_model,
                                g_loss_WMSE, inputs_train, labels_train, one_tensor, train_mask_tensor)
            grad_g_y = torch.autograd.grad(g_loss, main_model.parameters(), create_graph=True)
            JVP = torch.autograd.grad(grad_g_y, weighting_model.parameters(), grad_outputs=z_params)

            with torch.no_grad():
                # compute h_f (no direct x-gradient of f; see the Lazy branch)
                h_f = list(JVP)

                # update x
                for a, param in enumerate(weighting_model.parameters()):
                    param.data = param.data - alphas[alg_flag] * h_f[a]

        elif alg_flag == 2:
            # SOBA: one z-step, one y-step and one x-step per iteration.
            # compute g_yy*z (first term of h_q)
            g_loss = get_g_loss(y_optimizer, x_optimizer, main_model, weighting_model,
                                g_loss_WMSE, inputs_train, labels_train, one_tensor, train_mask_tensor)
            grad_g_y = torch.autograd.grad(g_loss, main_model.parameters(), create_graph=True)
            HVP = torch.autograd.grad(grad_g_y, main_model.parameters(), grad_outputs=z_params)

            # compute f_y (second term of h_q)
            f_loss = get_f_loss(y_optimizer, x_optimizer, main_model, f_loss_MSE, inputs_test, labels_test)
            grad_f_y = torch.autograd.grad(f_loss, main_model.parameters())

            with torch.no_grad():
                # compute h_q
                h_q = []
                for a in range(len(grad_f_y)):
                    h_q.append(HVP[a] + grad_f_y[a])

                # update z
                for a in range(len(z_params)):
                    z_params[a] = z_params[a] - gammas[alg_flag] * h_q[a]

            # compute h_g
            g_loss = get_g_loss(y_optimizer, x_optimizer, main_model, weighting_model,
                                g_loss_WMSE, inputs_train, labels_train, one_tensor, train_mask_tensor)
            h_g = torch.autograd.grad(g_loss, main_model.parameters())

            # update y
            with torch.no_grad():
                for a, param in enumerate(main_model.parameters()):
                    param.data = param.data - betas[alg_flag] * h_g[a]

            # update v (compute g_xy*z)
            g_loss = get_g_loss(y_optimizer, x_optimizer, main_model, weighting_model,
                                g_loss_WMSE, inputs_train, labels_train, one_tensor, train_mask_tensor)
            grad_g_y = torch.autograd.grad(g_loss, main_model.parameters(), create_graph=True)
            JVP = torch.autograd.grad(grad_g_y, weighting_model.parameters(), grad_outputs=z_params)

            with torch.no_grad():
                # compute h_f (no direct x-gradient of f; see the Lazy branch)
                h_f = list(JVP)

                if m_flag == 1:
                    if t == 0:
                        bar_h_f = list(h_f)
                    else:
                        bar_h_f = [mu*h_f[a] + (1-mu)*bar_h_f[a]
                                   for a in range(len(h_f))]

                # update x
                for a, param in enumerate(weighting_model.parameters()):
                    if m_flag == 0:
                        param.data = param.data - alphas[alg_flag] * h_f[a]
                    elif m_flag == 1:
                        param.data = param.data - alphas[alg_flag] * bar_h_f[a]

        # Wait for the step's kernels before reading the clock.
        _sync_device()
        end_time = time.perf_counter()
        elapsed_time = end_time - start_time
        times_t.append(elapsed_time)

        with torch.no_grad():
            g_loss = get_g_loss(y_optimizer, x_optimizer, main_model, weighting_model,
                                g_loss_WMSE, inputs_train, labels_train, one_tensor, train_mask_tensor)

            f_loss = get_f_loss(y_optimizer, x_optimizer, main_model, f_loss_MSE, inputs_test, labels_test)

        train_loss_t.append(g_loss.item())
        test_loss_t.append(f_loss.item())

        for param in weighting_model.parameters():
            x_param = param.tolist()
            soft_t.append(np.squeeze(scipy.special.softmax(x_param)))

    train_loss_run.append(train_loss_t)
    test_loss_run.append(test_loss_t)
    soft_run.append(soft_t)
    times_run.append(times_t)
    print(f"(Elapsed time: {sum(times_t):.3f} seconds)")

if save_file:
    print(file_name)
    items_to_save = [train_loss_run, test_loss_run, soft_run, times_run]
    with open(f'{file_name}.pkl', 'wb') as fh:
        pickle.dump(items_to_save, fh)

# Averages over runs (axis 0); every run has num_T iterations.
cum_times_run = np.mean(np.cumsum(np.array(times_run), axis=1), axis=0)
mean_train_loss = np.mean(np.array(train_loss_run), axis=0)
mean_test_loss = np.mean(np.array(test_loss_run), axis=0)

# Loss vs. iteration
plt.figure()
plt.plot(range(1, num_T+1), mean_train_loss)
plt.plot(range(1, num_T+1), mean_test_loss)
plt.legend(['Train (g_loss)','Test (f_loss)'])
plt.xlim([0, num_T])
plt.ylim([0, 100])
plt.grid()

# Loss vs. mean cumulative wall-clock time (seconds)
plt.figure()
plt.plot(cum_times_run, mean_train_loss)
plt.plot(cum_times_run, mean_test_loss)
plt.legend(['Train (g_loss)','Test (f_loss)'])
plt.xlim([0, 40])
plt.ylim([0, 100])
plt.grid()

# Dataset weights: num_T + 1 entries per run (initial value + one per iteration)
plt.figure()
plt.plot(range(0, num_T+1), np.mean(np.array(soft_run), axis=0))
plt.legend(['Dataset 1','Dataset 2'])
plt.xlim([0, num_T])
plt.ylim([0, 1])
plt.grid()

# When run as a script, the figures are otherwise created and discarded at exit.
plt.show()