import os
import argparse
from datasets import load_dataset, DatasetDict
import re
import torch
import torch.nn as nn
from collections import defaultdict
import json
import pandas as pd
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm import tqdm
from typing import Dict, Any
import numpy as np
import random
import argparse
import torch.distributed as dist
import torch.multiprocessing as mp
import sys
from datetime import datetime
try:
    from transformers.cache_utils import DynamicCache
except ImportError:
    DynamicCache = None 


from rosetta.model.projector import create_projector, load_projector
from rosetta.model.wrapper import RosettaModel
from rosetta.model.aggregator import WeightedAggregator, load_aggregator

def parse_args(args=None):
    """Parse the command-line options of the prediction script.

    Args:
        args: Optional list of argument strings. ``None`` lets argparse read
            ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``model`` (one of the supported model names)
        and ``e`` (``True`` when LongBench-E should be evaluated).

    Raises:
        SystemExit: If ``--model`` is missing or not one of the choices.
    """
    parser = argparse.ArgumentParser()
    # --model is used as a dictionary key by the script entry point, so a
    # missing value is rejected here with a usage message instead of showing
    # up later as ``KeyError: None``.
    parser.add_argument('--model', type=str, required=True,
                        choices=["Qwen3-0.6B", "Qwen3-4B", "Rosetta"])
    parser.add_argument('--e', action='store_true', help="Evaluate on LongBench-E")
    return parser.parse_args(args)

# Build the chat-style prompt in the format selected by the model name.
def build_chat(tokenizer, prompt, model_name):
    """Wrap ``prompt`` in the chat markup that matches ``model_name``.

    Args:
        tokenizer: Accepted for call-site compatibility; not used here.
        prompt: Text of the user turn.
        model_name: Model identifier. Names containing ``"Rosetta"`` get the
            instruction/response layout, every other name gets the
            ``<|im_start|>`` / ``<|im_end|>`` layout.

    Returns:
        The wrapped prompt string.
    """
    uses_instruction_layout = "Rosetta" in model_name
    if not uses_instruction_layout:
        return f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
    return f"### Instruction:\n{prompt}\n\n### Response:"

def post_process(response, model_name):
    """Extract the answer text from a raw model response.

    Args:
        response: Decoded model output.
        model_name: Model identifier; selects which marker is honoured.

    Returns:
        For names containing ``"Rosetta"``: the stripped text after the last
        ``"### Response:"`` marker. For every other name: the stripped text
        before the first ``"<|im_end|>"`` marker. When the marker is absent
        the stripped response is returned unchanged.
    """
    cleaned = response.strip()
    if "Rosetta" in model_name:
        # rpartition yields the full string as the tail when the marker is
        # missing, so no separate membership test is required.
        _, _, tail = cleaned.rpartition("### Response:")
        return tail.strip()
    # partition yields the full string as the head when the marker is missing.
    head, _, _ = cleaned.partition("<|im_end|>")
    return head.strip()

