import json
import argparse
from vllm import LLM, SamplingParams
from tqdm import tqdm
import os
import csv
import pandas as pd
from collections import Counter
from litellm import completion
from eval import evaluation_pvqrr
from utils import get_circle_list
import jinja2
import re
import random
import gc
import torch
from vllm.lora.request import LoRARequest

# Dummy credential for the judge model, which is called through litellm's
# "openai/" route against --api_base (default: a localhost server).
# NOTE(review): presumably the local server accepts any token — confirm before
# pointing --api_base at a real OpenAI endpoint.
# This assignment overwrites any OPENAI_API_KEY already set in the environment.
os.environ["OPENAI_API_KEY"] = "token-abc123"

def load_person_info(file_path):
    """Load the persona records from a JSON file.

    Args:
        file_path: Path to the JSON file to read.

    Returns:
        The decoded JSON content, unchanged. ``main`` samples from it and
        reads ``person['description']``, so it is presumably a list of dicts
        with a ``description`` key (not validated here).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    # Decode explicitly as UTF-8 (as load_questions does) instead of relying
    # on the platform's locale encoding, which breaks non-ASCII descriptions
    # on systems whose default is not UTF-8.
    with open(file_path, 'r', encoding='utf-8') as file:
        return json.load(file)

def load_questions(json_path):
    """Read the UTF-8 encoded JSON file at ``json_path`` and return its decoded content."""
    with open(json_path, mode='r', encoding='utf-8') as handle:
        raw_text = handle.read()
    return json.loads(raw_text)

def load_chat_template(template_path):
    """Read the UTF-8 text file at ``template_path`` and wrap it in a ``jinja2.Template``."""
    with open(template_path, encoding='utf-8') as source:
        template_text = source.read()
    compiled = jinja2.Template(template_text)
    return compiled

def get_chat_template(model_path):
    """Pick and load the chat template that matches a model name.

    The match is a case-insensitive substring test on ``model_path``:
    anything containing ``"qwen"`` gets the Qwen template; anything
    containing both ``"llama"`` and the character ``"3"`` gets the Llama-3
    template. Note that the ``"3"`` test matches a 3 anywhere in the name.

    Args:
        model_path: Model name or path used to choose the template.

    Returns:
        Whatever ``load_chat_template`` returns for the chosen file under
        ``Files/chat_templates``.

    Raises:
        ValueError: If the name matches neither supported family.
            Previously this case fell through with no template selected and
            failed with an UnboundLocalError.
    """
    model_name = model_path.lower()

    if "qwen" in model_name:
        template_file = "chat_template_qwen.jinja"
    elif "llama" in model_name and "3" in model_name:
        template_file = "chat_template_llama3.jinja"
    else:
        raise ValueError(
            f"No chat template available for model '{model_path}' "
            "(expected a Qwen or Llama-3 model name)"
        )

    return load_chat_template(os.path.join("Files", "chat_templates", template_file))

def create_chat_prompt(question, template, role_description):
    """Build the role-play prompt for one question and render it with a chat template.

    Args:
        question: Question text placed at the end of the instruction.
        template: Object exposing ``render(**kwargs)``; it is called with
            ``messages``, ``add_generation_prompt``, ``tools`` and ``bos_token``.
        role_description: Persona description embedded in the instruction.

    Returns:
        The rendered prompt, stripped, with every run of blank or
        whitespace-only lines collapsed to a single empty line. If rendering
        raises any ``Exception``, the raw instruction text is returned instead.
    """
    instruction = f"USER: Let's role-play. I will ask you a question and you must give me an answer. I want you to act as the person described below. Think from the person's perspective. \n {role_description}\n\nQuestion: {question}"
    conversation = [{"role": "user", "content": instruction}]

    try:
        rendered = template.render(
            messages=conversation,
            add_generation_prompt=True,
            tools=None,
            bos_token="",
        )
        return re.sub(r'\n\s*\n', '\n\n', rendered.strip())
    except Exception:
        # Best-effort fallback: send the un-templated instruction.
        return instruction

def setup_vllm(model_path, gpu_id, lora_path=None):
    """Create the vLLM engine and the sampling parameters used for generation.

    Args:
        model_path: Model name or path handed to ``LLM``; also inspected
            (lower-cased) to decide whether Llama-3 stop strings are added.
        gpu_id: String written to ``CUDA_VISIBLE_DEVICES`` before the engine
            is constructed.
        lora_path: Optional adapter path; only its truthiness is used here,
            to switch ``enable_lora`` on.

    Returns:
        A ``(model, sampling_params)`` tuple.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id
    engine = LLM(
        model=model_path,
        gpu_memory_utilization=0.9,
        enable_lora=bool(lora_path),
        max_lora_rank=256,
    )

    # NOTE(review): stop_token_ids=[2] is applied to every model — confirm
    # that id 2 is an end-of-sequence token in the target tokenizer.
    decoding_options = dict(
        temperature=0,
        top_p=0.95,
        max_tokens=500,
        stop_token_ids=[2],
    )

    lowered = model_path.lower()
    # Substring heuristic: any name containing "llama" and a "3" counts as Llama-3.
    if "llama" in lowered and "3" in lowered:
        decoding_options["stop"] = ["<|eot_id|>", "<|end_of_text|>"]

    return engine, SamplingParams(**decoding_options)

