import os 
import re
from vllm import SamplingParams, LLM 

import numpy as np
import pandas as pd
from tqdm import tqdm 
from transformers import set_seed

from src.prompt_hub import parsing_prompts, system_prompts
from src.eval_utils import (extract_hash_answer, 
                            extract_xml_answer, 
                            extract_number_xml_confidence)
from src.model_utils import create_tokenizer


def batch_generate(
    input_texts,
    model,
    sampling_params,
    batch_size=64
    ):
    """Generate one completion per prompt by calling ``model.generate`` in chunks.

    Args:
        input_texts: Sliceable sequence of fully formatted prompt strings.
        model: Object exposing ``generate(prompts, sampling_params, use_tqdm=...)``
            (a vLLM ``LLM`` in this script) returning one request output per
            prompt, each with an ``.outputs[0].text`` attribute.
        sampling_params: Sampling configuration forwarded unchanged to
            ``model.generate``.
        batch_size: Number of prompts per ``generate`` call; must be positive.

    Returns:
        List with the first generated text of every request, concatenated in
        the order the batches were processed.

    Raises:
        ValueError: If ``batch_size`` is not positive (a negative value would
            otherwise silently produce an empty result).
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    results = []
    for start in tqdm(range(0, len(input_texts), batch_size)):
        # Slicing past the end just truncates, so the final partial batch
        # needs no special case.
        batch_texts = input_texts[start:start + batch_size]
        request_outputs = model.generate(
            batch_texts,
            sampling_params,
            use_tqdm=False)
        # Keep only the first sampled completion of each request.
        results.extend(out.outputs[0].text for out in request_outputs)
    return results


def calculate_ece(y_tf, y_prob, n_bins=10, prob_scale=0.01):
    """Compute the Expected Calibration Error (ECE) of confidence scores.

    Confidences are assigned to ``n_bins`` equal-width bins over [0, 1]. Each
    bin spans (lo, hi], except that the first bin also contains 0.0. The ECE is
    the sample-weighted sum, over non-empty bins, of
    ``|mean confidence - accuracy|``.

    Args:
        y_tf: Per-sample correctness indicators (1/True correct, 0/False wrong).
        y_prob: Per-sample confidence scores. Every entry is multiplied by
            ``prob_scale``, so by default they are treated as percentages (0-100).
        n_bins: Number of equal-width confidence bins.
        prob_scale: Factor applied to each confidence. Pass ``1.0`` when
            ``y_prob`` already holds probabilities in [0, 1].

    Returns:
        The ECE as a float; 0.0 for empty input.

    Raises:
        ValueError: If ``y_tf`` and ``y_prob`` differ in length.
        TypeError: If a confidence is not numeric (e.g. ``None``).
    """
    y_tf = np.asarray(y_tf)
    # Scale element-wise in Python so a non-numeric confidence still raises
    # instead of silently turning into NaN.
    y_prob = np.array([prob_scale * p for p in y_prob], dtype=float)

    if len(y_tf) != len(y_prob):
        raise ValueError(
            f"y_tf and y_prob must have the same length, "
            f"got {len(y_tf)} and {len(y_prob)}"
        )

    n = len(y_tf)
    if n == 0:
        return 0.0

    bin_edges = np.linspace(0.0, 1.0, n_bins + 1)
    # With right=True, digitize returns i such that edges[i-1] < p <= edges[i],
    # so i - 1 is the bin index. A confidence of exactly 0.0 gets -1 and values
    # above 1.0 get n_bins. Those samples would fall outside every bin while
    # still counting in n, so they are clipped into the first/last bin.
    # (A NaN confidence also lands in the last bin and makes the result NaN.)
    bin_indices = np.clip(
        np.digitize(y_prob, bin_edges, right=True) - 1, 0, n_bins - 1
    )

    ece = 0.0
    for b in range(n_bins):
        in_bin = bin_indices == b
        count = int(np.sum(in_bin))
        if count == 0:
            continue
        gap = abs(np.mean(y_prob[in_bin]) - np.mean(y_tf[in_bin]))
        ece += (count / n) * gap

    return float(ece)


def _build_questions(base_data, data_name):
    """Fill the ``data_name`` parsing-prompt template once per row of ``base_data``.

    ``<model_answer>`` is replaced with the text returned by
    ``extract_xml_answer`` on the row's ``pred_answer``. For ``"math"``,
    ``<true_answer>`` is then replaced with the row's ``true_answer``.
    """
    template = parsing_prompts[data_name]
    questions = [
        template.replace("<model_answer>", str(extract_xml_answer(pred)))
        for pred in base_data['pred_answer']
    ]
    if data_name == "math":
        questions = [
            q.replace("<true_answer>", str(true_answer))
            for q, true_answer in zip(questions, base_data['true_answer'])
        ]
    return questions


def main(
    seed: int = 0,
    log_dir: str = None,
    data_name: str = "gsm",
    data_path: str = None,
    model_name: str = "meta-llama/Llama-3.1-8B-Instruct",
    batch_size: int = 1,
):
    """Run an LLM over parsing prompts built from saved predictions; save its outputs.

    Reads the CSV at ``data_path`` and builds one chat prompt per row from the
    ``data_name`` parsing template. Generates greedily with ``model_name``
    through vLLM, then writes the input rows plus an ``eval_result`` column
    with the generated text to ``<log_dir>/<data_name>_parsing.csv``.

    Args:
        seed: Seed for ``set_seed`` and for vLLM sampling.
        log_dir: Output directory, created if missing. Required.
        data_name: Key into ``parsing_prompts``. ``"math"`` additionally
            fills the ``<true_answer>`` placeholder.
        data_path: Input CSV. It needs a ``pred_answer`` column, plus
            ``true_answer`` when ``data_name == "math"``. Required.
        model_name: Model identifier passed to both ``create_tokenizer`` and
            ``LLM``. Names containing "qwen" (case-insensitive) use the Qwen
            system prompt; all others use the Llama one.
        batch_size: Number of prompts per ``LLM.generate`` call.

    Raises:
        ValueError: If ``data_path`` or ``log_dir`` is not provided. (Previously
            a missing ``log_dir`` wrote results into a directory named "None".)
    """
    # Fail fast, before loading data or the model.
    if data_path is None:
        raise ValueError("data_path is required (path to the predictions CSV).")
    if log_dir is None:
        raise ValueError("log_dir is required (directory for the output CSV).")
    # Fire may parse a numeric-looking flag value as int; normalize to str.
    out_dir = str(log_dir)

    set_seed(seed)

    base_data = pd.read_csv(data_path)
    questions = _build_questions(base_data, data_name)
    # Qwen models get their own system prompt; every other model uses the Llama one.
    base_prompt = system_prompts["qwen" if "qwen" in model_name.lower() else "llama"]

    sampling_params = SamplingParams(
        max_tokens=30,    # parsing responses are expected to be short
        temperature=0.0,  # greedy decoding
        seed=seed,
    )

    tokenizer = create_tokenizer(model_name)
    model = LLM(model=model_name)

    # Render each question as a chat (system + user turn) ending in the
    # assistant generation prompt, as plain text for vLLM.
    inputs = [
        tokenizer.apply_chat_template(
            [{"role": "system", "content": base_prompt},
             {"role": "user", "content": q}],
            tokenize=False, add_generation_prompt=True)
        for q in questions
    ]

    outputs = batch_generate(
        inputs,
        model,
        sampling_params,
        batch_size=batch_size,
    )

    # NOTE(review): accuracy/ECE scoring is disabled in this script. It was
    # kept as an inert string literal and is preserved below as comments. If
    # it is re-enabled, note that for "gsm" the `prob`/`ece` computation and
    # the prints are never reached: they are indented under the "math" branch.
    # if data_name == "gsm":
    #     pred_answers = [a.split('**Model\'s Final Answer is:**')[-1].split('\n')[0] for a in outputs]
    #     true_answers = [re.sub(r'[^0-9.]', '', extract_hash_answer(base_data.iloc[i]['true_answer']))
    #                 for i in range(len(base_data))]
    #     tf = [1 if str(g) == str(p) else 0 for g, p in zip(true_answers, pred_answers)]
    # elif data_name == "math":
    #     tf = [1 if "yes" in p.lower() else 0 for p in outputs]
    #
    #     prob = [extract_number_xml_confidence(base_data.iloc[i]['pred_answer'])
    #             for i in range(len(base_data))]
    #
    #     ece = calculate_ece(tf, prob)
    #
    #     print(f"ACC: {np.mean(tf):.4f}")
    #     print(f"ECE: {ece:.4f}")
    # base_data['tf'] = tf

    os.makedirs(out_dir, exist_ok=True)
    base_data['eval_result'] = outputs
    base_data.to_csv(
        os.path.join(out_dir, f"{data_name}_parsing.csv"),
        index=False,
    )

if __name__ == "__main__":
    # Expose main()'s keyword arguments as command-line flags,
    # e.g. --data_path=... --log_dir=... --data_name=math
    from fire import Fire

    Fire(main)