import torch
from bert_score import score
from tqdm import tqdm
import nltk
from pycocoevalcap.spice.spice import Spice
from pycocoevalcap.cider.cider import Cider
from evaluate import load
from utils import unwrap_150k_row
from rouge_score import rouge_scorer
# NLTK resources fetched at import time; nltk.download may contact the network
# if the data is not already present locally.
# NOTE(review): recent NLTK releases tokenize with the 'punkt_tab' resource rather
# than 'punkt' — confirm which one the installed NLTK version needs.
nltk.download('punkt')  # tokenizer models
nltk.download('wordnet')  # Required for METEOR


def compute_bertscore(model, processor, val_data, max_new_tokens=64, do_sample=False, temperature=0.7, top_p=0.9):
    """Generate an answer for every validation sample and return the mean BERTScore F1.

    Each row is unpacked with ``unwrap_150k_row`` into ``(question, gt_answer, img)``.
    The question gets a newline plus ``ASSISTANT:`` appended, is run through
    ``processor`` together with the image, and is answered with ``model.generate``.
    Predictions and references are lower-cased and stripped, then scored with
    ``bert_score.score`` using ``roberta-large``.

    Args:
        model: Vision-language model exposing ``eval()``, ``device`` and a
            ``generate`` method that accepts a ``processor`` keyword.
        processor: Multimodal processor with a ``tokenizer`` attribute. As a side
            effect its ``patch_size`` attribute is set to 14.
        val_data: Iterable of raw rows accepted by ``unwrap_150k_row``.
        max_new_tokens, do_sample, temperature, top_p: Forwarded to ``model.generate``.

    Returns:
        float: Mean BERTScore F1 over all samples, or 0.0 when ``val_data`` is empty.
    """
    model.eval()
    # Loop-invariant, so set once. Presumably the processor needs it to expand the
    # image placeholder into per-patch tokens — confirm against the processor class.
    processor.patch_size = 14
    predictions, references = [], []

    # model.eval() does not disable autograd; no_grad avoids recording a graph
    # (and holding activations) while generating.
    with torch.no_grad():
        for sample in tqdm(val_data, desc="Computing BERTScore"):
            question, gt_answer, img = unwrap_150k_row(sample)
            question = question + "\nASSISTANT:"

            inputs = processor(
                text=question,
                images=img,
                return_tensors="pt",
                padding=True,
                truncation=True
            ).to(model.device)

            generated_tokens = model.generate(
                processor=processor,
                pixel_values=inputs["pixel_values"],
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_new_tokens=max_new_tokens,
                do_sample=do_sample,
                temperature=temperature,
                top_p=top_p
            )

            # NOTE(review): generate()'s raw output is decoded directly; this assumes the
            # custom generate returns a 1-D sequence of answer tokens only — confirm.
            generated_text = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True).lower().strip()

            predictions.append(generated_text)
            references.append(gt_answer.lower().strip())

    if not predictions:
        # bert_score.score indexes refs[0] and would raise on an empty list.
        avg_bertscore = 0.0
    else:
        _, _, f1 = score(predictions, references, lang="en", model_type="roberta-large", verbose=False)
        avg_bertscore = f1.mean().item()
    print(f"Validation BERTScore (F1): {avg_bertscore:.4f}")
    return avg_bertscore

# Example BERTScore calculation
# NOTE(review): this runs at module import time. Every import of this file scores
# the pair below with the roberta-large BERTScore model and prints the result.
# Consider moving it under an `if __name__ == "__main__":` guard if nothing relies
# on the module-level names P, R, F1, example_predictions or example_references.
example_predictions = ["Cat is on mat"]
example_references = ["The cat is sitting on the mat"]
P, R, F1 = score(example_predictions, example_references, lang="en", model_type="roberta-large")
print(f"Example BERTScore (F1): {F1.item():.4f}")  # Expected: ~0.8-0.9 depending on model (rough estimate, not verified here)




