import time
import copy
import pickle
import scipy
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import argparse
from peft import LoraConfig, get_peft_model
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset

def parse_args(argv=None):
    """Parse the experiment's command-line options.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]`` exactly as before; passing a list makes the
            function usable without touching ``sys.argv``.

    Returns:
        argparse.Namespace with ``offset`` (int) and ``alg`` (int in {0, 1, 2}).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--offset', type=int, default=1, help='Offset to distinguish a saved file')
    # alg indexes the per-algorithm lists (0=Lazy, 1=AmIGO, 2=SOBA). Restrict it so a
    # negative or too-large value fails here rather than silently picking an
    # entry from the end of those lists via negative indexing.
    parser.add_argument('--alg', type=int, default=0, choices=(0, 1, 2), help='Algorithm type')
    args = parser.parse_args(argv)
    return args
args = parse_args()

# configuration
save_file = 0  # 1 -> pickle the per-run curves to f"{file_name}.pkl" at the end

algs = ["Lazy", "AmIGO", "SOBA"]
alg_flag = args.alg # algorithms
# A negative index would silently select an entry from the end of the
# per-algorithm lists below, and the training loop has no branch for such a
# value, so the run would perform no updates at all. Fail loudly instead.
if not 0 <= alg_flag < len(algs):
    raise ValueError(f"--alg must be in [0, {len(algs) - 1}], got {alg_flag}")
m_flag = 0 # momentum flag
lazy_flag = 1 # lazy JVP

num_run = 10
num_Ts = [900, 90, 300]
num_T = num_Ts[alg_flag]

# learning parameters
batch_size = 32
alphas = [0.005, 0.005, 0.005] # for x (weighting)
betas  = [0.0002, 0.0002, 0.0002] # for y (MLP)
gammas = [0.0003, 0.0003, 0.0003] # for z (auxiliary)

mu = 0.8 # momentum
lazy_N = 5
amigo_M = 5
amigo_N = 5

# x_net
num_DS = 2

# Algorithm-specific overrides. They run BEFORE mu is zeroed so that forcing
# momentum off (AmIGO) is also reflected in mu, and therefore in file_name.
if alg_flag != 0: # only SO-Lazy has lazy JVP (and uses lazy_N)
    lazy_flag = 0
    lazy_N = 0

if alg_flag == 1: # AmIGO does not have momentum
    m_flag = 0

if m_flag == 0:
    mu = 0.0

task = "LLM"
M_str = ["", "-M"]
L_str = ["", "-L"]
file_name = task + "_" + algs[alg_flag] + L_str[lazy_flag] + M_str[m_flag] + f"_Run{num_run}_T{num_T}_DS{num_DS}_mu{int(mu*10)}_N{lazy_N}"
print(file_name)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data_type = torch.bfloat16

config_lora = LoraConfig(
    task_type="CAUSAL_LM",
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules='all-linear'
)

def build_x(dim: int):
    """Build the upper-level weighting network.

    The network is a bias-free ``nn.Linear(1, dim)`` followed by a softmax
    over dimension 1. A ``(B, 1)`` input therefore yields ``(B, dim)`` rows
    that each sum to 1.
    """
    return nn.Sequential(
        nn.Linear(1, dim, bias=False),
        nn.Softmax(dim=1),
    )

def get_outputs(y_opt, x_opt, y_model, x_model, inputs_y_ids, input_y_masks, one_tensor, mask_tensor):
    """Clear both optimizers' gradients, then run the LM and the weighting model.

    Returns:
        ``(outputs_y, outputs_x)``:
        - ``outputs_y`` is ``y_model`` evaluated on ``inputs_y_ids`` /
          ``input_y_masks``, both moved to ``y_model.device``.
        - ``outputs_x`` is ``x_model(one_tensor, mask_tensor)``.
    """
    for opt in (y_opt, x_opt):
        opt.zero_grad()
    lm_out = y_model(
        input_ids=inputs_y_ids.to(y_model.device),
        attention_mask=input_y_masks.to(y_model.device),
    )
    weight_out = x_model(one_tensor, mask_tensor)
    return lm_out, weight_out

def get_output(y_opt, x_opt, y_model, inputs_y_ids, input_y_masks):
    """Clear both optimizers' gradients and run only the language model.

    The token ids and attention mask are moved to ``y_model.device`` before
    the forward call; the model's output object is returned unchanged.
    """
    for opt in (y_opt, x_opt):
        opt.zero_grad()
    target_device = y_model.device
    ids_on_device = inputs_y_ids.to(target_device)
    mask_on_device = input_y_masks.to(target_device)
    return y_model(input_ids=ids_on_device, attention_mask=mask_on_device)

def printIndex(str, s):
    """Return every start index at which ``s`` occurs in the text ``str``.

    Overlapping occurrences are all reported. An empty ``s`` matches at every
    index ``0 .. len(str) - 1``. The first parameter shadows the builtin
    ``str``; the name is kept so existing callers are unaffected.
    """
    text = str
    return [pos for pos in range(len(text)) if text.startswith(s, pos)]

def make_template(dataset, max_size):
    """Convert raw "Human:/Assistant:" transcripts into 3-message chat templates.

    Each kept entry becomes ``[system, user, assistant]``:
    - The user turn is the text between the first ``"\\n\\nHuman: "`` marker
      and the first ``"\\n\\nAssistant: "`` marker.
    - The assistant turn runs from there to the next ``"\\n\\nHuman:"``
      marker, or to the end of the text.

    Args:
        dataset: Sequence of transcript strings.
        max_size: Only strings with ``len(text) < max_size`` are kept.

    Returns:
        ``(idx, templates)``: the positions in ``dataset`` of the kept entries
        and their chat templates, in the same order.

    Entries that lack either marker, or whose assistant marker precedes the
    human marker, are skipped instead of raising IndexError.
    """
    human_tag = "\n\nHuman:"
    assistant_tag = "\n\nAssistant:"
    idx = []
    templates = []
    for a, text in enumerate(dataset):
        if len(text) >= max_size:
            continue
        human_pos = text.find(human_tag)
        assistant_pos = text.find(assistant_tag)
        if human_pos < 0 or assistant_pos < human_pos:
            continue  # malformed transcript: no usable user/assistant pair
        # +1 skips the space that follows each speaker tag.
        user_start = human_pos + len(human_tag) + 1
        reply_start = assistant_pos + len(assistant_tag) + 1
        next_human = text.find(human_tag, assistant_pos + len(assistant_tag))
        reply_end = next_human if next_human >= 0 else len(text)
        idx.append(a)
        templates.append([
            {"role": "system", "content": "You are a chatbot who answers a given question."},
            {"role": "user", "content": text[user_start:assistant_pos]},
            {"role": "assistant", "content": text[reply_start:reply_end]},
        ])
    return idx, templates


class x_model(nn.Module):
    """Upper-level variable x: learnable softmax weights over datasets.

    ``forward(one, idx)`` evaluates the weighting network on ``one`` and, for
    each row, returns the weight found in column ``idx`` of that row.
    """

    def __init__(self, dim):
        super().__init__()
        self.x_net = build_x(dim)

    def forward(self, one, idx):
        probs = self.x_net(one)
        # Row-wise pick: out[b, 0] = probs[b, idx[b, 0]]
        return probs.gather(1, idx)


class WeightedCrossEntropyLoss(nn.Module):
    """Token-level cross-entropy in which every token is scaled by its sample's weight.

    ``forward(inputs, targets, weights)``:
        inputs:  ``(N, C)`` flattened logits, ``N = B * L`` in sample-major order.
        targets: ``(N,)`` class ids; positions equal to ``ignore_idx`` are skipped.
        weights: ``B`` per-sample weights (any shape with ``B`` elements,
                 e.g. ``(B, 1)``), or a single scalar weight.

    Returns ``sum_i w(i) * ce_i / #valid_tokens``, where ``w(i)`` is the weight
    of the sample that token ``i`` belongs to. With uniform weights ``c`` this
    equals ``c * CrossEntropyLoss(ignore_index=ignore_idx)(inputs, targets)``.

    The previous implementation reduced the cross-entropy to a scalar first,
    which collapsed the result to ``mean(weights) * ce``, so weights were
    never tied to their samples. If no token is valid, 0 is returned instead
    of NaN.

    Raises:
        ValueError: if ``N`` is not a multiple of the number of weights.
    """

    def __init__(self, ignore_idx):
        super().__init__()
        self.ignore_idx = ignore_idx
        # reduction='none' keeps one loss per token (0 at ignored positions)
        # so that each token can be weighted individually.
        self.entropy_net = torch.nn.CrossEntropyLoss(ignore_index=ignore_idx, reduction='none')

    def forward(self, inputs, targets, weights):
        token_loss = self.entropy_net(inputs, targets)
        n_tokens = token_loss.shape[0]
        w = weights.reshape(-1)
        if w.numel() == 0 or n_tokens % w.numel() != 0:
            raise ValueError(
                f"{n_tokens} tokens cannot be split evenly over {w.numel()} sample weights"
            )
        # Broadcast each sample's weight over its contiguous block of tokens.
        token_w = w.repeat_interleave(n_tokens // w.numel())
        n_valid = (targets != self.ignore_idx).sum().clamp(min=1)
        return (token_w * token_loss).sum() / n_valid


dataset = load_dataset("Anthropic/hh-rlhf")
train_chosen = dataset['train'][:]['chosen']
train_rejected = dataset['train'][:]['rejected']

LLM_name = "meta-llama/Llama-3.2-1B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(LLM_name)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

idx_C, template_C = make_template(train_chosen, 150)
idx_R, template_R = make_template(train_rejected, 150)

del train_chosen, train_rejected

# train data: all chosen dialogues first, then all rejected ones
train_data = template_C + template_R
test_data = template_C

# dataset numbering: 0 for the chosen block, 1 for the rejected block.
# The boundary is the number of *chosen* templates (the first block of
# train_data); the two filtered lists generally differ in length.
dataset_num = np.zeros([len(train_data), 1])
dataset_num[len(template_C):] = 1


def _encode_dialogue(dialogue):
    """Tokenize one chat template into unpadded ``(input_ids, labels)``.

    The labels are next-token targets: ``labels[i]`` is the id the model
    should predict from position ``i``. Prompt positions are set to -100.
    """
    prompt = tokenizer.apply_chat_template(dialogue[:-1], tokenize=False, add_generation_prompt=True).replace(tokenizer.bos_token, "")
    prompt_id = tokenizer(prompt, padding=False, truncation=True, max_length=64).input_ids

    response = tokenizer.apply_chat_template(dialogue, tokenize=False, add_generation_prompt=False).replace(tokenizer.bos_token, "")
    # [1:] drops the BOS the tokenizer prepends to the response fragment.
    response_id = tokenizer(response[len(prompt):], padding=False, truncation=True, max_length=64).input_ids[1:]

    input_id = prompt_id + response_id
    label_id = [-100] * len(prompt_id) + response_id
    # The training loss compares logits[i] with label[i] directly, but a
    # causal LM's logits at position i predict token i+1, so shift left by one.
    label_id = label_id[1:] + [-100]
    return input_id, label_id


def _pad_example(input_id, label_id, length):
    """Right-pad one example to ``length``; returns (input_ids, labels, attention_mask)."""
    n_pad = length - len(input_id)
    # Mask by position rather than by token id: pad_token is the EOS token,
    # which the chat template may also emit inside real text (e.g. as an
    # end-of-turn marker), and those tokens must stay visible.
    attention_mask = [1] * len(input_id) + [0] * n_pad
    return (input_id + [tokenizer.pad_token_id] * n_pad,
            label_id + [-100] * n_pad,
            attention_mask)


train_input = []
train_label = []
for dialogue in train_data:
    input_id, label_id = _encode_dialogue(dialogue)
    train_input.append(input_id)
    train_label.append(label_id)

input_maxlen = max((len(ids) for ids in train_input), default=0)
maxlen = input_maxlen + 5

train_mask = []
for a in range(len(train_data)):
    train_input[a], train_label[a], attention_mask = _pad_example(train_input[a], train_label[a], maxlen)
    train_mask.append(attention_mask)

train_input = torch.as_tensor(np.array(train_input))
train_label = torch.as_tensor(np.array(train_label))
train_mask = torch.as_tensor(np.array(train_mask))

# test_data is a subset of train_data, so every example fits in maxlen.
test_input = []
test_label = []
test_mask = []
for dialogue in test_data:
    input_id, label_id = _encode_dialogue(dialogue)
    input_id, label_id, attention_mask = _pad_example(input_id, label_id, maxlen)
    test_input.append(input_id)
    test_label.append(label_id)
    test_mask.append(attention_mask)

test_input = torch.as_tensor(np.array(test_input))
test_label = torch.as_tensor(np.array(test_label))
test_mask = torch.as_tensor(np.array(test_mask))

# tensor
dataset_num = torch.as_tensor(dataset_num).long()
one_tensor = torch.as_tensor(np.ones([batch_size, 1])).float() # input to the weighting_model

train_loss_run = []
test_loss_run = []
soft_run = []
times_run = []

ignore_idx = -100

# Main experiment: `num_run` independent repetitions of the selected bilevel
# algorithm. Notation, following the step-size comments in the configuration:
#   y = trainable LoRA parameters of the LM (lower level, step betas[alg_flag])
#   x = weights of the dataset-weighting network (upper level, step alphas[alg_flag])
#   z = auxiliary vector, one tensor per LoRA parameter (step gammas[alg_flag])
#   g = weighted training loss (g_loss_WMSE); f = unweighted test loss (f_loss_MSE)
# Every parameter update below is a manual `.data` assignment. The Adam
# optimizers are never stepped; they are only used for zero_grad() inside
# get_outputs/get_output.
for run in range(num_run):
    print(f"Run {run+1} / {num_run}")
    main_model = AutoModelForCausalLM.from_pretrained(LLM_name, device_map=device, torch_dtype=data_type)
    # NOTE(review): the tokenizer pads with its EOS token, but the model config
    # is told pad == BOS. No loss below uses this id -- confirm the intent.
    main_model.config.pad_token_id = tokenizer.bos_token_id
    main_model = get_peft_model(main_model, config_lora)
    main_model.print_trainable_parameters()
    # Parameters still requiring grad after get_peft_model are the y variables.
    lora_params = [p for p in main_model.parameters() if p.requires_grad]
    weighting_model = x_model(num_DS).to(device)
    # state_dict() tensors share storage with the parameters, so this zeroes the
    # weights in place. Zero logits -> equal softmax weight for every dataset.
    weighting_model.state_dict()['x_net.0.weight'].fill_(0)
    # if run == 0:
    #     print(main_model)
    #     print(weighting_model)

    # z starts as uniform [0, 1) noise shaped like each LoRA parameter.
    z_params = []
    for param in lora_params:
        z_params.append(torch.rand_like(param))

    y_optimizer = torch.optim.Adam(main_model.parameters())
    x_optimizer = torch.optim.Adam(weighting_model.parameters())
    # Despite the "MSE" names, both losses are token-level cross-entropy.
    f_loss_MSE  = torch.nn.CrossEntropyLoss(ignore_index=ignore_idx)
    g_loss_WMSE = WeightedCrossEntropyLoss(ignore_idx=ignore_idx)

    train_loss_t = []
    test_loss_t = []
    soft_t = []
    times_t = []

    # Record the initial dataset weights (softmax of the weighting weights).
    for param in weighting_model.parameters():
        x_param = param.tolist()
        soft_t.append(np.squeeze(scipy.special.softmax(x_param)))

    for t in range(num_T):
        # Fresh minibatches every outer iteration (indices unique within a batch).
        # train_mask_tensor holds each training sample's dataset id; it selects
        # the sample's weight in x_model.forward.
        batch_idx_train = np.random.choice(len(train_input), batch_size, replace=False)
        batch_train_input = train_input[batch_idx_train, :]
        batch_train_label = train_label[batch_idx_train, :]
        batch_train_mask = train_mask[batch_idx_train, :]
        train_mask_tensor = dataset_num[batch_idx_train]

        batch_idx_test = np.random.choice(len(test_input), batch_size, replace=False)
        batch_test_input = test_input[batch_idx_test, :]
        batch_test_label = test_label[batch_idx_test, :]
        batch_test_mask = test_mask[batch_idx_test, :]

        # Restrict scaled-dot-product attention to the math kernel, presumably
        # because the double-backward (create_graph=True) passes below need it.
        # NOTE(review): torch.backends.cuda.sdp_kernel is deprecated in newer
        # PyTorch releases in favour of torch.nn.attention.sdpa_kernel.
        with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
            # NOTE(review): there is no torch.cuda.synchronize() around this timer,
            # so on GPU it may miss kernels still queued when end_time is read.
            start_time = time.time()
            if alg_flag == 0:
                # Lazy variant: z (and, if lazy_flag == 1, the JVP) are refreshed
                # only every lazy_N iterations; otherwise the last JVP is reused.
                if (t % lazy_N) == 0:
                    # compute g_yy*z (first term of h_q)
                    outputs_y, outputs_x = get_outputs(y_optimizer, x_optimizer, main_model, weighting_model,
                                        batch_train_input, batch_train_mask, one_tensor.to(device),
                                        train_mask_tensor.to(device))
                    logits = outputs_y.logits
                    # Logits and labels are flattened and compared position by
                    # position; no shift is applied here.
                    logits_flat = logits.view(-1, logits.size(-1))
                    labels_flat = batch_train_label.view(-1).long().to(device)
                    g_loss = g_loss_WMSE(logits_flat, labels_flat, outputs_x)

                    # Differentiating grad_y g against z gives the Hessian-vector
                    # product g_yy z without forming the Hessian.
                    grad_g_y = torch.autograd.grad(g_loss, lora_params, create_graph=True)
                    HVP = torch.autograd.grad(grad_g_y, lora_params, grad_outputs=z_params)

                    # compute f_y (second term of h_q)
                    outputs_y = get_output(y_optimizer, x_optimizer, main_model,
                                        batch_test_input, batch_test_mask)
                    logits = outputs_y.logits
                    logits_flat = logits.view(-1, logits.size(-1))
                    labels_flat = batch_test_label.view(-1).long().to(device)
                    f_loss = f_loss_MSE(logits_flat, labels_flat)

                    grad_f_y = torch.autograd.grad(f_loss, lora_params)

                    with torch.no_grad():
                        # compute h_q
                        # h_q = g_yy z + f_y: the gradient of 1/2 z^T g_yy z + f_y^T z,
                        # so the z-step moves toward solving g_yy z = -f_y.
                        h_q = []
                        for a in range(len(grad_f_y)):
                            h_q.append(HVP[a] + grad_f_y[a])

                        # update z
                        for a in range(len(z_params)):
                            z_params[a] = z_params[a] - gammas[alg_flag] * h_q[a]

                    if lazy_flag == 1:
                        # update v (compute g_xy*z)
                        outputs_y, outputs_x = get_outputs(y_optimizer, x_optimizer, main_model, weighting_model,
                                            batch_train_input, batch_train_mask, one_tensor.to(device),
                                            train_mask_tensor.to(device))
                        logits = outputs_y.logits
                        logits_flat = logits.view(-1, logits.size(-1))
                        labels_flat = batch_train_label.view(-1).long().to(device)
                        g_loss = g_loss_WMSE(logits_flat, labels_flat, outputs_x)

                        # Mixed second derivative: d/dx <grad_y g, z> = g_xy z.
                        grad_g_y = torch.autograd.grad(g_loss, lora_params, create_graph=True)
                        JVP = torch.autograd.grad(grad_g_y, weighting_model.parameters(), grad_outputs=z_params)

                if lazy_flag == 0:
                    # Non-lazy: recompute g_xy z every iteration with the current z.
                    # update v (compute g_xy*z)
                    outputs_y, outputs_x = get_outputs(y_optimizer, x_optimizer, main_model, weighting_model,
                                        batch_train_input, batch_train_mask, one_tensor.to(device),
                                        train_mask_tensor.to(device))
                    logits = outputs_y.logits
                    logits_flat = logits.view(-1, logits.size(-1))
                    labels_flat = batch_train_label.view(-1).long().to(device)
                    g_loss = g_loss_WMSE(logits_flat, labels_flat, outputs_x)

                    grad_g_y = torch.autograd.grad(g_loss, lora_params, create_graph=True)
                    JVP = torch.autograd.grad(grad_g_y, weighting_model.parameters(), grad_outputs=z_params)

                # compute h_g
                outputs_y, outputs_x = get_outputs(y_optimizer, x_optimizer, main_model, weighting_model,
                                    batch_train_input, batch_train_mask, one_tensor.to(device),
                                    train_mask_tensor.to(device))
                logits = outputs_y.logits
                logits_flat = logits.view(-1, logits.size(-1))
                labels_flat = batch_train_label.view(-1).long().to(device)
                g_loss = g_loss_WMSE(logits_flat, labels_flat, outputs_x)

                h_g = torch.autograd.grad(g_loss, lora_params)

                # update y
                # One plain gradient step on the weighted training loss.
                with torch.no_grad():
                    for a, param in enumerate(lora_params):
                        param.data = param.data - betas[alg_flag] * h_g[a]

                with torch.no_grad():
                    # compute h_f
                    # Only the g_xy z term: f_loss_MSE takes no weights, so the
                    # direct f_x term (commented out) would be zero.
                    h_f = []
                    for a, _ in enumerate(weighting_model.parameters()):
                        # h_f.append(grad_f_x[a] + JVP[a])
                        h_f.append(JVP[a])

                    # Moving average with mu on the NEW direction:
                    # bar_h_f <- mu * h_f + (1 - mu) * bar_h_f, seeded with h_f at t == 0.
                    if m_flag == 1:
                        if t == 0:
                            bar_h_f = []
                            for a, _ in enumerate(weighting_model.parameters()):
                                bar_h_f.append(h_f[a])
                        else:
                            new_bar_h_f = []
                            for a, _ in enumerate(weighting_model.parameters()):
                                new_bar_h_f.append(mu*h_f[a] + (1-mu)*bar_h_f[a])
                            bar_h_f = copy.deepcopy(new_bar_h_f)

                    # update x
                    for a, param in enumerate(weighting_model.parameters()):
                        if m_flag == 0:
                            param.data = param.data - alphas[alg_flag] * h_f[a]
                        elif m_flag == 1:
                            param.data = param.data - alphas[alg_flag] * bar_h_f[a]

            elif alg_flag == 1:
                # AmIGO: compute f_y once, take amigo_M z-steps and amigo_N y-steps
                # on the same batches, then one x-step with the updated z and y.
                # compute f_y (second term of h_q)
                outputs_y = get_output(y_optimizer, x_optimizer, main_model,
                                    batch_test_input, batch_test_mask)
                logits = outputs_y.logits
                logits_flat = logits.view(-1, logits.size(-1))
                labels_flat = batch_test_label.view(-1).long().to(device)
                f_loss = f_loss_MSE(logits_flat, labels_flat)

                grad_f_y = torch.autograd.grad(f_loss, lora_params)

                # NOTE(review): `a` is reused by the inner loops below. The outer
                # loop still runs amigo_M times because `for` rebinds `a` from its
                # own iterator, but renaming would make this clearer.
                for a in range(amigo_M):
                    # compute g_yy*z (first term of h_q)
                    outputs_y, outputs_x = get_outputs(y_optimizer, x_optimizer, main_model, weighting_model,
                                        batch_train_input, batch_train_mask, one_tensor.to(device),
                                        train_mask_tensor.to(device))
                    logits = outputs_y.logits
                    logits_flat = logits.view(-1, logits.size(-1))
                    labels_flat = batch_train_label.view(-1).long().to(device)
                    g_loss = g_loss_WMSE(logits_flat, labels_flat, outputs_x)

                    grad_g_y = torch.autograd.grad(g_loss, lora_params, create_graph=True)
                    HVP = torch.autograd.grad(grad_g_y, lora_params, grad_outputs=z_params)

                    with torch.no_grad():
                        # compute h_q
                        h_q = []
                        for a in range(len(grad_f_y)):
                            h_q.append(HVP[a] + grad_f_y[a])
                        
                        # update z
                        for a in range(len(z_params)):
                            z_params[a] = z_params[a] - gammas[alg_flag] * h_q[a]

                for a in range(amigo_N):
                    # compute h_g
                    outputs_y, outputs_x = get_outputs(y_optimizer, x_optimizer, main_model, weighting_model,
                                        batch_train_input, batch_train_mask, one_tensor.to(device),
                                        train_mask_tensor.to(device))
                    logits = outputs_y.logits
                    logits_flat = logits.view(-1, logits.size(-1))
                    labels_flat = batch_train_label.view(-1).long().to(device)
                    g_loss = g_loss_WMSE(logits_flat, labels_flat, outputs_x)

                    h_g = torch.autograd.grad(g_loss, lora_params)

                    # update y
                    with torch.no_grad():
                        for a, param in enumerate(lora_params):
                            param.data = param.data - betas[alg_flag] * h_g[a]

                # update v (compute g_xy*z)
                outputs_y, outputs_x = get_outputs(y_optimizer, x_optimizer, main_model, weighting_model,
                                    batch_train_input, batch_train_mask, one_tensor.to(device),
                                    train_mask_tensor.to(device))
                logits = outputs_y.logits
                logits_flat = logits.view(-1, logits.size(-1))
                labels_flat = batch_train_label.view(-1).long().to(device)
                g_loss = g_loss_WMSE(logits_flat, labels_flat, outputs_x)

                grad_g_y = torch.autograd.grad(g_loss, lora_params, create_graph=True)
                JVP = torch.autograd.grad(grad_g_y, weighting_model.parameters(), grad_outputs=z_params)

                with torch.no_grad():
                    # compute h_f
                    h_f = []
                    for a, _ in enumerate(weighting_model.parameters()):
                        # h_f.append(grad_f_x[a] + JVP[a])
                        h_f.append(JVP[a])
    
                    # update x (AmIGO uses no momentum)
                    for a, param in enumerate(weighting_model.parameters()):
                        param.data = param.data - alphas[alg_flag] * h_f[a]

            elif alg_flag == 2:
                # SOBA: a single z-step, y-step and x-step per outer iteration.
                # compute g_yy*z (first term of h_q)
                outputs_y, outputs_x = get_outputs(y_optimizer, x_optimizer, main_model, weighting_model,
                                    batch_train_input, batch_train_mask, one_tensor.to(device),
                                    train_mask_tensor.to(device))
                logits = outputs_y.logits
                logits_flat = logits.view(-1, logits.size(-1))
                labels_flat = batch_train_label.view(-1).long().to(device)
                g_loss = g_loss_WMSE(logits_flat, labels_flat, outputs_x)

                grad_g_y = torch.autograd.grad(g_loss, lora_params, create_graph=True)
                HVP = torch.autograd.grad(grad_g_y, lora_params, grad_outputs=z_params)

                # compute f_y (second term of h_q)
                outputs_y = get_output(y_optimizer, x_optimizer, main_model,
                                    batch_test_input, batch_test_mask)
                logits = outputs_y.logits
                logits_flat = logits.view(-1, logits.size(-1))
                labels_flat = batch_test_label.view(-1).long().to(device)
                f_loss = f_loss_MSE(logits_flat, labels_flat)

                grad_f_y = torch.autograd.grad(f_loss, lora_params)

                with torch.no_grad():
                    # compute h_q
                    h_q = []
                    for a in range(len(grad_f_y)):
                        h_q.append(HVP[a] + grad_f_y[a])

                    # update z
                    for a in range(len(z_params)):
                        z_params[a] = z_params[a] - gammas[alg_flag] * h_q[a]

                # compute h_g
                outputs_y, outputs_x = get_outputs(y_optimizer, x_optimizer, main_model, weighting_model,
                                    batch_train_input, batch_train_mask, one_tensor.to(device),
                                    train_mask_tensor.to(device))
                logits = outputs_y.logits
                logits_flat = logits.view(-1, logits.size(-1))
                labels_flat = batch_train_label.view(-1).long().to(device)
                g_loss = g_loss_WMSE(logits_flat, labels_flat, outputs_x)

                h_g = torch.autograd.grad(g_loss, lora_params)

                # update y
                with torch.no_grad():
                    for a, param in enumerate(lora_params):
                        param.data = param.data - betas[alg_flag] * h_g[a]

                # update v (compute g_xy*z)
                outputs_y, outputs_x = get_outputs(y_optimizer, x_optimizer, main_model, weighting_model,
                                    batch_train_input, batch_train_mask, one_tensor.to(device),
                                    train_mask_tensor.to(device))
                logits = outputs_y.logits
                logits_flat = logits.view(-1, logits.size(-1))
                labels_flat = batch_train_label.view(-1).long().to(device)
                g_loss = g_loss_WMSE(logits_flat, labels_flat, outputs_x)

                grad_g_y = torch.autograd.grad(g_loss, lora_params, create_graph=True)
                JVP = torch.autograd.grad(grad_g_y, weighting_model.parameters(), grad_outputs=z_params)

                with torch.no_grad():
                    # compute h_f
                    h_f = []
                    for a, _ in enumerate(weighting_model.parameters()):
                        # h_f.append(grad_f_x[a] + JVP[a])
                        h_f.append(JVP[a])

                    # Same momentum rule as the Lazy branch.
                    if m_flag == 1:
                        if t == 0:
                            bar_h_f = []
                            for a, _ in enumerate(weighting_model.parameters()):
                                bar_h_f.append(h_f[a])
                        else:
                            new_bar_h_f = []
                            for a, _ in enumerate(weighting_model.parameters()):
                                new_bar_h_f.append(mu*h_f[a] + (1-mu)*bar_h_f[a])
                            bar_h_f = copy.deepcopy(new_bar_h_f)

                    # update x
                    for a, param in enumerate(weighting_model.parameters()):
                        if m_flag == 0:
                            param.data = param.data - alphas[alg_flag] * h_f[a]
                        elif m_flag == 1:
                            param.data = param.data - alphas[alg_flag] * bar_h_f[a]

            end_time = time.time()
            elapsed_time = end_time - start_time
            times_t.append(elapsed_time)

        # Monitoring, outside the timed region: g on this iteration's training
        # batch and f on its test batch, after the updates (not a held-out set).
        with torch.no_grad():
            outputs_y, outputs_x = get_outputs(y_optimizer, x_optimizer, main_model, weighting_model,
                                batch_train_input, batch_train_mask, one_tensor.to(device),
                                train_mask_tensor.to(device))

            logits = outputs_y.logits
            logits_flat = logits.view(-1, logits.size(-1))
            labels_flat = batch_train_label.view(-1).long().to(device)

            g_loss = g_loss_WMSE(logits_flat, labels_flat, outputs_x)

            outputs_y = get_output(y_optimizer, x_optimizer, main_model,
                                batch_test_input, batch_test_mask)

            logits = outputs_y.logits
            logits_flat = logits.view(-1, logits.size(-1))
            labels_flat = batch_test_label.view(-1).long().to(device)

            f_loss = f_loss_MSE(logits_flat, labels_flat)

        train_loss_t.append(g_loss.item())
        test_loss_t.append(f_loss.item())

        # Current dataset weights = softmax of the weighting weights.
        for param in weighting_model.parameters():
            x_param = param.tolist()
            soft_t.append(np.squeeze(scipy.special.softmax(x_param)))

        if (t % 5) == 0:
            print(f"t = {t} / g_loss = {train_loss_t[-1]:.3f} / f_loss = {test_loss_t[-1]:.3f} / weight = {soft_t[-1][0]:.3f}")

    train_loss_run.append(train_loss_t)
    test_loss_run.append(test_loss_t)
    soft_run.append(soft_t)
    times_run.append(times_t)
    print(f"(Elapsed time: {sum(times_t):.3f} seconds)")

# Optionally persist the raw per-run curves.
if save_file:
    print(file_name)
    with open(f'{file_name}.pkl', 'wb') as file:
        pickle.dump([train_loss_run, test_loss_run, soft_run, times_run], file)

# Average over runs. The x-axes are the iteration index and the mean
# cumulative wall-clock time.
cum_times_run = np.mean(np.cumsum(np.array(times_run), axis=1), axis=0)
mean_train_loss = np.mean(np.array(train_loss_run), axis=0)
mean_test_loss = np.mean(np.array(test_loss_run), axis=0)
iterations = range(1, num_T + 1)
loss_labels = ['Train (g_loss)', 'Test (f_loss)']

# Loss vs. iteration
plt.figure()
plt.plot(iterations, mean_train_loss)
plt.plot(iterations, mean_test_loss)
plt.legend(loss_labels)
plt.xlim([1, num_T])
plt.grid()

# Loss vs. elapsed time
plt.figure()
plt.plot(cum_times_run, mean_train_loss)
plt.plot(cum_times_run, mean_test_loss)
plt.legend(loss_labels)
plt.xlim([0, cum_times_run[-1]])
plt.grid()

# Dataset weights (num_T + 1 points: the initial weights plus one per iteration)
plt.figure()
plt.plot(range(0, num_T + 1), np.mean(np.array(soft_run), axis=0))
plt.legend(['Dataset 1', 'Dataset 2'])
plt.xlim([0, num_T])
plt.ylim([0, 1])
plt.grid()