import argparse
import os
from datetime import datetime
import csv
import pandas as pd
import random
import re
import json
from tqdm import tqdm

# Assuming you have a prompt.py file that contains ch_prompt and eng_prompt
from prompt import ch_prompt, eng_prompt

# Import litellm.completion or OpenAI client, depending on your setup
from litellm import completion

def get_llm_response(inputs, model="gpt-4o", temp=0, seed=1, api_key=None, base_url=None):
    """Send one chat request through ``litellm.completion`` and return the reply text.

    Args:
        inputs: Chat messages, passed to ``completion`` as ``messages``.
        model: Model name forwarded to the provider.
        temp: Sampling temperature.
        seed: Seed forwarded with the request.
        api_key: API key forwarded to ``completion``; ``None`` is passed through as-is.
        base_url: Endpoint URL forwarded to ``completion``; ``None`` is passed through as-is.

    Returns:
        The ``content`` of the first choice's message, or an empty string when
        that content is ``None``, so callers can always treat the result as ``str``.
    """
    response = completion(
        api_key=api_key,
        base_url=base_url,
        model=model,
        messages=inputs,
        temperature=temp,
        custom_llm_provider="openai",
        seed=seed
    )
    content = response.choices[0].message.content
    # A reply can carry no text at all (content is None); normalise it so
    # string operations in callers such as ``.strip()`` do not raise.
    return content if content is not None else ""

def create_experiment_folder(model, context_num, test_count, base_path="/proposed_method/comon_nouns/exp_results"):
    """Create a timestamped results directory under *base_path* and return its path.

    The directory name has the form
    ``<YYYYmmdd_HHMMSS>_M<model>_TC<test_count>_CN<context_num>``; missing
    parent directories are created and an existing directory is reused.
    """
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    folder_name = "{}_M{}_TC{}_CN{}".format(stamp, model, test_count, context_num)
    target = os.path.join(base_path, folder_name)
    os.makedirs(target, exist_ok=True)
    return target

def generate_testfile(test_count, testing_file, instruction_prompt, common_words, repeat_test):
    """Write the experiment CSV template to *testing_file*.

    For each of the *test_count* shuffles the file gets, side by side:
      * a column headed ``Prompt: <instruction_prompt>`` holding the words
        numbered in their shuffled order,
      * a column headed ``order-<n>`` holding the shuffled keys of *common_words*,
      * *repeat_test* empty columns headed ``shuffle<n>-test<m>`` for responses.
    """
    prompt_header = f'Prompt: {instruction_prompt}'
    columns = []

    for shuffle_count in range(test_count):
        order = list(common_words.keys())
        random.shuffle(order)

        numbered_words = [
            f'{position}. {common_words[key]}'
            for position, key in enumerate(order, 1)
        ]
        columns.append([prompt_header, *numbered_words])
        columns.append([f'order-{shuffle_count}', *order])

        # One blank response column per repetition of this shuffle.
        blanks = [''] * len(order)
        columns.extend(
            [f'shuffle{shuffle_count}-test{count}', *blanks]
            for count in range(repeat_test)
        )

    # The data was assembled column-wise; zip(*) turns it into CSV rows.
    with open(testing_file, 'w', newline='', encoding='utf-8') as csv_file:
        csv.writer(csv_file).writerows(zip(*columns))

def extract_items(text):
    """Return only the lines of *text* that begin with a digit.

    Leading whitespace is ignored when testing a line, but kept lines are
    returned exactly as they appeared, joined with ``"\\n"``.
    """
    numbered_line = re.compile(r'^\d+.*$', re.IGNORECASE)
    kept = [
        raw_line
        for raw_line in text.splitlines()
        if numbered_line.match(raw_line.strip())
    ]
    return "\n".join(kept)