def _teardown_vllm_engine(model):
    """Best-effort: move the model weights to CPU and drop the engine attributes."""
    try:
        if hasattr(model, 'llm_engine'):
            # Reaches into vLLM internals; any failure is reported and ignored.
            model.llm_engine.model_executor.driver_worker.model_runner.model.cpu()
            del model.llm_engine
        if hasattr(model, 'model_executor'):
            del model.model_executor
    except Exception as e:
        print(f"wrong: {str(e)}")


def _flush_cuda_memory():
    """Run garbage collection, release cached CUDA memory and print what is still held."""
    gc.collect()

    if not torch.cuda.is_available():
        return

    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.reset_accumulated_memory_stats()
    torch.cuda.synchronize()

    try:
        torch.cuda.ipc_collect()
    except Exception:
        # Optional step; failing here must not mask the caller's result.
        pass

    gc.collect()
    torch.cuda.empty_cache()

    print(f"{torch.cuda.memory_allocated() / 1024**2:.2f} MB")
    print(f"{torch.cuda.memory_reserved() / 1024**2:.2f} MB")


def collect_responses_per_role(model_path, gpu_id, questions, selected_persons, lora_path=None):
    """Generate the model's answer to every question for every persona.

    For each persona all questions are rendered into prompts and sent to the
    model in one ``generate`` call. The engine is torn down and GPU memory
    released before returning, whether or not generation succeeded.

    Args:
        model_path: Model name or path passed to ``setup_vllm``.
        gpu_id: GPU id string passed to ``setup_vllm``.
        questions: Sequence of dicts; ``question`` and ``question_id`` are
            required, ``response``, ``centerpiece`` and ``options`` are
            optional.
        selected_persons: Sequence of dicts with a ``description`` key.
        lora_path: Optional LoRA adapter path; when set, a ``LoRARequest`` is
            attached to every generation call.

    Returns:
        Dict keyed by ``"person_<index>"`` (index into ``selected_persons``)
        with ``{'person': ..., 'responses': [...]}`` values. A persona whose
        generation step raises is reported and left out of the dict.

    Raises:
        Exception: Anything raised by engine setup, template loading or
            prompt construction propagates after cleanup has run.
    """
    all_responses = {}
    model = None
    sampling_params = None
    try:
        model, sampling_params = setup_vllm(model_path, gpu_id, lora_path)
        sampling_params.n = 1
        # NOTE(review): the template is fixed to the Qwen one regardless of
        # model_path — confirm this is intended for non-Qwen models.
        template = get_chat_template("Qwen2.5-7B-Instruct")
        lora_request = None
        if lora_path:
            lora_request = LoRARequest(lora_name="lora1", lora_int_id=1, lora_path=lora_path)

        for person_idx, person in enumerate(tqdm(selected_persons, desc="process", ncols=100)):
            person_id = f"person_{person_idx}"
            prompts = [
                create_chat_prompt(item['question'], template, person['description'])
                for item in questions
            ]
            try:
                outputs = model.generate(prompts, sampling_params, lora_request=lora_request)
                # Prompts were built in question order, so output i belongs to question i.
                responses = [None] * len(questions)
                for q_idx, output in enumerate(outputs):
                    item = questions[q_idx]
                    responses[q_idx] = {
                        'question_id': item['question_id'],
                        'question': item['question'],
                        'model_response': output.outputs[0].text.strip(),
                        'original_response': item.get('response', ''),
                        'centerpiece': item.get('centerpiece', ''),
                        'options': item.get('options', [])
                    }
            except Exception as e:
                # One failing persona must not abort the whole run: report it
                # and continue with the next one.
                print(f"wrong: {str(e)}")
                continue

            all_responses[person_id] = {
                'person': person,
                'responses': responses
            }
    finally:
        if model is not None:
            _teardown_vllm_engine(model)
            # Drop our references before collecting so the engine can actually be freed.
            model = None
            sampling_params = None
            _flush_cuda_memory()
    return all_responses

