import numpy as np
import random
from vllm import LLM, SamplingParams
import gc
import math
import csv
import fire
import torch
import torch.distributed
import os
from typing import List
from tqdm import tqdm
import json
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import LlamaConfig, LlamaTokenizer, LlamaForCausalLM

from utils.dataset_utils import get_dataset
from utils.icl_utils import get_icl_examples
from utils.prompt_utils import apply_prompt_template, apply_icl_prompt

from utils.model_utils import (
    setup,
    setup_environ_flags,
    clear_gpu_cache,
    load_peft_model,
)

import json
import re
import time
from datetime import datetime
import math

def print_start_time():
    """Print the current local time as the run's start and return it.

    Returns:
        The ``datetime`` captured at the moment of the call, suitable for
        passing to ``print_end_time`` later.
    """
    began_at = datetime.now()
    stamp = began_at.strftime('%Y-%m-%d %H:%M:%S')
    print(f"Start time: {stamp}")
    return began_at

def print_end_time(start_time):
    """Print the current local time and the time elapsed since ``start_time``.

    Args:
        start_time: The ``datetime`` at which the run began (as returned by
            ``print_start_time``).
    """
    finished_at = datetime.now()
    print(f"End time: {finished_at.strftime('%Y-%m-%d %H:%M:%S')}")
    # Subtracting two datetimes yields a timedelta, printed in its default form.
    print(f"Total running time: {finished_at - start_time}")

def save(output_file, out):
    """Write ``out`` to ``output_file`` in JSON Lines format.

    Any existing file at ``output_file`` is overwritten.

    Args:
        output_file: Path of the file to write.
        out: Iterable of JSON-serialisable items; each becomes one line.
    """
    with open(output_file, 'w') as handle:
        handle.writelines(json.dumps(record) + "\n" for record in out)

def combine_results(input_files: List[str], output_file: str):
    """Merge several JSON Lines shards into one file ordered by ``idx``.

    Every line of every input file is parsed as JSON, the records are sorted
    by their ``'idx'`` key, that key is removed from each record, and the
    result is written to ``output_file`` via ``save``.

    Args:
        input_files: Paths of the JSON Lines shards to merge.
        output_file: Path of the merged JSON Lines file to write.
    """
    merged = []
    for shard_path in input_files:
        with open(shard_path, 'r') as shard:
            merged.extend(json.loads(raw) for raw in shard)

    ordered = sorted(merged, key=lambda record: record['idx'])
    # 'idx' is only needed to restore the original order; drop it afterwards.
    for record in ordered:
        record.pop('idx')
    save(output_file, ordered)

def question_read_csv(text_file):
    """Return the first column of every non-empty row of a CSV file.

    Args:
        text_file: Path to a comma-delimited CSV file.

    Returns:
        A list of strings holding the first field of each row, in file
        order. Blank lines, which ``csv.reader`` yields as empty rows, are
        skipped rather than raising ``IndexError``.
    """
    # newline="" leaves line-ending handling to the csv module, so quoted
    # fields that contain embedded newlines are parsed correctly. The context
    # manager guarantees the file is closed even if parsing fails.
    with open(text_file, "r", newline="") as file:
        return [row[0] for row in csv.reader(file, delimiter=",") if row]

def question_read_json(text_file, prompt_key):
    """Load prompts from a JSON file.

    Args:
        text_file: Path to a JSON file whose top-level value is iterated.
        prompt_key: Key used to pull the prompt out of dict entries.

    Returns:
        A list with one element per entry: ``entry[prompt_key]`` for dict
        entries, and the entry itself otherwise.
    """
    with open(text_file, 'r') as handle:
        entries = json.load(handle)

    prompts = []
    for entry in entries:
        if isinstance(entry, dict):
            prompts.append(entry[prompt_key])
        else:
            prompts.append(entry)
    return prompts


def question_read_txt(text_file):
    """Read a text file and return its non-blank lines, whitespace-stripped.

    Args:
        text_file: Path to the text file.

    Returns:
        A list of the file's lines with surrounding whitespace removed;
        lines that are empty after stripping are omitted.
    """
    kept = []
    with open(text_file, 'r') as handle:
        for raw in handle:
            stripped = raw.strip()
            if stripped:
                kept.append(stripped)
    return kept
    