def compute_cider(model, processor, val_data, max_new_tokens=64, do_sample=False, temperature=0.7, top_p=0.9):
    """Generate an answer for every validation sample and return the corpus CIDEr score.

    Prompts are built from ``unwrap_150k_row`` output with a newline plus
    ``ASSISTANT:`` appended. The generated answer and the ground-truth answer are
    lower-cased and stripped, then scored with pycocoevalcap's ``Cider``. Each
    sample uses its positional index as the image id and has a single reference.

    Args:
        model: Vision-language model exposing ``eval()``, ``device`` and a
            ``generate`` method that accepts a ``processor`` keyword.
        processor: Multimodal processor with a ``tokenizer`` attribute. As a side
            effect its ``patch_size`` attribute is set to 14.
        val_data: Iterable of raw rows accepted by ``unwrap_150k_row``.
        max_new_tokens, do_sample, temperature, top_p: Forwarded to ``model.generate``.

    Returns:
        The average CIDEr score from ``Cider.compute_score``, or 0.0 when
        ``val_data`` is empty.
    """
    model.eval()
    # Loop-invariant, so set once. Presumably the processor needs it to expand the
    # image placeholder into per-patch tokens — confirm against the processor class.
    processor.patch_size = 14
    predictions, references = [], []

    # model.eval() does not disable autograd; no_grad avoids recording a graph
    # (and holding activations) while generating.
    with torch.no_grad():
        for sample in tqdm(val_data, desc="Computing CIDEr"):
            question, gt_answer, img = unwrap_150k_row(sample)
            question = question + "\nASSISTANT:"

            inputs = processor(
                text=question,
                images=img,
                return_tensors="pt",
                padding=True,
                truncation=True
            ).to(model.device)

            generated_tokens = model.generate(
                processor=processor,
                pixel_values=inputs["pixel_values"],
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_new_tokens=max_new_tokens,
                do_sample=do_sample,
                temperature=temperature,
                top_p=top_p
            )

            # NOTE(review): generate()'s raw output is decoded directly; this assumes the
            # custom generate returns a 1-D sequence of answer tokens only — confirm.
            generated_text = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True).lower().strip()

            predictions.append(generated_text)
            references.append([gt_answer.lower().strip()])  # CIDEr expects list of references

    if not predictions:
        # With no samples CIDEr's document-frequency step takes log(0) and the mean is NaN.
        avg_cider = 0.0
    else:
        # Format for pycocoevalcap: {image_id: [caption, ...]}.
        # NOTE(review): captions are only lower-cased/stripped, not run through a
        # tokenizer such as PTBTokenizer, so punctuation stays attached to words.
        gts = {i: refs for i, refs in enumerate(references)}
        res = {i: [pred] for i, pred in enumerate(predictions)}
        avg_cider, _ = Cider().compute_score(gts, res)
    print(f"Validation CIDEr: {avg_cider:.4f}")
    return avg_cider


def compute_meteor(model, processor, val_data, max_new_tokens=64, do_sample=False, temperature=0.7, top_p=0.9):
    """Generate an answer for every validation sample and return the mean METEOR score.

    Prompts are built from ``unwrap_150k_row`` output with a newline plus
    ``ASSISTANT:`` appended. Generated and ground-truth answers are lower-cased
    and stripped, then scored with the Hugging Face ``evaluate`` "meteor" metric.

    Args:
        model: Vision-language model exposing ``eval()``, ``device`` and a
            ``generate`` method that accepts a ``processor`` keyword.
        processor: Multimodal processor with a ``tokenizer`` attribute. As a side
            effect its ``patch_size`` attribute is set to 14.
        val_data: Iterable of raw rows accepted by ``unwrap_150k_row``.
        max_new_tokens, do_sample, temperature, top_p: Forwarded to ``model.generate``.

    Returns:
        float: Mean per-sentence METEOR score, or 0.0 when ``val_data`` is empty.
    """
    model.eval()
    # Loop-invariant, so set once. Presumably the processor needs it to expand the
    # image placeholder into per-patch tokens — confirm against the processor class.
    processor.patch_size = 14
    predictions, references = [], []

    # model.eval() does not disable autograd; no_grad avoids recording a graph
    # (and holding activations) while generating.
    with torch.no_grad():
        for sample in tqdm(val_data, desc="Computing METEOR"):
            question, gt_answer, img = unwrap_150k_row(sample)
            question = question + "\nASSISTANT:"

            inputs = processor(
                text=question,
                images=img,
                return_tensors="pt",
                padding=True,
                truncation=True
            ).to(model.device)

            generated_tokens = model.generate(
                processor=processor,
                pixel_values=inputs["pixel_values"],
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_new_tokens=max_new_tokens,
                do_sample=do_sample,
                temperature=temperature,
                top_p=top_p
            )

            # NOTE(review): generate()'s raw output is decoded directly; this assumes the
            # custom generate returns a 1-D sequence of answer tokens only — confirm.
            generated_text = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True).lower().strip()

            predictions.append(generated_text)
            references.append(gt_answer.lower().strip())

    if predictions:
        meteor = load("meteor")
        # evaluate's "meteor" metric averages per-sentence METEOR scores over the
        # batch, so a single call matches averaging one compute() per pair while
        # avoiding N separate compute() round trips.
        avg_meteor = float(meteor.compute(predictions=predictions, references=references)['meteor'])
    else:
        avg_meteor = 0.0
    print(f"Validation METEOR: {avg_meteor:.4f}")
    return avg_meteor



