import argparse
import gc
import json
import os
import traceback

import torch
import yaml
from huggingface_hub import login
from peft import PeftModel
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM

# Authenticate with the Hugging Face Hub using a token from the HF_TOKEN
# environment variable rather than a secret pasted into the source. Logging in
# with an empty token would fail, so login is skipped when none is provided.
_HF_TOKEN = os.environ.get("HF_TOKEN")
if _HF_TOKEN:
    login(token=_HF_TOKEN)

def generate_prompt_chatml(question, tokenizer, mcq_input=None):
    """Render a single question as a chat-formatted generation prompt.

    Args:
        question: Question text to place in the user turn.
        tokenizer: Object providing ``apply_chat_template`` (a Hugging Face
            tokenizer in this script).
        mcq_input: Optional answer options. When truthy, they are appended to
            the question under an ``Options:`` header; otherwise the question
            is used on its own.

    Returns:
        The result of ``tokenizer.apply_chat_template`` called with
        ``tokenize=False`` and ``add_generation_prompt=True``.
    """
    system_text = (
        "You are an expert physics assistant. You are given a question. "
        "Your task is to generate the final solution of the given question. "
        "Make sure the solution is mathematically accurate and correct, "
        "and at the end return the final correct option after all the intermediate steps. "
        "Let's think step by step."
    )
    user_text = question if not mcq_input else f"{question}\n\nOptions:\n{mcq_input}"

    conversation = [
        {"role": "system", "content": system_text},
        {"role": "user", "content": user_text},
    ]
    return tokenizer.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )


def batch_generate(model, tokenizer, questions, inputs=None, max_tokens=1024, *, strip_prompt=True):
    """Sample one completion per question in a single padded batch.

    Args:
        model: Causal LM (plain or PEFT-wrapped) exposing ``generate`` and
            ``device``.
        tokenizer: Tokenizer matching ``model``. It must define
            ``pad_token_id`` because ``padding=True`` is used.
        questions: List of question strings.
        inputs: Optional list of MCQ option strings, indexed in parallel with
            ``questions``; pass None (or an empty list) for no options.
        max_tokens: Maximum number of new tokens to generate per question.
        strip_prompt: When True (default), return only the newly generated
            text. When False, return prompt + completion as decoded by the
            tokenizer, which is the output of earlier versions.

    Returns:
        List of decoded strings, one per question, in input order.
    """
    if not questions:
        return []

    prompts = [
        generate_prompt_chatml(q, tokenizer, mcq_input=inputs[i] if inputs else None)
        for i, q in enumerate(questions)
    ]

    # The chat template already contains any BOS/special tokens the model
    # expects; letting the tokenizer add them again would duplicate them.
    inputs_enc = tokenizer(
        prompts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        add_special_tokens=False,
    ).to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs_enc,
            max_new_tokens=max_tokens,
            # Stochastic nucleus sampling: responses differ between runs.
            do_sample=True,
            temperature=0.7,
            top_p=0.95,
            repetition_penalty=1.1,
            pad_token_id=tokenizer.pad_token_id,
        )

    if strip_prompt:
        # For decoder-only models generate() returns the (padded) input ids
        # followed by the new tokens, so every row's prompt occupies exactly
        # the first input_ids.shape[1] positions.
        prompt_len = inputs_enc["input_ids"].shape[1]
        outputs = outputs[:, prompt_len:]

    return tokenizer.batch_decode(outputs, skip_special_tokens=True)


def evaluate_dataset(
    dataset_path, tokenizer, base_model, peft_model, merged_model,
    num_samples, batch_size, save_dir):
    """Generate responses from all three model variants for one dataset file.

    The file at ``dataset_path`` is loaded as JSON and treated as a list of
    record dicts. Each record's question comes from ``"question"`` (falling
    back to ``"prompt"``) and its optional MCQ options from ``"input"``.
    The generations are stored back on each evaluated record as
    ``"base_response"``, ``"dpo_response"`` (from ``peft_model``) and
    ``"dpo_merged_response"`` (from ``merged_model``). The evaluated records
    are then written to ``<save_dir>/<dataset stem>_eval.json``.

    Args:
        dataset_path: Path to the input ``.json`` dataset.
        tokenizer: Tokenizer shared by all three models.
        base_model: Model whose outputs go to ``"base_response"``.
        peft_model: Model whose outputs go to ``"dpo_response"``.
        merged_model: Model whose outputs go to ``"dpo_merged_response"``.
        num_samples: Evaluate only the first ``num_samples`` records; any
            negative value (or None) evaluates every record.
        batch_size: Questions per ``batch_generate`` call; must be >= 1.
        save_dir: Output directory, created if it does not exist.
    """
    with open(dataset_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Negative values mean "all records"; slicing with them would otherwise
    # silently drop trailing records (e.g. data[:-2]).
    if num_samples is not None and num_samples >= 0:
        data = data[:num_samples]

    questions = [sample.get("question") or sample.get("prompt") for sample in data]
    options = [sample.get("input", None) for sample in data]

    dataset_name = os.path.basename(dataset_path)
    for start in tqdm(range(0, len(questions), batch_size), desc=dataset_name):
        stop = start + batch_size
        batch_qs = questions[start:stop]
        batch_opts = options[start:stop]

        base_outputs = batch_generate(base_model, tokenizer, batch_qs, batch_opts)
        peft_outputs = batch_generate(peft_model, tokenizer, batch_qs, batch_opts)
        merged_outputs = batch_generate(merged_model, tokenizer, batch_qs, batch_opts)

        # data[start:stop] is a new list holding the same dict objects, so the
        # assignments below update the records in `data` itself.
        for record, b, p, m in zip(data[start:stop], base_outputs, peft_outputs, merged_outputs):
            record["base_response"] = b
            record["dpo_response"] = p
            record["dpo_merged_response"] = m

    os.makedirs(save_dir, exist_ok=True)
    # Replace only the final extension; str.replace would rewrite every
    # ".json" occurrence in the file name.
    stem, _ext = os.path.splitext(dataset_name)
    output_file = os.path.join(save_dir, f"{stem}_eval.json")
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=2)

    print(f"✅ Saved: {output_file}")


