import os
import json
import torch
import argparse

from tqdm import tqdm
from typing import List, Dict
from vllm import LLM, SamplingParams

# Register the local `steer_model` implementation with vLLM under the
# architecture name "DeepseekV3ForCausalLM".
# NOTE(review): presumably this makes vLLM instantiate the steering-aware class
# (instead of its built-in one) for checkpoints declaring that architecture --
# confirm against the vLLM ModelRegistry documentation.
from vllm import ModelRegistry
from steer_model import DeepseekV3ForCausalLM as SteerDeepseekV3ForCausalLM
ModelRegistry.register_model("DeepseekV3ForCausalLM", SteerDeepseekV3ForCausalLM)


def parse_mask_config(config_str: str) -> Dict[int, List[int]]:
    """Parse an expert-mask specification into a ``{layer: [experts]}`` mapping.

    The expected format is ``"layer_idx:expert1,expert2;layer_idx:expert3"``.
    Surrounding whitespace, empty segments (e.g. a trailing ``;``) and empty
    expert tokens (e.g. a trailing ``,``) are tolerated.  Segments that contain
    no ``:`` are skipped.  If the same layer appears more than once, the last
    occurrence wins.

    Args:
        config_str: The raw specification string; may be empty or blank.

    Returns:
        A dict mapping each layer index to the list of expert indices given
        for it.  An empty or blank ``config_str`` yields an empty dict.

    Raises:
        ValueError: If a segment has more than one ``:``, a non-integer layer
            or expert index, or no expert indices at all.
    """
    if not config_str.strip():
        return {}

    config_dict: Dict[int, List[int]] = {}
    for raw_spec in config_str.split(';'):
        layer_spec = raw_spec.strip()
        if ':' not in layer_spec:
            # Blank segment (or one without a layer/expert separator): ignore.
            continue

        layer_part, _, experts_part = layer_spec.partition(':')
        if ':' in experts_part:
            raise ValueError(
                f"Invalid mask spec {layer_spec!r}: expected exactly one ':' per layer entry"
            )

        try:
            layer_idx = int(layer_part)
            # Drop empty tokens so "1,2," or "1,,2" do not trip over int('').
            experts = [int(token) for token in experts_part.split(',') if token.strip()]
        except ValueError as err:
            raise ValueError(
                f"Invalid mask spec {layer_spec!r}: layer and expert indices must be integers"
            ) from err

        if not experts:
            raise ValueError(f"Invalid mask spec {layer_spec!r}: no expert indices given")

        config_dict[layer_idx] = experts
    return config_dict

def _format_mask_info(mask_config: Dict[int, List[int]], steer) -> str:
    """Return the file-name tag for the mask setup ("normal" when nothing is masked)."""
    if not mask_config:
        return "normal"
    layer_tags = [
        f"L{layer_idx}-" + ",".join(map(str, experts))
        for layer_idx, experts in sorted(mask_config.items())
    ]
    return "_".join([f"masked_steer{steer}"] + layer_tags)


def _load_dataset(data_path: str) -> list:
    """Load records from a JSON array file or a JSON-lines file."""
    with open(data_path, encoding="utf-8") as f:
        # A leading '[' means the whole file is one JSON array; otherwise the
        # file is read as one JSON document per non-blank line.
        first_char = f.read(1)
        f.seek(0)
        if first_char == '[':
            return json.load(f)
        return [json.loads(line) for line in f if line.strip()]