def main(
    model_path,
    train_dataset: str="sciq",
    test_dataset: str="sciq",
    data_start: int=0,
    data_end: int=-1,
    peft_model: str=None,
    quantization: bool=False,
    max_new_tokens = 256,
    prompt_template_style: str='sciq',
    seed: int=42,
    do_sample: bool=True,
    use_cache: bool=True,
    top_p: float=1.0,
    temperature: float=1.0,
    top_k: int=0,
    repetition_penalty: float=1.0,
    length_penalty: int=1,
    use_fast_kernels: bool = False,
    prompt_key: str = 'instruction',
    output: str = None,
    exp_num: int=10,
    k: int=4,
    subset_size: int=100,
    method: str = "diversity",
    dp_choice: str = "knn",
    metric: str = "cosine_similarity",
    emb: str = None,
    freq: int = 0,
    if_qwa: bool = False,
    permutation: int = 1,
    apply_chat_template: bool = False,
    **kwargs
):
    """Run in-context-learning inference with vLLM and save answers as JSONL.

    For each demonstration ordering, prompts are assembled from the test
    inputs and the selected training demonstrations, completed by the model,
    and written to
    ``results/[qwa/]{method}/{model_name}/{test_dataset}/{train_dataset}/{k}/{permutation_index}/{seed}/{emb}.jsonl``.
    Each record has the keys ``'prompt'`` (the raw test input) and
    ``'answer'`` (the generated text).

    Args:
        model_path: Path or name of the model; passed to vLLM and to
            ``AutoTokenizer.from_pretrained``.
        train_dataset: Dataset providing the demonstrations.
        test_dataset: Dataset providing the test inputs; also used as the
            prompt template style.
        repetition_penalty: Forwarded to ``SamplingParams``.
        use_fast_kernels: If true, attempts a BetterTransformer conversion.
        exp_num: Experiment number; also used as the random seed.
        k: Number of demonstrations per prompt.
        subset_size, method, dp_choice, metric, emb: Forwarded to
            ``get_icl_examples``; ``method`` and ``emb`` also appear in the
            output path.
        freq: If positive, generate and save whenever ``idx % freq == 0``;
            the batch is always flushed at the last prompt.
        if_qwa: If true, results go under ``results/qwa/`` instead of
            ``results/``.
        apply_chat_template: If true, wrap each prompt as a single user
            message and render it with the tokenizer's chat template.

    Note:
        Several parameters are accepted but have no effect, kept for
        interface compatibility: ``max_new_tokens`` is forced to 128,
        ``seed`` is replaced by ``exp_num``, ``prompt_template_style`` is
        replaced by ``test_dataset``, ``output`` is recomputed for every
        permutation, and ``temperature``/``top_p``/``top_k`` are hard-coded
        in the sampling parameters. ``data_start``, ``data_end``,
        ``peft_model``, ``quantization``, ``do_sample``, ``use_cache``,
        ``length_penalty``, ``prompt_key``, ``permutation`` and ``kwargs``
        are not read at all.
    """
    # The caller-supplied value is intentionally not honoured here; existing
    # results were produced with a 128-token limit.
    max_new_tokens = 128
    model_name = os.path.basename(model_path)

    start_time = print_start_time()
    program_info = f"{train_dataset}-{test_dataset}-{model_name}-{emb}-{method}-{dp_choice}-{metric}-{exp_num}-{subset_size}-{k}"
    print(f"program info: {program_info}")
    print(f"subset_size: {subset_size}")
    # The experiment number doubles as the seed, overriding the `seed` argument.
    seed = exp_num

    torch.cuda.manual_seed(seed)
    torch.manual_seed(seed)

    world_size = os.environ.get('WORLD_SIZE')

    if world_size is None:
        # Single-process run: no process group is created.
        world_size = 1
        rank = 0
    else:
        world_size = int(world_size)
        local_rank = int(os.environ["LOCAL_RANK"])
        rank = int(os.environ['RANK'])
        print(f"rank: {rank} local rank: {local_rank} world size: {world_size}")

        torch.distributed.init_process_group(backend='nccl', init_method='env://')

        # NOTE(review): confirm setup() does not initialise the process group
        # a second time.
        setup()
        torch.cuda.set_device(local_rank)
        clear_gpu_cache(local_rank)
        setup_environ_flags(rank)

    model = LLM(model=model_path, dtype=torch.bfloat16)
    # top_k=1 restricts sampling to the single most likely token, so decoding
    # is effectively greedy regardless of the temperature/top_p arguments.
    sampling_params = SamplingParams(
        temperature=1.0,
        top_p=1.0,
        top_k=1,
        max_tokens=max_new_tokens,
        repetition_penalty=repetition_penalty,
        stop=["\n\n","Question","\n\n\n","Question:", "Support", "Support:"],
    )

    if use_fast_kernels:
        # NOTE(review): `model` is a vLLM LLM object here; verify that
        # BetterTransformer.transform accepts it before enabling this flag.
        try:
            from optimum.bettertransformer import BetterTransformer
            model = BetterTransformer.transform(model)
        except ImportError:
            print("Module 'optimum' not found. Please install 'optimum' it before proceeding.")

    tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left")
    tokenizer.pad_token = tokenizer.eos_token

    train_inputs, train_outputs, _, _ = get_dataset(dataset=train_dataset, load_from_local=True)
    _, _, test_inputs, _ = get_dataset(dataset=test_dataset, load_from_local=True)

    # The template always follows the test dataset, overriding the argument.
    prompt_template_style = test_dataset
    _, train_inputs_prompts = apply_prompt_template(prompt_template_style, train_inputs, tokenizer, return_dialogs=True)
    _, test_input_prompts = apply_prompt_template(prompt_template_style, test_inputs, tokenizer, return_dialogs=True)

    idx_mat_origin = get_icl_examples(train_dataset=train_dataset, test_dataset=test_dataset, emb=emb, shuffle_seed=seed,k=k, method=method,dp_choice=dp_choice,subset_size=subset_size,metric=metric)

    np.random.seed(seed)
    random.seed(seed)

    # Use every ordering of the k demonstrations when k <= 3, otherwise cap
    # the number of distinct orderings at 10.
    if k > 3:
        per_total = 10
    else:
        per_total = math.factorial(k)

    # Rejection-sample distinct random orderings until enough are collected.
    permutations = []
    while len(permutations) < per_total:
        perm = np.random.permutation(k).tolist()
        if perm not in permutations:
            permutations.append(perm)

    # Only knn / k_means based methods are evaluated over several orderings.
    if "knn" in method or "k_means" in method:
        permutation_num = per_total
    else:
        permutation_num = 1

    for permutation_index in range(permutation_num):
        program_info = f"train_dataset: {train_dataset}, test_dataset: {test_dataset}, model_name: {model_name}, emb: {emb}, method: {method}, dp_choice: {dp_choice}, metric: {metric}, exp_num: {exp_num}, subset_size: {subset_size}, k: {k}, permutation: {permutation_index}, seed: {seed}, if_qwa: {if_qwa}"
        print(f"program info: {program_info}")
        if k != 0:
            # Reorder the demonstration columns according to this permutation.
            idx_mat = idx_mat_origin[:, np.array(permutations[permutation_index])]
        else:
            idx_mat = idx_mat_origin

        if if_qwa:
            output = f"results/qwa/{method}/{model_name}/{test_dataset}/{train_dataset}/{k}/{permutation_index}/{seed}"
        else:
            output = f"results/{method}/{model_name}/{test_dataset}/{train_dataset}/{k}/{permutation_index}/{seed}"
        # exist_ok avoids a FileExistsError when several ranks create the
        # same directory concurrently.
        os.makedirs(output, exist_ok=True)

        result_path = f"{output}/{emb}.jsonl"
        print(f'Start Inference. Output file {result_path}')

        # Remove stale results; try/except keeps this safe when another rank
        # deletes the file between a check and the removal.
        try:
            os.remove(result_path)
        except FileNotFoundError:
            print(f"{result_path} does not exist, no need to delete.")
        else:
            print(f"{result_path} has been deleted.")

        dialogs = apply_icl_prompt(test_input_prompts, train_inputs_prompts, train_outputs, idx_mat, prompt_template_style)

        if apply_chat_template:
            dialogs = [
                tokenizer.apply_chat_template(
                    [{"role":"user","content":d}],
                    add_generation_prompt=True,
                    tokenize = False,
                )
                for d in dialogs
            ]

        num_dialogs = len(dialogs)
        results = []
        batch = { 'dialogs': [], 'idx': [] }
        for idx, dialog in tqdm(enumerate(dialogs), total=num_dialogs):
            # Prompts are distributed round-robin across ranks.
            if idx % world_size == rank:
                batch['dialogs'].append(dialog)
                batch['idx'].append(idx)

            if (freq > 0 and idx % freq == 0) or idx == num_dialogs - 1:
                # A rank may own no prompt in this window; skip generation
                # but still take part in the save/barrier below.
                if batch['dialogs']:
                    with torch.no_grad():
                        outputs = model.generate(batch['dialogs'],sampling_params=sampling_params)
                    output_text = [completion.text for generated in outputs for completion in generated.outputs]

                    torch.cuda.empty_cache()
                    for i, o in zip(batch["idx"], output_text):
                        cur = {'prompt': test_inputs[i], 'answer': o}
                        if world_size > 1:
                            # Kept so combine_results can restore the order.
                            cur['idx'] = i
                        results.append(cur)

                        print('\n\n\n')
                        print('[answer]', o)

                batch = { 'dialogs': [], 'idx': [] }
                if world_size > 1:
                    save(output_file=f"{output}/{emb}.part{rank}.jsonl", out=results)
                    # Wait until every rank has written its part file.
                    torch.distributed.barrier()
                    if rank == 0:
                        combine_results(
                            input_files=[f"{output}/{emb}.part{i}.jsonl" for i in range(world_size)],
                            output_file=result_path)
                    # Keep other ranks from rewriting their part files while
                    # rank 0 is still reading them.
                    torch.distributed.barrier()
                else:
                    save(output_file=result_path, out=results)

    # Drop the engine's executor explicitly before the model so GPU resources
    # can be reclaimed by the garbage collector.
    del model.llm_engine.model_executor
    del model

    if world_size > 1:
        torch.distributed.destroy_process_group()
    gc.collect()
    print_end_time(start_time)

# When executed as a script, expose main() as a command-line interface via
# python-fire, which maps command-line flags onto main's parameters.
if __name__ == "__main__":
    fire.Fire(main)