from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments, Conv1D
from datasets import load_dataset
import numpy as np
import functools
import torch
from torch import nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR
from torch.utils.data import DataLoader
from torch.nn.utils.parametrize import register_parametrization


class LoraLayer(torch.nn.Module):
    """Weight parametrization that adds a trainable low-rank (LoRA) update.

    The module is meant to be attached to a layer's ``weight`` with
    ``register_parametrization``: ``forward`` receives the stored weight and
    returns ``weight + dropout(lora_A @ lora_B * scaling)``, transposed to the
    weight's layout when needed.

    Args:
        weight: The 2-D weight being parametrized. Only its shape, dtype and
            device are read.
        r: Rank of the update. ``0`` disables LoRA (``forward`` is the
            identity and no parameters exist).
        alpha: Numerator of the scaling factor ``alpha / r``.
        dropout_prob: Dropout probability applied to the low-rank update
            itself (i.e. to the weight delta, not to activations).
        fan_in_fan_out: ``True`` if ``weight`` is laid out as
            ``(in_features, out_features)``; ``False`` for the
            ``(out_features, in_features)`` layout.

    Raises:
        ValueError: If ``r`` is negative.
    """

    def __init__(self, weight, r, alpha=1, dropout_prob=0, fan_in_fan_out=False):
        super().__init__()

        if fan_in_fan_out:
            self.in_features, self.out_features = weight.shape[0], weight.shape[1]
        else:
            self.out_features, self.in_features = weight.shape[0], weight.shape[1]
        self.alpha = alpha
        self.fan_in_fan_out = fan_in_fan_out

        if dropout_prob > 0.:
            self.lora_dropout = nn.Dropout(p=dropout_prob)
        else:
            self.lora_dropout = nn.Identity()

        # Remembered so the factors can be re-created by change_lora_rank()
        # even when no factor currently exists (rank 0). The device is the one
        # seen at construction time and is not updated by a later .to().
        self._weight_dtype = weight.dtype
        self._weight_device = weight.device

        self._init_lora(r, weight_dtype=weight.dtype, weight_device=weight.device)

    def _init_lora(self, r, weight_dtype=None, weight_device=None):
        """(Re)create the LoRA factors for rank ``r``, or remove them if ``r == 0``.

        ``lora_A`` has shape ``(in_features, r)`` and ``lora_B`` has shape
        ``(r, out_features)``. When ``weight_dtype`` / ``weight_device`` are
        omitted they are taken from the existing ``lora_A`` if there is one,
        otherwise from the weight given at construction.
        """
        if r < 0:
            raise ValueError(f"LoRA rank must be non-negative, got {r}")

        current = getattr(self, "lora_A", None)
        if weight_dtype is None:
            weight_dtype = current.dtype if current is not None else self._weight_dtype
        if weight_device is None:
            weight_device = current.device if current is not None else self._weight_device

        if r > 0:
            # Actual trainable parameters.
            # lora_A must NOT start at zero: with both factors at zero the
            # gradient of (A @ B) w.r.t. each factor is zero and nothing would
            # ever train. lora_B starts at zero so the initial update is zero
            # and the parametrized weight equals the stored weight.
            # The draw is done in float32 on an (r, in_features) tensor so the
            # fan-in used by Kaiming init is in_features, then moved to the
            # (in_features, r) layout and the weight's dtype.
            lora_a = torch.empty((r, self.in_features), dtype=torch.float32, device=weight_device)
            nn.init.kaiming_uniform_(lora_a, a=5 ** 0.5)
            lora_a = lora_a.T.to(weight_dtype).contiguous()
            lora_b = torch.zeros((r, self.out_features), dtype=weight_dtype, device=weight_device)

            self.register_parameter('lora_A', nn.Parameter(lora_a))
            self.register_parameter('lora_B', nn.Parameter(lora_b))
            self.scaling = self.alpha / r
        else:
            # Ensure the parameters do not exist at all when the rank is zero.
            for name in ("lora_A", "lora_B", "scaling"):
                if hasattr(self, name):
                    delattr(self, name)
        self.r = r

    def change_lora_rank(self, new_rank):
        """Switch to ``new_rank``; a no-op if the rank is unchanged.

        The factors are freshly initialised, so previously learned values are
        discarded, and the ``lora_A`` / ``lora_B`` parameter objects are
        replaced by new ones.
        """
        if new_rank != self.r:
            self._init_lora(new_rank)

    def forward(self, X):
        """Return ``X`` plus the scaled low-rank update (``X`` itself if rank is 0)."""
        if self.r == 0:
            return X

        delta = self.lora_dropout((self.lora_A @ self.lora_B) * self.scaling)
        if not self.fan_in_fan_out:
            # A @ B is (in_features, out_features); X is (out_features, in_features).
            delta = delta.T
        return X + delta
        
