
# Calculate model consistency from a testing CSV, then generate stories
# from the word lists that the consistency check produces.

import argparse
import csv
import json
import os
from collections import Counter

import pandas as pd


# Command-line interface: the only input is the path to the testing CSV.
parser = argparse.ArgumentParser(description="Process a file for psychoEval.")

# Adding the file path as a command line argument
parser.add_argument('--testing_file', type=str, required=True,
                    help='The path to the testing file.')

args = parser.parse_args()

# Access the testing file path
print(f"Testing file path: {args.testing_file}")
testing_file = args.testing_file

# Split the path with os.path instead of a hand-written '/' split so that
# platform-specific separators are handled as well.
experiment_folder = os.path.dirname(testing_file)
file_name = os.path.basename(testing_file)

# The file name is read as "<model>-<language>.csv": the text after the last
# '-' is the language and everything before it is the model name (which may
# itself contain '-').  Only a trailing ".csv" is removed, so a ".csv" that
# happens to appear inside the model name is left untouched.
stem = file_name[:-len('.csv')] if file_name.endswith('.csv') else file_name
parts = stem.split('-')
lang = parts[-1]
model = '-'.join(parts[:-1])

print(f"Testing file: {testing_file}")
print(f"Experiment folder: {experiment_folder}")
print(f"Model: {model}")
print(f"Language: {lang}")


# Build the lookup table {1-based line number: word} from the noun list that
# matches the language taken from the file name.
noun_files = {
    'Chinese': './comon_nouns/chn_words.jsonl',
    'English': './comon_nouns/eng_words.jsonl',
}

if lang not in noun_files:
    raise ValueError(f"Unsupported language: {lang}. Please use 'Chinese' or 'English'.")

with open(noun_files[lang], 'r', encoding='utf-8') as fr:
    # Every line is parsed as a JSON object; only its 'word' field is kept.
    common_words = {
        line_number: json.loads(raw_line)['word']
        for line_number, raw_line in enumerate(fr, start=1)
    }

# Parse the file with pandas as well.  NOTE(review): `df` is not referenced
# anywhere else in the visible part of this file - confirm before removing.
df = pd.read_csv(testing_file, encoding='utf-8')

# calculate model consistency
#
# Every column whose header starts with "order" holds question indices.  The
# columns that follow it hold responses, and each response column is turned
# into one {question index: response} dictionary appended to `test_datas`.

test_datas = []

# The encoding is explicit so this read does not depend on the platform
# default ('utf-8-sig' also accepts plain UTF-8 and drops a leading BOM), and
# newline='' is what the csv module asks for when it is given a file object.
with open(testing_file, 'r', encoding='utf-8-sig', newline='') as csvfile:
    reader = csv.reader(csvfile)
    header = next(reader)
    # Read the data rows once instead of rewinding the file for every column.
    # Completely blank lines come back as empty lists and are skipped.
    rows = [row for row in reader if row]

# Take the index of each column which refers to the question order
order_indices = [
    index for index, column in enumerate(header) if column.startswith("order")
]

# For each question order, record the corresponding test data
for i, order_index in enumerate(order_indices):

    # start and end delimit the response columns that belong to this order
    # column.  For every group but the last, the column directly before the
    # next "order" column is NOT read.
    # NOTE(review): presumably that column is a separator - confirm.
    start = order_index + 1
    end = order_indices[i + 1] - 1 if i + 1 < len(order_indices) else len(header)

    # column_index walks over the response columns of this group
    for column_index in range(start, end):
        column_data = {}

        # For each row, store the question index x and its response y as {x: y}
        for row in rows:
            try:
                column_data[int(row[order_index])] = row[column_index]
            except ValueError:
                # The order cell is not an integer; report and skip the row.
                print(f'Column {column_index + 1} has error.')

        test_datas.append(column_data)


def compare_multiple_dicts(dict_list):
    """
    Compare several {question index: response} dictionaries key by key.

    Only the keys of the first dictionary are examined.  For each key the
    responses of all dictionaries are collected, with "LOST" standing in for
    a dictionary that lacks the key.  Keys are translated to words through
    the module-level `common_words` mapping.

    Args:
        dict_list: list of dictionaries to compare (at least two).

    Returns:
        A 4-tuple of lists
        (inconsistencies, comedy_words, tragedy_words, neutral_words):
        `inconsistencies` holds one {word: [responses...]} dict per key whose
        responses differ; the other three hold the words whose responses are
        unanimously 'COMEDY', 'TRAGEDY' or 'NEUTRAL'.

    Raises:
        ValueError: if fewer than two dictionaries are given.  (Previously a
            message string was returned, which could not be unpacked into
            the four result lists.)
        KeyError: if a reported key is missing from `common_words`.
    """
    if len(dict_list) < 2:
        raise ValueError("Need at least two dictionaries to compare")

    inconsistencies = []
    unanimous = {'COMEDY': [], 'TRAGEDY': [], 'NEUTRAL': []}
    reference_dict = dict_list[0]

    for key in reference_dict:
        values = [d.get(key, "LOST") for d in dict_list]
        distinct = set(values)  # built once per key

        if len(distinct) > 1:
            inconsistencies.append({common_words[key]: values})
            continue

        # Exactly one distinct response: record it if it is a tracked label.
        (label,) = distinct
        if label in unanimous:
            unanimous[label].append(common_words[key])

    return (
        inconsistencies,
        unanimous['COMEDY'],
        unanimous['TRAGEDY'],
        unanimous['NEUTRAL'],
    )