def get_pred(rank, world_size, data, max_length, max_gen, prompt_format, dataset, device, model_name, model2path, out_path):
    """Generate predictions for one shard of a dataset on GPU ``rank``.

    Every sample is formatted with ``prompt_format``, truncated in the middle
    when it exceeds ``max_length`` tokens, wrapped with the tokenizer's chat
    template and completed with sampled decoding. One JSON line per sample is
    appended to ``out_path``.

    Args:
        rank: Worker index; also selects the device ``cuda:{rank}``.
        world_size: Total number of workers (not used inside this function).
        data: Iterable of sample dicts. Each must provide the fields named in
            ``prompt_format`` plus ``answers``, ``all_classes``, ``length``
            and ``_id``.
        max_length: Maximum number of prompt tokens kept before the chat
            template is applied.
        max_gen: Maximum number of newly generated tokens.
        prompt_format: ``str.format`` template filled from each sample.
        dataset: Dataset name (not used inside this function).
        device: Ignored; the device is always derived from ``rank``.
        model_name: Key into ``model2path``; also selects the loader and how
            the output is decoded.
        model2path: Mapping from model name to model path.
        out_path: JSONL file the predictions are appended to.
    """
    device = torch.device(f'cuda:{rank}')
    model, tokenizer = load_model_and_tokenizer(model2path[model_name], model_name, device)

    # Identical for every sample, so built once outside the loop.
    # NOTE(review): the worker never seeds its own RNGs; with the 'spawn'
    # start method the parent's seeding presumably does not carry over, so
    # sampled outputs may differ between runs — confirm.
    sampling_params = {
        'do_sample': True,
        'temperature': 0.7,
        'top_p': 0.8,
        'top_k': 20,
        'min_p': 0.0,
        'repetition_penalty': 1.2,
    }

    for json_obj in tqdm(data):
        prompt = prompt_format.format(**json_obj)

        # Truncate in the middle: keep the first and last max_length/2 tokens
        # so both the start and the end of the prompt are preserved.
        tokenized_prompt = tokenizer(prompt, truncation=False, return_tensors="pt").input_ids[0]
        if len(tokenized_prompt) > max_length:
            half = max_length // 2
            prompt = (
                tokenizer.decode(tokenized_prompt[:half], skip_special_tokens=True)
                + tokenizer.decode(tokenized_prompt[-half:], skip_special_tokens=True)
            )

        # The chat template is applied to every dataset.
        messages = [{"role": "user", "content": prompt}]
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False,
        )

        model_inputs = tokenizer(text, return_tensors="pt").to(device)
        context_length = model_inputs.input_ids.shape[1]

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            outputs = model.generate(
                input_ids=model_inputs.input_ids,
                attention_mask=model_inputs.attention_mask,
                max_new_tokens=max_gen,
                **sampling_params,
            )

        if model_name == "Rosetta":
            # NOTE(review): the whole returned sequence is decoded here,
            # presumably because RosettaModel.generate returns only the new
            # tokens — confirm.
            pred = tokenizer.decode(outputs[0], skip_special_tokens=True)
        else:
            # Drop the prompt tokens that precede the generated continuation.
            pred = tokenizer.decode(outputs[0][context_length:], skip_special_tokens=True)

        record = {
            "pred": pred,
            "answers": json_obj["answers"],
            "all_classes": json_obj["all_classes"],
            "length": json_obj["length"],
            "_id": json_obj["_id"],
        }
        # All workers append to the same file; the line is emitted with a
        # single write call to keep each record contiguous.
        with open(out_path, "a", encoding="utf-8") as f:
            f.write(json.dumps(record, ensure_ascii=False) + '\n')

    # No process group is created in this function, and tearing down a
    # non-existent one raises, so only destroy it when one really exists.
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()
def seed_everything(seed):
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: Integer seed applied to every generator.
    """
    # Python and NumPy generators.
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch generators: CPU, current CUDA device, then all CUDA devices.
    for apply_seed in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        apply_seed(seed)

    # Disable cuDNN autotuning and request deterministic kernels.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def load_qwen_model(model_name, device):
    """Load a Qwen causal LM and its tokenizer from ``Qwen/<model_name>``.

    Args:
        model_name: Model name appended to the ``Qwen/`` prefix.
        device: Device the whole model is placed on.

    Returns:
        ``(model, tokenizer)``; the model is in eval mode with bfloat16
        weights, the tokenizer pads on the left.
    """
    print(f"Loading Qwen model: {model_name}")
    hub_id = f"Qwen/{model_name}"

    tokenizer_options = {"trust_remote_code": True, "padding_side": 'left'}
    tokenizer = AutoTokenizer.from_pretrained(hub_id, **tokenizer_options)
    # Fall back to the EOS token when the tokenizer defines no pad token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model_options = {"torch_dtype": torch.bfloat16, "device_map": {"": device}}
    model = AutoModelForCausalLM.from_pretrained(hub_id, **model_options)
    model = model.eval()

    return model, tokenizer

def _load_indexed_modules(checkpoint_dir, prefix, load_fn, device):
    """Load the modules stored as ``<prefix>_<i>.json`` / ``<prefix>_<i>.pt``.

    The number of modules is the number of files in ``checkpoint_dir`` whose
    full name is ``<prefix>_<digits>.pt``; indices are assumed to run
    contiguously from 0.

    Args:
        checkpoint_dir: Directory holding the per-module files.
        prefix: File-name prefix, e.g. ``"projector"``.
        load_fn: Callable that builds a module from the JSON config path.
        device: Device each module is moved to after its weights are loaded.

    Returns:
        List of loaded modules ordered by index.
    """
    # fullmatch so that look-alike files such as "<prefix>_0.pt.bak" or
    # "<prefix>_0.pth" do not inflate the module count.
    weights_name = re.compile(rf"{re.escape(prefix)}_\d+\.pt")
    count = sum(1 for name in os.listdir(checkpoint_dir) if weights_name.fullmatch(name))

    modules = []
    for idx in range(count):
        module = load_fn(os.path.join(checkpoint_dir, f"{prefix}_{idx}.json"))
        weights_path = os.path.join(checkpoint_dir, f"{prefix}_{idx}.pt")
        if os.path.exists(weights_path):
            # Weights are read on CPU and travel to the target device together
            # with the module below.
            state_dict = torch.load(weights_path, map_location="cpu")
            module.load_state_dict(state_dict, strict=False)
        modules.append(module.to(device))
    return modules


def load_rosetta_model(path, device):
    """Assemble the Rosetta model from two Qwen models and a checkpoint.

    Args:
        path: Unused; kept so the signature matches the other loaders. The
            model identifiers and checkpoint directory are fixed below.
        device: Device all sub-models, projectors and aggregators are
            placed on.

    Returns:
        ``(rosetta_model, slm_tokenizer)``: the RosettaModel in eval mode and
        the tokenizer of the small (base) model.
    """
    print("Loading Rosetta model...")
    checkpoint_dir = "script/train/local/checkpoints/20250827_110111/final"

    # Only these three settings are consumed here; projector and aggregator
    # objects are constructed by load_projector / load_aggregator from the
    # per-module JSON files in the checkpoint directory.
    slm_model_path = "Qwen/Qwen3-0.6B"
    llm_model_path = "Qwen/Qwen3-4B"
    include_response = False

    slm_tokenizer = AutoTokenizer.from_pretrained(slm_model_path)

    slm_model = AutoModelForCausalLM.from_pretrained(
        slm_model_path,
        torch_dtype=torch.bfloat16,
        device_map={"": device}
    ).eval()

    llm_model = AutoModelForCausalLM.from_pretrained(
        llm_model_path,
        torch_dtype=torch.bfloat16,
        device_map={"": device}
    ).eval()

    projector_list = _load_indexed_modules(checkpoint_dir, "projector", load_projector, device)
    aggregator_list = _load_indexed_modules(checkpoint_dir, "aggregator", load_aggregator, device)

    rosetta_model = RosettaModel(
        model_list=[slm_model, llm_model],
        base_model_idx=0,
        projector_list=projector_list,
        aggregator_list=aggregator_list,
        include_response=include_response,
    ).to(device).eval()

    # Both mapping configs are loaded unconditionally from the checkpoint
    # directory.
    rosetta_model.load_projector_config(os.path.join(checkpoint_dir, "projector_config.json"))
    rosetta_model.load_aggregator_config(os.path.join(checkpoint_dir, "aggregator_config.json"))

    return rosetta_model, slm_tokenizer

            
def load_model_and_tokenizer(path, model_name, device):
    """Dispatch to the loader that matches ``model_name``.

    Args:
        path: Model path; forwarded to the Rosetta loader only.
        model_name: ``"Rosetta"`` selects the Rosetta loader, any other name
            is handed to the Qwen loader.
        device: Device the model is placed on.

    Returns:
        The ``(model, tokenizer)`` pair produced by the selected loader.
    """
    if model_name != "Rosetta":
        return load_qwen_model(model_name, device)
    return load_rosetta_model(path, device)

if __name__ == '__main__':
    # Entry point: for every selected LongBench dataset, shard the test split
    # across the visible GPUs and run one get_pred worker process per GPU.
    seed_everything(42)
    args = parse_args()
    model_name = args.model
    if model_name is None:
        # model_name indexes the lookup tables below; fail with a clear
        # message instead of a KeyError on None.
        raise SystemExit("--model is required (choose one of: Qwen3-0.6B, Qwen3-4B, Rosetta)")

    world_size = torch.cuda.device_count()
    if world_size == 0:
        # Workers are pinned to cuda:{rank}; without a GPU no worker would be
        # started and the script would silently produce nothing.
        raise SystemExit("No CUDA device found: one worker per visible GPU is required.")
    mp.set_start_method('spawn', force=True)

    model2path = {
        "Qwen3-0.6B": "Qwen/Qwen3-0.6B",
        "Qwen3-4B": "Qwen/Qwen3-4B",
        "Rosetta": "Qwen/Qwen3-0.6B",
    }

    model2maxlen = {
        "Qwen3-0.6B": 32768,
        "Qwen3-4B": 32768,
        "Rosetta": 32768,
    }

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    max_length = model2maxlen[model_name]

    if args.e:
        datasets = ["qasper", "multifieldqa_en", "hotpotqa", "2wikimqa", "gov_report", "multi_news",
                    "trec", "triviaqa", "samsum", "passage_count", "passage_retrieval_en", "lcc", "repobench-p"]
    else:
        datasets = ["narrativeqa", "qasper", "multifieldqa_en", "multifieldqa_zh", "hotpotqa", "2wikimqa", "musique",
                    "dureader", "gov_report", "qmsum", "multi_news", "vcsum", "trec", "triviaqa", "samsum", "lsht",
                    "passage_count", "passage_retrieval_en", "passage_retrieval_zh", "lcc", "repobench-p"]

    # Context managers close the config files; UTF-8 is requested explicitly
    # so decoding does not depend on the locale.
    with open("longbench/config/dataset2prompt.json", "r", encoding="utf-8") as f:
        dataset2prompt = json.load(f)
    with open("longbench/config/dataset2maxlen.json", "r", encoding="utf-8") as f:
        dataset2maxlen = json.load(f)

    os.makedirs("pred", exist_ok=True)
    os.makedirs("pred_e", exist_ok=True)

    output_dir = f"pred_e/{model_name}" if args.e else f"pred/{model_name}"
    os.makedirs(output_dir, exist_ok=True)

    for dataset in datasets:
        # LongBench-E configurations carry an "_e" suffix.
        config_name = f"{dataset}_e" if args.e else dataset
        data = load_dataset('THUDM/LongBench', config_name, split='test')
        out_path = f"{output_dir}/{dataset}.jsonl"

        prompt_format = dataset2prompt[dataset]
        max_gen = dataset2maxlen[dataset]

        # Round-robin sharding: worker i receives samples i, i+world_size, ...
        data_all = list(data)
        data_subsets = [data_all[i::world_size] for i in range(world_size)]

        processes = []
        for rank in range(world_size):
            p = mp.Process(target=get_pred, args=(rank, world_size, data_subsets[rank], max_length,
                                                  max_gen, prompt_format, dataset, device, model_name,
                                                  model2path, out_path))
            p.start()
            processes.append(p)

        for p in processes:
            p.join()

        # Report workers that ended abnormally; their shard of out_path may be
        # incomplete. This is only a warning so the remaining datasets still run.
        failed_ranks = [rank for rank, p in enumerate(processes) if p.exitcode != 0]
        if failed_ranks:
            print(f"Warning: workers {failed_ranks} exited with a non-zero status for dataset {dataset}",
                  file=sys.stderr)