from vllm import LLM, SamplingParams
import gc
import numpy as np
import random
import csv
import fire
import torch
import torch.distributed
import os
from typing import List
from tqdm import tqdm
import json
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import LlamaConfig, LlamaTokenizer, LlamaForCausalLM

from utils.dataset_utils import get_dataset
from utils.icl_utils import get_icl_examples
from utils.prompt_utils import apply_prompt_template, apply_icl_prompt

from utils.model_utils import (
    setup,
    setup_environ_flags,
    clear_gpu_cache,
    load_peft_model,
)

import json
import re
import time
from datetime import datetime

def print_start_time():
    """Print the current wall-clock time and return it.

    Returns:
        The ``datetime`` captured at the moment of the call, so the caller
        can later hand it to ``print_end_time`` to report the elapsed time.
    """
    now = datetime.now()
    stamp = now.strftime('%Y-%m-%d %H:%M:%S')
    print(f"Start time: {stamp}")
    return now

def print_end_time(start_time):
    """Print the current wall-clock time and the time elapsed since *start_time*.

    Args:
        start_time: A ``datetime`` marking when the run began (for example
            the value returned by ``print_start_time``).
    """
    finished = datetime.now()
    elapsed = finished - start_time
    stamp = finished.strftime('%Y-%m-%d %H:%M:%S')
    print(f"End time: {stamp}")
    print(f"Total running time: {elapsed}")

def save(output_file, out):
    """Write *out* to *output_file* in JSON Lines form.

    The file is opened in write mode (any previous content is replaced) and
    every element of *out* is serialized with ``json.dumps`` on its own line.

    Args:
        output_file: Path of the file to write.
        out: Iterable of JSON-serializable records.
    """
    with open(output_file, 'w') as handle:
        handle.writelines(json.dumps(record) + "\n" for record in out)

def combine_results(input_files: List[str], output_file: str):
    """Merge several JSON Lines result files into one, ordered by ``idx``.

    Every line of every input file is parsed as JSON. The parsed records are
    sorted by their ``'idx'`` field, that field is then removed from each
    record, and the records are written to *output_file* with ``save``.

    Args:
        input_files: Paths of the JSON Lines files to merge.
        output_file: Path of the merged file to write.
    """
    merged = []
    for path in input_files:
        with open(path, 'r') as handle:
            merged.extend(json.loads(raw) for raw in handle)

    ordered = sorted(merged, key=lambda record: record['idx'])
    for record in ordered:
        # 'idx' only exists to restore the global order; drop it before saving.
        record.pop('idx')
    save(output_file, ordered)

def question_read_csv(text_file):
    """Return the first column of a comma-delimited CSV file as a list.

    The file is read inside a ``with`` block so the handle is released even
    if parsing fails, and it is opened with ``newline=""`` as the ``csv``
    module recommends. Rows that contain no fields (blank lines) are skipped
    instead of raising ``IndexError``.

    Args:
        text_file: Path of the CSV file to read.

    Returns:
        A list with the first field of every non-empty row, in file order.
    """
    with open(text_file, "r", newline="") as file:
        return [row[0] for row in csv.reader(file, delimiter=",") if row]

def question_read_json(text_file, prompt_key):
    """Load prompts from a JSON file.

    The file is parsed with ``json.load`` and iterated. An element that is a
    dict contributes its *prompt_key* value; any other element is returned
    unchanged.

    Args:
        text_file: Path of the JSON file to read.
        prompt_key: Key looked up in every dict element.

    Returns:
        A list with one entry per element of the parsed JSON value.
    """
    with open(text_file, 'r') as handle:
        entries = json.load(handle)

    prompts = []
    for entry in entries:
        if isinstance(entry, dict):
            prompts.append(entry[prompt_key])
        else:
            prompts.append(entry)
    return prompts