def extract_answer(origin_response, response, options, origin_ques, use_cycle=False, api_base=None):
    """Ask the judge model which option best describes how close two answers are.

    One judge call is made per option ordering: only the given ordering when
    ``use_cycle`` is false, otherwise every ordering returned by
    ``get_circle_list(options)``.

    Args:
        origin_response: Reference answer (person X in the prompt).
        response: Model answer being judged (person Y in the prompt).
        options: Similarity options; each ordering must have at least six
            entries, otherwise building the prompt raises ``IndexError``.
        origin_ques: The question both answers respond to.
        use_cycle: Whether to query every ordering from ``get_circle_list``.
        api_base: Base URL of the OpenAI-compatible judge endpoint.

    Returns:
        ``(results, letters)``, two lists with exactly one entry per ordering.
        ``results`` holds the raw judge reply (``""`` on failure). ``letters``
        holds the chosen option index with the ordering's position added
        modulo 6 — presumably undoing the rotation applied by
        ``get_circle_list``; confirm against ``utils`` — or ``''`` when no
        choice letter could be read.
    """
    results = []
    letters = []
    if use_cycle:
        new_options = get_circle_list(options)
    else:
        new_options = [options]
    letter2int = {
        'A': 0, 'B': 1, 'C': 2, 'D': 3, 'E': 4, 'F': 5
    }

    # NOTE(review): the prompt names the two people X and Y but its fifth line
    # refers to "B's" and "A's" responses; left untouched because changing the
    # wording changes judge behaviour.
    for i, op in enumerate(tqdm(new_options, desc="choose", leave=False)):
        extract_answer_example = f"""
        Question: {origin_ques}
        There is a person X whose answer is {origin_response}.
        There is a person Y whose answer is {response}.
        Disregarding the fact that X has not provided a detailed explanation, please assess the similarity between Y's and X's responses. Below are the options for similarity:
        Use the above information to determine which option B's response is more consistent with A's response.
        Options: A. {op[0]} B. {op[1]} C. {op[2]} D. {op[3]} E. {op[4]} F. {op[5]}
        Please give me the choice letter first and then give me the reason.
        """
        messages = [{"role": "user", "content": extract_answer_example}]

        answer = ""
        letter = ''
        try:
            # A missing (None) reply is treated like an empty one.
            answer = completion(
                model="openai/Qwen2.5-32B-Instruct",
                api_base=api_base,
                messages=messages,
                temperature=0
            ).choices[0].message.content or ""

            # Read the first non-blank character; slicing keeps an empty
            # reply from raising.
            first_char = answer.strip()[:1]
            if first_char in letter2int:
                letter = (letter2int[first_char] + i) % 6
        except Exception as e:
            print(f"wrong: {str(e)}")

        # Append exactly once per ordering so both lists stay index-aligned.
        results.append(answer)
        letters.append(letter)
    return results, letters

