import os 
import datasets
from datasets import Dataset

import pandas as pd
from transformers import set_seed, GenerationConfig

from src.logging import entrypoint
from src.gen_utils import generate_outputs_vllm, generate_outputs_base
from src.prompt_hub import confidence_prompts, confidence_prompts_no_reasoning


def _load_eval_data(data_name: str, data_type: str, c_type: str):
    """Load the requested benchmark split and resolve the prompt key.

    Args:
        data_name: One of "gsm", "mmlu", "math", "arc", "hellaswag".
        data_type: Split name; used as the GSM8K split key and as the suffix
            of the processed CSV file names.
        c_type: Base prompt key. "_mc" is appended for the mmlu, arc and
            hellaswag datasets.

    Returns:
        A ``(data, c_type)`` tuple: the loaded dataset and the (possibly
        suffixed) prompt key.

    Raises:
        ValueError: If ``data_name`` is not one of the supported names.
    """
    train_cap = 7500      # at most this many rows are kept from train splits
    shuffle_seed = 1129   # literal seed, so the subsample ignores `seed`

    if data_name == "gsm":
        return datasets.load_dataset('openai/gsm8k', 'main')[data_type], c_type

    if data_name == "math":
        frame = pd.read_csv(f"./data/processed/math_{data_type}.csv")
        frame.columns = ['question', 'answer', 'solution', 'type']
        return Dataset.from_pandas(frame), c_type

    if data_name == "arc":
        frame = pd.read_csv(f'./data/processed/arc_challenge_{data_type}.csv')
        frame.columns = ['question', 'true_answer_texts', 'answer']
        return Dataset.from_pandas(frame), c_type + "_mc"

    if data_name in ("mmlu", "hellaswag"):
        frame = pd.read_csv(f"./data/processed/{data_name}_{data_type}.csv")
        data = Dataset.from_pandas(frame)
        if data_type == "train":
            # Cap at the dataset size so a split with fewer than `train_cap`
            # rows is used whole instead of failing inside `select`.
            keep = min(train_cap, len(data))
            data = data.shuffle(seed=shuffle_seed).select(range(keep))
        return data, c_type + "_mc"

    # Previously an unknown name fell through and surfaced later as an
    # UnboundLocalError on `data`; fail here with an actionable message.
    raise ValueError(
        f"Unsupported data_name {data_name!r}; expected one of "
        "'gsm', 'mmlu', 'math', 'arc', 'hellaswag'."
    )