def convert_results(result, column_header):
    """Turn a raw model reply into one label per numbered answer line.

    Only lines kept by ``extract_items`` (those starting with a digit) are
    considered. Each such line is labelled, case-insensitively:

      * ``'COMEDY'``  if it contains "comedy" or "喜剧",
      * ``'TRAGEDY'`` otherwise, if it contains "tragedy" or "悲剧",
      * ``'NEUTRAL'`` if it contains neither.

    A line mentioning both genres is labelled ``'COMEDY'``.

    Args:
        result: Raw text returned by the model.
        column_header: Unused; kept for backward compatibility with callers.

    Returns:
        List of labels in the order the numbered lines appear.
    """
    comedy = re.compile(r"comedy|喜剧", re.IGNORECASE)
    tragedy = re.compile(r"tragedy|悲剧", re.IGNORECASE)

    result_list = []
    for element in extract_items(result).split('\n'):
        if not element.strip():
            # ''.split('\n') yields [''], so an empty reply lands here.
            continue
        if comedy.search(element):
            result_list.append('COMEDY')
        elif tragedy.search(element):
            result_list.append('TRAGEDY')
        else:
            result_list.append('NEUTRAL')
    return result_list

def parse_arguments():
    """Parse the experiment's command-line options from ``sys.argv``.

    Returns:
        ``argparse.Namespace`` with ``model``, ``context_num``, ``test_count``,
        ``lang`` and ``repeat_test``.
    """
    parser = argparse.ArgumentParser(description='Run LLM experiments with different parameters.')

    option_specs = (
        ('--model', dict(type=str, default='gpt-4o', help='Model name to use.')),
        ('--context_num', dict(type=int, default=30, help='Context length.')),
        ('--test_count', dict(type=int, default=2, help='Number of tests to run.')),
        ('--lang', dict(type=str, choices=['Chinese', 'English'], default='English', help='Language to use.')),
        ('--repeat_test', dict(type=int, default=1, help='Number of times to repeat the test.')),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)

    return parser.parse_args()

def _load_common_words(path):
    """Load a JSONL word list into a ``{1-based line number: word}`` dict.

    Every line of *path* must be a JSON object with a ``'word'`` key.
    """
    with open(path, 'r', encoding='utf-8') as fr:
        return {index: json.loads(line)['word'] for index, line in enumerate(fr, 1)}


def _fit_labels(labels, expected_len, context):
    """Return *labels* padded with ``'NEUTRAL'`` or truncated to *expected_len*.

    A warning naming *context* is printed whenever the length had to be changed.
    """
    if len(labels) == expected_len:
        return labels
    print(
        f"Warning: {context}: expected {expected_len} labels, got {len(labels)}; "
        "padding with 'NEUTRAL' / truncating."
    )
    padding = ['NEUTRAL'] * max(0, expected_len - len(labels))
    return (labels + padding)[:expected_len]


