import evaluate
import pandas as pd
import torch
from datasets import Dataset
from peft import LoraConfig, TaskType, get_peft_model
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    BatchEncoding,
    DataCollatorForSeq2Seq,
    EvalPrediction,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)

# Text-similarity metrics consumed by compute_wer, loaded once at import time
# (in this order: cer, wer, chrf, rouge, bleu, meteor, sacrebleu).
(
    cer_metric,
    wer_metric,
    chrf_metric,
    rouge_metric,
    bleu,
    meteor,
    sacrebleu,
) = (
    evaluate.load(metric_name)
    for metric_name in ("cer", "wer", "chrf", "rouge", "bleu", "meteor", "sacrebleu")
)

def compute_wer(pred: "EvalPrediction") -> dict[str, float]:
    """Decode predictions and references and score them with every loaded metric.

    Passed to the Seq2SeqTrainer as ``compute_metrics``. Despite the name it
    reports CER, WER, chrF, chrF++ (chrF with word_order=2), ROUGE-1, BLEU,
    SacreBLEU and METEOR.

    Args:
        pred: Trainer evaluation output. ``pred.predictions`` is either the
            predicted token-id array or a tuple/list whose first element is
            that array (``preprocess_logits_for_metrics`` returns
            ``(pred_ids, labels)``). ``pred.label_ids`` holds the reference
            token ids, where -100 marks ignored positions.

    Returns:
        Dict with keys "cer", "wer", "chrf", "chrf++", "rouge1", "bleu",
        "sbleu" and "meteor". The arrays inside ``pred`` are not modified.
    """
    import numpy as np  # local import: numpy is not imported at file level

    pad_id = tokenizer.pad_token_id

    predictions = pred.predictions
    # Accept both the (pred_ids, labels) tuple produced by
    # preprocess_logits_for_metrics and a bare id array; indexing [0] on a
    # bare array would silently keep only the first example.
    pred_ids = predictions[0] if isinstance(predictions, (tuple, list)) else predictions

    # -100 is not a decodable token id, so map it to the pad token, which
    # skip_special_tokens drops. np.where builds new arrays instead of
    # overwriting the trainer's EvalPrediction in place.
    pred_ids = np.asarray(pred_ids)
    label_ids = np.asarray(pred.label_ids)
    pred_ids = np.where(pred_ids == -100, pad_id, pred_ids)
    label_ids = np.where(label_ids == -100, pad_id, label_ids)

    def _decode(token_ids) -> list[str]:
        # Empty decodes are replaced by "$" so no metric receives an empty string.
        texts = tokenizer.batch_decode(token_ids, skip_special_tokens=True)
        return [text if text != "" else "$" for text in texts]

    pred_str = _decode(pred_ids)
    label_str = _decode(label_ids)

    cer_score = cer_metric.compute(predictions=pred_str, references=label_str)
    wer_score = wer_metric.compute(predictions=pred_str, references=label_str)
    chrf_score = chrf_metric.compute(predictions=pred_str, references=label_str)["score"]
    chrfpp_score = chrf_metric.compute(predictions=pred_str, references=label_str, word_order=2)["score"]
    rouge_score = rouge_metric.compute(predictions=pred_str, references=label_str)["rouge1"]
    bleu_score = bleu.compute(predictions=pred_str, references=label_str)['bleu']
    sbleu_score = sacrebleu.compute(predictions=pred_str, references=label_str)['score']
    meteor_score = meteor.compute(predictions=pred_str, references=label_str)['meteor']

    # metric [range] -> better
    # cer, wer [0, inf) -> lower; chrf, chrf++ [0, 100] -> higher; rouge1 [0, 1] -> higher
    return {
        "cer": cer_score,
        "wer": wer_score,
        "chrf": chrf_score,
        "chrf++": chrfpp_score,
        "rouge1": rouge_score,
        "bleu": bleu_score,
        "sbleu": sbleu_score,
        "meteor": meteor_score,
    }

