import transformers 
import torch
import torch.distributed as dist
import os
import math
from accelerate import Accelerator
from dataclasses import dataclass, field
from typing import Optional, List, Dict, Set, Any, Union
import vllm
from prompts import prompt_template, DEFAULT_SYSTEM_PROMPT
from tokens import chat_eos_token_id
from hyperparameters import hyperparameters
from vllm_ddp_monkey_patch import disable_parallel_config_checking
disable_parallel_config_checking()

import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

from src.common import read_jsonl, write_jsonl

@dataclass
class InputArguments:
    """Command-line arguments for batched vLLM generation over a JSONL dataset.

    Parsed with ``transformers.HfArgumentParser``; every field becomes a
    ``--<field_name>`` command-line flag.
    """

    model_path: str  # model passed to vllm.LLM(model=...)
    data_path: str  # input JSONL file
    output_path: str  # output JSONL file

    n_completion: Optional[int] = field(default=1, metadata={"help": "Number of completions to generate for each instruction"})
    # Applied as vLLM SamplingParams.max_tokens, i.e. a cap on generated tokens only.
    max_length: Optional[int] = field(default=8192, metadata={"help": "Maximum number of new tokens generated per completion (passed to vLLM as max_tokens)"})

    seed: Optional[int] = field(default=42, metadata={"help": "Random seed for model generation"})
    bf16: Optional[bool] = field(default=True, metadata={"help": "Whether to infer with bf16 or fp16 (if False)"})


if __name__ == '__main__':
    parser = transformers.HfArgumentParser((InputArguments,))
    (args,) = parser.parse_args_into_dataclasses()
    assert 1 <= args.n_completion <= 2

    accelerator = Accelerator()

    # Load dataset and model
    data_list = read_jsonl(args.data_path)

    # Make sure the data format is correct
    expected_keys = sorted(['id', 'label', 'system_prompt', 'query'])
    assert all([sorted(list(data.keys())) == expected_keys for data in data_list]), f"Error with keys expected: {expected_keys}"

    process_list = []
    for i, data in enumerate(data_list):
        data['system_prompt'] = data.get('system_prompt', DEFAULT_SYSTEM_PROMPT)
        if data['system_prompt'] == '':
            data['system_prompt'] = DEFAULT_SYSTEM_PROMPT

        process_list.append({
            'index': i,
            'prompt': prompt_template.format(instruction=data['query'].strip(), system_prompt=data['system_prompt']),
        })
    llm = vllm.LLM(model=args.model_path, dtype=torch.bfloat16 if args.bf16 else torch.float16, seed=args.seed)

    dataloader = torch.utils.data.DataLoader(process_list, shuffle=False, drop_last=False, batch_size=math.ceil(len(data_list) / accelerator.num_processes))
    dataloader = accelerator.prepare_data_loader(dataloader)
    index_list = [item.item() for batch in dataloader for item in batch['index']]
    prompt_list = [item for batch in dataloader for item in batch['prompt']]

    completions = []
    # iterate the dataset with different sampling parameters
    for t in range(args.n_completion):
        hyperparameters[t].max_tokens = args.max_length
        hyperparameters[t].stop_token_ids = [chat_eos_token_id]
        outputs = llm.generate(prompt_list, hyperparameters[t], use_tqdm=accelerator.is_main_process)
        completions.append([output.outputs[0].text.strip() for output in outputs])
    torch.cuda.empty_cache()

    # gather completions across multiple gpus
    if accelerator.num_processes != 1:
        all_index_gather, all_completions_gather = tuple([None] * dist.get_world_size() for _ in range(2))
        dist.all_gather_object(all_index_gather, index_list)
        dist.all_gather_object(all_completions_gather, completions)  # all_completions_gather = [completions_gpu1, completions_gpu2, ...]

        all_index = []
        all_completions = [[] for _ in range(args.n_completion)]
        for index, completions in zip(all_index_gather, all_completions_gather):  # concate responses from multiple gpus
            all_index.extend(index)
            for t in range(args.n_completion):
                all_completions[t].extend(completions[t])
        
        all_index = all_index[:len(data_list)]
        for t in range(args.n_completion): # throw the padding
            all_completions[t] = all_completions[t][:len(data_list)]

    else:
        all_index = index_list
        all_completions = completions
    assert all(i == x['index'] for i, x in zip(all_index, process_list)), "Order Error!"

    results = [
        {
            'id': data['id'],
            'label': data['label'],
            'system_prompt': data['system_prompt'],
            'query': data['query'],
            'outputs': [all_completions[t][i] for t in range(args.n_completion)],
            'meta_data': [hyperparameters[t].__dict__ | {'actor': args.model_path, 'seed': args.seed} for t in range(args.n_completion)],
        }
        for i, data in enumerate(data_list)
    ]

    if accelerator.is_main_process:
        write_jsonl(args.output_path, results)
