import os
import time
import json
import torch
import random
import numpy as np
import pandas as pd
from tqdm import tqdm
from collections import Counter
from typing import List, Dict
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
from utils import read_json, write_json, get_keywords, get_alphabet_choice, remove_boxed, last_boxed_only_string, is_math_equiv, _strip_string
from math_verify import parse, verify
from vllm.distributed.parallel_state import destroy_distributed_environment, destroy_model_parallel
import contextlib
import gc
import subprocess
import math
import ast
# Sampling configuration. MAX_TOKENS, TEMPERATURE and TOP_P are passed to
# vLLM's SamplingParams in get_model_responses().
MAX_TOKENS = 32768  # per-completion generation cap (SamplingParams.max_tokens)
N_SAMPLES = 1  # NOTE(review): not referenced in the visible part of this file
TEMPERATURE = 1.0  # sampling temperature
THRESHOLD = 4000  # NOTE(review): not referenced in the visible part of this file
TOP_P = 1.0  # nucleus-sampling mass (1.0 keeps the full distribution)

# Short agent name -> local snapshot directory of the model checkpoint.
# The trailing comment on each entry is the corresponding Hugging Face repo id.
# get_model_responses() looks the agent name up here to load tokenizer and model.
agent_map = {
    "Llama": "/datasets/ai/llama3/hub/models--meta-llama--Llama-3.1-8B-Instruct/snapshots/0e9e39f249a16976918f6564b8830bc894c89659/", # meta-llama/Meta-Llama-3.1-8B-Instruct
    "Qwen": "/datasets/ai/qwen/hub/models--Qwen--Qwen2.5-7B-Instruct/snapshots/a09a35458c702b33eeacc393d103063234e8bc28/", # Qwen/Qwen2.5-7B-Instruct
    "Mistral": "/datasets/ai/mixtral/hub/models--mistralai--Mistral-Nemo-Instruct-2407/snapshots/8aedd450f2583e9c67fae1929f6936b8fc5aef9c/", # mistralai/Mistral-Nemo-Instruct-2407
    "Phi": "/datasets/ai/phi/hub/models--microsoft--Phi-3.5-mini-instruct/snapshots/3145e03a9fd4cdd7cd953c34d9bbf7ad606122ca/", # microsoft/Phi-3.5-mini-instruct
    "Gemma": "/datasets/ai/gemma/hub/models--google--gemma-2-9b-it/snapshots/1937c70277fcc5f7fb0fc772fc5bc69378996e71/", # google/gemma-2-9b-it
    "GLM": "/datasets/ai/glm/hub/models--THUDM--glm-4-9b-chat/snapshots/bd8234fe5e0c09c48637a92abb0c797cb5fa0e73/", # THUDM/glm-4-9b-chat
    "Exaone": "/datasets/ai/lg/hub/models--LGAI-EXAONE--EXAONE-3.5-7.8B-Instruct/snapshots/0ff6b5ec7c13b049b253a16a889aa269e6b79a94/", # LGAI-EXAONE/EXAONE-3.5-7.8B-Instruct
    "Granite": "/datasets/ai/ibm-granite/hub/models--ibm-granite--granite-3.1-8b-instruct/snapshots/3f05a1d007b2484bbf17593efe110bd5b9d67655/", # ibm-granite/granite-3.1-8b-instruct
    "QwenMath": "/datasets/ai/qwen/hub/models--Qwen--Qwen2.5-Math-7B-Instruct/snapshots/ef9926d75ab1d54532f6a30dd5e760355eb9aa4d/", # Qwen/Qwen2.5-Math-7B-Instruct
    "QwenCode": "/datasets/ai/qwen/hub/models--Qwen--Qwen2.5-Coder-7B-Instruct/snapshots/c03e6d358207e414f1eca0bb1891e29f1db0e242/", # Qwen/Qwen2.5-Coder-7B-Instruct
    "DeepSeekMath": "/datasets/ai/deepseek/hub/models--deepseek-ai--deepseek-math-7b-instruct/snapshots/0a5828f800a36df0fd7f0ed581b983246c0677ff/", # deepseek-ai/deepseek-math-7b-instruct
    "QwenR1": "/datasets/ai/deepseek/hub/models--deepseek-ai--DeepSeek-R1-Distill-Qwen-7B/snapshots/6602cadec947dbb53e64f3d8d6425320b2197247/", # deepseek-ai/DeepSeek-R1-Distill-Qwen-7B
    "LlamaR1": "/datasets/ai/deepseek/hub/models--deepseek-ai--DeepSeek-R1-Distill-Llama-8B/snapshots/74fbf131a939963dd1e244389bb61ad0d0440a4d/", # deepseek-ai/DeepSeek-R1-Distill-Llama-8B
    "InternLM": "/datasets/ai/internlm/hub/models--internlm--internlm3-8b-instruct/snapshots/28c99415adaf61767bd1c619f4f99f308fdfd223/", # internlm/internlm3-8b-instruct
    "Mathstral": "/datasets/ai/mixtral/hub/models--mistralai--Mathstral-7B-v0.1/snapshots/b6408c37979d6805935973ab06468089ff72ce95/", # mistralai/Mathstral-7B-v0.1
    "BioLlama": "/datasets/ai/contactdoctor/hub/models--ContactDoctor--Bio-Medical-Llama-3-8B/snapshots/b42b41f30767e43b6a636490bde14d82e5bad0c1/", # ContactDoctor/Bio-Medical-Llama-3-8B
    "Qwen72B": "/datasets/ai/qwen/hub/models--Qwen--Qwen2.5-72B-Instruct/snapshots/495f39366efef23836d0cfae4fbe635880d2be31/", # Qwen/Qwen2.5-72B-Instruct
    "Llama70B": "/datasets/ai/llama3/hub/models--meta-llama--Llama-3.3-70B-Instruct/snapshots/6f6073b423013f6a7d4d9f39144961bfbfbc386b/", # meta-llama/Llama-3.3-70B-Instruct
    "NLlama": "/datasets/ai/nvidia/hub/models--nvidia--Llama-3.1-Nemotron-Nano-8B-v1/snapshots/a22e1c57330633cd3522903f9bb82480bf3192a6/", # nvidia/Llama-3.1-Nemotron-Nano-8B-v1
}