@entrypoint(with_wandb=False)
def main(
    mode: str = "vllm",
    seed: int = 0,
    log_dir: str = None,
    data_name: str = "gsm", 
    data_type: str = "test",
    model_name: str = "meta-llama/Llama-3.2-3B-Instruct",
    query_peft_dir: str = None,
    max_new_tokens: int = 2048,
    do_sample: bool = False, 
    temperature: float = 1.0,
    top_p: float = 1.0,
    batch_size: int = 1,
    c_type: str = "base", 
    template_add: bool = True,
    track_logits: bool = False,
    target_idx: list = None,
    think_mode: bool = True,
    suffix: bool = True
):
    """Generate model outputs for one benchmark split and save them to disk.

    Results are written to ``<log_dir>/<data_name>_<data_type>.csv``, or to a
    ``.json`` file holding the ``conf``, ``logits`` and ``conf_target_idx``
    columns when ``track_logits`` is set.

    Args:
        mode: "vllm" uses ``generate_outputs_vllm``; any other value uses the
            accelerate-based ``generate_outputs_base`` path.
        seed: Passed to ``set_seed`` and to the vLLM sampling parameters.
        log_dir: Output directory; created if missing.
        data_name: Dataset to load ("gsm", "mmlu", "math", "arc",
            "hellaswag").
        data_type: Dataset split, e.g. "test" or "train".
        model_name: Model identifier handed to the generation backend.
        query_peft_dir: Optional adapter id/dir loaded via ``get_lora_model``
            (non-vLLM path only).
        max_new_tokens: Generation length limit.
        do_sample: When False, the vLLM path forces temperature 0.0.
        temperature: Sampling temperature.
        top_p: Nucleus sampling parameter.
        batch_size: Forwarded to the generation helper.
        c_type: Key into the prompt dictionaries; "_mc" is appended for
            mmlu, arc and hellaswag.
        template_add: When False, ``skip_template=True`` is passed to the
            generation helper.
        track_logits: Forwarded to ``generate_outputs_base`` and selects the
            JSON output format.
        target_idx: Forwarded to ``generate_outputs_base``.
        think_mode: Selects ``confidence_prompts`` (True) or
            ``confidence_prompts_no_reasoning`` (False).
        suffix: Forwarded to the generation helper.

    Raises:
        ValueError: If ``data_name`` is not supported.
    """
    set_seed(seed)

    data, c_type = _load_eval_data(data_name, data_type, c_type)

    prompt_table = confidence_prompts if think_mode else confidence_prompts_no_reasoning
    base_prompt = prompt_table[c_type]
    skip_template = not template_add

    if mode == "vllm":
        from vllm import SamplingParams 

        questions = [d['question'] for d in data]
        sampling_params = SamplingParams(
            max_tokens=max_new_tokens,
            # Temperature 0.0 is used whenever sampling is disabled.
            temperature=temperature if do_sample else 0.0,
            top_p=top_p,
            seed=seed,
        )
        # TODO: add track_logits option
        # NOTE(review): track_logits is not forwarded on this path, yet the
        # JSON branch below reads df['logits'] / df['conf_target_idx'];
        # confirm generate_outputs_vllm provides them before combining
        # mode="vllm" with track_logits=True.
        df = generate_outputs_vllm(
            model_name,
            base_prompt,
            questions,
            batch_size,
            sampling_params,
            skip_template=skip_template,
            suffix=suffix,
            think_mode=think_mode,
        )
        df['true_answer'] = [d['answer'] for d in data]

    else:
        from accelerate import Accelerator
        from src.model_utils import create_model
        from src.peft_utils import get_lora_model

        accelerator = Accelerator()

        model, tokenizer = create_model(
            model_name,
            device_map="auto",
        )

        generation_config = GenerationConfig(
            pad_token_id=tokenizer.pad_token_id,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            temperature=temperature,
            top_p=top_p,
        )

        if query_peft_dir:
            # Attach the adapter for inference only.
            model = get_lora_model(
                model,
                peft_id_or_dir=query_peft_dir,
                is_trainable=False,
                adapter_name="query",
            )

        model.eval()
        df = generate_outputs_base(
            accelerator,
            model,
            tokenizer,
            base_prompt,
            data,
            batch_size,
            generation_config,
            skip_template=skip_template,
            track_logits=track_logits,
            target_idx=target_idx,
            suffix=suffix,
            think_mode=think_mode,
        )

    # NOTE(review): log_dir is string-formatted, so a literal None would
    # become the directory "None"; presumably the entrypoint decorator
    # supplies a real path -- confirm.
    out_dir = f"{log_dir}"
    os.makedirs(out_dir, exist_ok=True)
    out_stem = data_name + '_' + data_type

    if track_logits:
        import json 

        data_dict = {
            "conf": df['conf'].tolist(),
            "conf_logits": df['logits'].tolist(),
            "conf_idx": df["conf_target_idx"].tolist(),
        }
        with open(os.path.join(out_dir, out_stem + '.json'), 'w') as f:
            json.dump(data_dict, f, indent=4)
    else:
        df.to_csv(os.path.join(out_dir, out_stem + '.csv'), index=False)
    

if __name__ == "__main__":
    # Expose `main`'s keyword arguments as command-line flags via Fire.
    from fire import Fire

    Fire(main)