def _load_models(model_name, adapter_path):
    """Load the tokenizer and three independent model variants for one checkpoint.

    ``PeftModel.from_pretrained`` injects the LoRA layers into the module it
    receives, and ``merge_and_unload()`` then folds the adapter weights into
    that same module. Sharing one base instance would therefore make the
    "base" model silently run with the adapter. Each variant gets its own
    freshly loaded base weights instead. This holds three copies of the
    weights at once; with ``device_map="auto"``, copies that do not fit on the
    GPU may be placed on CPU, which makes them slower.

    Returns:
        (tokenizer, base_model, peft_model, merged_model)
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token
    # Left padding keeps every prompt flush against its generated tokens.
    tokenizer.padding_side = "left"

    def load_base():
        return AutoModelForCausalLM.from_pretrained(
            model_name, trust_remote_code=True, device_map="auto"
        )

    base_model = load_base().eval()
    peft_model = PeftModel.from_pretrained(load_base(), adapter_path).eval()
    merged_model = PeftModel.from_pretrained(load_base(), adapter_path).merge_and_unload().eval()
    return tokenizer, base_model, peft_model, merged_model


def main():
    """Evaluate every configured model (base, LoRA adapter, merged) on every dataset.

    Reads ``model_names``, ``model_dir`` and ``save_dir`` from the YAML
    config. For each model, the adapter is loaded from
    ``<model_dir>/<short name>/lora_dpo_adapter``. Results go to
    ``./outputs/<save_dir>/<short name>/``. A failure on one model or one
    dataset is reported with a traceback, and the run continues.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--num_samples", type=int, default=-1, help="Number of questions to evaluate per dataset")
    parser.add_argument("--batch_size", type=int, default=4, help="Batch size for generation")
    parser.add_argument("--config", default="configs/config_evaluate.yaml",
                        help="Path to the evaluation YAML config")
    parser.add_argument("--dataset_dir", default="./Dataset",
                        help="Directory containing the *.json datasets to evaluate")
    args = parser.parse_args()

    with open(args.config, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    # The dataset list does not depend on the model, so discover it once;
    # sorting makes the processing order reproducible across runs.
    datasets = sorted(f for f in os.listdir(args.dataset_dir) if f.endswith(".json"))

    for model_name in config["model_names"]:
        tokenizer = base_model = peft_model = merged_model = None
        try:
            short_name = model_name.split("/")[-1]
            adapter_path = os.path.join(config["model_dir"], short_name, "lora_dpo_adapter")
            save_dir_path = os.path.join("./outputs", config["save_dir"], short_name)

            tokenizer, base_model, peft_model, merged_model = _load_models(model_name, adapter_path)
            os.makedirs(save_dir_path, exist_ok=True)

            for filename in datasets:
                dataset_path = os.path.join(args.dataset_dir, filename)
                try:
                    evaluate_dataset(
                        dataset_path, tokenizer,
                        base_model, peft_model, merged_model,
                        num_samples=args.num_samples,
                        batch_size=args.batch_size,
                        save_dir=save_dir_path,
                    )
                except Exception as e:
                    # One bad dataset should not discard the remaining ones
                    # for a model that is expensive to load.
                    print(f"❌ Error evaluating {model_name} on {filename}: {e}")
                    traceback.print_exc()

            print(f"✅ Finished evaluating {model_name}.")
        except Exception as e:
            print(f"❌ Error evaluating {model_name}: {e}")
            traceback.print_exc()
        finally:
            # Release every reference before the next model is loaded, on the
            # failure path too; otherwise the old weights stay alive while the
            # next ones load and can exhaust GPU memory.
            tokenizer = base_model = peft_model = merged_model = None
            gc.collect()
            torch.cuda.empty_cache()

# Run the evaluation only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()