import fishfarm
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from fishfarm.tasks.competation_math import (
    LatexFormatMathTask,
    MathSample,
    last_boxed_only_string,
    remove_boxed,
)
import datasets

import sys
sys.path.append('.')
from utils import get_vllm_model, extract_ans, forward, load_hf_params_to_vllm


# Hugging Face model id. It is used both to build the vLLM engine and to load
# the HF model whose state dict is combined with the expert vectors.
MODEL_ID = 'meta-llama/Meta-Llama-3-8B-Instruct'
# Loaded with torch.load(weights_only=True) and passed to `forward` together with
# the base weights. Presumably this is the SVD decomposition of MODEL_ID's
# weights (TODO confirm).
DECOMPOSED_PARAM_FILE = 'llama3_decomposed_params.pt'
# When True, the classification pass runs with the dispatcher's learnable params
# applied to the vLLM weights. When False, it runs with whatever weights
# get_vllm_model loaded.
USE_DISPATCHER = False


# System prompt for the classification pass. The model is asked to label each
# problem as 'code', 'math', 'reasoning' or 'other' inside \boxed{...}.
# This is not a raw string, so each doubled backslash reaches the model as a
# single one ("\boxed{...}"). The label is later pulled out of the generation by
# `extract_ans`, presumably the last \boxed{} content (confirm in utils).
system_message = """
# Analyze the given question and classify it into one of four categories: 'code', 'math', 'reasoning' or 'other'. Follow these guidelines:

1. Code: Questions asking for programming solutions, functions, algorithms. Often includes specific programming terms, language syntax, or data structures.
2. Math: Questions involving mathematical calculations, formulas, statistics. Often includes numbers, equations, or references to mathematical operations.
3. Reasoning: Questions requiring logical thinking, application of scientific knowledge, or critical analysis of information. Often presents statements that need evaluation based on general understanding. 
4. Other: Questions not clearly fit into above categories.
 
Instructions:
- Consider the primary focus, skills, and knowledge required to answer the question.
- If a question spans multiple categories, choose the most dominant one.
- Provide your final classification within \\boxed{} notation. Example: \\boxed{reasoning}

Format your response as follows:
Classification: \\boxed{category}
"""

# Load the competition_math test split and wrap every row as a MathSample.
# The reference answer is extracted from the worked solution with
# last_boxed_only_string + remove_boxed, presumably the contents of the final
# \boxed{...}.
dataset = datasets.load_dataset("hendrycks/competition_math", "main", split="test")
math_samples = [
    MathSample(
        problem=row["problem"],
        answer=remove_boxed(last_boxed_only_string(row["solution"])),
        type=row["type"],
    )
    for row in dataset
]

# Build the vLLM engine used for both the classification and evaluation passes.
# NOTE(review): the meaning of the second argument (2) is defined in utils.get_vllm_model.
vllm_model = get_vllm_model(MODEL_ID, 2)
vllm_model.chat_template = None  # use the default

if USE_DISPATCHER:
    # Adapt the vLLM weights with the trained dispatcher's learnable params so
    # that the classification pass below runs on the adapted model.
    model_dir = "trained_dispacher_path"
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID, torch_dtype=torch.bfloat16)
    base_params = model.state_dict()
    decomposed_params = torch.load(DECOMPOSED_PARAM_FILE, weights_only=True)
    model_filepath = f'{model_dir}/learnable_params.pt'
    # weights_only=True matches the decomposed-params load above. It avoids
    # unpickling arbitrary objects from a checkpoint file.
    learnable_params = torch.load(model_filepath, weights_only=True)

    print('Learnable params loaded.')
    new_params = forward(
        model, base_params, decomposed_params, learnable_params
    )
    load_hf_params_to_vllm(new_params, vllm_model.llm)
    del new_params

# Build one classification request per problem: the classifier system prompt
# followed by the problem text as the user turn.
requests = [
    fishfarm.models.GenerationRequest(
        messages=[
            fishfarm.Message(role="system", content=system_message),
            fishfarm.Message(role="user", content=sample.problem),
        ]
    )
    for sample in math_samples
]