# Agent names (keys of agent_map).
# NOTE(review): this list is not referenced anywhere in the visible part of
# this file; the script takes its model list from the --models argument.
MODELS = [
    "Llama",
    "Qwen",
    "Gemma",
]

# Task name -> {agent name -> weight}. The script's ms_mad voting branch reads
# these as per-model weights; agents missing from a task's entry get 1.0 there.
# NOTE(review): the values look like per-model accuracies on each task —
# confirm where they were measured.
task_to_weights = {
    "MATH500": {"Llama": 0.5464, "Qwen": 0.7528, "Gemma": 0.5179},
    # "MATH500": {"Qwen": 1, "Gemma": 1, "Llama": 1},
    # "AIME24": {"Llama": 0.065, "Qwen": 0.1243, "Gemma": 0.0141},
    "AIME24": {"NLlama": 0.6921, "QwenR1": 0.5367},
    "AIME25": {"Llama": 0.065, "Qwen": 0.1243, "Gemma": 0.0141},
    "GSM8K": {"Llama": 0.8723, "Qwen": 0.9445, "Gemma": 0.9133},
    "MMLU_Pro": {"Llama": 0.3457, "Qwen": 0.5343, "Gemma": 0.5314},
}

# Export the OpenMP thread count into this process's environment.
# NOTE(review): this runs after numpy/torch have already been imported —
# confirm it still takes effect for them.
os.environ["OMP_NUM_THREADS"] = "20"

from typing import List, Tuple, Optional

