import time
import copy
import pickle
import scipy
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import random
import argparse
from transformers import AutoTokenizer
from datasets import load_dataset

def parse_args():
    """Parse the command-line options of this script.

    Returns:
        argparse.Namespace: namespace with one integer attribute,
        ``offset`` (defaults to 1 when ``--offset`` is not given).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--offset',
        type=int,
        default=1,
        help='Offset to distinguish a saved file',
    )
    return cli.parse_args()
args = parse_args()
# NOTE(review): args.offset is parsed ("Offset to distinguish a saved file")
# but is not read anywhere in this file, so it currently has no effect on
# file_name. Confirm whether it was meant to be part of the output name.

# configuration
save_file = 1  # truthy -> pickle the collected metrics at the end of the script

seeds = [1, 2, 3, 4, 5]
num_run = len(seeds)  # one independent run per seed
num_step = 3000       # update steps per run

# learning parameters
batch_size = 256
beta  = 0.00000001 # step size used by the manual parameter updates

# penalty coefficients
u = 1
v = 1

# x_net
num_T = 3 # number of training sets
num_V = 5 # number of test sets

# preference vector: one entry per test set (indexed by s in range(num_V)).
# It must be defined here because file_name below and the training loop read it;
# previously every candidate was commented out, which raised NameError.
# NOTE(review): the choice between the two listed settings cannot be inferred
# from the code -- confirm which experiment is intended.
# pref = [0.96, 0.01, 0.01, 0.01, 0.01]
pref = [0.84, 0.04, 0.04, 0.04, 0.04]
if len(pref) != num_V:
    raise ValueError(
        f"pref must have one entry per test set: got {len(pref)}, expected {num_V}"
    )


# y_net
in_dim = 500
width = 500
depth = 5
out_dim = 1

task = "RM"
# The list repr of pref is embedded verbatim in the name (brackets, commas, spaces).
file_name = task + f"_Run{num_run}_S{num_step}_T{num_T}_V{num_V}_pref{pref}"
print(file_name)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
print()

def build_y(input_dim: int, hidden_width: int, hidden_depth: int, output_dim: int):
    """Assemble a fully connected ReLU network as an ``nn.Sequential``.

    Layout: one input projection, ``hidden_depth`` square hidden layers
    (each followed by ReLU, as is the input projection), and a final linear
    output layer without bias.

    Args:
        input_dim: size of each input feature vector.
        hidden_width: number of units in every hidden layer.
        hidden_depth: number of additional hidden_width -> hidden_width layers.
        output_dim: size of the network output.

    Returns:
        nn.Sequential containing ``2 * (hidden_depth + 1) + 1`` modules.
    """
    # Linear layers are instantiated in network order (input, hidden..., output).
    modules = [nn.Linear(input_dim, hidden_width, bias=True), nn.ReLU()]
    for _ in range(hidden_depth):
        modules += [nn.Linear(hidden_width, hidden_width, bias=True), nn.ReLU()]
    modules.append(nn.Linear(hidden_width, output_dim, bias=False))
    return nn.Sequential(*modules)

def build_x(dim: int):
    """Build the weighting network: a bias-free 1 -> ``dim`` linear map
    followed by a softmax over dimension 1.

    Args:
        dim: number of softmax outputs.

    Returns:
        nn.Sequential of ``nn.Linear(1, dim, bias=False)`` and ``nn.Softmax(dim=1)``.
    """
    return nn.Sequential(
        nn.Linear(1, dim, bias=False),
        nn.Softmax(dim=1),
    )

def get_g_loss(y_opt, x_opt, y_model, x_model, loss_model, inputs_y, labels_y, one_tensor, mask_tensor):
    """Clear both optimizers' gradients and evaluate the weighted loss.

    Args:
        y_opt, x_opt: objects exposing ``zero_grad()``; both are reset first.
        y_model: callable applied to ``inputs_y`` to get predictions.
        x_model: callable applied to ``(one_tensor, mask_tensor)`` to get weights.
        loss_model: callable taking ``(predictions, labels_y, weights)``.
        inputs_y, labels_y: batch passed to ``y_model`` / ``loss_model``.
        one_tensor, mask_tensor: inputs forwarded to ``x_model``.

    Returns:
        Whatever ``loss_model`` returns for this batch.
    """
    for optimizer in (y_opt, x_opt):
        optimizer.zero_grad()
    predictions = y_model(inputs_y)
    sample_weights = x_model(one_tensor, mask_tensor)
    return loss_model(predictions, labels_y, sample_weights)

def get_f_loss(y_opt, x_opt, y_model, loss_model, inputs_y, labels_y):
    """Clear both optimizers' gradients and evaluate the unweighted loss.

    Args:
        y_opt, x_opt: objects exposing ``zero_grad()``; both are reset first.
        y_model: callable applied to ``inputs_y`` to get predictions.
        loss_model: callable taking ``(predictions, labels_y)``.
        inputs_y, labels_y: batch passed to ``y_model`` / ``loss_model``.

    Returns:
        Whatever ``loss_model`` returns for this batch.
    """
    for optimizer in (y_opt, x_opt):
        optimizer.zero_grad()
    return loss_model(y_model(inputs_y), labels_y)

def sort_data(dataset, max_size):
    """Keep the records whose prompt and response are short enough.

    A record is kept when ``len(prompt) + len(response)`` is strictly less
    than ``max_size`` (lengths are those of the stored values, e.g. characters
    for strings).

    Args:
        dataset: iterable of mappings with ``'prompt'`` and ``'response'`` keys.
        max_size: exclusive upper bound on the combined length.

    Returns:
        Three lists of equal length: the positions of the kept records in
        ``dataset``, their prompts, and their responses.
    """
    kept = [
        (position, record['prompt'], record['response'])
        for position, record in enumerate(dataset)
        if len(record['prompt']) + len(record['response']) < max_size
    ]
    idx = [entry[0] for entry in kept]
    prompt = [entry[1] for entry in kept]
    response = [entry[2] for entry in kept]
    return idx, prompt, response

def get_scores(dataset):
    """Stack every column after the first two features into one array.

    Only the columns that are actually returned are read from ``dataset``;
    the first two features are skipped without being materialised.
    NOTE(review): presumably the two skipped features are the prompt and
    response text columns -- confirm against the dataset schema.

    Args:
        dataset: object exposing ``features`` (iterable of column names, in
            order) and column access via ``dataset[name]``.

    Returns:
        np.ndarray with one row per remaining column, in feature order.
    """
    score_columns = list(dataset.features)[2:]
    return np.array([dataset[name] for name in score_columns])

def combine_data(prompt, response, Tokenizer, max_len):
    """Tokenize each (prompt, response) pair and right-pad the ids with zeros.

    Args:
        prompt: sequence of prompts.
        response: sequence of responses, indexed in parallel with ``prompt``.
        Tokenizer: callable such that ``Tokenizer(p, r).input_ids`` is the
            sequence of token ids for the pair.
        max_len: length every row is padded to.

    Returns:
        Tuple ``(data, lengths)``: ``data`` is a float array with one
        zero-padded row of ``max_len`` entries per pair, ``lengths`` holds the
        unpadded token count of each pair.

    Raises:
        ValueError: if a pair produces more than ``max_len`` token ids.
    """
    data_H_len = []
    data_H = []
    for position, prompt_text in enumerate(prompt):
        input_ids = Tokenizer(prompt_text, response[position]).input_ids
        num_tokens = len(input_ids)
        # Fail with an explicit message instead of numpy's
        # "negative dimensions are not allowed" when padding is impossible.
        if num_tokens > max_len:
            raise ValueError(
                f"sample {position} has {num_tokens} tokens, "
                f"which exceeds max_len={max_len}"
            )
        padded = np.zeros(max_len)  # float64 row, zeros act as padding
        padded[:num_tokens] = input_ids
        data_H_len.append(num_tokens)
        data_H.append(padded)
    return np.array(data_H), np.array(data_H_len)


class y_model(nn.Module):
    """Module wrapping the network returned by ``build_y`` as ``self.y_net``."""

    def __init__(self, input_dim, hidden_width, hidden_depth, output_dim):
        super().__init__()
        self.y_net = build_y(
            input_dim=input_dim,
            hidden_width=hidden_width,
            hidden_depth=hidden_depth,
            output_dim=output_dim,
        )

    def forward(self, feature):
        """Return ``self.y_net`` applied to ``feature``."""
        return self.y_net(feature)


class x_model(nn.Module):
    """Module wrapping the network returned by ``build_x`` as ``self.x_net``."""

    def __init__(self, dim):
        super().__init__()
        self.x_net = build_x(dim)

    def forward(self, one, idx):
        """Run ``self.x_net`` on ``one`` and pick, along dimension 1, the
        entries addressed by ``idx``.

        Args:
            one: input tensor for ``self.x_net``.
            idx: integer index tensor used with ``gather`` along dimension 1.

        Returns:
            Tensor with the same shape as ``idx``.
        """
        return self.x_net(one).gather(1, idx)


class WeightedMSELoss(nn.Module):
    """Mean of element-wise weighted squared errors."""

    def forward(self, inputs, targets, weights):
        """Return ``mean(weights * (inputs - targets) ** 2)`` over all elements.

        Args:
            inputs: predictions.
            targets: reference values, broadcastable against ``inputs``.
            weights: per-element weights, broadcastable against the errors.
        """
        squared_error = (inputs - targets) ** 2
        return (weights * squared_error).mean()
    

Tokenizer_name = "microsoft/deberta-v3-large"
Tokenizer_D = AutoTokenizer.from_pretrained(Tokenizer_name)
print()

# Load the dataset once and take both splits from the same object.
helpsteer = load_dataset("nvidia/HelpSteer")
ds_tra = helpsteer['train']
ds_val = helpsteer['validation']

# only use the data having less than 1000 characters
idx_tra, prompt_tra, response_tra = sort_data(ds_tra, 1000)
idx_val, prompt_val, response_val = sort_data(ds_val, 1000)
# Sizes are taken from the filter result instead of being hard-coded
# (the original comments recorded 6524 training and 320 validation samples).
num_tra = len(idx_tra)
num_val = len(idx_val)

# Score columns restricted to the kept samples, one row per sample.
scores_tra = get_scores(ds_tra)[:, np.array(idx_tra)].T             # shape = (num_tra, n_scores)
scores_val = get_scores(ds_val)[:, np.array(idx_val)].T             # shape = (num_val, n_scores)

train_data_H, train_data_H_len = combine_data(prompt_tra, response_tra, Tokenizer_D, 500)   # shape = (num_tra, 500), (num_tra,)
test_data_H, test_data_H_len = combine_data(prompt_val, response_val, Tokenizer_D, 500)     # shape = (num_val, 500), (num_val,)

# train data: the same inputs are repeated once per label set.
# The third label set is random integers in [0, 5), drawn with a fixed seed.
# NOTE(review): score columns 2 and 4 are presumably "coherence" and
# "verbosity" per the HelpSteer column order -- confirm against the dataset card.
np.random.seed(42)
rdm_score = np.random.randint(0, 5, size=(num_tra,))
train_label_sets = [scores_tra[:, 2], scores_tra[:, 4], rdm_score]
train_data = np.concatenate([train_data_H] * len(train_label_sets), axis=0)     # shape = (3 * num_tra, 500)
train_label = np.concatenate(train_label_sets, axis=0)[:, np.newaxis]           # shape = (3 * num_tra, 1)

# validation data: the same inputs are repeated once per score column 0..4.
test_label_sets = [scores_val[:, k] for k in range(5)]
test_data = np.concatenate([test_data_H] * len(test_label_sets), axis=0)        # shape = (5 * num_val, 500)
test_label = np.concatenate(test_label_sets, axis=0)[:, np.newaxis]             # shape = (5 * num_val, 1)

# The training loop indexes the weighting model with num_T outputs and groups
# test samples by range(num_V); fail early if the data does not match.
if len(train_label_sets) != num_T or len(test_label_sets) != num_V:
    raise ValueError(
        f"expected {num_T} training sets and {num_V} test sets, "
        f"built {len(train_label_sets)} and {len(test_label_sets)}"
    )

# dataset numbering: row i belongs to training set i // num_tra
dataset_num = np.repeat(np.arange(len(train_label_sets)), num_tra)[:, np.newaxis]

# testset numbering: row i belongs to test set i // num_val
testset_num = np.repeat(np.arange(len(test_label_sets)), num_val)[:, np.newaxis]

# tensor
train_data = torch.as_tensor(train_data).float().to(device)
train_label = torch.as_tensor(train_label).float().to(device)
test_data = torch.as_tensor(test_data).float().to(device)
test_label = torch.as_tensor(test_label).float().to(device)
dataset_num = torch.as_tensor(dataset_num).long().to(device)
testset_num = torch.as_tensor(testset_num).long().to(device)
one_tensor = torch.as_tensor(np.ones([batch_size, 1])).float().to(device) # input to the weighting_model

# Per-run histories, appended to at the end of every run.
train_loss_run = []
test_loss_run = []
soft_run = []
times_run = []


# One independent training run per seed. Each run builds fresh models, performs
# num_step manual update steps on the main model, the weighting model, rho and
# delta, and records per-step losses, softmax weights and wall-clock times.
# `pref` is expected to be a module-level sequence with one entry per test set
# (it is indexed by s in range(num_V) below).
for run in range(num_run):

    # Seed every RNG used below (python, numpy batch sampling, torch init).
    seed = seeds[run]
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    print(seed)

    print(f"Run {run+1} / {num_run}")
    main_model = y_model(in_dim, width, depth, out_dim).to(device)
    weighting_model = x_model(num_T).to(device)
    # Zero the linear weights of the weighting net so its softmax output starts uniform.
    # Presumably relies on state_dict() returning tensors that share storage with the parameters.
    weighting_model.state_dict()['x_net.0.weight'].fill_(0)

    # Scalar rho and one delta per test set; both are updated manually each step.
    rho = 1
    delta = [0] * num_V

    # The optimizers are only passed to get_g_loss / get_f_loss, which call
    # zero_grad() on them; no optimizer.step() is issued in this loop.
    y_optimizer = torch.optim.Adam(main_model.parameters())
    x_optimizer = torch.optim.Adam(weighting_model.parameters())
    f_loss_MSE  = torch.nn.MSELoss()
    g_loss_WMSE = WeightedMSELoss()

    # Per-step histories of this run.
    train_loss_t = []
    test_loss_t = []
    soft_t = []
    times_t = []

    # Record the initial softmax of the weighting parameters (before any update).
    for param in weighting_model.parameters():
        x_param = param.tolist()
        soft_t.append(np.squeeze(scipy.special.softmax(x_param)))                       # [np.shape = (3,)]

    for t in range(num_step):
        # Sample a training batch together with the id of the training set each row came from.
        batch_idx_train = np.random.choice(len(train_data), batch_size, replace=False)  # shape = (256,)
        inputs_train = train_data[batch_idx_train]                                      # torch.size([256, 500])
        labels_train = train_label[batch_idx_train]                                     # torch.size([256, 1])
        train_mask_tensor = dataset_num[batch_idx_train]                                # torch.size([256, 1])

        # Sample a test batch together with the id of the test set each row came from.
        batch_idx_test = np.random.choice(len(test_data), batch_size, replace=False)    # shape = (256,)
        inputs_test = test_data[batch_idx_test]                                         # torch.size([256, 500])
        labels_test = test_label[batch_idx_test]                                        # torch.size([256, 1])
        test_mask_tensor = testset_num[batch_idx_test]                                  # torch.size([256, 1])

        # Split the test batch by test-set id.
        # NOTE(review): if a batch happens to contain no row of some set s, that
        # group is empty and its mean loss would presumably be NaN -- confirm
        # whether this needs guarding.
        inputs_grouped = {}
        labels_grouped = {}
        mask_grouped = {}
        for s in range(num_V):
            mask_s = (test_mask_tensor.squeeze() == s)  # shape: [256]
            inputs_grouped[s] = inputs_test[mask_s]  # shape: [num_s, 500]
            labels_grouped[s] = labels_test[mask_s]  # shape: [num_s, 1]
            mask_grouped[s] = test_mask_tensor[mask_s]
            # NOTE(review): this condition is always true for s in range(num_V),
            # and mask_grouped is not read again in this loop.
            if s > -1:
                mask_grouped[s].zero_()


        # Everything between start_time and end_time is counted as the step time.
        start_time = time.time()

        # compute g_y
        # g_loss: weighted training loss; g_y: its gradient w.r.t. the main model,
        # kept differentiable (create_graph=True) so it can be differentiated again.
        g_loss = get_g_loss(y_optimizer, x_optimizer, main_model, weighting_model,
                            g_loss_WMSE, inputs_train, labels_train, one_tensor, train_mask_tensor)     # (tensor(loss))
        g_y = torch.autograd.grad(g_loss, main_model.parameters(), create_graph=True)                   # tuple, len = 13 with: torch.Size([500, 500]), torch.Size([500])

        # g_v is the squared norm of the flattened gradient g_y; it is then
        # differentiated w.r.t. the weighting parameters and the main parameters.
        g_y_flat = torch.cat([gy.view(-1) for gy in g_y])                                               # torch.Size([1503500])
        g_v = torch.dot(g_y_flat, g_y_flat)
        HVP_weighting = torch.autograd.grad(g_v, weighting_model.parameters(), retain_graph=True)
        HVP_main = torch.autograd.grad(g_v, main_model.parameters())                

        # Scale both gradients by the (detached) squared norm of g_y.
        grad_norm = g_y_flat.norm().detach() ** 2
        HVP_weighting = tuple(h / grad_norm for h in HVP_weighting)
        HVP_main = tuple(h / grad_norm for h in HVP_main)


        # compute f
        # One unweighted MSE loss per test set, plus its gradient w.r.t. the main model.
        f_loss = []
        f_main = []
        for s in range(num_V):
            loss_temp = get_f_loss(y_optimizer, x_optimizer, main_model, f_loss_MSE,
                              inputs_grouped[s], labels_grouped[s])
            f_loss.append(loss_temp)
            f_main.append(torch.autograd.grad(loss_temp, main_model.parameters()))

        # Per-set scalar pref[s] * f_s + delta[s] - rho, used as a coefficient in the updates below.
        value = []
        for s in range(num_V):
            value.append(pref[s] * f_loss[s].item() + delta[s] - rho)



        # update models
        # Weighting model: one step along the scaled gradient of g_v.
        with torch.no_grad():
            for a, param in enumerate(weighting_model.parameters()):
                param.data = param.data - u * beta * HVP_weighting[a]

        # Main model: scaled gradient of g_v plus the value-weighted test-loss gradients.
        # NOTE(review): the u * beta * HVP_main[a] term sits inside the loop over s,
        # so it is subtracted num_V times per step, whereas HVP_weighting is applied
        # once above -- confirm this is intended.
        with torch.no_grad():
            for a, param in enumerate(main_model.parameters()):
                for s in range(num_V):
                    param.data = param.data - u * beta * HVP_main[a] - v * beta * value[s] * pref[s] * f_main[s][a]

        # Scalar updates for rho and delta; delta is clipped at zero from below.
        rho = rho - beta * (1 - v * sum(value))
        for s in range(num_V):
            delta[s] = delta[s] - beta * value[s] * v
            if delta[s] < 0:
                delta[s] = 0


        end_time = time.time()
        elapsed_time = end_time - start_time
        times_t.append(elapsed_time)


        # Log the losses computed before this step's updates.
        train_loss_t.append(g_loss.item())
        f_loss_vector = []
        for s in range(num_V):
            f_loss_vector.append(f_loss[s].item())
        test_loss_t.append(f_loss_vector)

        # Log the softmax of the weighting parameters after this step's update.
        for param in weighting_model.parameters():
            x_param = param.tolist()
            soft_t.append(np.squeeze(scipy.special.softmax(x_param)))

        if (t+1) % 100 == 0:
            print(f"Finished {t+1} of {num_step} steps")
            print(f_loss_vector)

    # Store this run's histories in the module-level result lists.
    train_loss_run.append(train_loss_t)
    test_loss_run.append(test_loss_t)
    soft_run.append(soft_t)
    times_run.append(times_t)
    print(f"(Elapsed time: {sum(times_t):.3f} seconds)")

    # print()
    # print(test_loss_run[0][:5])
    # print()
    # print(test_loss_run[0][-5:])


# Persist the collected histories when saving is enabled.
if save_file:
    print(file_name)
    # Pickled as one list, in this order: train losses, test losses,
    # softmax weights, step times (each indexed by run).
    items_to_save = [train_loss_run, test_loss_run, soft_run, times_run]
    output_path = f'/Anonymous/{file_name}.pkl'
    with open(output_path, 'wb') as file:
        pickle.dump(items_to_save, file)