### TEST for equivalence with loralib.Linear (kept commented out; needs the optional `loralib` package)
#import loralib
#l = torch.nn.Linear(4,10)
#register_parametrization(l, "weight", LoraLayer(l.weight, 2))

#l2 = loralib.Linear(4,10,2,merge_weights=False)
#with torch.no_grad():
#    l2.weight.copy_(l.parametrizations.weight.original.data)
#    l2.lora_A.copy_(l.parametrizations.weight[0].lora_A.T)
#    l2.lora_B.copy_(l.parametrizations.weight[0].lora_B.T)
#    l2.bias.copy_(l.bias.data)

#l.bias.data.copy_(l.bias.new_zeros(l.bias.shape))
#l2.bias.data.copy_(l2.bias.new_zeros(l2.bias.shape))

#x = torch.randn(6,4)

# --- Model and data --------------------------------------------------------
model_name_or_path = "gpt2-medium"  # alternative: "meta-llama/Meta-Llama-3-8B"
dataset_name = "e2e_nlg"

# --- Optimisation ----------------------------------------------------------
seed = 0
num_train_epochs = 5
batch_size = 8  # used for both the train and the eval per-device batch size
learning_rate = 6e-3
weight_decay = 0.01

# --- LoRA ------------------------------------------------------------------
lora_r = 4        # rank of the low-rank update
lora_alpha = 1    # the update is scaled by lora_alpha / lora_r
lora_dropout = 0  # dropout probability on the low-rank update (0 disables it)

# Load the pretrained causal language model and its tokenizer from the same
# checkpoint name; the two loads are independent of each other.
model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

# Attach a LoraLayer parametrization to the weight of every module whose type
# is exactly nn.Linear or Conv1D (subclasses are deliberately not matched).
# The mapped value is the fan_in_fan_out flag handed to LoraLayer: with True
# it reads the weight as (in_features, out_features).
# NOTE(review): nothing in this loop changes requires_grad on the pretrained
# weights, so they stay trainable next to the LoRA factors - confirm whether
# the base model was meant to be frozen.
lora_targets = {torch.nn.Linear: False, Conv1D: True}

for module in model.modules():
    module_type = type(module)
    if module_type not in lora_targets:
        continue
    lora = LoraLayer(
        weight=module.weight,
        r=lora_r,
        alpha=lora_alpha,
        dropout_prob=lora_dropout,
        fan_in_fan_out=lora_targets[module_type],
    )
    register_parametrization(module, "weight", lora)

# Load the dataset named by `dataset_name`; the 'train' and 'validation'
# entries of the result are used further down when building the Trainer.
datasets = load_dataset(dataset_name)

def tokenize_function(examples):
    """Tokenize a batch and join each input with its reference text.

    ``examples["meaning_representation"]`` is tokenized as ``text`` and
    ``examples["human_reference"]`` as ``text_target``. For every example the
    two id lists are concatenated into one ``input_ids`` tensor, and a second
    tensor of the same length is built that is 0 over the first segment and 1
    over the second.

    Returns:
        A dict with the keys ``"input_ids"`` and ``"attention_mask"``, each a
        list with one integer tensor per example.

    NOTE(review): the column called ``attention_mask`` is 0 on the ``text``
    tokens. Used as a real attention mask it would hide that segment from the
    model; it looks more like a loss mask - confirm the intent.
    """
    encoded = tokenizer(
        text=examples["meaning_representation"],
        text_target=examples["human_reference"],
        return_attention_mask=False,
    )

    def _join(source_ids, target_ids):
        ids = torch.tensor(source_ids + target_ids, dtype=torch.long)
        mask = torch.cat([
            torch.zeros(len(source_ids), dtype=torch.long),
            torch.ones(len(target_ids), dtype=torch.long),
        ])
        return ids, mask

    # Relies on the tokenizer output holding exactly two columns, in the
    # order (ids of `text`, ids of `text_target`).
    rows = [_join(source_ids, target_ids) for source_ids, target_ids in zip(*encoded.values())]
    columns = [list(column) for column in zip(*rows)]
    return dict(zip(["input_ids", "attention_mask"], columns))