def parse_args():
    """Parse this script's command-line options from ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes ``task``, ``gpus``, ``seed``,
        ``n_sampling``, ``consistency_threshold``, ``voting_strategy`` and
        ``models``.
    """
    import argparse

    # (flag, add_argument keyword arguments), registered in this order.
    option_table = (
        ('--task', {'type': str}),
        ('--gpus', {'type': int}),
        ('--seed', {'type': int, 'default': 42}),
        ('--n_sampling', {
            'type': int,
            'default': 5,
            'help': "Number of samples to get from each model",
        }),
        ('--consistency_threshold', {
            'type': float,
            'default': 1.0,
            'help': "Consistency threshold (0 < t <= 1) to determine if a switch is needed",
        }),
        ('--voting_strategy', {
            'type': str,
            'default': 'consistency',
            'choices': ['consistency', 'ms_md', 'ms_mad'],
            'help': "Voting strategy to use for final prediction: consistency (original), ms_md (most frequent), ms_mad (advanced decision)",
        }),
        ('--models', {
            'type': str,
            'default': '',
            'help': "Models to use for prediction",
        }),
    )

    cli = argparse.ArgumentParser()
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()

def seed_everything(seed=42):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) random generators.

    Also writes ``seed`` to ``os.environ['PYTHONHASHSEED']`` and turns on
    cuDNN's deterministic mode.

    Args:
        seed: Integer seed handed to every generator.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    torch.backends.cudnn.deterministic = True

def get_model_responses(task, agent, n_gpus, indices, n_sampling, is_reasoning=False):
    """Sample completions from one model for a subset of a task's test set.

    Reads ``./Datasets/test/{task}_test.json``, keeps the samples at
    ``indices``, wraps each question in the model's chat template, generates
    ``n_sampling`` completions per question with vLLM, and then tears the
    engine down so the next model can be loaded.

    Args:
        task: Task name. "MATH500", "AIME24", "AIME25" and "GSM8K" use the
            ``\\boxed{X}`` answer format; every other task uses the
            multiple-choice ``(X)`` format.
        agent: Key of ``agent_map`` selecting the checkpoint to load.
        n_gpus: Tensor-parallel size passed to vLLM.
        indices: Positions in the test set to query.
        n_sampling: Number of completions to sample per prompt.
        is_reasoning: If true, use the step-by-step prompt; otherwise use the
            "answer directly" prompt.

    Returns:
        ``(responses, n_tokens)``: two lists aligned with ``indices``. Each
        element is a list with one entry per sampled completion — the
        generated text and its number of generated tokens respectively.

    Raises:
        KeyError: If ``agent`` is not a key of ``agent_map``.
    """
    # Fail fast on a mistyped agent name instead of handing None to the
    # tokenizer/model loaders.
    if agent not in agent_map:
        raise KeyError(f"Unknown agent {agent!r}; expected one of {sorted(agent_map)}")
    model_id = agent_map[agent]

    samples = read_json(f"./Datasets/test/{task}_test.json")
    samples = [samples[i] for i in indices]

    # The prompt is "Question: ...\n" + an instruction lead-in + the answer
    # format. Only the variant selected by `is_reasoning` is built.
    if task in ["MATH500", "AIME24", "AIME25", "GSM8K"]:
        answer_format = (
            "\"The answer is \\boxed{X}\", "
            "where X is the final answer, at the end of your response."
        )
    else:
        answer_format = (
            "\"The answer is (X)\", "
            "where X is the answer choice (one capital letter), at the end of your response."
        )
    if is_reasoning:
        lead_in = "Provide your step-by-step reasoning first, and then print "
    else:
        lead_in = "Directly provide the correct answer choice without explanation, formatted as: "
    prompts = [
        f"Question: {sample['question']}\n{lead_in}{answer_format}"
        for sample in samples
    ]

    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    chat_prompts = [
        tokenizer.apply_chat_template(
            [{"role": "user", "content": prompt}],
            tokenize=False,
            add_generation_prompt=True
        )
        for prompt in prompts
    ]

    # Some agents are loaded with an explicit context-length cap; the rest use
    # vLLM's default for the checkpoint.
    llm_kwargs = {
        "model": model_id,
        "download_dir": "/nas-ssd2/cychen/saved_models",
        "tensor_parallel_size": n_gpus,
        "trust_remote_code": True,
    }
    if agent in ["Phi", "Mistral"]:
        llm_kwargs["max_model_len"] = 16000
    elif agent in ["DeepSeekMath"]:
        llm_kwargs["max_model_len"] = 4096
    llm = LLM(**llm_kwargs)

    sampling_params = SamplingParams(temperature=TEMPERATURE, max_tokens=MAX_TOKENS, n=n_sampling, top_p=TOP_P)
    outputs = llm.generate(chat_prompts, sampling_params)
    responses = [[completion.text for completion in output.outputs] for output in outputs]
    n_tokens = [[len(completion.token_ids) for completion in output.outputs] for output in outputs]

    # Tear down the engine and distributed state so the next model can be
    # loaded in the same process. The order of these steps is kept as is.
    destroy_model_parallel()
    destroy_distributed_environment()
    del llm.llm_engine.model_executor
    if hasattr(llm, 'engine'):
        del llm.engine
    del llm
    gc.collect()
    torch.cuda.empty_cache()
    import torch.distributed as dist
    if dist.is_initialized():
        dist.destroy_process_group()
    # A second destroy asserts when no group is left; that is expected here.
    with contextlib.suppress(AssertionError):
        torch.distributed.destroy_process_group()
    time.sleep(10)
    # Print GPU state after cleanup.
    subprocess.run(["nvidia-smi"], check=True)

    return responses, n_tokens