def evaluate_responses_for_person(person_responses, api_base, use_cycle=False):
    """Run the judge over every response of one persona.

    Args:
        person_responses: Iterable of dicts with ``original_response``,
            ``model_response``, ``options``, ``question`` and ``question_id``.
        api_base: Judge endpoint, forwarded to ``extract_answer``.
        use_cycle: Forwarded to ``extract_answer``.

    Returns:
        A list with one ``{'question_id', 'results', 'letters'}`` dict per
        response, in input order.
    """
    def judge(entry):
        raw_replies, chosen = extract_answer(
            entry['original_response'],
            entry['model_response'],
            entry['options'],
            entry['question'],
            use_cycle=use_cycle,
            api_base=api_base,
        )
        return {
            'question_id': entry['question_id'],
            'results': raw_replies,
            'letters': chosen,
        }

    return [judge(entry) for entry in tqdm(person_responses, desc="eval", leave=False)]

def generate_answers_for_person(evaluation_results):
    """Turn per-question judge votes into a single answer letter each.

    Args:
        evaluation_results: Iterable of dicts with ``question_id`` and
            ``letters``; ``letters`` holds option indices 0-5, with ``''``
            marking a judge call that produced no usable choice.

    Returns:
        A list of ``{'question_id', 'respond_answer'}`` dicts where
        ``respond_answer`` is the letter (A-F) of the most frequent valid
        vote. Ties go to the vote seen first; a question with no valid vote
        falls back to ``'A'``.
    """
    num2letter = {0: 'A', 1: 'B', 2: 'C', 3: 'D', 4: 'E', 5: 'F'}
    final_answers = []

    for eval_result in evaluation_results:
        # Ignore placeholders such as '' so a failed judge call can neither
        # win the vote nor trigger a KeyError in the letter lookup.
        votes = Counter(
            letter for letter in eval_result['letters'] if letter in num2letter
        )
        # max() returns the first key reaching the top count, i.e. the
        # earliest-seen vote among ties.
        answer = max(votes, key=votes.get) if votes else 0

        final_answers.append({
            'question_id': eval_result['question_id'],
            'respond_answer': num2letter[answer]
        })

    return final_answers

def evaluate_person_score(final_answers):
    """Score one persona's answers with ``evaluation_pvqrr``.

    Returns the ``'level3'`` entry of the evaluation result. If the
    evaluation (or the lookup) raises any ``Exception``, the error is printed
    and ``0`` is returned instead.
    """
    try:
        report = evaluation_pvqrr(final_answers)
        level3_score = report['level3']
    except Exception as err:
        print(f"wrong: {str(err)}")
        return 0
    return level3_score

def save_results_to_csv(results, output_file):
    """Write ``results`` to ``output_file`` as UTF-8 CSV without an index column, then print the path."""
    frame = pd.DataFrame(data=results)
    frame.to_csv(path_or_buf=output_file, index=False, encoding='utf-8')
    print(f"{output_file}")

def create_result_directory(model_name, use_cycle, num_roles, lora_path):
    """Create (if needed) and return the output directory for one run.

    The directory is ``<cwd>/lora_results/<model_name>_roles<num_roles>_<cycle|nocycle>``
    with ``_<adapter name>`` appended when ``lora_path`` is given.

    Args:
        model_name: Short model name used as the directory prefix.
        use_cycle: Selects the ``cycle`` / ``nocycle`` suffix.
        num_roles: Number of personas, embedded in the name.
        lora_path: Optional adapter path; its final path component is used.

    Returns:
        The absolute path of the directory, which exists on return.
    """
    cycle_str = "cycle" if use_cycle else "nocycle"
    dir_name = f"{model_name}_roles{num_roles}_{cycle_str}"
    if lora_path:
        # Strip trailing slashes first so "adapters/run1/" yields "run1"
        # rather than an empty name (same approach main uses for model_path).
        lora_name = os.path.basename(lora_path.rstrip('/'))
        dir_name = f"{dir_name}_{lora_name}"
    result_dir = os.path.join(os.getcwd(), "lora_results", dir_name)

    os.makedirs(result_dir, exist_ok=True)
    print(f"{result_dir}")

    return result_dir