def main():
    """Run the whole experiment: build the CSV template, query the model, save labels.

    Steps:
      1. Parse CLI options and load the word list / instruction for the language.
      2. Create the experiment folder and the shuffled CSV template.
      3. For every shuffle and repetition, send the words to the model in
         chunks of ``context_num``, log prompts and replies, and convert each
         reply into labels.
      4. Write the labels back into the CSV.

    Credentials are read from the environment: ``OPENAI_API_KEY`` for the key
    and ``OPENAI_BASE_URL`` (falling back to ``OPENAI_API_BASE``) for the
    endpoint. Unset variables are forwarded as ``None``.

    Raises:
        ValueError: if ``--context_num`` is not positive or the language is
            not supported.
    """
    args = parse_arguments()

    # Set experiment parameters
    model = args.model
    context_num = args.context_num
    test_count = args.test_count
    lang = args.lang
    repeat_test = args.repeat_test

    if context_num < 1:
        # A zero or negative chunk size cannot partition the word list.
        raise ValueError(f"--context_num must be a positive integer, got {context_num}.")

    # These names were previously used without ever being defined.
    api_key = os.environ.get("OPENAI_API_KEY")
    base_url = os.environ.get("OPENAI_BASE_URL") or os.environ.get("OPENAI_API_BASE")

    language_resources = {
        'Chinese': ('./comon_nouns/chn_words.jsonl', ch_prompt),
        'English': ('./comon_nouns/eng_words.jsonl', eng_prompt),
    }
    if lang not in language_resources:
        raise ValueError(f"Unsupported language: {lang}. Please use 'Chinese' or 'English'.")
    words_path, instruction = language_resources[lang]
    common_words = _load_common_words(words_path)

    # Model names may contain path separators (e.g. "provider/model"); keep
    # them out of file and folder names so every output path stays valid.
    safe_model = re.sub(r'[\\/]', '_', model)

    experiment_folder = create_experiment_folder(model=safe_model, context_num=context_num, test_count=test_count)

    print(f"Experiment folder: {experiment_folder}")

    # Generate test file (CSV)
    testing_file = os.path.join(experiment_folder, f"{safe_model}-{lang}.csv")
    generate_testfile(
        test_count=test_count,
        testing_file=testing_file,
        instruction_prompt=instruction,
        common_words=common_words,
        repeat_test=repeat_test
    )

    total_iterations = test_count * repeat_test

    df = pd.read_csv(testing_file, encoding='utf-8')

    order_columns = [col for col in df.columns if col.startswith("order")]

    prompts_folder = os.path.join(experiment_folder, "prompts")
    responses_folder = os.path.join(experiment_folder, "responses")
    os.makedirs(prompts_folder, exist_ok=True)
    os.makedirs(responses_folder, exist_ok=True)

    with tqdm(total=total_iterations) as pbar:
        for shuffle_count, order_col in enumerate(order_columns):
            # The numbered-question column sits immediately left of its order column.
            questions_column_index = df.columns.get_loc(order_col) - 1

            word_list = df.iloc[:, questions_column_index].astype(str)

            separated_questions = [
                word_list[start:start + context_num]
                for start in range(0, len(word_list), context_num)
            ]

            # Renumber every chunk from 1 so each request is self-contained.
            word_list_chunks = [
                '\n'.join(f"{idx + 1}.{q.split('.', 1)[1]}" for idx, q in enumerate(questions))
                for questions in separated_questions
            ]
            print(f" {len(word_list_chunks)} rounds")

            prompt_log = os.path.join(prompts_folder, f'{safe_model}-shuffle{shuffle_count}.txt')
            response_log = os.path.join(responses_folder, f'{safe_model}-shuffle{shuffle_count}.txt')

            # Number of response columns inserted so far for this shuffle.
            insert_count = 0

            for k in range(repeat_test):
                column_header = f'shuffle{shuffle_count}-test{k}'

                result_list = []
                chunk_pairs = zip(separated_questions, word_list_chunks)
                for chunk_no, (questions, questions_string) in enumerate(
                        tqdm(chunk_pairs, total=len(word_list_chunks)), 1):
                    inputs = [
                        {"role": "user", "content": instruction + '\n' + questions_string}
                    ]
                    # Treat a missing reply (None) as empty text.
                    result = get_llm_response(inputs, model=model, api_key=api_key, base_url=base_url) or ''

                    with open(prompt_log, "a", encoding='utf-8') as prompt_file:
                        prompt_file.write(f'{inputs}\n====\n')
                    with open(response_log, "a", encoding='utf-8') as response_file:
                        response_file.write(f'{result}\n====\n')

                    # Fit labels per chunk so a short or long reply cannot
                    # shift the labels of later chunks or abort the run when
                    # the column is assigned below.
                    chunk_labels = convert_results(result.strip(), column_header)
                    result_list.extend(
                        _fit_labels(chunk_labels, len(questions), f'{column_header} chunk {chunk_no}')
                    )

                if column_header in df.columns:
                    df[column_header] = result_list
                else:
                    # Place new columns right after the order column, in test order.
                    df.insert(
                        loc=questions_column_index + insert_count + 2,
                        column=column_header,
                        value=result_list
                    )
                    insert_count += 1

                pbar.update(1)

    df.to_csv(testing_file, index=False, encoding='utf-8')

# Run the experiment only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()