def normalize_math_string(s: str) -> str:
    r"""Return ``s`` after passing it through ``_strip_string``.

    This is a thin, named entry point for the answer-cleaning pipeline.
    According to the original author's note, ``_strip_string`` applies helpers
    such as ``_fix_sqrt``, ``_fix_fracs``, ``_fix_a_slash_b`` and
    ``_remove_right_units`` so that inputs like `\frac12`, `a/b`, `sqrt3` or
    values with units end up in one canonical LaTeX form; those helpers are
    defined in ``utils`` and are not visible here.
    """
    canonical = _strip_string(s)
    return canonical

def is_math_equiv_normalized(a: str, b: str) -> bool:
    """Return True if ``a`` and ``b`` verify as equivalent after normalization.

    Both strings are normalized with ``normalize_math_string`` and then
    compared with ``verify(parse(...), parse(...))`` in three forms, stopping
    at the first form that verifies:

    1. both sides wrapped in ``$...$``;
    2. both sides as they are;
    3. ``a`` as it is against ``b`` with ``\\(`` and ``\\)`` removed.

    Each form is tried independently: an exception raised while parsing or
    verifying one form counts as "not equivalent" for that form only, so it
    cannot hide a match found by another form. Only ``Exception`` subclasses
    are swallowed; ``KeyboardInterrupt`` and ``SystemExit`` propagate.

    Args:
        a: First answer string.
        b: Second answer string.

    Returns:
        True if any form verifies, False otherwise.
    """
    a_norm = normalize_math_string(a)
    b_norm = normalize_math_string(b)

    candidate_pairs = (
        (f"${a_norm}$", f"${b_norm}$"),
        (a_norm, b_norm),
        (a_norm, b_norm.replace("\\(", "").replace("\\)", "")),
    )
    for left, right in candidate_pairs:
        try:
            if verify(parse(left), parse(right)):
                return True
        except Exception:
            # A failure on this form says nothing about the remaining ones.
            continue
    return False

# most_common_math(['\\text{Evelyn}', 'Evelyn'])
    