def run(args):
    """Generate responses for every question in the dataset and append them to a JSONL file.

    Args:
        args: Parsed command-line namespace.  Reads ``mask_experts``, ``steer``,
            ``data_path``, ``output_dir``, ``domain``, ``model_name_or_path``,
            ``temperature``, ``top_p``, ``n``, ``max_tokens`` and
            ``add_chat_template``.  ``args.batch_size`` is overwritten with the
            number of questions (see the note in the body).

    Raises:
        ValueError: If the dataset contains no records, or if ``mask_experts``
            is malformed.
        KeyError: If a record has no ``"question"`` field.

    Each output line is a JSON object with ``data_id`` (the record's
    ``data_id``, falling back to ``id``), ``problem`` and ``response``; one
    line is written per sampled completion.
    """
    # Parse mask config
    mask_config = parse_mask_config(args.mask_experts)

    # Path
    base_name = os.path.basename(args.data_path).split('.')[0]
    mask_info = _format_mask_info(mask_config, args.steer)

    # data -- loaded and validated before the model so bad input fails fast
    dataset = _load_dataset(args.data_path)
    data = [item["question"] for item in dataset]
    if not data:
        raise ValueError(f"No records found in {args.data_path!r}")

    # NOTE: the CLI --batch_size is deliberately overridden so the whole
    # dataset is submitted to vLLM in a single generate() call; the value is
    # also embedded in the output file name.
    args.batch_size = len(data)

    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)
    output_path = os.path.join(
        args.output_dir,
        f"{args.domain}_{base_name}_t{args.temperature}_p{args.top_p}_n{args.n}_{mask_info}_outputs_bs{args.batch_size}.jsonl",
    )

    # model
    def hf_overrides_func(hf_config):
        # Attach the mask/steer settings to the HF config object that is
        # handed to the model.
        hf_config.mask_config = mask_config
        hf_config.steer = args.steer
        print(f"Overriding hf_config | {hf_config.mask_config=} | {hf_config.steer=}")
        return hf_config

    llm = LLM(
        model=args.model_name_or_path,
        tensor_parallel_size=16,
        pipeline_parallel_size=1,
        trust_remote_code=True,
        quantization='fp8',
        gpu_memory_utilization=0.95,
        max_seq_len_to_capture=32768,
        max_model_len=32768,
        enforce_eager=True,
        hf_overrides=hf_overrides_func if mask_config else None,  # override only when necessary
    )
    tokenizer = llm.get_tokenizer()

    # generate
    if args.temperature == 0:
        # Greedy decoding: a single completion per prompt, top_p not applied.
        sampling_params = SamplingParams(
            n=1,
            temperature=0,
            max_tokens=args.max_tokens,
        )
    else:
        sampling_params = SamplingParams(
            n=args.n,
            temperature=args.temperature,
            top_p=args.top_p,
            max_tokens=args.max_tokens,
        )

    for i in tqdm(range(0, len(data), args.batch_size), desc="Processing"):
        batch = data[i: i + args.batch_size]
        batch_items = dataset[i: i + args.batch_size]
        if args.add_chat_template:
            batch = [
                tokenizer.apply_chat_template(
                    [{"role": "user", "content": p}],
                    tokenize=False,
                    add_generation_prompt=True,
                )
                for p in batch
            ]
        # batch = ['<｜begin▁of▁sentence｜><｜User｜>' + p + '<｜Assistant｜><think>\n' for p in batch]  # only for Deepseek-R1

        with torch.no_grad():
            outputs = llm.generate(
                prompts=batch,
                sampling_params=sampling_params,
                use_tqdm=False
            )

        # Append so results from every batch end up in the same file.
        with open(output_path, 'a') as f:
            for item, output in zip(batch_items, outputs):
                for completion in output.outputs:
                    result = {
                        'data_id': item.get('data_id', item.get('id')),
                        'problem': item['question'],
                        'response': completion.text,
                    }
                    f.write(json.dumps(result) + '\n')

if __name__ == "__main__":
    # Command-line entry point: collect the options and hand them to run().
    cli = argparse.ArgumentParser()

    # Mandatory locations: model, output directory and input data.
    for path_flag in ("--model_name_or_path", "--output_dir", "--data_path"):
        cli.add_argument(path_flag, type=str, required=True)

    # Run labelling, steering strength and expert-mask specification.
    cli.add_argument("--domain", default='unknown', type=str)
    cli.add_argument("--batch_size", default=1, type=int)
    cli.add_argument("--steer", default=1, type=float)
    cli.add_argument("--max_tokens", default=32768, type=int)
    cli.add_argument(
        "--mask_experts",
        default="",
        type=str,
        help="Format: 'layer_idx:expert1,expert2;layer_idx:expert3'",
    )

    # Sampling controls forwarded to SamplingParams inside run().
    cli.add_argument("--n", default=1, type=int)
    cli.add_argument(
        '--temperature',
        default=0,
        type=float,
        help='temperature for sampling',
    )
    cli.add_argument(
        '--top_p',
        default=0.95,
        type=float,
        help='top_p for sampling',
    )

    # Opt-in flag: wrap each question with the tokenizer's chat template.
    cli.add_argument(
        '--add_chat_template',
        default=False,
        action='store_true',
        help='add chat template to the input',
    )

    run(cli.parse_args())