# Run the classifier and store each predicted category on its sample as
# `expert_label`. The label is extracted from the raw generation by extract_ans.
generations = vllm_model.generate(requests)
classfied_mmlu_samples = []
for sample, generation_result in zip(math_samples, generations):
    sample.expert_label = extract_ans(generation_result.generation)
    classfied_mmlu_samples.append(sample)

# Print how the classifier distributed the samples across the known categories.
labels = ['other', 'math', 'code', 'reasoning']
num_classified = len(classfied_mmlu_samples)
for label in labels:
    cnt = sum(1 for s in classfied_mmlu_samples if s.expert_label == label)
    # Guard the ratio so an empty sample list does not raise ZeroDivisionError.
    perc = cnt / num_classified if num_classified else 0.0
    print(f"label {label} has {cnt} samples, perc is {perc:.2f}")

# Predictions outside `labels` are never evaluated by the per-category loop, so
# report them instead of letting them vanish silently.
num_unrecognized = sum(
    1 for s in classfied_mmlu_samples if s.expert_label not in labels
)
if num_unrecognized:
    print(f"{num_unrecognized} samples have an unrecognized label and will not be evaluated")

# System prompt for the per-category accuracy evaluation on the MATH task.
# Adjacent string literals are concatenated verbatim, so the separating space
# must be written explicitly; otherwise the prompt reads "step by step,and ...".
SYSTEM_MSG = (
    "Solve the question below by reasoning step by step, "
    "and put the final answer within \\boxed{}."
)

# The evaluation pass uses fishfarm's LLAMA3 chat template. The classification
# pass above ran with chat_template = None.
vllm_model.chat_template = fishfarm.chat_templates.LLAMA3

# Without the dispatcher, the HF model, its state dict and the decomposed params
# have not been loaded yet. The expert-vector loop below needs all three.
if not USE_DISPATCHER:
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16)
    # NOTE(review): `tokenizer` is not referenced anywhere else in the visible code.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    base_params = model.state_dict()
    decomposed_params = torch.load(DECOMPOSED_PARAM_FILE, weights_only=True)

# Evaluate each predicted category on the MATH task. Before each evaluation,
# patch the vLLM weights with that category's expert vector ('other' has none).
EXPERT_VECTOR_DIRS = {
    'math': "trained_math_z_expert_vector",
    'code': "trained_code_z_expert_vector",
    'reasoning': "trained_reasoning_z_expert_vector",
    'other': None,
}

for label in labels:
    cur_samples = [s for s in classfied_mmlu_samples if s.expert_label == label]
    if not cur_samples:
        continue

    if label not in EXPERT_VECTOR_DIRS:
        raise ValueError(f"No expert vector configured for label {label!r}")
    model_dir = EXPERT_VECTOR_DIRS[label]

    if model_dir is not None:
        model_filepath = f'{model_dir}/learnable_params.pt'
        # weights_only=True matches the decomposed-params load and avoids
        # unpickling arbitrary objects from a checkpoint file.
        learnable_params = torch.load(model_filepath, weights_only=True)
        new_params = forward(
            model, base_params, decomposed_params, learnable_params
        )
        load_hf_params_to_vllm(new_params, vllm_model.llm)
        # Free these only on the branch that created them. Labels without an
        # expert vector never bind these names, so an unconditional `del`
        # after evaluation would raise NameError.
        del learnable_params, new_params
    # NOTE(review): for 'other', no weights are loaded, so vLLM keeps its current
    # weights. Because 'other' comes first in `labels`, those are the startup
    # weights: presumably the base model, or the dispatcher-adapted weights when
    # USE_DISPATCHER is True. Confirm that this is intended.

    test_eval = LatexFormatMathTask(
        samples=cur_samples,
        context_messages=[
            fishfarm.Message("system", SYSTEM_MSG),
        ],
    )
    result = test_eval.evaluate(vllm_model)
    print(f" {label} {result.aggregate_metrics}, with has num samples {len(cur_samples)}")