import torch
from tqdm import tqdm
from rouge_score import rouge_scorer
import numpy as np
import random
import datasets

def run_generation(input_ids, model, tokenizer, max_new_tokens, model_family='llama2-7b'):
    """Generate completions for a batch of prompt+answer sequences.

    Each sequence in ``input_ids`` is decoded to text and split at the
    chat-template marker that closes the prompt. The prompt part (marker
    re-appended) is re-encoded with left padding and passed to
    ``model.generate``; the text after the marker is kept as ground truth.

    Args:
        input_ids: Batch of token-id sequences accepted by
            ``tokenizer.batch_decode``. Every decoded string must contain
            the marker for ``model_family``.
        model: Object exposing ``eval()``, ``device`` and ``generate(...)``.
        tokenizer: Tokenizer used for decoding and re-encoding.
            NOTE: it is mutated in place -- ``padding_side`` is set to
            ``'left'`` and the pad token is set to the EOS token -- and is
            not restored afterwards.
        max_new_tokens: Maximum number of tokens to generate per sequence.
        model_family: Selects the prompt/answer marker. Supported values
            are ``'llama2-7b'`` and ``'qwen3-8b'``.

    Returns:
        A tuple ``(input_strings, strs, gt)`` of three lists with one entry
        per sequence: the prompts (ending with the marker), the decoded
        newly generated text, and the ground-truth answers.

    Raises:
        NotImplementedError: If ``model_family`` is not supported.
        ValueError: If a decoded sequence does not contain the marker.
    """
    split_symbols = {
        'llama2-7b': " [/INST]",
        'qwen3-8b': "assistant\n",
    }

    model.eval()
    if model_family not in split_symbols:
        raise NotImplementedError(f"Model family {model_family} not supported.")
    split_symbol = split_symbols[model_family]

    decoded = tokenizer.batch_decode(input_ids, skip_special_tokens=True)

    input_strings = []
    gt = []
    for text in decoded:
        parts = text.split(split_symbol)
        if len(parts) < 2:
            raise ValueError(
                f"Prompt marker {split_symbol!r} not found in decoded input: {text!r}"
            )
        # Re-append the marker so the prompt ends exactly where the model
        # is expected to start answering.
        input_strings.append(parts[0] + split_symbol)
        gt.append(parts[1])

    # Left padding keeps every prompt flush against the generated tokens,
    # so the continuation can be sliced off by the padded prompt length.
    left_pad_tokenizer = tokenizer
    left_pad_tokenizer.padding_side = 'left'
    left_pad_tokenizer.pad_token = left_pad_tokenizer.eos_token
    left_pad_tokenizer.pad_token_id = left_pad_tokenizer.eos_token_id

    inputs = left_pad_tokenizer.batch_encode_plus(
        input_strings,
        add_special_tokens=True,
        return_tensors='pt',
        padding=True,
    ).to(model.device)

    out = model.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=0.7,     # add randomness
        top_p=0.9,           # nucleus sampling
        pad_token_id=left_pad_tokenizer.eos_token_id,  # avoids warning
    )

    # Drop the prompt tokens; only the newly generated tail is decoded.
    prompt_len = inputs.input_ids.shape[-1]
    strs = left_pad_tokenizer.batch_decode(out[:, prompt_len:], skip_special_tokens=True)
    return input_strings, strs, gt

def eval_gen(
    model, tokenizer,
    data,
    max_new_tokens: int = 100,
    max_samples: "int | None" = 512
):
    """Generate answers for ``data`` and score them with ROUGE recall.

    Args:
        model: Model forwarded to ``run_generation``.
        tokenizer: Tokenizer forwarded to ``run_generation``.
        data: Iterable of batches, each unpacking to
            ``(input_ids, label, attention_mask, idx)``. Only ``input_ids``
            and ``idx`` are used; ``idx`` must support ``.cpu().numpy()``.
        max_new_tokens: Generation budget forwarded to ``run_generation``.
        max_samples: Upper bound on the number of samples evaluated. The
            last batch is truncated so the bound is met exactly. Pass
            ``None`` to evaluate everything in ``data``.

    Returns:
        ``{'rouge1_recall': {...}, 'rougeL_recall': {...}}`` where each
        inner dict maps a sample index (taken from ``idx``) to its recall.
        If an index occurs more than once, the later score overwrites the
        earlier one.
    """
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rougeL'], use_stemmer=True)
    rouge1_recall = {}
    rougeL_recall = {}
    gen_outputs = []
    ground_truths = []
    all_indices = []
    num_samples = 0

    for sample in tqdm(data):
        if max_samples is not None and num_samples >= max_samples:
            break
        input_ids, _label, _attention_mask, idx = sample

        if max_samples is not None:
            # Trim the final batch so no more than max_samples are generated.
            remaining = max_samples - num_samples
            input_ids = input_ids[:remaining]
            idx = idx[:remaining]

        all_indices.extend(idx.cpu().numpy().tolist())
        with torch.no_grad():
            _, gen_output, gt = run_generation(
                input_ids, model, tokenizer=tokenizer, max_new_tokens=max_new_tokens
            )
        gen_outputs.extend(gen_output)
        ground_truths.extend(gt)
        num_samples += len(input_ids)

    for gen, gt, idx in zip(gen_outputs, ground_truths, all_indices):
        # The ground truth is passed first and the generation second.
        rouge_scores = scorer.score(gt, gen)
        rouge1_recall[idx] = rouge_scores['rouge1'].recall
        rougeL_recall[idx] = rouge_scores['rougeL'].recall

    return {'rouge1_recall': rouge1_recall, 'rougeL_recall': rougeL_recall}


def aggregate_results(eval_result_dict):
    """Collapse per-file ROUGE-L recall scores into one mean per task.

    Args:
        eval_result_dict: Mapping from result file name to a dict holding a
            ``'rougeL_recall'`` mapping of per-sample recall values. File
            names must be among ``'eval_rouge.json'``,
            ``'eval_rouge_forget.json'`` and ``'eval_rouge_test.json'``.

    Returns:
        Dict with the keys ``'ROUGE Retain'``, ``'ROUGE Forget'`` and
        ``'ROUGE Test'``. A task whose file is present maps to the mean of
        its ROUGE-L recalls; a task whose file is absent keeps the empty
        list placeholder.

    Raises:
        KeyError: If a file name is not one of the known result files, or
            an entry has no ``'rougeL_recall'`` key.
    """
    task_labels = {
        'eval_rouge.json': 'Retain',
        'eval_rouge_forget.json': 'Forget',
        'eval_rouge_test.json': 'Test'
    }
    metric_names = ['ROUGE']

    # Every metric/task pair starts out as an empty-list placeholder.
    summary = {
        f'{metric_name} {label}': []
        for label in task_labels.values()
        for metric_name in metric_names
    }

    # One entry per result file: replace its placeholder with the mean.
    for file_name, file_scores in eval_result_dict.items():
        recalls = np.array(list(file_scores['rougeL_recall'].values()))
        mean_recall = recalls.mean()
        summary[f'ROUGE {task_labels[file_name]}'] = mean_recall
    return summary