def compute_rouge_l(model, processor, val_data, max_new_tokens=64, do_sample=False, temperature=0.7, top_p=0.9):
    """Generate an answer for every validation sample and return the mean ROUGE-L F-measure.

    Prompts are built from ``unwrap_150k_row`` output with a newline plus
    ``ASSISTANT:`` appended. Generated and ground-truth answers are lower-cased
    and stripped. Each pair is scored with ``rouge_score``'s ``RougeScorer``
    (``rougeL``, stemming enabled), and the per-pair F-measures are averaged.

    Args:
        model: Vision-language model exposing ``eval()``, ``device`` and a
            ``generate`` method that accepts a ``processor`` keyword.
        processor: Multimodal processor with a ``tokenizer`` attribute. As a side
            effect its ``patch_size`` attribute is set to 14.
        val_data: Iterable of raw rows accepted by ``unwrap_150k_row``.
        max_new_tokens, do_sample, temperature, top_p: Forwarded to ``model.generate``.

    Returns:
        float: Mean ROUGE-L F-measure, or 0.0 when ``val_data`` is empty.
    """
    model.eval()
    # Loop-invariant, so set once. Presumably the processor needs it to expand the
    # image placeholder into per-patch tokens — confirm against the processor class.
    processor.patch_size = 14
    predictions, references = [], []

    # model.eval() does not disable autograd; no_grad avoids recording a graph
    # (and holding activations) while generating.
    with torch.no_grad():
        for sample in tqdm(val_data, desc="Computing ROUGE-L"):
            question, gt_answer, img = unwrap_150k_row(sample)
            question = question + "\nASSISTANT:"

            inputs = processor(
                text=question,
                images=img,
                return_tensors="pt",
                padding=True,
                truncation=True
            ).to(model.device)

            generated_tokens = model.generate(
                processor=processor,
                pixel_values=inputs["pixel_values"],
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_new_tokens=max_new_tokens,
                do_sample=do_sample,
                temperature=temperature,
                top_p=top_p
            )

            # NOTE(review): generate()'s raw output is decoded directly; this assumes the
            # custom generate returns a 1-D sequence of answer tokens only — confirm.
            generated_text = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True).lower().strip()

            predictions.append(generated_text)
            references.append(gt_answer.lower().strip())

    # RougeScorer.score takes (target, prediction), hence (ref, pred).
    scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
    rouge_l_scores = [scorer.score(ref, pred)['rougeL'].fmeasure for ref, pred in zip(references, predictions)]
    avg_rouge_l = sum(rouge_l_scores) / len(rouge_l_scores) if rouge_l_scores else 0.0
    print(f"Validation ROUGE-L (F1): {avg_rouge_l:.4f}")
    return avg_rouge_l

    

def compute_spice(model, processor, val_data, max_new_tokens=64, do_sample=False, temperature=0.7, top_p=0.9):
    """Generate an answer for every validation sample and return the average SPICE score.

    Prompts are built from ``unwrap_150k_row`` output with a newline plus
    ``ASSISTANT:`` appended, the same format the other metric functions use.
    Generated and ground-truth answers are lower-cased and stripped, then scored
    with pycocoevalcap's ``Spice``. Each sample uses its positional index as the
    image id and has a single reference.

    Args:
        model: Vision-language model exposing ``eval()``, ``device`` and a
            ``generate`` method that accepts a ``processor`` keyword.
        processor: Multimodal processor with a ``tokenizer`` attribute. As a side
            effect its ``patch_size`` attribute is set to 14.
        val_data: Iterable of raw rows accepted by ``unwrap_150k_row``.
        max_new_tokens, do_sample, temperature, top_p: Forwarded to ``model.generate``.

    Returns:
        The average score from ``Spice.compute_score``, or 0.0 when ``val_data``
        is empty.
    """
    model.eval()
    # Loop-invariant, so set once. Presumably the processor needs it to expand the
    # image placeholder into per-patch tokens — confirm against the processor class.
    processor.patch_size = 14
    predictions, references = [], []

    # model.eval() does not disable autograd; no_grad avoids recording a graph
    # (and holding activations) while generating.
    with torch.no_grad():
        for sample in tqdm(val_data, desc="Computing SPICE"):
            question, gt_answer, img = unwrap_150k_row(sample)
            # Without this suffix SPICE would score answers to a differently
            # formatted prompt than the other metrics.
            question = question + "\nASSISTANT:"

            inputs = processor(
                text=question,
                images=img,
                return_tensors="pt",
                padding=True,
                truncation=True
            ).to(model.device)

            generated_tokens = model.generate(
                processor=processor,
                pixel_values=inputs["pixel_values"],
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_new_tokens=max_new_tokens,
                do_sample=do_sample,
                temperature=temperature,
                top_p=top_p
            )

            # NOTE(review): generate()'s raw output is decoded directly; this assumes the
            # custom generate returns a 1-D sequence of answer tokens only — confirm.
            generated_text = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True).lower().strip()

            predictions.append(generated_text)
            references.append(gt_answer.lower().strip())

    if not predictions:
        avg_spice = 0.0
    else:
        # pycocoevalcap expects {image_id: [caption, ...]}. Positional indices are
        # used as ids so repeated or missing sample["id"] values cannot collapse
        # dictionary entries and silently drop samples.
        gts = {i: [ref] for i, ref in enumerate(references)}
        res = {i: [pred] for i, pred in enumerate(predictions)}
        avg_spice, _ = Spice().compute_score(gts, res)

    print(f"Current validation SPICE Score: {avg_spice:.4f}")
    return avg_spice
