import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from datasets import load_dataset
from PIL import Image
from IPython.display import display
from transformers import LlavaForConditionalGeneration, AutoProcessor
import traceback
import random
import numpy as np
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from utils import *
from transformers import TrainingArguments, Trainer, default_data_collator
import os
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch.nn.functional as F
from IPython.display import display
from datasets import load_dataset
import argparse
import sys

from metric_utils import *
from pycocoevalcap.spice.spice import Spice
from pycocoevalcap.cider.cider import Cider
from evaluate import load
from rouge_score import rouge_scorer
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

# Command-line interface: the single option --name selects the training run
# directory whose merged checkpoint is evaluated below.
parser = argparse.ArgumentParser()
parser.add_argument('--name', help='name of trainning save dir of')
args = parser.parse_args()

if args.name is None:
    # sys.exit(<str>) writes the message to stderr and terminates with status 1,
    # so shells and job schedulers see the missing argument as a failure.
    # The bare exit() helper used previously is injected by the `site` module
    # and, called without arguments, ends the process with status 0.
    sys.exit("Please provide the trainning save dir by --name [name]")

# ---------------------------------------------------------------------------
# Evaluation configuration
# ---------------------------------------------------------------------------

# Run directory name given on the command line, and the checkpoint path
# derived from it.
today_dir = args.name
merged_path = f"tinyllava-lora/{today_dir}/merged"

# Number of streamed samples taken for the metric dataset, and the generation
# length setting handed to the metric helpers.
metric_ds_num = 2000
max_new_tokens = 128

# Switches copied onto model.token_mixer once the model is loaded.
similarity_threshold = 0.06
think_mode = False
enhance_mode = False
text_enhance_mode = True

# Filled with one entry per metric and printed at the end of the script.
result = {}

# Read the checkpoint stored under merged_path in float16 and place it on the
# GPU. Per the original author's note, the LoRA weights are already merged
# into the base model in this checkpoint.
# NOTE(review): CustomLlavaForConditionalGeneration is not imported by name
# here; presumably it arrives through `from utils import *` - confirm.
model = CustomLlavaForConditionalGeneration.from_pretrained(
    merged_path, torch_dtype=torch.float16, device_map="cuda"
)

# The processor (tokenizer + vision processor) is read from the same directory.
processor = AutoProcessor.from_pretrained(merged_path)
processor.patch_size = 14
model.set_token_mixer_processor(processor)

# Copy the configured switches onto the model's token mixer, one attribute each.
for option_name, option_value in (
    ("similarity_threshold", similarity_threshold),
    ("think_mode", think_mode),
    ("enhance_mode", enhance_mode),
    ("text_enhance_mode", text_enhance_mode),
):
    setattr(model.token_mixer, option_name, option_value)

# Switch to inference mode before any metric is computed.
model.eval()

# Original author's note: it is very important that this seed is the same one
# used for the training set.
seed = 126

# Stream the train split, shuffle it through a 1000-element buffer with the
# fixed seed, and keep only the first metric_ds_num examples.
instruct_stream = load_dataset(
    "liuhaotian/LLaVA-Instruct-150K", split="train", streaming=True
)
shuffled_stream = instruct_stream.shuffle(seed=seed, buffer_size=1000)
metric_ds_iter = shuffled_stream.take(metric_ds_num)

print("###############################BLEU Score###############################")


# Sanity check of the BLEU setup on one hand-written pair before the long
# validation run: lower-case both sides, split on whitespace, and weight
# unigram and bigram precision equally with method1 smoothing.
example_predictions = ["Cat is on mat"]
example_references = [["The cat is sitting on the mat"]]
example_predictions = [text.lower() for text in example_predictions]
example_references = [[ref.lower() for ref in refs] for refs in example_references]
nltk_bleu = sentence_bleu(
    references=[ref[0].split() for ref in example_references],
    hypothesis=example_predictions[0].split(),
    weights=(0.5, 0.5),
    smoothing_function=SmoothingFunction().method1
)
print(f"Example BLEU Score: {nltk_bleu:.4f}, Should be ~0.2727")


# Validation BLEU over the streamed metric dataset.
# The generation length comes from the shared max_new_tokens setting, exactly
# as for the other metrics, so changing that setting affects every metric the
# same way instead of leaving BLEU on a separate hard-coded value.
torch.cuda.empty_cache()
bleu_score = compute_bleu(
    model=model,
    processor=processor,
    val_data=metric_ds_iter,
    max_new_tokens=max_new_tokens,
    do_sample=False
)
print(f"Validation BLEUScore after LORA fine tuning {bleu_score:.4f}")

result['bleu_score'] = bleu_score

print("###############################BERT Score###############################")
# Smoke test of BERTScore on a single hand-written pair.
# NOTE(review): `score` is not imported by name in this file; presumably it is
# bert_score.score made available by one of the star imports - confirm.
# The original author expected an F1 of roughly 0.8-0.9 depending on the model.
example_predictions = ["Cat is on mat"]
example_references = ["The cat is sitting on the mat"]
_example_precision, _example_recall, example_f1 = score(
    example_predictions,
    example_references,
    lang="en",
    model_type="roberta-large",
)
print(f"Example BERTScore (F1): {example_f1.item():.4f}")