# Group the words by how the test runs agree on them.
inconsistent_words, comedy_words, tragedy_words, neutral_words = compare_multiple_dicts(test_datas)


# Initialise the scores: tally every label over all test runs
label_totals = Counter()
for ts_data in test_datas:
    ts_count = Counter(ts_data.values())
    print(f"Counts for current data: {ts_count}")
    label_totals.update(ts_count)

reluctations = label_totals['NEUTRAL']
coms = label_totals['COMEDY']
neurs = 0
trage = label_totals['TRAGEDY']


# First method: raw label counts divided by fixed totals (10000 and 5000)
reluct_score, trage_score, coms_score = (
    reluctations / 10000,
    trage / 10000,
    coms / 10000,
)
inconsistant_score = len(inconsistent_words) / 5000


print("\nFirst scoring method without consistency align")
print(f"Reluctations score (based on test data): {reluct_score:.4f}")
print(f"TRAGEDY score (based on test data): {trage_score:.4f}")
print(f"COMEDY score (based on test data): {coms_score:.4f}")
print(f"Inconsistency score (based on test data): {inconsistant_score:.4f}")

# Second method: inconsistent words are counted together with the neutral ones.
# Each inconsistency entry is a one-item dict whose key is the word.
neutral_words_ad = neutral_words + [next(iter(entry)) for entry in inconsistent_words]

neu_score = (len(neutral_words) + len(inconsistent_words)) / 5000
com_score = len(comedy_words) / 5000
tgd_score = len(tragedy_words) / 5000


print("\nSecond scoring method (based on word lists):")
print(f"Reluctations score (based on test data): {reluct_score:.4f}")
print(f"Neutral words score: {neu_score:.4f}")
print(f"Comedy words score: {com_score:.4f}")
print(f"Tragedy words score: {tgd_score:.4f}")




# story generation

from prompt import ch_prompt, eng_prompt
from tqdm import tqdm
import os
from datetime import datetime
import csv
import pandas as pd
import random
from openai import OpenAI
import re
import json
from prompt import eng_story_prompt, chn_story_prompt

# LLM requests are sent through litellm's `completion` function
from litellm import completion

def get_llm_response(inputs, model="gpt-4o", temp=0, seed=1, api_key="", base_url=""):
    """
    Send chat messages through litellm's `completion` and return the reply text.

    Args:
        inputs: list of chat messages, passed on as `messages`.
        model: model name forwarded to the endpoint.
        temp: sampling temperature.
        seed: seed forwarded with the request.
        api_key: API key forwarded with the request.
        base_url: base URL of the endpoint.

    Returns:
        The `message.content` of the first choice in the response.
    """
    request_options = {
        "api_key": api_key,
        "base_url": base_url,
        "model": model,
        "messages": inputs,
        "temperature": temp,
        # The request is always routed through litellm's "openai" provider.
        "custom_llm_provider": "openai",
        "seed": seed,
    }
    reply = completion(**request_options)
    return reply.choices[0].message.content

# Seed words generation
# 'positive_<k>' maps to 100 shuffled five-word lists, each made of k comedy
# words and 5 - k tragedy words.
seed_words_dict = {f'positive_{k}': [] for k in range(6)}

for positive_count in range(6):
    bucket = seed_words_dict[f'positive_{positive_count}']
    for _ in range(100):
        # Tragedy words are drawn first, then comedy words, then the mix is shuffled.
        sampled_items = (
            random.sample(tragedy_words, 5 - positive_count)
            + random.sample(comedy_words, positive_count)
        )
        random.shuffle(sampled_items)
        bucket.append(sampled_items)

# Stories and prompts are stored in sub-folders of the experiment folder.
stories_folder = os.path.join(experiment_folder, "stories")
prompts_folder = os.path.join(experiment_folder, "prompts")
for output_folder in (stories_folder, prompts_folder):
    os.makedirs(output_folder, exist_ok=True)

# The language does not change during a run, so the word separator and the
# instruction text are picked once, before the loops.
if lang == 'Chinese':
    word_separator, instruction = '，', chn_story_prompt
else:
    word_separator, instruction = ', ', eng_story_prompt

# One prompt log per degree; one story file per seed word list.
for degree, seed_words in tqdm(seed_words_dict.items(), desc="Generating stories"):
    prompt_file_path = os.path.join(prompts_folder, f'prompt_{degree}.txt')

    # Append mode: prompts from earlier runs stay in the file.
    with open(prompt_file_path, "a", encoding='utf-8') as prompt_file:
        for seed_word in tqdm(seed_words):
            words_str = word_separator.join(seed_word)
            content = instruction + words_str
            inputs = [{"role": "user", "content": content}]

            # Log the prompt before the request is sent
            prompt_file.write(f'Prompt: {content}\n')

            # Ask the model named in the testing file for a story
            result = get_llm_response(inputs, model=model)

            # The story file name contains the degree and the joined seed words
            story_file_path = os.path.join(stories_folder, f'{degree}-seed_{words_str}.txt')
            with open(story_file_path, "a", encoding='utf-8') as response_file:
                response_file.write(f'{result}\n')