import argparse
import time
from accelerate import Accelerator
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig,BitsAndBytesConfig
from tqdm import tqdm
import openai

# from repositories1.opendilab_ACE.ding.hpc_rl.tests.test_lstm import batch_size
from utils.utils import load_jsonl, dump_jsonl, make_needed_dir
import copy
import torch
from utils.build_prompt import build_prompt
from utils.utils import CodexTokenizer, CodeGenTokenizer, StarCoderTokenizer,DeepseekCoderTokenizer
import ollama
from ollama import ChatResponse
from ollama import chat
import pynvml
from openai import OpenAI
import requests


# Preferred CUDA device index for the locally loaded models. The index is
# clamped to the GPUs that are actually present so that hosts with fewer than
# three GPUs do not fail later with an invalid device ordinal; when CUDA is
# unavailable everything runs on the CPU.
_PREFERRED_CUDA_INDEX = 2
if torch.cuda.is_available():
    device = torch.device(f"cuda:{min(_PREFERRED_CUDA_INDEX, torch.cuda.device_count() - 1)}")
else:
    device = torch.device("cpu")
#
# model_memory = 266 * 1024 * 1024  # meant as "256 MiB converted to bytes" (note: the literal is 266, not 256)
# def get_free_gpu():
#     pynvml.nvmlInit()
#     num_gpus = torch.cuda.device_count()
#     free_gpu_list = []
#     for i in range(num_gpus):
#         handle = pynvml.nvmlDeviceGetHandleByIndex(i)
#         mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
#
#         if   mem_info.used / mem_info.total < 0.90:
#             free_gpu_list.append(i)
#     pynvml.nvmlShutdown()
#     if free_gpu_list:
#         # Simply pick the first free GPU here; other selection strategies are possible.
#         return free_gpu_list[0]
#     else:
#         return None
# #
# device_id = get_free_gpu()
# if device_id is not None:
#     device = torch.device(f"cuda:{device_id}")
# else:
#     device = torch.device("cpu")
#
# print(f"Selected device: {device}")


def parser_args(argv=None):
    """Parse the command-line options of the generation script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) ``sys.argv[1:]`` is used, exactly as before; passing
            a list allows the parser to be driven programmatically.

    Returns:
        An ``argparse.Namespace`` with the attributes ``input_file_name``,
        ``model``, ``mode``, ``max_top_k`` and ``max_new_tokens``.
    """
    parser = argparse.ArgumentParser(description="Generate response from llm")
    parser.add_argument(
        '--input_file_name', default='api_level', type=str,
        help="stem of the input file ./search_results/<name>.search_res.jsonl",
    )
    parser.add_argument(
        '--model', default='gpt-3.5-turbo-instruct', type=str,
        help="name of the model used for generation",
    )
    parser.add_argument(
        '--mode', default='retrieval', type=str,
        choices=['infile', 'retrieval', 'crossfile', 'crossfile_java',
                 'pure_crossfile', 'crossfile_test', 'crosscode1'],
        help="prompt construction mode forwarded to build_prompt",
    )
    parser.add_argument(
        '--max_top_k', default=10, type=int,
        help="max_top_k value forwarded to build_prompt",
    )
    parser.add_argument(
        '--max_new_tokens', default=100, type=int,
        help="maximum number of tokens to generate per case",
    )

    return parser.parse_args(argv)


# Token budget (prompt + generated tokens) used for every supported model.
_MAX_NUM_TOKENS = 2048
_OPENAI_MODEL = 'gpt-3.5-turbo-instruct'
_CHAT_MODEL = 'deepseek-coder'
# Local causal LMs whose full output (prompt included) is decoded and then
# trimmed by dropping the prompt's lines.
_PROMPT_ECHO_MODELS = ('CodeLlama', 'codegen25-7b', 'codegen2-7b', 'codegen2-1b')
_SUPPORTED_MODELS = (_OPENAI_MODEL, _CHAT_MODEL) + _PROMPT_ECHO_MODELS


def _greedy_generation_config(tokenizer_raw, max_new_tokens):
    """Build the greedy-decoding GenerationConfig shared by all local models."""
    # `temperature` is intentionally not set: it only affects sampling, and
    # combining it with do_sample=False merely triggers validation warnings.
    return GenerationConfig(
        max_new_tokens=max_new_tokens,
        do_sample=False,
        eos_token_id=tokenizer_raw.eos_token_id,
        pad_token_id=tokenizer_raw.pad_token_id,
    )