def most_common_math(
    answers: List[str],
    n: Optional[int] = None
) -> List[Tuple[str,int]]:
    """Tally answers, merging those that are math-equivalent.

    Each answer joins the first existing group whose representative
    ``is_math_equiv_normalized`` accepts; otherwise it opens a new group whose
    representative is the answer's normalized form.

    Args:
        answers: Answer strings to tally.
        n: If given, keep only the ``n`` largest groups.

    Returns:
        ``(representative, count)`` pairs sorted by descending count; groups
        with equal counts stay in the order they were created.
    """
    tallies: List[Tuple[str,int]] = []

    for answer in answers:
        for slot, (representative, count) in enumerate(tallies):
            if is_math_equiv_normalized(representative, answer):
                tallies[slot] = (representative, count + 1)
                break
        else:
            # No existing group matched: open one keyed by the normalized form.
            tallies.append((normalize_math_string(answer), 1))

    ranked = sorted(tallies, key=lambda pair: pair[1], reverse=True)
    if n is None:
        return ranked
    return ranked[:n]

def calculate_consistency_ratio(answers):
    """Return the share of ``answers`` that fall in the largest equivalence group.

    The largest group is taken from ``most_common_math``. An empty or missing
    ``answers`` gives ``0.0``.
    """
    if answers:
        _, top_count = most_common_math(answers)[0]
        return top_count / len(answers)
    return 0.0

def ms_md_voting(all_answers):
    """MS_MD (Model Switch - Most Frequent) vote over every model's answers.

    All per-model answer lists are pooled, the sentinel values ``"Error"``
    and ``""`` are dropped, the rest is normalized, and the representative of
    the largest math-equivalence group is returned.

    Args:
        all_answers: One list of answer strings per model.

    Returns:
        The winning answer, or ``""`` when there is nothing to vote on.
    """
    if not all_answers:
        return ""

    ignored = ("Error", "")
    usable = []
    for model_answers in all_answers:
        usable.extend(ans for ans in model_answers if ans not in ignored)
    if not usable:
        return ""

    # Canonicalize first, then let the equivalence-aware counter pick the top.
    canonical = list(map(normalize_math_string, usable))
    ranked = most_common_math(canonical, n=1)
    if ranked:
        winner, _ = ranked[0]
        return winner
    return ""


