#!/usr/bin/env python3
"""
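Multi-GPU evaluation script for MMLU chat-format datasets.

Splits the dataset across the GPUs listed in the config, evaluates each chunk
in its own worker process, and computes the overall accuracy.

Note: the script is named after MMLUFilteredChatDataset, but main_parallel()
currently loads MMLUChatDataset; the filtered loader is commented out there.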
"""

import argparse
import os
import re
import json
import torch
import torch.nn as nn
import torch.multiprocessing as mp
from collections import defaultdict
import numpy as np
import pandas as pd
from pathlib import Path
from tqdm import tqdm
from typing import Dict, Any, List
from transformers import AutoModelForCausalLM, AutoTokenizer
from datetime import datetime
import sys

from rosetta.train.dataset_adapters import MMLUFilteredChatDataset, MMLUChatDataset
from rosetta.model.projector import create_projector, load_projector
from rosetta.model.wrapper import RosettaModel
from rosetta.model.aggregator import WeightedAggregator, load_aggregator

try:
    from transformers.cache_utils import DynamicCache
except ImportError:
    DynamicCache = None 


# ---------------------------------------------------------------------------
# Module-level setup: CLI parsing, config loading and output directory.
#
# NOTE(review): this runs at import time. The script selects the "spawn" start
# method under __main__, and spawned workers re-import the main module, so
# this block presumably executes again in every worker (re-parsing argv and
# re-reading the config). That appears to be how evaluate_on_gpu() gets access
# to model_config / eval_config - confirm before moving this into a function.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description='MMLUFilteredChatDataset Evaluation')
parser.add_argument("--config", type=str, default="recipe/default_config.json", help="Path to JSON config file")
parser.add_argument("--data_path", type=str, default="teacher_datasets/mmlu_qwen2.5_math_filtered", help="Path to MMLUFiltered dataset")
parser.add_argument("--split", type=str, default="test", help="Dataset split to evaluate")
parser.add_argument("--num_samples", type=int, default=None, help="Number of samples to evaluate (None for all)")
parser.add_argument("--sample_interval", type=int, default=1, help="Sample every N examples")
args = parser.parse_args()

# The JSON config must provide the top-level sections "model", "output" and "eval".
with open(args.config, "r") as f:
    cfg: Dict[str, Any] = json.load(f)

# Extract configuration sections
model_config = cfg["model"]
output_config = cfg["output"]
eval_config = cfg["eval"]

# Echo the effective run settings; eval_config must contain "gpu_ids" and
# "answer_method" or these lines raise KeyError.
print(f"Available GPUs: {torch.cuda.device_count()}")
print(f"Requested GPU ID: {eval_config['gpu_ids']}")
print(f"Answer method: {eval_config['answer_method']}")
print(f"Dataset path: {args.data_path}")
print(f"Split: {args.split}")


# Remove any CUDA_VISIBLE_DEVICES restriction from this process's environment.
# NOTE(review): this happens after torch.cuda.device_count() was already
# called above - confirm that this ordering is intended.
os.environ.pop("CUDA_VISIBLE_DEVICES", None)

# Directory that receives every result file written by main_parallel().
OUTPUT_DIR = Path(output_config["output_dir"])
if not OUTPUT_DIR.exists():
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

def format_example_chat(conversation, use_cot=True):
    """Build the evaluation prompt from the first message of a conversation.

    Args:
        conversation: Sequence of message dicts; only ``conversation[0]["content"]``
            (the question together with its choices) is used.
        use_cot: When true, append chain-of-thought instructions that ask for a
            final line starting with ``Answer:``. When false, append
            instructions asking for a bare option letter.

    Returns:
        The prompt string: the user content followed by an instruction block.

    Raises:
        ValueError: If ``conversation`` is empty or falsy.
    """
    is_empty = not conversation or len(conversation) < 1
    if is_empty:
        raise ValueError("Invalid conversation format")

    # The first message carries the question and its answer choices.
    question_text = conversation[0]["content"]

    # Both variants share the header and the first two instructions.
    pieces = [
        question_text + "\n\nInstructions:\n",
        "- Carefully read the question and all options.\n",
        "- Select the single most correct answer.\n",
    ]

    if use_cot:
        # Chain-of-thought variant: reasoning first, then an "Answer:" line.
        pieces.append("- Let's think step by step and then answer the question starting with Answer:\n")
    else:
        # Direct variant: the reply should be nothing but the letter.
        pieces.append("- Respond ONLY with the letter (A, B, C, D) corresponding to the correct answer.\n")
        pieces.append("- Do not include any explanations, additional text, or punctuation in your response.\n\n")
        pieces.append("Your answer:")

    return "".join(pieces)

# An option letter that is not glued to another letter or digit, so the "B" in
# "Based" or the "A" in "Answer" is never mistaken for a chosen option.
_ANSWER_OPTION_RE = re.compile(r'(?<![A-Za-z0-9])[A-D](?![A-Za-z0-9])')
# "Answer:" marker (case-insensitive); the group captures the rest of the line
# that follows the marker and any whitespace, including newlines.
_ANSWER_MARKER_RE = re.compile(r'Answer:\s*(.*)', re.IGNORECASE)


def extract_answer_from_content(text):
    """Extract the predicted option letter (A-D) from generated text.

    Strategy:
      1. Find the first ``Answer:`` marker and return the first standalone
         upper-case option letter on the line that follows it.
      2. Otherwise return the last standalone option letter anywhere in the text.

    Only standalone letters count (not adjacent to another ASCII letter or
    digit). The previous character-by-character scan returned letters from
    inside words, e.g. 'B' for "Answer: Based on the above, C" and 'A' (from
    the word "Answer") for "Answer: unknown".

    Args:
        text: Decoded model output. ``None`` or an empty string yields ``None``.

    Returns:
        One of ``"A"``, ``"B"``, ``"C"``, ``"D"``, or ``None`` when no option
        letter can be identified.
    """
    if not text:
        return None

    text = text.strip()

    marker = _ANSWER_MARKER_RE.search(text)
    if marker:
        on_answer_line = _ANSWER_OPTION_RE.search(marker.group(1))
        if on_answer_line:
            return on_answer_line.group(0)

    # No usable "Answer:" line: fall back to the last option letter mentioned.
    candidates = _ANSWER_OPTION_RE.findall(text)
    return candidates[-1] if candidates else None

def load_hf_model(model_name, device):
    """Load a Hugging Face causal LM and its tokenizer onto one device.

    Args:
        model_name: Model identifier or path; converted with ``str()``.
        device: Device the whole model is mapped to via ``device_map``.

    Returns:
        ``(model, tokenizer)``: the model in eval mode, loaded as bfloat16, and
        a left-padding tokenizer whose pad token falls back to the EOS token.
    """
    model_id = str(model_name)

    tok = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True, padding_side='left')
    # Some tokenizers ship without a pad token; reuse EOS in that case.
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    causal_lm = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map={"": device},
    )
    causal_lm = causal_lm.eval()

    return causal_lm, tok

def build_shared_mlp(source_dim: int, hidden_dim: int, target_dim: int, num_layers: int,
                use_layer_norm: bool, dropout: float, dtype: torch.dtype) -> nn.Sequential:
    """Build a single MLP projection module.

    For ``num_layers >= 2`` the network is an input block, ``num_layers - 2``
    hidden blocks and a final ``Linear(hidden_dim, target_dim)``. Each block is
    ``Linear -> [LayerNorm] -> GELU -> Dropout``. For ``num_layers <= 1`` the
    result is a single ``Linear(source_dim, target_dim)``.

    Args:
        source_dim: Input feature size.
        hidden_dim: Width of the hidden blocks.
        target_dim: Output feature size.
        num_layers: Number of linear layers; values below 2 give one layer.
        use_layer_norm: Insert a LayerNorm after each non-final linear layer.
        dropout: Dropout probability used in every block.
        dtype: Parameter dtype for the Linear and LayerNorm modules.

    Returns:
        The assembled ``nn.Sequential``.
    """
    # Single-layer case: return before building anything else. The previous
    # code first allocated (and randomly initialised) a source->hidden layer
    # that was immediately thrown away, which also advanced the RNG state.
    if num_layers <= 1:
        return nn.Sequential(nn.Linear(source_dim, target_dim, dtype=dtype))

    def _block(in_dim: int) -> List[nn.Module]:
        """Return one Linear -> [LayerNorm] -> GELU -> Dropout block."""
        block: List[nn.Module] = [nn.Linear(in_dim, hidden_dim, dtype=dtype)]
        if use_layer_norm:
            block.append(nn.LayerNorm(hidden_dim, dtype=dtype))
        block.append(nn.GELU())
        block.append(nn.Dropout(dropout))
        return block

    # Input projection
    layers = _block(source_dim)

    # Hidden layers
    for _ in range(num_layers - 2):
        layers.extend(_block(hidden_dim))

    # Output projection
    layers.append(nn.Linear(hidden_dim, target_dim, dtype=dtype))

    return nn.Sequential(*layers)

def _count_indexed_checkpoints(checkpoint_dir, prefix):
    """Count files in *checkpoint_dir* named exactly ``<prefix>_<digits>.pt``."""
    # fullmatch (not match) so leftovers such as "projector_0.pt.bak" are not
    # counted; a miscount would make the caller load a non-existent index.
    pattern = re.compile(rf"{re.escape(prefix)}_\d+\.pt")
    return sum(1 for name in os.listdir(checkpoint_dir) if pattern.fullmatch(name))


def load_rosetta_model(model_config, eval_config, device):
    """Assemble a RosettaModel from two causal LMs plus trained checkpoints.

    Reads ``eval_config["checkpoints_dir"]`` and
    ``model_config["rosetta_config"]`` (keys ``base_model``, ``teacher_model``,
    ``include_response``). Both language models are loaded as bfloat16 in eval
    mode on ``device``. Projectors and aggregators are rebuilt from
    ``<kind>_<i>.json`` and, when ``<kind>_<i>.pt`` exists, their weights are
    loaded non-strictly. Finally ``projector_config.json`` and
    ``aggregator_config.json`` from the checkpoint directory are applied.

    The number of modules is taken from the count of ``<kind>_<digits>.pt``
    files, and indices ``0 .. count-1`` are loaded, so the checkpoint indices
    are expected to be contiguous and start at 0.

    Args:
        model_config: The "model" section of the run config.
        eval_config: The "eval" section of the run config.
        device: Device for both models and all projector/aggregator modules.

    Returns:
        ``(rosetta_model, slm_tokenizer)``: the wrapped model in eval mode and
        the tokenizer of the base model.
    """
    # Create projectors/aggregators and load from checkpoint
    checkpoint_dir = eval_config["checkpoints_dir"]

    rosetta_config = model_config["rosetta_config"]
    slm_model_path = rosetta_config["base_model"]
    llm_model_path = rosetta_config["teacher_model"]

    slm_tokenizer = AutoTokenizer.from_pretrained(str(slm_model_path))

    slm_model = AutoModelForCausalLM.from_pretrained(
        str(slm_model_path),
        torch_dtype=torch.bfloat16,
        device_map={"": device}
    ).eval()

    llm_model = AutoModelForCausalLM.from_pretrained(
        str(llm_model_path),
        torch_dtype=torch.bfloat16,
        device_map={"": device}
    ).eval()

    # NOTE(review): torch.load unpickles the checkpoint files below; only point
    # checkpoints_dir at trusted checkpoints.

    # Load projectors
    num_projectors = _count_indexed_checkpoints(checkpoint_dir, "projector")
    projector_list = []
    for t in range(num_projectors):
        json_cfg = os.path.join(checkpoint_dir, f"projector_{t}.json")
        proj = load_projector(json_cfg)
        proj = proj.to(device)
        pt_path = os.path.join(checkpoint_dir, f"projector_{t}.pt")
        if os.path.exists(pt_path):
            state_dict = torch.load(pt_path, map_location=device)
            # strict=False: missing/unexpected keys are tolerated silently.
            proj.load_state_dict(state_dict, strict=False)
        projector_list.append(proj)

    # Load aggregators (weights are read on CPU, then the module is moved)
    num_aggregators = _count_indexed_checkpoints(checkpoint_dir, "aggregator")
    aggregator_list = []
    for t in range(num_aggregators):
        json_cfg = os.path.join(checkpoint_dir, f"aggregator_{t}.json")
        agg_path = os.path.join(checkpoint_dir, f"aggregator_{t}.pt")
        agg = load_aggregator(json_cfg)
        if os.path.exists(agg_path):
            sd = torch.load(agg_path, map_location="cpu")
            agg.load_state_dict(sd, strict=False)
        agg = agg.to(device)
        aggregator_list.append(agg)

    # The base (small) model is index 0 and the teacher is index 1.
    rosetta_model = RosettaModel(
        model_list=[slm_model, llm_model],
        base_model_idx=0,
        projector_list=projector_list,
        aggregator_list=aggregator_list,
        include_response=rosetta_config["include_response"],
    ).to(device).eval()

    # Load projector/aggregator mapping configs saved during training
    proj_cfg_path = os.path.join(checkpoint_dir, "projector_config.json")
    agg_cfg_path = os.path.join(checkpoint_dir, "aggregator_config.json")
    rosetta_model.load_projector_config(proj_cfg_path)
    rosetta_model.load_aggregator_config(agg_cfg_path)

    return rosetta_model, slm_tokenizer

@torch.no_grad()
def generate_answer_with_logits(model, tokenizer, prompt, option_ids, device, model_type="qwen"):
    """Pick an option by comparing next-token logits of the option letters.

    The prompt is wrapped with the tokenizer's chat template, the assistant
    turn is primed with "The correct answer is", and a single forward pass is
    run. The logits at the last position are read at ``option_ids`` and turned
    into probabilities with a softmax over those options only.

    Args:
        model: Model to evaluate. With ``model_type == "rosetta"`` it is called
            through ``model.forward`` with position ids and a ``kv_cache_index``;
            otherwise it is called as ``model(input_ids)``.
        tokenizer: Tokenizer providing ``apply_chat_template`` and ``__call__``.
        prompt: User message content.
        option_ids: Token ids of the options, in letter order (A, B, C, ...).
        device: Device the inputs are moved to.
        model_type: ``"rosetta"`` selects the Rosetta call path.

    Returns:
        ``(pred, probs)``: the predicted letter and a float32 NumPy array of
        option probabilities, one entry per id in ``option_ids``.
    """
    messages = [{
        "role": "user",
        "content": prompt
    }]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False
    )
    # Prime the assistant turn so the next token is the option letter.
    text += "The correct answer is"
    input_ids = tokenizer(text, return_tensors="pt").to(device)['input_ids']
    attention_mask = torch.ones(input_ids.shape, dtype=torch.long).to(device)

    if model_type == "rosetta":
        position_ids = attention_mask.long().cumsum(-1) - 1
        # One [1, 0] row for every token except the last, and a single [-1, 0]
        # row for the last one. Use `device` rather than a hard-coded "cuda",
        # which only matched when the current CUDA device was the target GPU.
        instruction_index = (
            torch.tensor([1, 0], dtype=torch.long)
            .repeat(input_ids.shape[1] - 1, 1)
            .unsqueeze(0)
            .to(device)
        )
        # NOTE(review): the response index stays on the CPU, as before;
        # confirm which device RosettaModel expects for it.
        response_index = torch.tensor([[-1, 0]], dtype=torch.long).unsqueeze(0)
        outputs = model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            kv_cache_index=[instruction_index, response_index],
        )
    else:
        # A plain forward pass is all that is needed for logits. Sampling
        # options (do_sample, temperature, max_new_tokens, ...) belong to
        # generate() and must not be passed to the forward call.
        outputs = model(input_ids)

    # Logits for the token that would follow the primed text.
    logits = outputs.logits[0, -1]
    option_index = torch.as_tensor(option_ids, dtype=torch.long, device=logits.device)
    option_logits = logits[option_index].float().cpu()

    probs = torch.nn.functional.softmax(option_logits, dim=0).numpy()
    pred = chr(65 + int(np.argmax(probs)))
    return pred, probs

@torch.no_grad()
def generate_answer_with_generate(model, tokenizer, prompt, device):
    """Sample a free-form completion and parse the chosen option from it.

    Args:
        model: Model exposing ``generate``. For a ``RosettaModel`` the first
            returned sequence is decoded as-is; for any other model the prompt
            tokens are sliced off first.
        tokenizer: Tokenizer used for the chat template, encoding and decoding.
        prompt: User message content.
        device: Device the encoded inputs are moved to.

    Returns:
        ``(pred, probs, input_length, gen_length)``: the parsed option letter
        (or ``None``), a uniform placeholder probability array (this path has
        no per-option scores), the prompt length in tokens and the length of
        the decoded token sequence.
    """
    chat_text = tokenizer.apply_chat_template(
        [{"role": "user", "content": prompt}],
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,
    )
    encoded = tokenizer(chat_text, return_tensors="pt").to(device)
    prompt_len = encoded.input_ids.shape[1]

    # Sampling-based decoding, so repeated runs can give different answers.
    generation = model.generate(
        input_ids=encoded.input_ids,
        attention_mask=encoded.attention_mask,
        do_sample=True,
        temperature=0.7,
        top_p=0.8,
        top_k=20,
        min_p=0.0,
        repetition_penalty=1.2,
        max_new_tokens=1024,
    )

    first_sequence = generation[0]
    if isinstance(model, RosettaModel):
        # Used unchanged; presumably RosettaModel.generate already returns
        # only the new tokens - TODO confirm.
        new_token_ids = first_sequence
    else:
        # Drop the echoed prompt tokens.
        new_token_ids = first_sequence[prompt_len:]

    decoded = tokenizer.decode(new_token_ids, skip_special_tokens=True).strip("\n")
    answer = extract_answer_from_content(decoded)

    # No option probabilities exist on this path; report a uniform placeholder.
    placeholder_probs = np.array([0.25, 0.25, 0.25, 0.25])

    return answer, placeholder_probs, prompt_len, new_token_ids.shape[0]

# Canonical phrasing; the lookahead keeps "The correct answer is Berlin" from
# being read as option B.
_GT_EXPLICIT_RE = re.compile(r'The correct answer is ([A-D])(?![A-Za-z0-9])')
# An option letter that is not part of a longer word or number.
_GT_STANDALONE_RE = re.compile(r'(?<![A-Za-z0-9])[A-D](?![A-Za-z0-9])')


def extract_ground_truth_from_conversation(conversation):
    """Extract the ground-truth option letter from a conversation.

    The label is read from the second message (the assistant response). The
    phrase "The correct answer is X" is preferred; otherwise the last
    standalone option letter in the response is used.

    Only standalone letters count. The previous fallback scanned single
    characters and therefore produced labels from ordinary words (for example
    'C' for "Cannot be determined"), silently corrupting the accuracy. Such
    samples now return ``None`` so the caller can skip them.

    Args:
        conversation: Sequence of message dicts with a "content" key.

    Returns:
        One of ``"A"``-``"D"``, or ``None`` when the conversation has fewer
        than two messages, the response is not a string, or no option letter
        can be identified.
    """
    if not conversation or len(conversation) < 2:
        return None

    # Assistant's response should contain "The correct answer is X."
    assistant_response = conversation[1]["content"]
    if not isinstance(assistant_response, str):
        return None

    explicit = _GT_EXPLICIT_RE.search(assistant_response)
    if explicit:
        return explicit.group(1)

    # Fallback: the last standalone option letter in the response.
    candidates = _GT_STANDALONE_RE.findall(assistant_response)
    return candidates[-1] if candidates else None

@torch.no_grad()
def evaluate_dataset_chunk(dataset_chunk, model, tokenizer, device, model_type="qwen", rank=0, chunk_start_idx=0):
    """Evaluate one chunk of conversations and collect per-question results.

    Reads ``use_cot`` and ``answer_method`` from the module-level
    ``eval_config``. Conversations without an extractable ground truth are
    skipped and only counted.

    Args:
        dataset_chunk: Indexable collection of conversations.
        model: Model under evaluation.
        tokenizer: Tokenizer matching ``model``.
        device: Device used for inference.
        model_type: Forwarded to ``generate_answer_with_logits``.
        rank: Worker rank, used in log lines and result records.
        chunk_start_idx: Offset added to local indices to form global ids.

    Returns:
        ``(cors, acc, probs, length_stats, correct_question_ids,
        question_results)``. ``cors`` and ``probs`` are NumPy arrays, or
        ``None`` when nothing was evaluated; ``acc`` is the mean correctness,
        or ``0`` when nothing was evaluated.
    """
    def _first_token_id(letter):
        # Token id of " <letter>"; EOS is the fallback for an empty encoding.
        token_ids = tokenizer.encode(" " + letter, add_special_tokens=False)
        return token_ids[0] if token_ids else tokenizer.eos_token_id

    option_ids = [_first_token_id(letter) for letter in ["A", "B", "C", "D"]]

    correctness = []
    prob_rows = []
    length_records = []
    correct_ids = []
    per_question = []
    n_skipped = 0

    print(f"GPU {rank}: Processing {len(dataset_chunk)} samples")

    for local_idx in tqdm(range(len(dataset_chunk)), desc=f"GPU {rank} Evaluating"):
        sample = dataset_chunk[local_idx]
        global_idx = chunk_start_idx + local_idx

        # Samples without a usable label cannot be scored.
        gold = extract_ground_truth_from_conversation(sample)
        if gold is None:
            n_skipped += 1
            continue

        prompt = format_example_chat(sample, use_cot=eval_config["use_cot"])

        # Any method other than 'logits' goes through sampling-based generation.
        if eval_config["answer_method"] == 'logits':
            pred, probs = generate_answer_with_logits(model, tokenizer, prompt, option_ids, device, model_type)
            input_length = gen_length = None
        else:
            pred, probs, input_length, gen_length = generate_answer_with_generate(model, tokenizer, prompt, device)

        # An unparseable prediction (None) always counts as wrong.
        hit = pred is not None and pred == gold

        correctness.append(hit)
        prob_rows.append(probs)
        if hit:
            correct_ids.append(global_idx)

        record = {
            'global_question_id': global_idx,
            'local_question_id': local_idx,
            'gpu_rank': rank,
            'is_correct': hit,
            'pred': pred,
            'true_answer': gold,
            'probs': probs.tolist() if isinstance(probs, np.ndarray) else probs
        }

        has_lengths = (
            eval_config["answer_method"] == 'generate'
            and input_length is not None
            and gen_length is not None
        )
        if has_lengths:
            if input_length > 0:
                length_ratio = gen_length / input_length
            else:
                length_ratio = 0
            record['input_length'] = input_length
            record['gen_length'] = gen_length
            record['length_ratio'] = length_ratio
            length_records.append({
                'global_question_id': global_idx,
                'sample_id': local_idx,
                'input_length': input_length,
                'gen_length': gen_length,
                'length_ratio': length_ratio,
                'is_correct': hit,
                'pred': pred,
                'true_answer': gold
            })

        per_question.append(record)

    # Every scored sample contributed exactly one correctness entry.
    n_evaluated = len(correctness)

    if n_evaluated > 0:
        acc = np.mean(correctness)
        print(f"GPU {rank} accuracy: {acc*100:.2f}% (evaluated on {n_evaluated} samples, skipped {n_skipped})")
        print(f"GPU {rank} correct questions: {len(correct_ids)}/{n_evaluated}")
    else:
        acc = 0
        print(f"GPU {rank} skipped all samples ({n_skipped} skipped)")

    cors_array = np.array(correctness) if correctness else None
    probs_array = np.array(prob_rows) if prob_rows else None
    return (cors_array, acc, probs_array,
            length_records, correct_ids, per_question)

def evaluate_on_gpu(rank, gpu_id, dataset_chunk, return_dict, chunk_start_idx):
    """Worker entry point: evaluate one chunk on one GPU and publish the result.

    Args:
        rank: Worker index; used as the key into ``return_dict``.
        gpu_id: CUDA device index for this worker.
        dataset_chunk: Conversations assigned to this worker.
        return_dict: Shared mapping that receives this worker's result dict.
        chunk_start_idx: Global index of the first sample in ``dataset_chunk``.
    """
    torch.cuda.set_device(gpu_id)
    device = torch.device(f"cuda:{gpu_id}")

    # A model name containing "Rosetta" selects the Rosetta loader.
    uses_rosetta = "Rosetta" in model_config["model_name"]
    if uses_rosetta:
        model_type = "rosetta"
        model, tokenizer = load_rosetta_model(model_config, eval_config, device=device)
    else:
        model_type = "others"
        model, tokenizer = load_hf_model(model_config["model_name"], device=device)

    (chunk_cors, chunk_acc, chunk_probs,
     chunk_lengths, chunk_correct_ids, chunk_records) = evaluate_dataset_chunk(
        dataset_chunk, model, tokenizer, device, model_type, rank, chunk_start_idx
    )

    payload = {}
    payload["cors"] = chunk_cors
    payload["acc"] = chunk_acc
    payload["probs"] = chunk_probs
    payload["length_stats"] = chunk_lengths
    payload["correct_question_ids"] = chunk_correct_ids
    payload["question_results"] = chunk_records
    payload["total_samples"] = len(dataset_chunk)
    payload["chunk_start_idx"] = chunk_start_idx

    # A single assignment publishes the complete result under this rank.
    return_dict[rank] = payload

def merge_gpu_results(results_by_rank):
    """Concatenate the per-worker results into flat lists.

    Results are merged in ascending rank order. Workers store their result
    whenever they finish, so plain mapping order would be completion order and
    the merged lists would differ from run to run.

    Args:
        results_by_rank: Mapping ``rank -> result dict`` with the keys "cors",
            "probs", "length_stats", "correct_question_ids",
            "question_results" and "total_samples". "cors" and "probs" may be
            ``None`` for a worker that evaluated nothing.

    Returns:
        ``(all_cors, all_probs, all_length_stats, total_samples,
        all_correct_question_ids, all_question_results)``.
    """
    all_cors = []
    all_probs = []
    all_length_stats = []
    all_correct_question_ids = []
    all_question_results = []
    total_samples = 0

    for _, result in sorted(results_by_rank.items(), key=lambda item: item[0]):
        if result["cors"] is not None:
            all_cors.extend(result["cors"])
        if result["probs"] is not None:
            all_probs.extend(result["probs"])
        if result["length_stats"]:
            all_length_stats.extend(result["length_stats"])
        if result["correct_question_ids"]:
            all_correct_question_ids.extend(result["correct_question_ids"])
        if result["question_results"]:
            all_question_results.extend(result["question_results"])
        total_samples += result["total_samples"]

    return all_cors, all_probs, all_length_stats, total_samples, all_correct_question_ids, all_question_results

def main_parallel():
    """Run the evaluation across all configured GPUs and write the result files.

    Steps:
      1. Load the dataset and optionally keep every ``--sample_interval``-th item.
      2. Split the samples into one contiguous chunk per entry of
         ``eval_config["gpu_ids"]``; the last chunk also takes the remainder.
      3. Start one worker process per GPU and wait for all of them.
      4. Merge the worker results and write four JSON files into ``OUTPUT_DIR``:
         detailed results, correct ids, length statistics (only when present)
         and the summary.

    Workers that exit with a non-zero code or never store a result are
    reported with a warning and listed under ``failed_ranks`` in the summary,
    so that partial results are not mistaken for a complete run.

    Raises:
        ValueError: If ``eval_config["gpu_ids"]`` is empty.
    """
    gpu_ids = eval_config["gpu_ids"]
    num_gpus = len(gpu_ids)
    if num_gpus == 0:
        # Fail with a clear message instead of a ZeroDivisionError further down.
        raise ValueError("eval_config['gpu_ids'] must contain at least one GPU id")
    print(f"Using {num_gpus} GPUs: {gpu_ids}")

    print(f"Loading MMLUFilteredChatDataset from {args.data_path}")
    # NOTE(review): the filtered loader below is disabled, and args.data_path
    # is not passed to MMLUChatDataset, so the path printed above and stored in
    # the summary does not influence which data is loaded.
    # dataset = MMLUFilteredChatDataset(
    #     data_path=args.data_path,
    #     split=args.split,
    #     num_samples=args.num_samples,
    # )
    dataset = MMLUChatDataset(
        split=args.split,
        num_samples=args.num_samples,
    )

    print(f"Dataset loaded with {len(dataset)} samples")

    if args.sample_interval > 1:
        sampled_data = [dataset[i] for i in range(0, len(dataset), args.sample_interval)]
        print(f"Sampled {len(sampled_data)} examples (every {args.sample_interval})")
    else:
        sampled_data = [dataset[i] for i in range(len(dataset))]

    # Contiguous chunks; the last GPU also takes the remainder.
    chunk_size = len(sampled_data) // num_gpus
    dataset_chunks = []
    chunk_start_indices = []

    for i in range(num_gpus):
        start_idx = i * chunk_size
        end_idx = len(sampled_data) if i == num_gpus - 1 else (i + 1) * chunk_size

        chunk = sampled_data[start_idx:end_idx]
        dataset_chunks.append(chunk)
        chunk_start_indices.append(start_idx)
        print(f"GPU {i} will process {len(chunk)} samples (indices {start_idx}:{end_idx})")

    manager = mp.Manager()
    return_dict = manager.dict()
    processes = []

    for rank, gpu_id in enumerate(gpu_ids):
        p = mp.Process(
            target=evaluate_on_gpu,
            args=(rank, gpu_id, dataset_chunks[rank], return_dict, chunk_start_indices[rank]),
        )
        p.start()
        processes.append(p)

    for p in processes:
        p.join()

    # A crashed worker leaves no entry in return_dict. Surface that instead of
    # silently reporting accuracy over the surviving chunks only.
    failed_ranks = [
        rank for rank, p in enumerate(processes)
        if p.exitcode != 0 or rank not in return_dict
    ]
    if failed_ranks:
        print(f"WARNING: workers for ranks {failed_ranks} failed or returned no results; "
              f"the reported numbers cover only the remaining chunks.")

    (all_cors, all_probs, all_length_stats, total_samples,
     all_correct_question_ids, all_question_results) = merge_gpu_results(return_dict)

    # Present results in dataset order regardless of which worker produced them.
    all_question_results.sort(key=lambda x: x['global_question_id'])
    all_correct_question_ids.sort()

    overall_accuracy = np.mean(all_cors) if all_cors else 0

    length_summary = {}
    if all_length_stats:
        length_summary = {
            "avg_input_length": np.mean([s['input_length'] for s in all_length_stats]),
            "avg_gen_length": np.mean([s['gen_length'] for s in all_length_stats]),
            "avg_length_ratio": np.mean([s['length_ratio'] for s in all_length_stats]),
            "total_samples_with_length": len(all_length_stats)
        }

    summary = {
        "model": model_config["model_name"],
        "dataset_path": args.data_path,
        "split": args.split,
        "answer_method": eval_config["answer_method"],
        "use_cot": eval_config["use_cot"],
        "sample_interval": args.sample_interval,
        "num_gpus": num_gpus,
        "gpu_ids": gpu_ids,
        "overall_accuracy": overall_accuracy,
        "total_samples_evaluated": len(all_cors),
        "total_samples_in_chunks": total_samples,
        "total_correct": len(all_correct_question_ids),
        "correct_question_ids": all_correct_question_ids,
        "length_statistics": length_summary,
        "failed_ranks": failed_ranks
    }

    # Make sure the output directory exists before the first file is written.
    # The previous os.makedirs(os.path.dirname(summary_file)) ran only after
    # the other files were written and raised for an output dir of ".",
    # because dirname() is then the empty string.
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    # All output files share one prefix so that a run's files sort together.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    model_name_for_file = model_config['model_name'].split('/')[-1]
    file_prefix = f"{model_name_for_file}_mmlu_filtered_{eval_config['answer_method']}_{timestamp}"
    summary_file = OUTPUT_DIR / f"{file_prefix}_summary.json"

    detailed_results_file = OUTPUT_DIR / f"{file_prefix}_detailed_results.json"
    with open(detailed_results_file, "w") as f:
        json.dump(all_question_results, f, indent=2)
    print(f"Detailed question results saved to {detailed_results_file}")

    correct_ids_file = OUTPUT_DIR / f"{file_prefix}_correct_ids.json"
    with open(correct_ids_file, "w") as f:
        json.dump({
            "model": model_config["model_name"],
            "total_correct": len(all_correct_question_ids),
            "total_evaluated": len(all_cors),
            "accuracy": overall_accuracy,
            "correct_question_ids": all_correct_question_ids
        }, f, indent=2)
    print(f"Correct question IDs saved to {correct_ids_file}")

    if all_length_stats:
        detailed_length_file = OUTPUT_DIR / f"{file_prefix}_detailed_length.json"
        with open(detailed_length_file, "w") as f:
            json.dump(all_length_stats, f, indent=2)
        print(f"Detailed length statistics saved to {detailed_length_file}")

    with open(summary_file, "w") as f:
        json.dump(summary, f, indent=2)

    print("\n=== Evaluation Results ===")
    print(f"Model: {model_config['model_name']}")
    print(f"Dataset: MMLUFilteredChatDataset ({args.split})")
    print(f"Total Samples: {len(all_cors)}")
    print(f"Correct Answers: {len(all_correct_question_ids)}")
    print(f"Overall Accuracy: {overall_accuracy*100:.2f}%")
    print(f"Results saved to: {summary_file}")
    print(f"Correct question IDs saved to: {correct_ids_file}")

    print("\n=== Data Consistency Information ===")
    print("Dataset loading parameters:")
    print(f"  - data_path: {args.data_path}")
    print(f"  - split: {args.split}")
    print(f"  - num_samples: {args.num_samples}")
    print(f"  - sample_interval: {args.sample_interval}")
    print("To ensure consistent ordering between runs, use the same parameters above.")
    print("The correct_question_ids are 0-indexed based on the sampled dataset.")

if __name__ == "__main__":
    # Select the "spawn" start method before any worker is created; force=True
    # overrides a start method that was already set. PyTorch's multiprocessing
    # notes call for spawn/forkserver when subprocesses use CUDA.
    mp.set_start_method("spawn", force=True)
    main_parallel()