# Run the tokenizer over the loaded dataset, 1000 examples per call.
tokenized_datasets = datasets.map(
    tokenize_function,
    batch_size=1000,
    batched=True,
)

# Return the two model inputs as PyTorch tensors; the remaining columns are
# still included in the output (output_all_columns=True).
tensor_columns = ["input_ids", "attention_mask"]
tokenized_datasets.set_format(
    "pt",
    columns=tensor_columns,
    output_all_columns=True,
)

def collate_fn(batch, tokenizer):
    """Pad a list of examples into one right-padded batch.

    Args:
        batch: Sequence of dicts, each holding 1-D tensors under
            ``"input_ids"`` and ``"attention_mask"``; other keys are ignored.
        tokenizer: Object whose ``eos_token_id`` is used as the padding value
            for ``input_ids``.

    Returns:
        A dict with ``"input_ids"`` (padded with ``tokenizer.eos_token_id``)
        and ``"attention_mask"`` (padded with 0), both batch-first.
    """
    pad_sequence = torch.nn.utils.rnn.pad_sequence
    fill_values = {
        "input_ids": tokenizer.eos_token_id,
        "attention_mask": 0,
    }
    return {
        key: pad_sequence(
            [example[key] for example in batch],
            batch_first=True,
            padding_value=fill,
        )
        for key, fill in fill_values.items()
    }

# Trainer configuration. Evaluation and checkpointing both happen once per
# epoch, and the best checkpoint is reloaded when training ends.
# NOTE(review): output_dir is named after llama3 while model_name_or_path is
# currently "gpt2-medium" - confirm the directory is the intended one.
training_args = TrainingArguments(
    output_dir="../llama3",
    seed=seed,
    # optimisation
    num_train_epochs=num_train_epochs,
    learning_rate=learning_rate,
    weight_decay=weight_decay,
    label_smoothing_factor=0.01,
    # batching
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    # evaluation / checkpointing
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
)

def compute_metrics(eval_pred):
    """Compute the mean cross-entropy of predictions against labels.

    Args:
        eval_pred: A ``(logits, labels)`` pair. Each element may be a NumPy
            array (what Hugging Face ``Trainer`` supplies, per its
            documentation) or a tensor. Logits are either in a layout
            ``torch.nn.functional.cross_entropy`` accepts directly, or
            class-last with extra leading dimensions, e.g.
            ``(batch, seq, vocab)`` with labels of shape ``(batch, seq)``.

    Returns:
        ``{"cross_entropy": float}`` - a dict, because ``Trainer`` expects
        ``compute_metrics`` to return a mapping from metric name to value.

    NOTE(review): no next-token shift is applied between logits and labels;
    confirm whether the labels arriving here are already aligned.
    """
    logits, labels = eval_pred

    # as_tensor converts NumPy arrays and leaves tensors untouched.
    logits = torch.as_tensor(logits)
    labels = torch.as_tensor(labels)
    if not labels.is_floating_point():
        # Class-index targets must be int64 for cross_entropy.
        labels = labels.long()

    # cross_entropy wants the class dimension at index 1. Class-last logits
    # with more than two dimensions are flattened to (N, classes) / (N,).
    if logits.dim() > 2 and logits.shape[:-1] == labels.shape:
        logits = logits.reshape(-1, logits.shape[-1])
        labels = labels.reshape(-1)

    loss = torch.nn.functional.cross_entropy(logits, labels)
    return {"cross_entropy": loss.item()}

# Bind the tokenizer into the collate function so it can be called with the
# batch alone.
batch_collator = functools.partial(collate_fn, tokenizer=tokenizer)

# Assemble the Trainer from the model, the tokenized splits and the settings
# defined above.
trainer = Trainer(
    args=training_args,
    model=model,
    tokenizer=tokenizer,
    data_collator=batch_collator,
    compute_metrics=compute_metrics,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
)


# Debug hook: open an interactive IPython shell with everything defined above
# in scope, then terminate the script with exit status 1 once the shell is left.
import IPython

IPython.embed()
exit(1)