import os
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

import time
import copy
import pickle
import scipy
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import argparse
from peft import LoraConfig, get_peft_model
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
import random



############################################### Preliminaries ###############################################


def parse_args(argv=None):
    """Parse the script's command-line options.

    Args:
        argv: Argument list to parse, without the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with ``offset`` (int, default 1).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--offset', type=int, default=1, help='Offset to distinguish a saved file')
    return parser.parse_args(argv)
args = parse_args()
# NOTE(review): args.offset is parsed but never used in this file (it is not part of
# file_name), so runs launched with different --offset values save to the same path.
# Confirm whether it should be appended to file_name.

# configuration
save_file = 1  # non-zero -> pickle the collected losses / weights / timings at the end

seeds = [1, 2, 3, 4, 5]

num_run = len(seeds)  # one run per seed
num_step = 3000       # training steps per run

# penalty terms
u = 10
v = 10

# learning parameters
batch_size = 32
alpha = 0.00001 # for x (weighting)
beta  = 0.00001 # for y (neural network)

# x_net
num_T = 2 # number of training sets
num_V = 5 # number of validation sets

# preference: one weight per validation set (must have num_V entries).
# It has to be defined here: file_name below and the update rule in the training
# loop both read it, and with every option commented out they raised NameError.
pref = [0.96, 0.01, 0.01, 0.01, 0.01]
# pref = [0.84, 0.04, 0.04, 0.04, 0.04]
if len(pref) != num_V:
    raise ValueError(f"pref must have num_V={num_V} entries, got {len(pref)}")

task = "LLM"
file_name = task + f"_Run{num_run}_S{num_step}_T{num_T}_pref{pref}"
print(file_name)

# NOTE: GPU index 3 is hard-coded; falls back to CPU only when CUDA is unavailable.
device = torch.device('cuda:3' if torch.cuda.is_available() else 'cpu')
data_type = torch.bfloat16
print(device)
print()

# LoRA adapters (rank 8) on every linear layer of the causal LM.
config_lora = LoraConfig(
    task_type="CAUSAL_LM",
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules='all-linear'
)

def build_x(dim: int) -> nn.Sequential:
    """Return the weighting network: a bias-free ``Linear(1, dim)`` followed by a row softmax.

    Fed a column of ones, the linear layer's weight acts as ``dim`` free logits,
    and the softmax (over dim 1) turns them into a probability vector per row.
    """
    return nn.Sequential(
        nn.Linear(1, dim, bias=False),
        nn.Softmax(dim=1),
    )

def get_outputs(y_opt, x_opt, y_model, x_model, inputs_y_ids, input_y_masks, one_tensor, mask_tensor):
    """Clear both optimizers' gradients, then run the LM and the weighting model.

    The token ids and attention mask are moved to ``y_model.device`` before the
    LM forward pass; ``one_tensor`` / ``mask_tensor`` are passed to ``x_model``
    unchanged (the caller is responsible for their device).

    Returns:
        ``(lm_outputs, weighting_outputs)``.
    """
    for opt in (y_opt, x_opt):
        opt.zero_grad()
    target = y_model.device
    lm_outputs = y_model(input_ids=inputs_y_ids.to(target), attention_mask=input_y_masks.to(target))
    weighting_outputs = x_model(one_tensor, mask_tensor)
    return lm_outputs, weighting_outputs

def get_output(y_opt, x_opt, y_model, inputs_y_ids, input_y_masks):
    """Clear both optimizers' gradients and run only the LM forward pass.

    Token ids and attention mask are moved to ``y_model.device`` first.
    Returns the LM's raw output object.
    """
    for opt in (y_opt, x_opt):
        opt.zero_grad()
    target = y_model.device
    return y_model(input_ids=inputs_y_ids.to(target), attention_mask=input_y_masks.to(target))

class x_model(nn.Module):
    """Weighting model: looks up one softmax weight per sample from its group index.

    ``forward(one, idx)`` runs the network built by ``build_x(dim)`` on ``one``
    (one row per sample) and returns, for each row, the entry selected by
    ``idx`` along dim 1 (``torch.gather`` semantics, so ``idx`` must be an
    int64 tensor with the same number of rows).
    """

    def __init__(self, dim):
        super().__init__()
        # The attribute name is part of the state_dict key ('x_net.0.weight').
        self.x_net = build_x(dim)

    def forward(self, one, idx):
        return self.x_net(one).gather(1, idx)


class WeightedCrossEntropyLoss(nn.Module):
    """Token-level cross-entropy in which every token is scaled by its sample's weight.

    ``inputs`` are logits of shape (N, C) and ``targets`` class ids of shape (N,),
    where the N rows are B samples' tokens flattened batch-major (as produced by
    ``logits.view(-1, C)`` on a (B, L, C) tensor). ``weights`` holds one weight per
    sample (B elements, e.g. shape (B, 1)); a single weight is broadcast to all tokens.

    The loss is ``sum_ij w_i * ce_ij / n_valid``, where ``ce_ij`` is 0 for targets
    equal to ``ignore_idx`` and ``n_valid`` counts the non-ignored targets. With all
    weights equal to ``w`` this is exactly ``w * CrossEntropyLoss(ignore_index=ignore_idx)``.
    """

    def __init__(self, ignore_idx):
        super().__init__()
        self.ignore_idx = ignore_idx
        # reduction='none' keeps one loss per token so each can be scaled by its own
        # sample's weight; with the default 'mean' the weights could only rescale the
        # batch average (the result was mean(weights) * CE).
        self.entropy_net = torch.nn.CrossEntropyLoss(ignore_index=ignore_idx, reduction='none')

    def forward(self, inputs, targets, weights):
        """Return the weighted mean token loss as a 0-d tensor.

        Raises:
            ValueError: if the N token losses cannot be split evenly across the
                number of weights given.
        """
        per_token = self.entropy_net(inputs, targets)   # (N,), 0 where targets == ignore_idx
        if per_token.numel() == 0:
            return per_token.sum()
        n_samples = weights.numel()
        if n_samples == 0 or per_token.numel() % n_samples != 0:
            raise ValueError(
                f"cannot split {per_token.numel()} token losses evenly across {n_samples} sample weights"
            )
        weighted = weights.reshape(n_samples, 1) * per_token.reshape(n_samples, -1)
        # Normalize by the number of scored tokens; clamp so an all-ignored batch gives 0, not NaN.
        n_valid = (targets != self.ignore_idx).sum().clamp(min=1)
        return weighted.sum() / n_valid


ignore_idx = -100



################################################## Dataset ##################################################


def make_template():
    """Build chat-format validation conversations, one list per validation group.

    Reads the module-level ``num_V``, ``scores_val_idx``, ``prompt_val`` and
    ``response_val``. Group ``g`` contains, for every index ``i`` in
    ``scores_val_idx[g]`` (in order), a three-message conversation:
    system instruction, user prompt ``prompt_val[i]``, assistant reply ``response_val[i]``.
    """
    system_text = "You are a chatbot who answers a given question."

    def _conversation(i):
        return [
            {"role": "system", "content": system_text},
            {"role": "user", "content": prompt_val[i]},
            {"role": "assistant", "content": response_val[i]},
        ]

    return [[_conversation(i) for i in scores_val_idx[g]] for g in range(num_V)]

def sort_data(dataset, max_size):
    """Keep rows whose prompt and response together are shorter than ``max_size`` characters.

    Iterates ``dataset`` once; each row must support ``row['prompt']`` and
    ``row['response']``.

    Returns:
        ``(indices, prompts, responses)`` — three parallel lists holding the
        original positions of the kept rows and their texts, in dataset order.
    """
    kept = [
        (pos, row['prompt'], row['response'])
        for pos, row in enumerate(dataset)
        if len(row['prompt']) + len(row['response']) < max_size
    ]
    indices = [pos for pos, _, _ in kept]
    prompts = [p for _, p, _ in kept]
    responses = [r for _, _, r in kept]
    return indices, prompts, responses

def get_scores(dataset):
    """Stack every column after the first two into a 2-D array (one row per column).

    Column order follows ``dataset.features``. For HelpSteer this presumably
    drops the prompt/response text columns and keeps the attribute scores —
    verify against the dataset schema.
    """
    score_columns = list(dataset.features)[2:]
    return np.array([dataset[name] for name in score_columns])

def combine_data(prompt, response, Tokenizer, max_len):
    """Tokenize prompt/response pairs and right-pad them with 0 to ``max_len`` ids.

    ``Tokenizer(prompt[a], response[a])`` must return an object with an
    ``input_ids`` sequence (presumably a Hugging Face tokenizer encoding the pair).

    Returns:
        ``(ids, lengths)``: an int64 array of shape (len(prompt), max_len) holding
        the padded token ids, and an int64 array with each sequence's unpadded length.

    Raises:
        ValueError: if a tokenized pair is longer than ``max_len``.
    """
    rows = []
    lengths = []
    for a, text in enumerate(prompt):
        input_ids = list(Tokenizer(text, response[a]).input_ids)
        if len(input_ids) > max_len:
            raise ValueError(f"sequence {a} has {len(input_ids)} tokens, more than max_len={max_len}")
        lengths.append(len(input_ids))
        # Pad with integer zeros so the ids stay integral (float padding made the whole matrix float64).
        rows.append(input_ids + [0] * (max_len - len(input_ids)))
    ids = np.array(rows, dtype=np.int64).reshape(len(rows), max_len)
    return ids, np.array(lengths, dtype=np.int64)


LLM_name = "meta-llama/Llama-3.2-1B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(LLM_name)
# Reuse EOS as the padding token; sequences are padded on the right.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# Load HelpSteer once and take both splits from the same DatasetDict
# (previously load_dataset was called once per split).
_helpsteer = load_dataset("nvidia/HelpSteer")
ds_tra = _helpsteer['train']
ds_val = _helpsteer['validation']
del _helpsteer

# Keep only examples whose prompt + response is shorter than 1000 characters.
idx_tra, prompt_tra, response_tra = sort_data(ds_tra, 1000)         # len = 6524
idx_val, prompt_val, response_val = sort_data(ds_val, 1000)         # len = 320

# Score matrices restricted to the kept examples (one row per score column).
scores_tra = get_scores(ds_tra)[:, np.array(idx_tra)]                       # shape = (5, 6524)
scores_val = get_scores(ds_val)[:, np.array(idx_val)]                       # shape = (5, 320)

# Train
# Split the kept training examples by mean score: group 0 = high (> 2.5), group 1 = low (<= 2).
scores_tra_high_idx = np.where(np.average(scores_tra, axis=0) > 2.5)[0]     # shape = (2660,)
scores_tra_low_idx = np.where(np.average(scores_tra, axis=0) <= 2)[0]       # shape = (2394,)

templates = []
for idx_list in [scores_tra_high_idx, scores_tra_low_idx]:
    templates_temp = []
    for idx in idx_list:
        chat_template = []
        chat_template.append({"role": "system", "content": "You are a chatbot who answers a given question."})
        chat_template.append({"role": "user", "content": prompt_tra[idx]})
        chat_template.append({"role": "assistant", "content": response_tra[idx]})
        templates_temp.append(chat_template)
    templates.append(templates_temp)
template_tra = templates[0] + templates[1]
# Group id per training conversation: 0 for every high-score one, 1 for every low-score one.
# Slice by the number of high-score conversations (len(templates[0])); len(template_tra[0])
# is only the 3 messages of the first conversation and labelled just 3 rows as group 0.
train_num = np.ones([len(template_tra), 1])
train_num[:len(templates[0])] = 0                                           # shape = (5054, 1)


train_input = []
train_label = []
input_maxlen = 0
for a in range(len(template_tra)):
    # NOTE(review): max_length=64 with padding=False makes the tokenizer fall back to truncation
    # (Transformers logs a warning), so long prompts lose their tail, including the assistant
    # header. Confirm this is intended.
    prompt = tokenizer.apply_chat_template(template_tra[a][:-1], tokenize=False, add_generation_prompt=True).replace(
        tokenizer.bos_token, "")
    output_p = tokenizer(prompt, padding=False, max_length=64)
    prompt_id = output_p.input_ids

    response = tokenizer.apply_chat_template(template_tra[a], tokenize=False, add_generation_prompt=False).replace(
        tokenizer.bos_token, "")
    output_r = tokenizer(response[len(prompt):], padding=False, max_length=64)
    response_id = output_r.input_ids[1:]    # drop the BOS the tokenizer prepends

    input_id = prompt_id + response_id
    label_id = [-100] * len(prompt_id) + response_id    # only response tokens are scored

    if len(input_id) > input_maxlen:
        input_maxlen = len(input_id)

    train_input.append(input_id)
    train_label.append(label_id)

maxlen = input_maxlen + 5                                                   # 132
train_mask = []
for a in range(len(template_tra)):
    n_real = len(train_input[a])
    n_pad = maxlen - n_real
    train_input[a] = train_input[a] + [tokenizer.pad_token_id] * n_pad
    train_label[a] = train_label[a] + [-100] * (maxlen - len(train_label[a]))
    # Mask by position, not by token id: pad_token is eos_token, so testing
    # `t != pad_token_id` would also hide any genuine EOS-id token inside the text.
    train_mask.append([1] * n_real + [0] * n_pad)

train_input = torch.as_tensor(np.array(train_input))                # shape = torch.Size([5054, 132])
train_label = torch.as_tensor(np.array(train_label))                # shape = torch.Size([5054, 132])
train_mask = torch.as_tensor(np.array(train_mask))                  # shape = torch.Size([5054, 132])


# Test
# Validation group s: examples whose s-th score (row s of scores_val) is 3 or 4.
scores_val_idx = [np.where((row == 3) | (row == 4))[0].tolist() for row in scores_val]
template_val = make_template()
val_data = [conv for group in template_val for conv in group]
group_sizes = [len(group) for group in template_val]                        # <192 193 262 21 21>
# Group id of each validation conversation, in val_data order.
test_num = np.concatenate(
    [np.full([size, 1], float(s)) for s, size in enumerate(group_sizes)], axis=0
)                                                                           # shape = (689, 1)

# Contiguous index range of each group inside val_data, derived from the actual
# group sizes instead of hard-coded counts.
group_offsets = np.cumsum([0] + group_sizes).tolist()
class_indices = [list(range(group_offsets[s], group_offsets[s + 1])) for s in range(num_V)]
total_samples = group_offsets[-1]

# Stratified validation batches: batch_size // num_V examples per group, with the
# remainder spread over randomly chosen groups.
base_samples_per_class = batch_size // num_V
remainder = batch_size % num_V
samples_per_class = [base_samples_per_class] * num_V
if remainder > 0:
    # Dedicated seeded generator: this runs before the per-run seeding in the training
    # loop, so drawing from the global RNG gave a different layout on every invocation.
    extra_sample_classes = np.random.default_rng(seeds[0]).choice(num_V, remainder, replace=False)
    for class_idx in extra_sample_classes:
        samples_per_class[class_idx] += 1
for s in range(num_V):
    if samples_per_class[s] > group_sizes[s]:
        raise ValueError(f"validation group {s} has {group_sizes[s]} examples, "
                         f"fewer than the {samples_per_class[s]} drawn per batch")


test_input = []
test_label = []
test_mask = []
for a in range(len(val_data)):
    prompt = tokenizer.apply_chat_template(val_data[a][:-1], tokenize=False, add_generation_prompt=True).replace(tokenizer.bos_token, "")
    output_p = tokenizer(prompt, padding=False, max_length=64)
    prompt_id = output_p.input_ids

    response = tokenizer.apply_chat_template(val_data[a], tokenize=False, add_generation_prompt=False).replace(tokenizer.bos_token, "")
    output_r = tokenizer(response[len(prompt):], padding=False, max_length=64)
    response_id = output_r.input_ids[1:]    # drop the BOS the tokenizer prepends

    input_id = prompt_id + response_id
    label_id = [-100] * len(prompt_id) + response_id    # only response tokens are scored

    # maxlen comes from the training set; fail clearly instead of building a ragged array.
    n_real = len(input_id)
    if n_real > maxlen:
        raise ValueError(f"validation example {a} has {n_real} tokens, more than maxlen={maxlen}")
    input_id = input_id + [tokenizer.pad_token_id] * (maxlen - n_real)
    label_id = label_id + [-100] * (maxlen - len(label_id))
    # Mask by position: pad_token is eos_token, so an id test would also hide real EOS-id tokens.
    attention_mask = [1] * n_real + [0] * (maxlen - n_real)

    test_input.append(input_id)
    test_label.append(label_id)
    test_mask.append(attention_mask)

test_input = torch.as_tensor(np.array(test_input))                  # shape = torch.Size([689, 132])
test_label = torch.as_tensor(np.array(test_label))                  # shape = torch.Size([689, 132])
test_mask = torch.as_tensor(np.array(test_mask))                    # shape = torch.Size([689, 132])


# tensor
# Group ids become int64 tensors (torch.gather, used by the weighting model, needs int64 indices).
train_num = torch.as_tensor(train_num).long()                       # shape = torch.Size([5054, 1])
test_num = torch.as_tensor(test_num).long()                         # shape = torch.Size([689, 1])
# Constant float32 column of ones fed to the weighting model.
one_tensor = torch.ones(batch_size, 1, dtype=torch.float32)

# The raw HelpSteer splits are not needed past this point.
del ds_tra, ds_val
print()


train_loss_run = []   # per run: g (weighted training) loss at every step
test_loss_run = []    # per run: list of num_V validation losses at every step
soft_run = []         # per run: softmax of the weighting-model logits, before training and after each step
times_run = []        # per run: wall-clock seconds of each step's compute section



################################################## Running ##################################################


def _next_token_labels(labels):
    """Shift labels one position left so position i is scored against token i+1.

    A causal LM's logits at position i predict token i+1. Shifting the labels
    (and filling the last position with ``ignore_idx``) lets the loss use the
    unsliced logits via ``view`` instead of copying a (batch, seq, vocab) slice.
    """
    tail = torch.full_like(labels[:, :1], ignore_idx)
    return torch.cat((labels[:, 1:], tail), dim=1)


for run in range(num_run):

    seed = seeds[run]
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    print(seed)

    print(f"Run {run+1} / {num_run}")
    main_model = AutoModelForCausalLM.from_pretrained(LLM_name, device_map=device, torch_dtype=data_type)
    # NOTE(review): the tokenizer pads with eos_token but the model config gets bos_token_id here;
    # attention masks are passed explicitly below, so this looks unused by the loss — confirm.
    main_model.config.pad_token_id = tokenizer.bos_token_id
    main_model = get_peft_model(main_model, config_lora)
    main_model.print_trainable_parameters()
    lora_params = [p for p in main_model.parameters() if p.requires_grad]
    weighting_model = x_model(num_T).to(device)
    # Zero logits -> the softmax starts uniform over the num_T training groups.
    weighting_model.state_dict()['x_net.0.weight'].fill_(0)

    rho = 1
    delta = [0] * num_V

    # These optimizers are never stepped; they are only handed to get_output(s) for zero_grad().
    # All parameter updates below are manual.
    y_optimizer = torch.optim.Adam(main_model.parameters())
    x_optimizer = torch.optim.Adam(weighting_model.parameters())
    f_loss_MSE  = torch.nn.CrossEntropyLoss(ignore_index=ignore_idx)   # per-group validation loss f
    g_loss_WMSE = WeightedCrossEntropyLoss(ignore_idx=ignore_idx)     # weighted training loss g

    train_loss_t = []
    test_loss_t = []
    soft_t = []
    times_t = []

    for param in weighting_model.parameters():
        x_param = param.tolist()
        soft_t.append(np.squeeze(scipy.special.softmax(x_param)))

    for t in range(num_step):
        batch_idx_train = np.random.choice(len(train_input), batch_size, replace=False)
        batch_train_input = train_input[batch_idx_train, :]                                 # shape = torch.Size([32, 132])
        batch_train_label = train_label[batch_idx_train, :]                                 # shape = torch.Size([32, 132])
        batch_train_mask = train_mask[batch_idx_train, :]                                   # shape = torch.Size([32, 132])
        train_mask_tensor = train_num[batch_idx_train]                                      # group id per sample, [32, 1]

        # Stratified validation batch: samples_per_class[i] examples from group i, then shuffled.
        final_indices = []
        for i in range(num_V):
            num_to_sample = samples_per_class[i]
            selected = np.random.choice(class_indices[i], num_to_sample, replace=False)
            final_indices.extend(selected)
        batch_idx_test = np.array(final_indices)
        np.random.shuffle(batch_idx_test)

        batch_test_input = test_input[batch_idx_test, :]
        batch_test_label = test_label[batch_idx_test, :]
        batch_test_mask = test_mask[batch_idx_test, :]
        test_mask_tensor = test_num[batch_idx_test]

        # Split the validation batch back into its num_V groups.
        inputs_grouped = {}
        labels_grouped = {}
        mask_batch_grouped = {}
        for s in range(num_V):
            mask_s = (test_mask_tensor.squeeze(1) == s)
            inputs_grouped[s] = batch_test_input[mask_s]
            labels_grouped[s] = batch_test_label[mask_s]
            mask_batch_grouped[s] = batch_test_mask[mask_s]


        # Math SDPA backend only (flash / mem-efficient disabled) — presumably because the
        # double backward below is not supported by those kernels; confirm before changing.
        with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
            start_time = time.time()


            # compute g_y: gradient of the weighted training loss w.r.t. the LoRA parameters
            outputs_y, outputs_x = get_outputs(y_optimizer, x_optimizer, main_model, weighting_model,
                                               batch_train_input, batch_train_mask, one_tensor.to(device),
                                               train_mask_tensor.to(device))
            logits = outputs_y.logits
            logits_flat = logits.view(-1, logits.size(-1))
            labels_flat = _next_token_labels(batch_train_label).view(-1).long().to(device)
            g_loss = g_loss_WMSE(logits_flat, labels_flat, outputs_x)
            g_y = torch.autograd.grad(g_loss, lora_params, create_graph=True)

            # Gradients of ||g_y||^2 w.r.t. x and y (second-order terms), normalized by ||g_y||^2.
            g_y_flat = torch.cat([gy.view(-1) for gy in g_y])                               # shape = torch.Size([5636096])
            g_v = torch.dot(g_y_flat, g_y_flat)
            HVP_weighting = torch.autograd.grad(g_v, weighting_model.parameters(), retain_graph=True)   # tuple
            HVP_main = torch.autograd.grad(g_v, lora_params)                                            # tuple

            grad_norm = g_y_flat.norm().detach() ** 2
            HVP_weighting = tuple(h / grad_norm for h in HVP_weighting)
            HVP_main = tuple(h / grad_norm for h in HVP_main)


            # compute f_y: per-group validation losses and their LoRA gradients
            f_loss = []
            f_main = []
            for s in range(num_V):
                outputs_y = get_output(y_optimizer, x_optimizer, main_model,
                                       inputs_grouped[s], mask_batch_grouped[s])
                logits = outputs_y.logits
                logits_flat = logits.view(-1, logits.size(-1))
                labels_flat = _next_token_labels(labels_grouped[s]).view(-1).long().to(device)
                temp_loss = f_loss_MSE(logits_flat, labels_flat)
                f_loss.append(temp_loss)
                f_main.append(torch.autograd.grad(temp_loss, lora_params))


            # update models
            value = []
            for s in range(num_V):
                value.append(pref[s] * f_loss[s].item() + delta[s] - rho)

            with torch.no_grad():
                for a, param in enumerate(weighting_model.parameters()):
                    # alpha is the configured step size for x (equal to beta at the time of writing).
                    param.data = param.data - u * alpha * HVP_weighting[a]

            with torch.no_grad():
                for a, param in enumerate(lora_params):
                    # The HVP penalty enters once per step; only the validation terms are
                    # summed over groups (it used to be subtracted once per group).
                    update = u * beta * HVP_main[a]
                    for s in range(num_V):
                        update = update + v * beta * value[s] * pref[s] * f_main[s][a]
                    param.data = param.data - update

            rho = rho - beta * (1 - v * sum(value))
            for s in range(num_V):
                delta[s] = delta[s] - beta * value[s] * v
                if delta[s] < 0:
                    delta[s] = 0

            end_time = time.time()
            elapsed_time = end_time - start_time
            times_t.append(elapsed_time)

        del outputs_x, outputs_y, logits, logits_flat, temp_loss, f_main, g_y, g_y_flat, HVP_main, HVP_weighting, grad_norm


        with torch.no_grad():
            train_loss_t.append(g_loss.item())
            f_loss_vector = []
            for s in range(num_V):
                f_loss_vector.append(f_loss[s].item())
            test_loss_t.append(f_loss_vector)

            for param in weighting_model.parameters():
                x_param = param.tolist()
                soft_t.append(np.squeeze(scipy.special.softmax(x_param)))


            if (t + 1) % 30 == 0:
                print(f"Finished {t + 1} of {num_step} steps")
                print(f_loss_vector)

    train_loss_run.append(train_loss_t)
    test_loss_run.append(test_loss_t)
    soft_run.append(soft_t)
    times_run.append(times_t)
    print(f"(Elapsed time: {sum(times_t):.3f} seconds)")


if save_file:
    # Persist everything collected across runs as one pickled list:
    # [train_loss_run, test_loss_run, soft_run, times_run].
    print(file_name)
    result_path = f'/Anonymous/{file_name}.pkl'
    payload = [train_loss_run, test_loss_run, soft_run, times_run]
    with open(result_path, 'wb') as out_file:
        pickle.dump(payload, out_file)