def preprocess_logits_for_metrics(
    logits: "torch.Tensor | tuple[torch.Tensor, ...]",
    labels: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Reduce raw logits to predicted token ids before the trainer accumulates them.

    Original Trainer may have a memory leak. This is a workaround to avoid
    storing too many tensors that are not needed: only the argmax ids are
    kept instead of the full logits.

    Args:
        logits: Model output logits. For seq2seq models this is typically a
            tuple whose first element is the LM-head logits (the remaining
            entries are discarded); a bare logits tensor is also accepted.
        labels: Reference label ids, passed through unchanged.

    Returns:
        ``(pred_ids, labels)``, where ``pred_ids`` is the argmax over the last
        (presumably vocabulary) dimension. compute_wer reads
        ``predictions[0]``, so this tuple layout must be kept.
    """
    # Indexing [0] on a bare tensor would select the first batch row instead
    # of the LM logits, so only unwrap when a tuple/list is given.
    lm_logits = logits[0] if isinstance(logits, (tuple, list)) else logits
    pred_ids = torch.argmax(lm_logits, dim=-1)
    return pred_ids, labels
    

# Columns present in the CSV splits that the seq2seq task does not use.
_UNUSED_COLUMNS = ["speaker_id", "language", "formula_id", "is_tts", "pronunciation", "audio_path"]


def _load_split(csv_path: str) -> pd.DataFrame:
    """Read one CSV split, drop unused columns and fill missing text fields.

    Missing transcriptions become a single space and missing LaTeX targets
    become "$", so neither the model input nor the target is NaN.
    """
    frame = pd.read_csv(csv_path)
    frame = frame.drop(_UNUSED_COLUMNS, axis="columns")
    return frame.fillna({"whisper_transcription": " ", "latex": "$"})


train_df = _load_split("train_ENG.csv")
val_df = _load_split("val_ENG.csv")

base_model = 'google/flan-t5-large'
# Output directory for trainer checkpoints and the final save_pretrained call.
fine_tunned = './ckpts/seq/flan_t5'

    

# Load the base model onto the first GPU. google/flan-t5-large is served by
# the built-in T5 classes, so trust_remote_code is not needed. Leaving it off
# avoids executing code shipped inside the model repository.
# NOTE(review): re-enable only if base_model is switched to a custom-code checkpoint.
# NOTE(review): with a single-device device_map string, max_memory is likely ignored — confirm.
model = AutoModelForSeq2SeqLM.from_pretrained(
    base_model,
    device_map="cuda:0",
    max_memory={0: "70GB"},
)

# Disable the decoder KV cache for training and enable gradient checkpointing
# to trade extra compute for lower activation memory. (The Llama-only
# `pretraining_tp` setting was dropped; T5 does not read it.)
model.config.use_cache = False
model.gradient_checkpointing_enable()

tokenizer = AutoTokenizer.from_pretrained(base_model)
    
def create_prompt(sample: dict[str, list[str]]) -> "BatchEncoding":
    """Tokenize a batch of prompted transcriptions together with their LaTeX targets.

    Intended for ``Dataset.map(..., batched=True)``.

    Args:
        sample: Batch with parallel lists ``"pron"`` (spoken-formula
            transcription) and ``"latex"`` (target LaTeX source).

    Returns:
        Tokenizer output for the prompts, plus a ``"labels"`` entry. The
        labels are the tokenized targets with pad positions set to -100.
    """
    # Wrap every transcription in the instruction template; these prompts
    # (not the bare transcriptions) are what the model is trained on.
    full_prompt = [
        "\nThis is a pronunciation of a formula.\n\n"
        + pron
        + "\nConvert this to LaTeX code.\n"
        for pron in sample["pron"]
    ]
    model_inputs = tokenizer(full_prompt, padding=True, truncation=True)
    labels = tokenizer(text_target=sample["latex"], padding=True, truncation=True)

    # Replace pad tokens in the targets with -100, the ignore index of the
    # cross-entropy loss, so padding does not contribute to training.
    pad_id = tokenizer.pad_token_id
    labels["input_ids"] = [
        [(tok if tok != pad_id else -100) for tok in label]
        for label in labels["input_ids"]
    ]
    model_inputs["labels"] = labels["input_ids"]

    return model_inputs

def _to_dataset(frame: pd.DataFrame) -> Dataset:
    """Build a Dataset with ``pron`` (transcription) and ``latex`` (target, as str) columns."""
    return Dataset.from_dict({
        "pron": frame["whisper_transcription"].to_list(),
        "latex": [str(formula) for formula in frame["latex"].to_list()],
    })


train_dataset = _to_dataset(train_df)
val_dataset = _to_dataset(val_df)

# Shared map settings for both splits: tokenize in batches of 12 across 16
# worker processes and drop the raw text columns, keeping only model inputs.
_MAP_KWARGS = {"num_proc": 16, "batched": True, "batch_size": 12, "remove_columns": ["pron", "latex"]}
train_dataset = train_dataset.map(create_prompt, **_MAP_KWARGS)
val_dataset = val_dataset.map(create_prompt, **_MAP_KWARGS)

data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

# LoRA adapters on the "q", "k", "v" and "o" modules (presumably T5's
# attention projections), rank 64 with lora_alpha / r = 2, no bias training.
peft_config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    r=64,
    lora_alpha=128,
    lora_dropout=0.1,
    target_modules=["q", "v", "k", "o"],
    bias="none",
)

# Wrap the base model with the adapters; from here on `model` is the PEFT model.
model = get_peft_model(model, peft_config)

# Trainer configuration. eval_steps is not set, so evaluation presumably
# follows logging_steps (500), matching save_steps; the checkpoint with the
# lowest eval CER is restored at the end of training.
training_arguments = Seq2SeqTrainingArguments(
    # Checkpointing.
    output_dir=fine_tunned,
    save_steps=500,
    save_total_limit=2,
    # Epochs and batch sizes.
    num_train_epochs=3,
    per_device_train_batch_size=12,
    per_device_eval_batch_size=12,
    # Optimizer and learning-rate schedule.
    optim="adamw_torch_fused",
    learning_rate=1e-4,
    weight_decay=0.001,
    lr_scheduler_type="linear",
    warmup_ratio=0.1,
    # Evaluation, logging and best-model selection.
    eval_strategy="steps",
    logging_steps=500,
    load_best_model_at_end=True,
    metric_for_best_model="eval_cer",
    greater_is_better=False,
    # Precision, regularization, reporting.
    bf16=True,
    neftune_noise_alpha=0.1,  # NEFTune: random noise added to the input embeddings
    report_to="none",
)

# predict_with_generate is not set, so evaluation metrics are presumably
# computed from the teacher-forced argmax ids produced by
# preprocess_logits_for_metrics rather than from generate().
trainer = Seq2SeqTrainer(
    model=model,
    args=training_arguments,
    tokenizer=tokenizer,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_wer,
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)
trainer.train()

# Save the trained PEFT-wrapped model (the best checkpoint, since
# load_best_model_at_end=True) into the same directory as the checkpoints.
trainer.model.save_pretrained(fine_tunned)