import pandas as pd
import torch
import gc
import json
import os
from datasets import load_dataset
from huggingface_hub import snapshot_download
from src.models.language_model import load_model_and_tokenizer


def cleanup():
    """Release Python garbage and cached CUDA memory between repeated runs.

    Garbage collection runs first so that tensors only kept alive by
    reference cycles are freed *before* the CUDA cache is emptied; doing it
    the other way round leaves their blocks in the caching allocator until
    the next call.
    """
    gc.collect()

    # The CUDA calls are only meaningful when a GPU is present.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()


def dataset_fields(dataset_name):
    """Resolve a dataset name to its Hub location, field names and prompt builder.

    Matching is a case-insensitive substring test against the supported keys,
    tried in this order: ``gsm8k``, ``math-500``, ``humaneval``, ``scibench``.

    Args:
        dataset_name: String containing one of the supported dataset keys.

    Returns:
        An 8-tuple ``(dataset_name, dataset_subset, dataset_split, text_field,
        answer_field, choices_field, local_dataset_dir, prompt_generator)``.
        ``dataset_name`` is the canonical Hub repo id, ``choices_field`` is
        always ``None`` for the datasets handled here, and
        ``prompt_generator`` is a callable mapping one example (a mapping
        indexed by ``text_field``) to the prompt string.

    Raises:
        ValueError: If ``dataset_name`` matches none of the supported datasets.
    """
    # (match key, Hub repo id, subset, split, text field, answer field)
    known_datasets = (
        ("gsm8k", "openai/gsm8k", "main", "test", "question", "answer"),
        ("math-500", "HuggingFaceH4/MATH-500", "default", "test", "problem", "answer"),
        ("humaneval", "openai/openai_humaneval", "openai_humaneval", "test",
         "prompt", "canonical_solution"),
        ("scibench", "xw27/scibench", "default", "train", "problem_text", "answer_number"),
    )

    requested = dataset_name.lower()
    spec = next((entry for entry in known_datasets if entry[0] in requested), None)
    if spec is None:
        # Fail before any per-dataset variable is used, so callers get a
        # clear ValueError instead of an UnboundLocalError.
        raise ValueError(f"Unsupported dataset: {dataset_name}")

    _, dataset_name, dataset_subset, dataset_split, text_field, answer_field = spec
    choices_field = None

    local_dataset_dir = f"./local/data/{dataset_name}-{dataset_subset}"

    def default_prompt(example):
        return example[text_field]

    def scibench_prompt(example):
        # The backslash must be escaped: a bare "\b" in a normal string
        # literal is the backspace control character, not "\" + "b".
        return example[text_field] + "\n. Provide your final answer with: \\boxed{[ANSWER]}"

    def human_eval_prompt(example):
        return "Complete the following Python function:\n\n" + example[text_field]

    # Datasets without a dedicated template use the raw question text.
    special_prompts = {
        "openai/openai_humaneval": human_eval_prompt,
        "xw27/scibench": scibench_prompt,
    }
    prompt_generator = special_prompts.get(dataset_name, default_prompt)

    return (
        dataset_name,
        dataset_subset,
        dataset_split,
        text_field,
        answer_field,
        choices_field,
        local_dataset_dir,
        prompt_generator,
    )


def get_dataset(local_dataset_dir, dataset_name, dataset_subset, dataset_split, hf_cache_dir, max_examples):
    """Download a dataset if needed, load one split and drop columns containing None.

    Args:
        local_dataset_dir: Directory holding (or receiving) the dataset snapshot.
        dataset_name: Hub repo id passed to ``snapshot_download``.
        dataset_subset: Configuration name passed to ``load_dataset``.
        dataset_split: Split name passed to ``load_dataset``.
        hf_cache_dir: Cache directory passed to ``load_dataset``.
        max_examples: If truthy, keep only the first ``max_examples`` rows
            (or all rows when the dataset is shorter).

    Returns:
        The loaded dataset with every column that has at least one ``None``
        value (among the selected rows) removed.
    """
    # === Download dataset if not cached ===
    # NOTE(review): only the directory's existence is checked, so a partially
    # downloaded snapshot is treated as complete — confirm this is acceptable.
    if not os.path.exists(local_dataset_dir):
        print("Downloading dataset...")
        snapshot_download(
            repo_id=f"{dataset_name}",
            repo_type="dataset",
            local_dir=local_dataset_dir,
            local_dir_use_symlinks=False,
            resume_download=True,
        )
    else:
        print("Dataset already exists locally.")

    # Load dataset from the local snapshot directory
    print(f"Loading dataset from {local_dataset_dir}...")
    dataset = load_dataset(
        local_dataset_dir,
        dataset_subset,
        split=dataset_split,
        cache_dir=hf_cache_dir,
        trust_remote_code=True,
    )

    if max_examples:
        dataset = dataset.select(range(min(max_examples, len(dataset))))

    # Find columns containing None in a single pass over the rows, instead of
    # re-iterating (and re-decoding) the whole dataset once per column.
    column_names = dataset.column_names
    flagged = set()
    for example in dataset:
        flagged.update(col for col in column_names if example[col] is None)

    # Preserve the dataset's own column order in the removal list.
    cols_with_none = [col for col in column_names if col in flagged]
    dataset = dataset.remove_columns(cols_with_none)

    return dataset


def get_model_and_tokenizer(local_model_dir, hf_model_name, hf_cache_dir, device):
    """Ensure a model snapshot exists locally, then load it with its tokenizer.

    Args:
        local_model_dir: Directory holding (or receiving) the model snapshot.
        hf_model_name: Hub repo id downloaded when the directory is missing.
        hf_cache_dir: Cache directory forwarded to the project loader.
        device: Device argument forwarded to the project loader.

    Returns:
        ``(model, tokenizer)`` as produced by ``load_model_and_tokenizer``,
        after ``model.eval()`` has been called.
    """
    # Only the directory's existence decides whether a download happens.
    snapshot_present = os.path.exists(local_model_dir)
    if snapshot_present:
        print("Model already exists locally.")
    else:
        print("Downloading model...")
        snapshot_download(repo_id=hf_model_name, local_dir=local_model_dir)

    # Always load from the local directory, never from the Hub.
    loader_kwargs = dict(
        model_name=local_model_dir,
        device=device,
        local_files_only=True,
        cache_dir=hf_cache_dir,
    )
    model, tokenizer = load_model_and_tokenizer(**loader_kwargs)
    model.eval()

    return model, tokenizer


def get_model_params(model_dir):
    """Read the explicit generation settings shipped with a model.

    Args:
        model_dir: Directory that may contain ``generation_config.json``.

    Returns:
        The parsed contents of ``generation_config.json`` (only what is
        explicitly in the file), or an empty dict when the file is absent.
    """
    gen_config_path = os.path.join(model_dir, "generation_config.json")

    if not os.path.exists(gen_config_path):
        return {}

    # JSON is UTF-8 by specification; do not depend on the locale's default
    # encoding when decoding the file.
    with open(gen_config_path, "r", encoding="utf-8") as f:
        return json.load(f)