def _load_model(model_name, max_new_tokens):
    """Load one supported model; return (model, tokenizer, tokenizer_raw, generation_config)."""
    if model_name == _OPENAI_MODEL:
        # Credentials are taken from the OPENAI_API_KEY / OPENAI_BASE_URL
        # environment variables instead of blank hard-coded strings, which
        # produced a client that could not issue any request.
        # There is no Hugging Face tokenizer or generation config for the API
        # model. NOTE(review): build_prompt therefore receives None as its raw
        # tokenizer for this model -- confirm it copes with that.
        return openai.OpenAI(), CodexTokenizer(), None, None

    if model_name == 'codegen25-7b':
        # This model is fetched by its hub identifier rather than from ./models_cache.
        model_source = tokenizer_source = "Salesforce/codegen25-7b-mono_P"
    else:
        model_source = f"./models_cache/{model_name}/"
        tokenizer_source = f"./models_cache/{model_name}-tokenizer/"

    # Only these two checkpoints were loaded with remote code enabled for the
    # model itself; the tokenizer always is.
    model_kwargs = {}
    if model_name in (_CHAT_MODEL, 'codegen25-7b'):
        model_kwargs['trust_remote_code'] = True

    tokenizer_raw = AutoTokenizer.from_pretrained(tokenizer_source, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_source, **model_kwargs).to(device)

    if model_name == _CHAT_MODEL:
        tokenizer = DeepseekCoderTokenizer(tokenizer_raw)
    else:
        tokenizer = CodeGenTokenizer(tokenizer_raw)
    return model, tokenizer, tokenizer_raw, _greedy_generation_config(tokenizer_raw, max_new_tokens)


def _strip_prompt_lines(prompt, decoded):
    """Return *decoded* without as many leading lines as *prompt* has."""
    # NOTE(review): this relies on the decoded text reproducing the prompt line
    # for line. If the prompt ends mid-line, the generated remainder of that
    # line is discarded together with it -- confirm this is intended.
    n_prompt_lines = len(prompt.splitlines(keepends=True))
    return "".join(decoded.splitlines(keepends=True)[n_prompt_lines:])


def _generate_response(args, model, tokenizer, tokenizer_raw, generation_config, prompt):
    """Produce the completion text for a single prompt with the selected model."""
    if args.model == _OPENAI_MODEL:
        completion = model.completions.create(
            model=args.model,
            prompt=prompt,
            max_tokens=args.max_new_tokens,
        )
        return completion.choices[0].text

    if args.model == _CHAT_MODEL:
        messages = [{'role': 'user', 'content': prompt}]
        inputs = tokenizer_raw.apply_chat_template(
            messages, return_tensors="pt", add_generation_prompt=False
        ).to(device)
        outputs = model.generate(inputs, generation_config=generation_config)
        # Decode only the tokens that follow the input sequence.
        return tokenizer_raw.decode(
            outputs[0][len(inputs[0]):],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )

    # Remaining local models: move the encoded prompt to the target device,
    # generate, decode the whole sequence and trim the echoed prompt.
    prompt_ids = tokenizer_raw(prompt, return_tensors="pt").to(device)
    response_ids = model.generate(
        prompt_ids['input_ids'],
        generation_config=generation_config,
        attention_mask=prompt_ids['attention_mask'],
    )
    return _strip_prompt_lines(prompt, tokenizer.decode(response_ids[0]))


def main(args, input_cases, responses_save_name):
    """Generate a completion for every input case and dump the results as JSONL.

    Args:
        args: Parsed options; ``model``, ``mode``, ``max_top_k`` and
            ``max_new_tokens`` are read.
        input_cases: Sized iterable of cases. Each case is handed to
            ``build_prompt`` and deep-copied into the output with an extra
            ``'generate_response'`` key holding the generated text.
        responses_save_name: Path passed to ``dump_jsonl`` for the results.

    Raises:
        ValueError: If ``args.model`` is not a supported model, or if
            ``args.max_new_tokens`` leaves no room for the prompt.
    """
    if args.model not in _SUPPORTED_MODELS:
        raise ValueError(
            f"Unsupported model {args.model!r}; expected one of {', '.join(_SUPPORTED_MODELS)}"
        )

    max_prompt_tokens = _MAX_NUM_TOKENS - args.max_new_tokens
    if max_prompt_tokens <= 0:
        raise ValueError(
            f"max_new_tokens={args.max_new_tokens} leaves no room for the prompt "
            f"within the {_MAX_NUM_TOKENS}-token budget"
        )

    model, tokenizer, tokenizer_raw, generation_config = _load_model(args.model, args.max_new_tokens)
    print('Model loading finished')

    responses = []
    for case in tqdm(input_cases, total=len(input_cases), desc='Processing...'):
        prompt = build_prompt(case, tokenizer, tokenizer_raw, max_prompt_tokens,
                              max_top_k=args.max_top_k, mode=args.mode)
        response = _generate_response(args, model, tokenizer, tokenizer_raw, generation_config, prompt)

        # Keep the input case untouched; the output record is an annotated copy.
        case_res = copy.deepcopy(case)
        case_res['generate_response'] = response
        responses.append(case_res)

    dump_jsonl(responses, responses_save_name)

if __name__ == "__main__":
    # Script entry point: parse options, load the retrieval results, make sure
    # the output directory exists, then run generation.
    args = parser_args()
    input_cases = load_jsonl(f"./search_results/{args.input_file_name}.search_res.jsonl")
    print('Input loading finished')

    # Results are written under a per-model directory; the file name encodes
    # the input file, the prompt mode and the model.
    responses_save_name = f"./generation_results/{args.model}/{args.input_file_name}.{args.mode}.{args.model}.gen_res.jsonl"
    make_needed_dir(responses_save_name)
    main(args, input_cases, responses_save_name)


