import argparse
import logging
import os
from datetime import datetime
import torch
import random
import json
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from tqdm import tqdm
import transformers
import time
from openai import OpenAI 
import httpx
import tiktoken

# Placeholder credential used by every OpenAI entry below; substitute a real
# key locally and never commit it.
_OPENAI_KEY_PLACEHOLDER = 'your openai key here'

# --infer_model name -> [OpenAI API key, OpenAI model identifier].
gptMap = {
    'gpt3.5': [_OPENAI_KEY_PLACEHOLDER, 'gpt-3.5-turbo'],
    'gpt4': [_OPENAI_KEY_PLACEHOLDER, 'gpt-4'],
}

# Directory under which the local checkpoints listed in modelMap live.
_MODEL_ROOT = "/your/path/to"

# --infer_model name -> local checkpoint path handed to from_pretrained().
modelMap = {
    alias: f"{_MODEL_ROOT}/{dirname}"
    for alias, dirname in (
        ('qwen2.5-7b', "Qwen2.5-7B-Instruct"),
        ('qwen2-7b', "Qwen2-7B-Instruct"),
        ('llama2Alpaca', "Alpaca_llama-2"),
        ('llama2Chat7b', "llama2-7b-chat"),
    )
}

def setup_logger(args, log_dir):
    """Return the shared ``"model_logger"``, attaching a file handler on first use.

    The log file is named ``run_log_seed<seed>_<timestamp>.log`` and is created
    inside ``log_dir`` (which is created if missing).

    Args:
        args: Namespace providing ``seed``, used in the log file name.
        log_dir: Directory in which the log file is written.

    Returns:
        The ``logging.Logger`` named ``"model_logger"`` at INFO level.
    """
    logger = logging.getLogger("model_logger")
    logger.setLevel(logging.INFO)

    # Look only at this logger's own handlers. Logger.hasHandlers() also
    # returns True when an ancestor (e.g. the root logger) has a handler, which
    # would silently skip creating the log file.
    if not logger.handlers:
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)
        time_str = datetime.now().strftime("%Y%m%d_%H%M%S")
        log_filename = os.path.join(log_dir, f"run_log_seed{args.seed}_{time_str}.log")
        file_handler = logging.FileHandler(log_filename)
        file_handler.setLevel(logging.INFO)
        file_handler.setFormatter(
            logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
        )
        logger.addHandler(file_handler)
    return logger

def log_args(args, logger):
    """Write every attribute of ``args`` to ``logger`` at INFO level.

    A header line is logged first, then one ``name: value`` line per attribute
    in the order ``vars(args)`` yields them.
    """
    logger.info("Running with the following arguments:")
    settings = vars(args)
    for name in settings:
        logger.info(f"{name}: {settings[name]}")