def main():
    """Command-line entry point for the role-play evaluation.

    Steps: parse arguments, create the result directory and record the run
    parameters, sample personas, generate the model's answers for each
    persona, judge and score them, write one ``<person_id>_details.json`` per
    persona and a ``role_scores.csv`` summary.

    ``--seed`` is optional; when omitted the persona sample is not seeded,
    so it differs between runs.
    """
    parser = argparse.ArgumentParser(description='role eval')
    parser.add_argument('--input', type=str, default="data.json")
    parser.add_argument('--person_file', type=str, default="data_person.json")
    parser.add_argument('--model_path', type=str, default="Llama-3.1-8B-Instruct")
    parser.add_argument('--gpu_id', type=str, default="7")
    parser.add_argument('--api_base', type=str, default="http://localhost:8820/v1")
    parser.add_argument('--num_roles', type=int, default=100)
    parser.add_argument('--use_cycle', action='store_true')
    parser.add_argument('--lora_path', type=str, default=None)
    parser.add_argument('--seed', type=int, default=None)
    args = parser.parse_args()

    if args.seed is not None:
        # Makes the persona sample below repeatable across runs.
        random.seed(args.seed)

    model_name = os.path.basename(args.model_path.rstrip('/'))

    result_dir = create_result_directory(model_name, args.use_cycle, args.num_roles, args.lora_path)

    # Record the parameters up front so an interrupted run can still be
    # identified; lora_path and seed are included because they affect results.
    experiment_params = {
        'model_path': args.model_path,
        'gpu_id': args.gpu_id,
        'api_base': args.api_base,
        'num_roles': args.num_roles,
        'use_cycle': args.use_cycle,
        'input_file': args.input,
        'person_file': args.person_file,
        'lora_path': args.lora_path,
        'seed': args.seed
    }

    params_file = os.path.join(result_dir, "experiment_params.json")
    with open(params_file, 'w', encoding='utf-8') as f:
        json.dump(experiment_params, f, ensure_ascii=False, indent=4)

    person_info = load_person_info(args.person_file)

    # Never ask for more personas than the file contains.
    selected_persons = random.sample(person_info, min(args.num_roles, len(person_info)))

    questions = load_questions(args.input)

    all_responses = collect_responses_per_role(
        model_path=args.model_path,
        gpu_id=args.gpu_id,
        questions=questions,
        selected_persons=selected_persons,
        lora_path=args.lora_path
    )

    results = []

    for person_id, person_data in tqdm(all_responses.items(), desc="role eval", ncols=100):
        person = person_data['person']
        person_responses = person_data['responses']

        evaluation_results = evaluate_responses_for_person(
            person_responses,
            api_base=args.api_base,
            use_cycle=args.use_cycle
        )

        final_answers = generate_answers_for_person(evaluation_results)

        score = evaluate_person_score(final_answers)

        results.append({
            'person_id': person_id,
            'description': person['description'],
            'score': score
        })

        person_detailed_result = {
            'person_id': person_id,
            'description': person['description'],
            'responses': person_responses,
            'evaluation_results': evaluation_results,
            'final_answers': final_answers,
            'score': score
        }

        person_result_file = os.path.join(result_dir, f"{person_id}_details.json")
        with open(person_result_file, 'w', encoding='utf-8') as f:
            json.dump(person_detailed_result, f, ensure_ascii=False, indent=4)

    csv_output_file = os.path.join(result_dir, "role_scores.csv")
    save_results_to_csv(results, csv_output_file)

# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main() 