# Validation BERTScore over the streamed metric dataset.
torch.cuda.empty_cache()
bertscore_options = {
    "model": model,
    "processor": processor,
    "val_data": metric_ds_iter,
    "max_new_tokens": max_new_tokens,
    "do_sample": False,
}
bertscore = compute_bertscore(**bertscore_options)
print(f"Validation BERTScore after LORA fine tuning: {bertscore:.4f}\n")
result['bertscore'] = bertscore


print("###############################METEOR Score###############################")
# Smoke test of METEOR on a single hand-written pair. The original author
# expected a value of roughly 0.3-0.4 (not verified here).
example_predictions = ["Cat is on mat"]
example_references = ["The cat is sitting on the mat"]
meteor = load("meteor")
example_meteor_output = meteor.compute(
    predictions=example_predictions,
    references=example_references,
)
example_meteor = example_meteor_output['meteor']
print(f"Example METEOR: {example_meteor:.4f}")

# Validation METEOR over the streamed metric dataset.
torch.cuda.empty_cache()
meteor_options = {
    "model": model,
    "processor": processor,
    "val_data": metric_ds_iter,
    "max_new_tokens": max_new_tokens,
    "do_sample": False,
}
meteor_score = compute_meteor(**meteor_options)
print(f"Validation METEOR after LORA fine tuning: {meteor_score:.4f}\n")
result['meteor_score'] = meteor_score



print("###############################CIDEr Score###############################")
# Smoke test of CIDEr on a single hand-written pair, keyed by item id 0.
# The original author expected roughly 0.2-0.5 depending on n-gram overlap.
# NOTE(review): with only one item in the corpus the term weighting may
# behave degenerately - confirm the expected range against a real run.
example_predictions = ["Cat is on mat"]
example_references = [["The cat is sitting on the mat"]]
gts = {0: example_references[0]}          # id -> list of reference captions
res = {0: [example_predictions[0]]}       # id -> single-element candidate list
cider_scorer = Cider()
cider_example_output = cider_scorer.compute_score(gts, res)
example_cider = cider_example_output[0]   # first element is the overall score
print(f"Example CIDEr: {example_cider:.4f}")

# Validation CIDEr over the streamed metric dataset.
torch.cuda.empty_cache()
cider_options = {
    "model": model,
    "processor": processor,
    "val_data": metric_ds_iter,
    "max_new_tokens": max_new_tokens,
    "do_sample": False,
}
cider_score = compute_cider(**cider_options)
print(f"Validation CIDEr after LORA fine tuning: {cider_score:.4f}\n")
result['cider_score'] = cider_score




print("###############################ROUGE Score###############################")
# Smoke test of ROUGE-L (with stemming) on a single hand-written pair; the
# reference is passed first and the prediction second. The original author
# expected an F-measure of roughly 0.5-0.6.
example_predictions = ["Cat is on mat"]
example_references = ["The cat is sitting on the mat"]
scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
example_rouge = scorer.score(example_references[0], example_predictions[0])
rouge_l = example_rouge['rougeL'].fmeasure
print(f"Example ROUGE-L (F1): {rouge_l:.4f}")

# Validation ROUGE-L over the streamed metric dataset.
torch.cuda.empty_cache()
rouge_options = {
    "model": model,
    "processor": processor,
    "val_data": metric_ds_iter,
    "max_new_tokens": max_new_tokens,
    "do_sample": False,
}
rouge_l_score = compute_rouge_l(**rouge_options)
print(f"Validation ROUGE-L after LORA fine tuning: {rouge_l_score:.4f}\n")
result['rouge_l_score'] = rouge_l_score




print("###############################SPICE Score###############################")
# Smoke test of SPICE on a single hand-written pair. Both sides are turned
# into {image_id: [caption]} mappings, the input shape passed to compute_score.
example_predictions = [{"image_id": 0, "caption": "Cat is on mat"}]
example_references = [{"image_id": 0, "caption": "The cat is sitting on the mat"}]
spice_scorer = Spice()
example_spice, _ = spice_scorer.compute_score(
    gts={item["image_id"]: [item["caption"]] for item in example_references},
    res={item["image_id"]: [item["caption"]] for item in example_predictions}
)
print(f"Example SPICE Score: {example_spice:.4f}")


# Give every example an integer "id" equal to its position in the stream.
# This sets the field unconditionally: any "id" already present on an example
# is replaced by the running index.
spice_val_data = metric_ds_iter.map(lambda x, idx: {**x, "id": idx}, with_indices=True)

# Validation SPICE over the id-annotated metric dataset.
torch.cuda.empty_cache()
spice_score = compute_spice(
    model=model,
    processor=processor,
    val_data=spice_val_data,
    max_new_tokens=max_new_tokens,
    do_sample=False
)
# The model under evaluation is the merged (fine-tuned) checkpoint, like for
# every other metric in this script, so the score is reported as "after".
print(f"Validation SPICE Score after LoRA fine-tuning: {spice_score:.4f}")
result['spice_score'] = spice_score


# Final summary: every metric collected above, as one dict.
print(result)