def ms_mad_voting(all_answers, model_weights=None):
    """MS_MAD (Model Switch - Advanced Decision) weighted vote.

    For each model, its answers are cleaned, normalized and grouped by math
    equivalence. Each group's count is scaled by an internal-consistency
    weight derived from the entropy of that model's group sizes, then by the
    model's own weight. Scores are finally summed across models per
    equivalence class and the highest-scoring class wins.

    Args:
        all_answers: One list of answer strings per model, in model order.
            ``"Error"``, ``""`` and non-string entries are ignored.
        model_weights: Optional mapping from the 1-based model position as a
            string (``"1"``, ``"2"``, ...) to that model's weight. Missing
            positions default to 1.0; ``None`` weights every model 1.0.

    Returns:
        ``(winner, total_scores)``: the representative answer of the best
        class and a dict mapping every class representative to its score.
        When there is no usable answer the result is ``("", {})``, so callers
        can always unpack or index the result. Ties go to the class whose
        representative was seen first.
    """
    if not all_answers:
        return "", {}

    # Step 1: clean & normalize answers for each model.
    # Only strings are usable; anything else (e.g. a None from a failed
    # extraction — presumably possible, verify against the extractors) is
    # skipped like the "Error" sentinel.
    per_model: List[List[str]] = [
        [
            normalize_math_string(ans.replace(" ", ""))
            for ans in answers
            if isinstance(ans, str) and ans not in ("Error", "")
        ]
        for answers in all_answers
    ]

    if not any(per_model):
        return "", {}

    # Step 2: default equal weights if none provided
    if model_weights is None:
        model_weights = {str(i + 1): 1.0 for i in range(len(per_model))}

    def same_answer(rep: str, ans: str) -> bool:
        # Identical strings are equivalent by definition; checking that first
        # keeps representatives unique even if the checker rejects a string.
        return rep == ans or is_math_equiv_normalized(rep, ans)

    # Step 3: internal consistency -> weighted counts per model
    def internal_consistency_score(normed_answers: List[str]) -> Dict[str, float]:
        """Map each group representative to its consistency-weighted count."""
        cnts: Dict[str, int] = {}
        for ans in normed_answers:
            for rep in cnts:
                if same_answer(rep, ans):
                    cnts[rep] += 1
                    break
            else:
                cnts[ans] = 1

        total = sum(cnts.values())
        if total == 0:
            return {}

        # Normalized entropy of the group sizes: 0 when all answers agree,
        # 1 when they are spread evenly over the groups.
        ent = -sum((c / total) * math.log2(c / total) for c in cnts.values())
        max_ent = math.log2(len(cnts)) if cnts else 0

        # Bias term so the weight never drops below 1/total.
        bias = 1.0 / total
        weight = (bias + (1 - bias) * (1 - ent / max_ent)) if max_ent > 0 else 1.0

        return {rep: cnt * weight for rep, cnt in cnts.items()}

    # Step 4: apply the per-model weight (keyed by 1-based position)
    weighted_list = []
    for i, answers in enumerate(per_model):
        w = model_weights.get(str(i + 1), 1.0)
        weighted_list.append(
            {ans: sc * w for ans, sc in internal_consistency_score(answers).items()}
        )

    # Step 5: build equivalence classes over the distinct answers.
    # dict.fromkeys keeps first-seen order, so representatives and ties do not
    # depend on string hash order.
    distinct_answers = dict.fromkeys(a for lst in per_model for a in lst)
    equiv_groups: List[Tuple[str, List[str]]] = []
    for ans in distinct_answers:
        for rep, members in equiv_groups:
            if is_math_equiv_normalized(rep, ans):
                members.append(ans)
                break
        else:
            equiv_groups.append((ans, [ans]))

    # Sum the weighted scores of every member across all models.
    total_scores: Dict[str, float] = {
        rep: sum(
            weighted_scores.get(mem, 0)
            for weighted_scores in weighted_list
            for mem in members
        )
        for rep, members in equiv_groups
    }

    if not total_scores:
        return "", {}

    # Step 6: pick the highest score; max() returns the first maximal key.
    winner = max(total_scores, key=total_scores.get)
    return winner, total_scores

def parse_model_list(spec):
    """Turn the ``--models`` argument into a list of agent names.

    Accepts a Python literal list/tuple/set of names (``"['Llama', 'Qwen']"``),
    a single quoted name (``"'Qwen'"``), or a plain comma-separated string
    (``"Llama,Qwen"``).

    Args:
        spec: Raw command-line string.

    Returns:
        List of non-empty, whitespace-stripped names in the given order.

    Raises:
        ValueError: If ``spec`` contains no model name or is a literal of an
            unsupported type.
    """
    text = (spec or "").strip()
    try:
        parsed = ast.literal_eval(text)
    except (ValueError, SyntaxError):
        # Not a Python literal: treat it as a comma-separated list of names.
        parsed = text.split(",")

    if isinstance(parsed, str):
        parsed = [parsed]
    elif isinstance(parsed, (list, tuple, set)):
        parsed = list(parsed)
    else:
        raise ValueError(f"--models must be a list of model names, got {spec!r}")

    names = [str(name).strip() for name in parsed]
    names = [name for name in names if name]
    if not names:
        raise ValueError("--models must name at least one model, e.g. \"['Llama', 'Qwen']\"")
    return names