def random_all(seed):
    """Seed Python's ``random`` and torch's CPU and CUDA generators with ``seed``."""
    seeders = (random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for apply_seed in seeders:
        apply_seed(seed)

def get_model_and_tokenizer(model_path, device):
    """Load a causal language model and its tokenizer from ``model_path``.

    Args:
        model_path: Checkpoint location passed to ``from_pretrained``.
        device: Value forwarded as ``device_map`` when loading the model.

    Returns:
        A ``(model, tokenizer)`` tuple. The tokenizer pads with its EOS token;
        the model config has ``use_cache=False`` and ``pretraining_tp=1``.
    """
    tok = AutoTokenizer.from_pretrained(model_path)
    # Pad with the EOS token rather than a dedicated padding token.
    tok.pad_token = tok.eos_token

    lm = AutoModelForCausalLM.from_pretrained(model_path, device_map=device)
    for option, value in (("use_cache", False), ("pretraining_tp", 1)):
        setattr(lm.config, option, value)
    return lm, tok

def release_model_memory(model):
    """Best-effort release of memory held by ``model``.

    ``del model`` only removes this function's local reference; the model is
    freed only once the caller has dropped its own references too. A garbage
    collection pass is run before emptying the CUDA cache so that objects kept
    alive solely by reference cycles are reclaimed first.

    Args:
        model: The model object the caller no longer needs.
    """
    import gc  # local import keeps this block self-contained

    del model
    gc.collect()
    torch.cuda.empty_cache()

class Fairmaker():
    """Refine a randomly chosen "fairwords" prefix for a query with a causal LM.

    Attributes set in ``__init__``: ``args``, ``model``, ``tokenizer``,
    ``query``, ``fairwords_file``, ``fairwords_list``, ``device`` and ``seed``.
    """

    def __init__(self, args, model, tokenizer, query):
        """Store the collaborators and load the fairwords list.

        Args:
            args: Namespace providing ``fairwords_file``, ``fm_gpu_id`` and ``seed``.
            model: Model exposing ``generate``.
            tokenizer: Tokenizer matching ``model``.
            query: The query the prefix will be refined for.
        """
        self.args = args
        self.model = model
        self.fairwords_file = args.fairwords_file
        self.query = query
        self.tokenizer = tokenizer
        self.device = torch.device(f"cuda:{args.fm_gpu_id}" if torch.cuda.is_available() else "cpu")
        self.fairwords_list = self.load_fairwords()
        self.seed = args.seed

    def load_fairwords(self):
        """Return the parsed JSON content of ``self.fairwords_file``."""
        # Explicit UTF-8 so non-ASCII fairwords load the same on every platform
        # instead of depending on the locale's default encoding.
        with open(self.fairwords_file, encoding="utf-8") as f:
            return json.load(f)

    def formatted_fairmaker(self, fairwords, instruction):
        """Build the refinement prompt embedding ``fairwords`` and ``instruction``.

        The system turn holds the rewriting guidelines; the user turn holds the
        prefix wrapped in the Begin/End-of-Fairwords markers followed by the query.
        """
        prompt = f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are an expert in prompt engineering. Help me optimize the prefix (between the special token '<|Begin_of_Fairwords|>' and '<|End_of_Fairwords|>') before a query to ensure it encourages fair, helpful, and harmless responses. The optimized prefix should follow these guidelines:
1. The prefix should prompt fair, helpful, and harmless response but should not be overly specific.
2. Only modify the prefix between the special tokens '<|Begin_of_Fairwords|>' and '<|End_of_Fairwords|>', not the rest of the prompt. Just return the prefix.
3. Avoid overly specific constraints unrelated to the query.
4. Keep the prefix under 30 tokens.
5. Focus on improving the prefix, do not response to the prompt.
6. Minimize unnecessary changes to the prefix.
Don't answer to the query. 
Don't give an explanation.
Just give me ONE modified prefix.<|eot_id|>
<|start_header_id|>user<|end_header_id|><|Begin_of_Fairwords|>{fairwords}<|End_of_Fairwords|>{instruction}<|eot_id|>"""
        return prompt

    def prompt_refine(self, fairwords, instruction):
        """Sample a refined prefix from the model and return it as a string.

        Up to 50 new tokens are sampled with ``top_p=0.9``; generation stops at
        the tokenizer's EOS token or ``<|eot_id|>``.
        """
        prompt = self.formatted_fairmaker(fairwords, instruction)
        inputs = self.tokenizer([prompt], return_tensors="pt").to(self.device)
        terminators = [
            self.tokenizer.eos_token_id,
            self.tokenizer.convert_tokens_to_ids("<|eot_id|>"),
        ]
        generation_config = GenerationConfig(
            do_sample=True,
            top_p=0.9,
            max_new_tokens=50,
            pad_token_id=self.tokenizer.eos_token_id,
            eos_token_id=terminators,
        )
        outputs = self.model.generate(**inputs, generation_config=generation_config)
        # Special tokens are kept in the decoded text because they are the
        # delimiters used below: take what follows the last header marker and
        # cut it at the first end-of-turn marker.
        decoded = self.tokenizer.decode(outputs[0], skip_special_tokens=False)
        after_header = decoded.split('<|end_header_id|>')[-1]
        return after_header.split('<|eot_id|>')[0].strip()

    def run(self):
        """Pick a random fairwords prefix, refine it and count its tokens.

        Returns:
            ``(fairwords, opt_query, token_count)``: the original prefix, the
            refined prefix, and the refined prefix's length in tokens.
        """
        fairwords = random.choice(self.fairwords_list)
        opt_query = self.prompt_refine(fairwords=fairwords, instruction=self.query)
        # NOTE(review): this threshold is 30 *characters*, while the prompt asks
        # the model for fewer than 30 *tokens* -- confirm which is intended.
        if len(opt_query) > 30:
            # Keep only the first sentence, restoring the period that the split
            # removed when more sentences followed.
            sentences = opt_query.split('. ')
            opt_query = sentences[0] + ('.' if len(sentences) > 1 else '')
        token_count = len(self.tokenizer.tokenize(opt_query))
        return fairwords, opt_query, token_count
    
class LanguageModel:
    """Holds the inference backend: a local model with tokenizer, or an OpenAI client."""

    def __init__(self, args):
        """Record the run settings and initialise the backend.

        Args:
            args: Namespace providing ``seed``, ``infer_model`` and ``infer_gpu_id``.
        """
        self.args = args
        self.seed = args.seed
        self.infer_model = args.infer_model
        if torch.cuda.is_available():
            target = f"cuda:{args.infer_gpu_id}"
        else:
            target = "cpu"
        self.device = torch.device(target)
        self.model, self.tokenizer, self.client = self.infer_model_init()

    def gpt_client(self):
        """Create an OpenAI client using the key registered in ``gptMap``."""
        entry = gptMap[self.args.infer_model]
        return OpenAI(api_key=entry[0])

    def infer_model_init(self):
        """Return ``(model, tokenizer, client)`` for the selected backend.

        Names containing ``'gpt'`` yield ``(None, None, client)``; every other
        name is looked up in ``modelMap`` and yields ``(model, tokenizer, None)``.
        """
        name = self.args.infer_model
        if 'gpt' in name:
            return None, None, self.gpt_client()

        checkpoint = modelMap[name]
        tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map=self.device)
        if 'llama2' in name:
            # llama2 variants get an explicit pad token and left-side padding.
            tokenizer.pad_token = "[PAD]"
            tokenizer.padding_side = "left"
        return model, tokenizer, None
    
class ResponseGenerator(LanguageModel):
    """Generate a single-turn answer from the configured inference model.

    ``run`` dispatches on the ``infer_model`` name (``llama2`` / ``qwen`` /
    ``gpt``) and always returns ``(n_input, n_generated, response)``.
    """

    def __init__(self, args):
        """Initialise the backend and, for local models, a text-generation pipeline."""
        super().__init__(args)
        self.pipeline = None
        if not self.client:
            self.pipeline = transformers.pipeline(
                "text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                torch_dtype=torch.float16,
                device_map=self.device,
            )

    def prompt_format(self, query):
        """Wrap ``query`` in the one-round Human/Assistant instruction.

        Returns:
            The plain prompt string for ``llama2`` models, otherwise a
            single-message chat list ``[{"role": "user", "content": prompt}]``.
        """
        prompt = f"In the following dialogue, there are two roles, Human and Assistant. You are required to play the role of Assistant and answer the Human's question to complete the conversation. Note that this is a one round conversation, just give your response once.\nHuman's question: {query} \n Assistant's answer:"
        if 'llama2' in self.args.infer_model:
            return prompt
        return [{"role": "user", "content": prompt}]

    def _count_tokens(self, text):
        """Return how many token ids the tokenizer produces for ``text``."""
        return len(self.tokenizer([text]).input_ids[0])

    def query_llama2_series(self, query):
        """Answer ``query`` through the text-generation pipeline (greedy decoding).

        Returns:
            ``(n_input, n_generated, response)`` where ``n_input`` is the token
            count of the raw ``query`` (not of the formatted prompt) and
            ``n_generated`` is the token count of the re-tokenized response.
        """
        n_input = self._count_tokens(query)
        prompt = self.prompt_format(query)
        sequences = self.pipeline(
            prompt,
            do_sample=False,
            num_return_sequences=1,
            eos_token_id=self.tokenizer.eos_token_id,
            # Limit the *generated* tokens, as the qwen and gpt paths do.
            # `max_length` would also count the prompt and shrink the answer
            # budget by the prompt's length.
            max_new_tokens=self.args.max_new_tokens,
        )
        # The pipeline output is split on the prompt's final marker so that
        # only the text after it is kept as the answer.
        generated_text = sequences[0]['generated_text']
        assistant_response = generated_text.split("Assistant's answer:")[-1].strip()
        n_generated = self._count_tokens(assistant_response)
        return n_input, n_generated, assistant_response

    def query_qwen_series(self, query):
        """Answer ``query`` with ``model.generate`` using the chat template (greedy).

        Returns:
            ``(n_input, n_generated, response)`` where ``n_input`` is the token
            count of the raw ``query`` and ``n_generated`` the number of newly
            generated token ids.
        """
        n_input = self._count_tokens(query)
        message = self.prompt_format(query)
        text = self.tokenizer.apply_chat_template(
            message,
            tokenize=False,
            add_generation_prompt=True,
        )
        model_inputs = self.tokenizer([text], return_tensors="pt").to(self.device)

        generated_ids = self.model.generate(
            **model_inputs,
            max_new_tokens=self.args.max_new_tokens,
            do_sample=False,
        )
        # Drop the prompt ids so only the continuation is decoded and counted.
        generated_ids = [
            output_ids[len(input_ids):]
            for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
        ]
        n_generated = len(generated_ids[0])
        assistant_response = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return n_input, n_generated, assistant_response

    def query_gpt_series(self, query):
        """Answer ``query`` through the OpenAI chat-completions API.

        The request is attempted up to three times. If every attempt fails, a
        placeholder triple with zero token counts and an error message is
        returned instead of raising.

        Returns:
            ``(n_input, n_generated, response)`` using the token usage reported
            by the API.
        """
        logger = logging.getLogger("model_logger")
        message = self.prompt_format(query)
        max_retries = 3
        for attempt in range(max_retries):
            try:
                completion = self.client.chat.completions.create(
                    model=gptMap[self.args.infer_model][1],
                    messages=message,
                    max_tokens=self.args.max_new_tokens,
                    logprobs=True,
                )
                response = completion.choices[0].message.content
                n_input = completion.usage.prompt_tokens
                n_generated = completion.usage.completion_tokens
                return n_input, n_generated, response
            except Exception as e:
                # Deliberately broad: any failure of this attempt triggers a retry.
                logger.error(f"Attempt {attempt + 1} failed: {e}")
        logger.error(f"All {max_retries} attempts failed for query: {query}")
        # Three values, matching the success path, so callers can always unpack.
        return 0, 0, "Error: Unable to generate response after multiple attempts"

    def run(self, query):
        """Generate a response for ``query`` with the configured model family.

        Returns:
            ``(n_input, n_generated, response)``.

        Raises:
            ValueError: If ``infer_model`` names none of the supported families.
        """
        model_name = self.args.infer_model
        if 'llama2' in model_name:
            return self.query_llama2_series(query)
        if 'qwen' in model_name:
            return self.query_qwen_series(query)
        if 'gpt' in model_name:
            return self.query_gpt_series(query)
        raise ValueError(f"Unsupported inference model: {model_name!r}")


def build_arg_parser():
    """Build the command-line parser for the response-generation script."""
    parser = argparse.ArgumentParser(description='Fairmaker Model parameters.')
    parser.add_argument('--mode', default='fm', choices=['baseline','ori','fm'], help='The query mode.')
    parser.add_argument('--bl_id', default=0, help='The index of baseline.')
    parser.add_argument('--bl', type=str, help='The baseline prompt.')
    parser.add_argument('--fw', default='fw', help='Whether save the respones of fairwords')
    parser.add_argument('--fm_model_path', default='/path/to/fairmaker/model', type=str, help='The fairmaker model path')
    parser.add_argument('--infer_model', default='/name/of/inference/model', type=str, help='The inference model type')
    parser.add_argument('--fairwords_file', default='/path/to/fairwords.json', type=str, help='The fairwords file')
    parser.add_argument('--query_file', default='/path/to/query.json', type=str, help='The query file')
    parser.add_argument('--query_type', default='test', type=str, help='Type of the query')
    parser.add_argument('--fm_gpu_id', default=1, type=int, help='The GPU ID')
    parser.add_argument('--infer_gpu_id', default=0, type=int, help='The GPU ID')
    parser.add_argument('--seed', default=42, type=int, help='The random seed')
    parser.add_argument('--max_new_tokens', default=512, type=int, help='Maximum number of tokens to generate')
    parser.add_argument('--out_dir', default='../results/GA-test', type=str, help='Directory for saving results')
    # Previously hard-coded as "+400" at every place a record index was written.
    parser.add_argument('--idx_offset', default=400, type=int, help='Offset added to each query position to form the recorded idx')
    return parser


def make_record(config, idx, query, input_query, n_input, response, n_generated, inference_time, extra=None):
    """Assemble one output record.

    Keys are inserted as: ``config``, ``idx``, ``query``, then the items of
    ``extra`` (if any), then ``input_query``, ``input_token_num``,
    ``response``, ``output_token_num`` and ``inference_time``.
    """
    record = {'config': config, 'idx': idx, 'query': query}
    if extra:
        record.update(extra)
    record.update({
        'input_query': input_query,
        'input_token_num': n_input,
        'response': response,
        'output_token_num': n_generated,
        'inference_time': inference_time,
    })
    return record


def append_jsonl(path, record, logger, idx):
    """Append ``record`` to ``path`` as one JSON line (best effort).

    The record is serialized before the file is opened so that a serialization
    failure cannot leave a partial line behind. Failures are logged and printed
    rather than raised so that one bad record does not abort the whole run.

    Returns:
        True when the line was written, False otherwise.
    """
    try:
        line = json.dumps(record)
        with open(path, 'a') as file:
            file.write(line + "\n")
    except (OSError, TypeError, ValueError) as e:
        logger.error(f"Processing query {idx} failed: {e}")
        print("jsonl write error: ", e)
        return False
    return True


def main():
    """Run the selected mode (baseline / ori / fm) over every query in the query file."""
    parser = build_arg_parser()
    args = parser.parse_args()
    if args.mode == 'baseline' and args.bl is None:
        # Fail before any model is loaded instead of crashing on the first
        # query with "NoneType + str".
        parser.error("--bl is required when --mode is 'baseline'")

    # Initialize logger
    save_dir = os.path.join(args.out_dir, f'{args.infer_model}/responses')
    os.makedirs(save_dir, exist_ok=True)
    logger = setup_logger(args, save_dir)
    log_args(args, logger)
    random_all(args.seed)
    generator = ResponseGenerator(args)

    with open(args.query_file, 'r', encoding='utf-8') as file:
        queries = json.load(file)

    # Output paths are fixed before the loops so the closing log lines never
    # refer to an undefined name (empty query file, or --fw disabled).
    responses_dir = f'{args.out_dir}/{args.infer_model}/responses'
    file_suffix = f'responses_{args.infer_model}_seed{args.seed}.jsonl'

    def numbered_queries():
        # Wrap the list itself (not enumerate) so tqdm knows the total.
        return enumerate(tqdm(queries, desc="Processing queries", unit="query"))

    if args.mode == 'baseline':
        out_file = f'{responses_dir}/{args.mode}{args.bl_id}_{file_suffix}'
        config = {
            'mode': args.mode,
            'baseline_id': args.bl_id,
            'infer_model': args.infer_model,
            'seed': args.seed,
        }
        for position, prompt in numbered_queries():
            idx = position + args.idx_offset
            start_time = time.time()
            input_prompt = args.bl + prompt
            n_input, n_generated, response = generator.run(input_prompt)
            generate_time = time.time() - start_time
            record = make_record(config, idx, prompt, input_prompt, n_input, response, n_generated, generate_time)
            append_jsonl(out_file, record, logger, idx)
        logger.info(f"Script completed. Original responses output saved at {out_file}")

    elif args.mode == 'ori':
        out_file = f'{responses_dir}/{args.mode}_{file_suffix}'
        config = {
            'mode': args.mode,
            'infer_model': args.infer_model,
            'seed': args.seed,
        }
        for position, prompt in numbered_queries():
            idx = position + args.idx_offset
            start_time = time.time()
            n_input, n_generated, response = generator.run(prompt)
            generate_time = time.time() - start_time
            record = make_record(config, idx, prompt, prompt, n_input, response, n_generated, generate_time)
            append_jsonl(out_file, record, logger, idx)
        logger.info(f"Script completed. Original responses output saved at {out_file}")

    elif args.mode == 'fm':
        out_file = f'{responses_dir}/{args.mode}_{file_suffix}'
        fw_out_file = f'{responses_dir}/fw_{file_suffix}'
        save_fairwords = args.fw == 'fw'
        fm_config = {
            'mode': args.mode,
            'infer_model': args.infer_model,
            'seed': args.seed,
        }
        fw_config = {
            'mode': 'fw',
            'infer_model': args.infer_model,
            'seed': args.seed,
        }

        # Same device rule Fairmaker applies to its inputs, so a CPU-only
        # machine does not try to place the model on a missing GPU.
        fm_device = f"cuda:{args.fm_gpu_id}" if torch.cuda.is_available() else "cpu"
        fm_model, fm_tokenizer = get_model_and_tokenizer(args.fm_model_path, fm_device)
        # One Fairmaker for the whole run: the fairwords file is read once
        # instead of once per query; only the query changes per iteration.
        fairmaker = Fairmaker(args, fm_model, fm_tokenizer, None)

        for position, prompt in numbered_queries():
            idx = position + args.idx_offset

            # Generate the Fairmaker response and save it.
            start_time = time.time()
            fairmaker.query = prompt
            fairwords, refined_prompt, n_refined_prompt = fairmaker.run()
            refine_time = time.time() - start_time
            optimized_prompt = f"[{refined_prompt}] {prompt}"
            n_input, n_generated, response = generator.run(optimized_prompt)
            # NOTE(review): measured from before refinement, so this value
            # includes refine_time (unlike the other modes) -- confirm intended.
            generate_time = time.time() - start_time
            record = make_record(
                fm_config, idx, prompt, optimized_prompt, n_input, response, n_generated, generate_time,
                extra={
                    'fairwords': fairwords,
                    'fairmaker': refined_prompt,
                    'n_fairmaker': n_refined_prompt,
                    'refine_time': refine_time,
                },
            )
            append_jsonl(out_file, record, logger, idx)

            if save_fairwords:
                # Generate the response for the unrefined fairwords and save it.
                start_time = time.time()
                fw_prompt = f'[{fairwords}] {prompt}'
                n_input, n_generated, response = generator.run(fw_prompt)
                generate_time = time.time() - start_time
                record = make_record(
                    fw_config, idx, prompt, fw_prompt, n_input, response, n_generated, generate_time,
                    extra={'fairwords': fairwords},
                )
                append_jsonl(fw_out_file, record, logger, idx)

        summary = f"Script completed. Fairmaker responses saved at {out_file}."
        if save_fairwords:
            summary += f" Fairwords responses saved at {fw_out_file}"
        logger.info(summary)


if __name__ == "__main__":
    main()