def question_read_txt(text_file):
    """Read a text file and return its non-blank lines, stripped.

    Args:
        text_file: Path of the text file to read.

    Returns:
        A list of the lines that are non-empty after ``str.strip()``, each
        with its leading and trailing whitespace removed, in file order.
    """
    kept = []
    with open(text_file, 'r') as handle:
        for raw in handle:
            stripped = raw.strip()
            if stripped:
                kept.append(stripped)
    return kept
    
def main(
    model_path,
    dataset: str="gsm8k",
    data_start: int=0,
    data_end: int=-1,
    peft_model: str=None,
    quantization: bool=False,
    max_new_tokens = 256,
    prompt_template_style: str='gsm8k',
    seed: int=42,
    do_sample: bool=True,
    use_cache: bool=True,
    top_p: float=1.0,
    temperature: float=1.0,
    top_k: int=0,
    repetition_penalty: float=1.0,
    length_penalty: int=1,
    use_fast_kernels: bool = False,
    prompt_key: str = 'instruction',
    output: str = None,
    exp_num: int=10,
    k: int=4,
    subset_size: int=100,
    method: str = "diversity",
    dp_choice: str = "knn",
    metric: str = "cosine_similarity",
    emb: str = None,
    freq: int = 0,
    apply_chat_template: bool = False,
    **kwargs
):
    """Generate answers for a dataset's test prompts with in-context examples.

    When ``method == "compute_relation"`` the function only calls
    ``get_icl_examples``, prints the three statistics it returns, appends
    them to ``results/test/relation/{dataset}.txt`` and returns.

    Otherwise it loads the model with vLLM, builds one prompt per test input
    with ``apply_icl_prompt``, generates a completion for each prompt and
    writes ``{"prompt", "answer"}`` records as JSON Lines under
    ``results/test/{dataset}/{model_name}/{emb}/``. One output file is
    produced per example-order permutation: 10 permutations when ``"knn"``
    is part of ``method`` and ``k > 3``, otherwise a single one.

    If the ``WORLD_SIZE`` environment variable is set, a NCCL process group
    is initialised, prompts are split round-robin across ranks
    (``idx % world_size == rank``), every rank writes a ``.part{rank}.jsonl``
    file and rank 0 merges them.

    Args:
        model_path: Model location passed to ``vllm.LLM`` and
            ``AutoTokenizer.from_pretrained``; its basename appears in the
            output paths.
        dataset: Dataset name passed to ``get_dataset`` and
            ``get_icl_examples``. It is also used as the prompt template
            style and in the output paths.
        max_new_tokens: ``max_tokens`` of the vLLM sampling parameters.
        repetition_penalty: Forwarded to the vLLM sampling parameters.
        use_fast_kernels: If true, try to apply
            ``BetterTransformer.transform`` to the model.
        exp_num: Experiment number. It is used as the random seed (it
            replaces ``seed``) and is part of the output file name.
        k: Forwarded to ``get_icl_examples``; the columns of the returned
            index matrix are reordered with permutations of ``range(k)``
            (presumably the number of in-context examples per prompt).
        subset_size, method, dp_choice, metric, emb: Forwarded to
            ``get_icl_examples`` and embedded in the output file name.
        freq: When positive, results are generated and saved every time
            ``idx % freq == 0``; in all cases they are flushed after the
            last prompt.
        apply_chat_template: If true, each prompt is wrapped as a single
            user message and rendered with the tokenizer's chat template.
        seed, prompt_template_style, output: Accepted for CLI
            compatibility but overwritten inside the function (by
            ``exp_num``, ``dataset`` and a generated path respectively).
        data_start, data_end, peft_model, quantization, do_sample,
        use_cache, top_p, temperature, top_k, length_penalty, prompt_key,
        **kwargs: Accepted for CLI compatibility; not read by this function.
    """
    if apply_chat_template:
        print("Apply Chat Template", f"{model_path}")
    else:
        print("Not Apply Chat Template", f"{model_path}")

    model_name = os.path.basename(model_path)

    start_time = print_start_time()

    # The experiment number doubles as the random seed; the ``seed``
    # argument is intentionally ignored so runs are keyed by ``exp_num``.
    seed = exp_num

    if method == "compute_relation":
        overall_similarities, overall_rank_medians, overlap_rates = get_icl_examples(dataset=dataset, emb=emb, shuffle_seed=seed,k=k, method=method,dp_choice=dp_choice,subset_size=subset_size,metric=metric)
        print(f"overall_similarities: {overall_similarities}")
        print(f"overall_rank_medians: {overall_rank_medians}")
        print(f"overlap_rates: {overlap_rates['knn'], overlap_rates['diversity']}")
        output = f"results/test/relation/{dataset}.txt"

        # Appending would fail on a fresh checkout where the directory
        # does not exist yet.
        os.makedirs(os.path.dirname(output), exist_ok=True)
        with open(output, 'a') as f:
            f.write(f"{model_name}-{emb}-{exp_num}-{subset_size}-{k}\n")
            f.write(f"overall_similarities: {overall_similarities}\n")
            f.write(f"overall_rank_medians: {overall_rank_medians}\n")
            f.write(f"overlap_rates: {overlap_rates['knn'], overlap_rates['diversity']}\n")
        print_end_time(start_time)

        return

    torch.cuda.manual_seed(seed)
    torch.manual_seed(seed)

    world_size = os.environ.get('WORLD_SIZE')

    if world_size is None:
        # Single-process run.
        world_size = 1
        local_rank = 0
        rank = 0
    else:
        world_size = int(world_size)
        local_rank = int(os.environ["LOCAL_RANK"])
        rank = int(os.environ['RANK'])
        print(f"rank: {rank} local rank: {local_rank} world size: {world_size}")

        torch.distributed.init_process_group(backend='nccl', init_method='env://')

        setup()
        torch.cuda.set_device(local_rank)
        clear_gpu_cache(local_rank)
        setup_environ_flags(rank)

    model = LLM(model=model_path, dtype=torch.bfloat16)
    # NOTE(review): temperature/top_p/top_k are hard-coded here and the
    # function arguments of the same names are not consulted; top_k=1 keeps
    # only the top-ranked token, so decoding is effectively greedy.
    sampling_params = SamplingParams(
        temperature=1.0,
        top_p=1.0,
        top_k=1,
        max_tokens=max_new_tokens,
        repetition_penalty=repetition_penalty,
        stop=["\n\n","Question","\n\n\n","Question:"],
    )

    if use_fast_kernels:
        # NOTE(review): this passes a vLLM ``LLM`` object to BetterTransformer;
        # confirm that combination is actually supported before enabling it.
        try:
            from optimum.bettertransformer import BetterTransformer
            model = BetterTransformer.transform(model)
        except ImportError:
            print("Module 'optimum' not found. Please install 'optimum' it before proceeding.")

    # The tokenizer is only needed for prompt templating; generation itself
    # receives raw prompt strings.
    tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left")
    tokenizer.pad_token = tokenizer.eos_token

    train_inputs, train_outputs, test_inputs, _ = get_dataset(dataset=dataset, load_from_local=True)

    # The template style always follows the dataset name; the argument of
    # the same name is overwritten.
    prompt_template_style = dataset
    _, train_inputs_prompts = apply_prompt_template(prompt_template_style, train_inputs, tokenizer, return_dialogs=True)
    _, test_input_prompts = apply_prompt_template(prompt_template_style, test_inputs, tokenizer, return_dialogs=True)

    idx_mat_origin = get_icl_examples(train_dataset=dataset, test_dataset=dataset, emb=emb, shuffle_seed=seed,k=k, method=method,dp_choice=dp_choice,subset_size=subset_size,metric=metric)

    np.random.seed(seed)
    random.seed(seed)

    # Draw ``per_total`` distinct orderings of the k example slots.
    per_total = 10 if k > 3 else 1
    permutations = []
    while len(permutations) < per_total:
        perm = np.random.permutation(k).tolist()
        if perm not in permutations:
            permutations.append(perm)

    # Only "knn" methods are evaluated under every ordering.
    permutation_num = per_total if "knn" in method else 1

    result_dir = f"results/test/{dataset}/{model_name}/{emb}"

    for permutation_index in range(permutation_num):
        program_info = f"{dataset}-{model_name}-{emb}-{method}-{dp_choice}-{metric}-{permutation_index}-{exp_num}-{subset_size}-{k}"
        print(f"program info: {program_info}")

        output = f"{result_dir}/{method}-{dp_choice}-{metric}-{permutation_index}-{exp_num}-{subset_size}-{k}"
        # exist_ok avoids a crash when several ranks create the directory
        # at the same time.
        os.makedirs(result_dir, exist_ok=True)

        # Remove a stale result file. Attempting the removal directly avoids
        # the exists()/remove() race between ranks.
        try:
            os.remove(f"{output}.jsonl")
        except FileNotFoundError:
            print(f"{output}.jsonl not exists, no need to delete.")
        else:
            print(f"{output}.jsonl has been deleted.")

        if k == 0:
            idx_mat = idx_mat_origin
        else:
            # Reorder the example columns according to this permutation.
            idx_mat = idx_mat_origin[:, np.array(permutations[permutation_index])]

        dialogs = apply_icl_prompt(test_input_prompts, train_inputs_prompts, train_outputs, idx_mat, prompt_template_style)

        if apply_chat_template:
            dialogs = [
                tokenizer.apply_chat_template(
                    [{"role":"user","content":d}],
                    add_generation_prompt=True,
                    tokenize = False,
                )
                for d in dialogs
            ]

        results = []
        batch = { 'dialogs': [], 'idx': [] }
        last_idx = len(dialogs) - 1
        for idx, dialog in tqdm(enumerate(dialogs), total=len(dialogs)):
            # Round-robin sharding of prompts across ranks.
            if idx % world_size == rank:
                batch['dialogs'].append(dialog)
                batch['idx'].append(idx)

            flush = (freq > 0 and idx % freq == 0) or idx == last_idx
            if not flush:
                continue

            # A rank may have nothing queued at a flush point; skip generation
            # but still take part in the save/barrier below.
            if batch['dialogs']:
                with torch.no_grad():
                    outputs = model.generate(batch['dialogs'],sampling_params=sampling_params)
                output_text = [completion.text for output_item in outputs for completion in output_item.outputs]

                torch.cuda.empty_cache()
                for i, o in zip(batch["idx"], output_text):
                    cur = {'prompt': test_inputs[i], 'answer': o}
                    if world_size > 1:
                        # Kept so rank 0 can restore the global order on merge.
                        cur['idx'] = i
                    results.append(cur)

                    print('\n\n\n')
                    print('>>> sample - %d' % i)
                    print('[prompt]', dialogs[i])
                    print('[answer]', o)

            batch = { 'dialogs': [], 'idx': [] }
            if world_size > 1:
                save(output_file=f"{output}.part{rank}.jsonl", out=results)
                # Wait until every rank has written its part file.
                torch.distributed.barrier()
                if rank == 0:
                    combine_results(
                        input_files=[f"{output}.part{i}.jsonl" for i in range(world_size)],
                        output_file=f"{output}.jsonl")
                # Keep other ranks from rewriting their part file while
                # rank 0 is still reading it.
                torch.distributed.barrier()
            else:
                save(output_file=f"{output}.jsonl", out=results)

    # Drop the engine's executor explicitly before releasing the model.
    del model.llm_engine.model_executor
    del model

    if world_size > 1:
        torch.distributed.destroy_process_group()
    gc.collect()
    print_end_time(start_time)
# Script entry point: hand main() to python-fire so its parameters can be
# supplied on the command line.
if __name__ == "__main__":
    fire.Fire(main)