if __name__ == "__main__":
    args = parse_args()
    seed_everything(args.seed)

    num_choice = 10 if args.task == "MMLU_Pro" else 4
    is_math_task = args.task in ["MATH500", "AIME24", "AIME25", "GSM8K"]
    test_samples = read_json(f"./Datasets/test/{args.task}_test.json")

    # Models are queried in the order given on the command line.
    available_models = parse_model_list(args.models)
    unknown_models = [name for name in available_models if name not in agent_map]
    if unknown_models:
        # Fail before any model is loaded.
        raise ValueError(f"Unknown models {unknown_models}; expected keys of agent_map")

    # Per-sample records, one entry appended for each model that saw the sample.
    all_answers = [[] for _ in range(len(test_samples))]
    all_predictions = [[] for _ in range(len(test_samples))]
    models_used = [[] for _ in range(len(test_samples))]
    final_predictions = [None for _ in range(len(test_samples))]
    num_llm_calls = 0

    # Start with all samples
    current_indices = np.arange(len(test_samples))

    # Query the models one after another; a sample moves on to the next model
    # only while the current model's answers are not consistent enough.
    for model_idx, model in enumerate(available_models):
        if len(current_indices) == 0:
            break

        print(f"Processing model {model} ({model_idx + 1}/{len(available_models)}) with {len(current_indices)} samples")

        reasoning_responses, _ = get_model_responses(args.task, model, args.gpus, current_indices, args.n_sampling, is_reasoning=True)
        num_llm_calls += len(current_indices) * args.n_sampling

        next_indices = []
        for sample_idx, responses in zip(current_indices, reasoning_responses):
            # Extract the final answer from every sampled response.
            if is_math_task:
                answers = [remove_boxed(last_boxed_only_string(response)) for response in responses]
            else:
                answers = [get_alphabet_choice(response, num_choice=num_choice) for response in responses]

            all_answers[sample_idx].append(answers)
            all_predictions[sample_idx].append(responses)
            models_used[sample_idx].append(model)

            if calculate_consistency_ratio(answers) < args.consistency_threshold:
                next_indices.append(sample_idx)
            else:
                final_predictions[sample_idx] = answers[0]

        current_indices = np.array(next_indices, dtype=int)
        print(f"After model {model}: {len(current_indices)} samples still need more models")

    # Samples that went through every model get their prediction from the
    # chosen voting strategy. Everything in this loop is indexed by `j`, the
    # sample being voted on.
    for j, curr_used_models in enumerate(models_used):
        if len(curr_used_models) != len(available_models):
            continue

        if args.voting_strategy in ('consistency', 'ms_md'):
            # Plain majority over the answers pooled from all models.
            pooled = [ans for each_model_answers in all_answers[j] for ans in each_model_answers]
            final_predictions[j] = Counter(pooled).most_common(1)[0][0] if pooled else ""
        elif args.voting_strategy == 'ms_mad':
            # Weights are keyed by the 1-based position of the model in the
            # order it was used; models without a task weight get 1.0.
            task_weights = task_to_weights.get(args.task, {})
            model_weights = {
                str(position + 1): task_weights.get(name, 1.0)
                for position, name in enumerate(curr_used_models)
            }
            outcome = ms_mad_voting(all_answers[j], model_weights)
            # An empty result means there was no usable answer to vote on.
            final_predictions[j] = outcome[0] if outcome else ""

    # Calculate accuracy
    df = pd.DataFrame(test_samples)
    correctness = [is_math_equiv(final_answer, str(df['gold_answer'][i])) for i, final_answer in enumerate(final_predictions)]
    acc = round(sum(correctness) / len(correctness) * 100, 2)

    # Save results
    df['pred'] = final_predictions
    df['correctness'] = correctness
    df['all_answers'] = all_answers
    df['all_predictions'] = all_predictions

    # Make sure the output directory exists so the finished run is not lost
    # to a missing folder.
    os.makedirs(f"./Results/skills/{args.task}", exist_ok=True)

    df.to_csv(f"./Results/skills/{args.task}/model_switch_seed{args.seed}_budget{args.n_sampling*len(available_models)}_{round(num_llm_calls / len(test_samples), 2)}_consistency{args.consistency_threshold}_voting{args.voting_strategy}_acc{acc}.csv", index=False)

    print(f"ModelSwitch completed with accuracy: {acc}%")
    print(f"Voting strategy used: {args.voting_strategy}")
    print(f"Average LLM calls per sample: {num_llm_calls / len(test_samples):.2f}")
    print(f"Total LLM calls: {num_llm_calls}")
    print(f"Models used: {available_models}")