import os
import re
from typing import Dict, List, Optional, Tuple

import numpy as np
import openai
import pandas as pd

# Module-level default; the functions below take simulation_version as an explicit parameter.
simulation_version = "v2"
# model name -> {"combined": [...], "train": [...], "test": [...]} lists of per-run
# invalid ratios, filled in by main() while collecting metrics.
validation_per_model = {}

def create_client():
    """Construct a default OpenAI client.

    No credentials are passed explicitly; the SDK picks up its configuration
    (e.g. OPENAI_API_KEY, which main() sets) from the environment.
    """
    client = openai.OpenAI()
    return client

def get_conversation_history(user_data, round_number, event_order, sender_id, recipient_id, simulation_version, isLLM: bool):
    """
    Render one chat round, from ``sender_id``'s point of view, as plain text.

    All tweets of the round come first, followed by every ``message_sent`` event
    whose ``event_order`` is <= ``event_order``. A row belongs to the conversation
    when ``sender_id`` is its sender or its recipient. Lines authored by
    ``sender_id`` are labelled "My ..."; all others are labelled
    "<recipient_id>'s ...". Messages with missing or whitespace-only text are
    omitted; tweets are always taken from the ``text`` column.

    NOTE(review): the message at exactly ``event_order`` is included, and rows
    exchanged with partners other than ``recipient_id`` are not filtered out.

    Args:
        user_data: DataFrame with chat_round_order, sender_id, recipient_id,
            event_type, event_order and text columns (plus llm_text when read).
        round_number: ``chat_round_order`` value to select.
        event_order: Upper bound (inclusive) on the messages' ``event_order``.
        sender_id: Participant whose perspective the history is written from.
        recipient_id: Name used to label the other party's lines.
        simulation_version: For "v2" messages are always read from ``text``;
            for other versions ``isLLM`` selects ``llm_text`` over ``text``.
        isLLM: See ``simulation_version``.

    Returns:
        Newline-joined history lines ("" when nothing matches).
    """
    if simulation_version != "v2" and isLLM:
        text_field = "llm_text"
    else:
        text_field = "text"

    involves_sender = (user_data["sender_id"] == sender_id) | (user_data["recipient_id"] == sender_id)
    in_scope = user_data[(user_data["chat_round_order"] == round_number) & involves_sender]
    tweets = in_scope[in_scope["event_type"] == "tweet"]
    messages = in_scope[(in_scope["event_type"] == "message_sent") & (in_scope["event_order"] <= event_order)]

    def _label(row):
        # "My" for the perspective owner, otherwise the other party's name.
        return "My" if row["sender_id"] == sender_id else f"{recipient_id}'s"

    lines = [f"{_label(tweet)} Tweet: {tweet['text']}" for _, tweet in tweets.iterrows()]

    for _, message in messages.iterrows():
        body = message[text_field]
        if pd.isna(body) or body is None or body.strip() == '':
            continue
        lines.append(f"{_label(message)} Response: {body}")

    return "\n".join(lines)

# Verdict words recognised in the grader's reply. Only these capitalised spellings
# count, so a lowercase "valid" inside free-text reasoning is ignored; the word
# boundaries keep labels such as "Validity:" / "Validation:" from matching "Valid".
_VALIDITY_PATTERN = re.compile(r'\b(Valid|Invalid|INVALID|VALID)\b')


def _parse_validation_status(validation_text: Optional[str]) -> str:
    """Return the first whole-word verdict in *validation_text*, or "Unknown"."""
    if not validation_text:
        # message.content may be None/empty; re.search would raise on None.
        return "Unknown"
    match = _VALIDITY_PATTERN.search(validation_text)
    return match.group(1) if match else "Unknown"


def validate_response(
    topic: str,
    conversation_history: str,
    client: Optional["openai.OpenAI"] = None,
    last_message: Optional[str] = None
) -> Tuple[str, Optional[str], List[Dict[str, str]]]:
    """
    Ask gpt-4o-mini whether ``last_message`` is a valid reply within the conversation.

    The system and user prompts are read from ``../../prompts/invalid_response``
    (relative to the current working directory) and filled in with ``str.format``.

    Args:
        topic: Discussion topic; substituted for ``{TOPIC}`` in the system prompt.
        conversation_history: Rendered chat history (as produced by
            ``get_conversation_history``); substituted for ``{CHAT_HISTORY}``.
        client: OpenAI client to use; a new one is created when omitted.
        last_message: Message under review; substituted for ``{LAST_MESSAGE}``.

    Returns:
        ``(status, validation_text, messages)``: ``status`` is the first verdict word
        found in the reply ("Valid", "Invalid", "VALID" or "INVALID") or "Unknown";
        ``validation_text`` is the raw reply content; ``messages`` is the chat
        payload that was sent to the model.
    """
    if client is None:
        client = create_client()

    # The two fasting topics listed here get CURRENT_OR_OTHER="the"; any other
    # topic is introduced with "another".
    topic_in_file = ('Regular fasting will improve your health', 'Regular fasting will not improve your health')
    current_or_other = "the" if topic in topic_in_file else "another"

    prompt_dir = os.path.join('..', '..', "prompts", "invalid_response")
    with open(os.path.join(prompt_dir, "system_message.md"), 'r', encoding="utf-8") as f:
        system_prompt = f.read().format(
            TOPIC=topic,
            CURRENT_OR_OTHER=current_or_other
        )
    with open(os.path.join(prompt_dir, "user_message.md"), 'r', encoding="utf-8") as f:
        user_prompt = f.read().format(
            CHAT_HISTORY=conversation_history,
            LAST_MESSAGE=last_message
        )

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt}
    ]

    response = client.chat.completions.create(
        model="gpt-4o-mini-2024-07-18",
        messages=messages,
        temperature=0.1
    )

    validation_text = response.choices[0].message.content
    return _parse_validation_status(validation_text), validation_text, messages

def _is_blank_text(value) -> bool:
    """Return True for a missing (NaN/None) or whitespace-only cell value."""
    # pd.isna(None) is True, so None needs no separate check.
    return pd.isna(value) or value.strip() == ''


def apply_validation(df: pd.DataFrame, data_prefix: str, topic: str, client, simulation_version: str, isLLM: bool):
    """Annotate each non-empty ``message_sent`` row of *df* with an LLM validity verdict.

    For every such row, the conversation history is rebuilt with
    ``get_conversation_history`` and the row's message is graded with
    ``validate_response``. Results go into the ``validity``, ``validity_reason``
    and ``input_prompt_validation`` columns. *df* is modified in place and also
    returned. Rows whose history text or message under review is missing or
    blank are skipped.

    Args:
        df: Simulation events (event_type, sender_id, recipient_id,
            chat_round_order, event_order, text and/or llm_text columns).
        data_prefix: Run identifier; accepted for interface compatibility, unused here.
        topic: Discussion topic passed to the grader.
        client: OpenAI client passed to ``validate_response``.
        simulation_version: "v2" or an earlier version (see below).
        isLLM: Grade the LLM-generated text (``llm_text``) instead of ``text``.

    Returns:
        The same DataFrame object, with the validity columns filled in.
    """
    message_column = "llm_text" if isLLM else "text"
    # In v2 the history always comes from the human "text" column and only the
    # message under review is taken from message_column; earlier versions read
    # both from the same column.
    history_column = "text" if simulation_version == "v2" else message_column

    for index, row in df.iterrows():
        if row["event_type"] != "message_sent":
            continue
        if _is_blank_text(row[history_column]) or _is_blank_text(row[message_column]):
            continue
        conversation_history = get_conversation_history(
            df,
            row["chat_round_order"],
            row["event_order"],
            row["sender_id"],
            row["recipient_id"],
            simulation_version,
            isLLM,
        )
        last_message = row[message_column]
        validity, validation_text, messages = validate_response(topic, conversation_history, client, last_message)
        df.at[index, "validity"] = validity
        df.at[index, "validity_reason"] = validation_text
        df.at[index, "input_prompt_validation"] = str(messages)

    return df

def _invalid_share(valid_mask: pd.Series, invalid_mask: pd.Series) -> float:
    """Return invalid / (valid + invalid) over the rows flagged by the masks (0.0 if none)."""
    n_valid = int(valid_mask.sum())
    n_invalid = int(invalid_mask.sum())
    total = n_valid + n_invalid
    return n_invalid / total if total > 0 else 0.0


def get_metrics(df: pd.DataFrame, isSplit: bool = False, split_type: str = "round", data_prefix: Optional[str] = None, isLLM: bool = False):
    """Compute the share of INVALID verdicts among VALID/INVALID-labelled rows.

    Verdicts are compared case-insensitively; rows with any other label (e.g.
    "Unknown" or missing) are ignored.

    * Human data, or LLM data with ``isSplit=False``: only the overall ratio.
    * LLM, ``split_type="round"``: round 3 is the test split, every other round
      is train; the overall ratio covers both.
    * LLM, other split types: the run is assigned to train or test by looking
      for ``<data_prefix>.<3-char ext>`` under
      ``../../data/finetune_data/<split_type>_split_data/{train,test}``.

    Args:
        df: Rows with a ``validity`` column (and ``chat_round_order`` for round splits).
        isSplit: Whether to compute train/test ratios (LLM only).
        split_type: "round", or the name of a directory-based split (e.g. "topic").
        data_prefix: Run identifier used for directory-based splits.
        isLLM: Whether *df* holds LLM-generated messages.

    Returns:
        ``(invalid_ratio, invalid_ratio_train, invalid_ratio_test)``; entries that
        do not apply are None. All three are None if the run is in neither split
        directory or if an error occurs (the error is printed).
    """
    invalid_ratio = None
    invalid_ratio_train = None
    invalid_ratio_test = None
    try:
        status = df['validity'].str.upper()
        is_valid = status == 'VALID'
        is_invalid = status == 'INVALID'

        if not (isLLM and isSplit):
            invalid_ratio = _invalid_share(is_valid, is_invalid)
        elif split_type == "round":
            is_test = df['chat_round_order'] == 3
            # Overall ratio across train and test rows together.
            invalid_ratio = _invalid_share(is_valid, is_invalid)
            invalid_ratio_train = _invalid_share(is_valid & ~is_test, is_invalid & ~is_test)
            invalid_ratio_test = _invalid_share(is_valid & is_test, is_invalid & is_test)
        else:
            split_root = os.path.join('..', '..', 'data', 'finetune_data', f'{split_type}_split_data')
            split_data_path_train = os.path.join(split_root, 'train')
            split_data_path_test = os.path.join(split_root, 'test')
            # file[:-4] strips a 4-character extension such as ".csv".
            if any(data_prefix == file[:-4] for file in os.listdir(split_data_path_train)):
                invalid_ratio_train = _invalid_share(is_valid, is_invalid)
            elif any(data_prefix == file[:-4] for file in os.listdir(split_data_path_test)):
                invalid_ratio_test = _invalid_share(is_valid, is_invalid)
            else:
                print(f"Data prefix {data_prefix} not found in {split_data_path_train} or {split_data_path_test}")

        return invalid_ratio, invalid_ratio_train, invalid_ratio_test
    except Exception as e:
        # Batch boundary: callers treat (None, None, None) as "skip this run".
        print(f"Error getting metrics: {e}")
        return None, None, None

# Only simulations of this model are validated and summarised. Alternative values the
# author noted next to this filter: ft:Llama-3.1-8B-Instruct-SFT-20250711:group-5epochs,
# ft:Llama-3.1-8B-Instruct-SFT-20250710:round-5epochs, ft:Llama-3.1-8B-Instruct-SFT-20250710:topic-5epochs
_TARGET_MODEL = "gpt-4o-mini-2024-07-18"
# Run timestamps skipped when summarising metrics (the reason is not recorded in this file).
_EXCLUDED_RUN_STAMPS = (
    "20250331_202343",
    "20250403_191749",
    "20250418_185529",
    "20250430_154612",
    "20250428_001616",
    "20250422_040204",
    "20250429_162344",
    "20250424_220026",
)
# Simulation run directories must start with a 2025-03..2025-06 date.
_RUN_DIR_PATTERN = re.compile(r"2025(03|04|05|06).*")
# "<YYYYMMDD>_<HHMMSS>_<topic_words>_<26-char suffix>" -> topic words.
_RUN_TOPIC_PATTERN = re.compile(r'\d{8}_\d{6}_(.*)_.{26}')
_VALIDITY_RESULT_DIR = os.path.join('..', '..', 'result', 'eval', 'validity')


def _load_breadth_topics(breadth_dir: str) -> set:
    """Collect the base names of every ``*.csv`` under *breadth_dir* (recursively).

    The ``_0.0.1`` version tag and the ``.csv`` extension are removed so the names
    can be compared with simulation run directory names.
    """
    breadth_set = set()
    for _root, _dirs, files in os.walk(breadth_dir):
        for filename in files:
            if filename.endswith('.csv'):
                breadth_set.add(filename.replace('_0.0.1', '').replace('.csv', ''))
    return breadth_set


def _run_validation(simulation_eval_path: str, breadth_set: set, client, isLLM: bool, simulation_version: str) -> None:
    """Grade every eligible simulation CSV and write the annotated copy under the result dir.

    Outputs that already exist are left untouched, so an interrupted run can resume.
    """
    for data_prefix in os.listdir(simulation_eval_path):
        if not _RUN_DIR_PATTERN.match(data_prefix):
            print(f"Skipping {data_prefix} because it doesn't match the pattern")
            continue
        if data_prefix not in breadth_set:
            continue
        output_dir = os.path.join(_VALIDITY_RESULT_DIR, data_prefix)
        for model in os.listdir(os.path.join(simulation_eval_path, data_prefix)):
            if model != _TARGET_MODEL:
                continue
            file_path = os.path.join(simulation_eval_path, data_prefix, model, f'simulation-{simulation_version}.csv')
            if not os.path.exists(file_path):
                print(f"Skipping {file_path} because it doesn't exist")
                continue
            if isLLM:
                output_file = os.path.join(output_dir, model, f"simulation-llm-{simulation_version}.csv")
            else:
                output_file = os.path.join(output_dir, "simulation-human.csv")
            if os.path.exists(output_file):
                print(f"Skipping {file_path} because output file {output_file} already exists")
                continue
            topic_match = _RUN_TOPIC_PATTERN.search(data_prefix)
            if topic_match is None:
                # Without a match, .group(1) would raise and abort the whole batch.
                print(f"Skipping {file_path} because no topic could be parsed from {data_prefix}")
                continue
            topic = re.sub(r' +', ' ', topic_match.group(1).replace('_', ' '))

            df = pd.read_csv(file_path)
            print(f"Processing {file_path}")
            df = apply_validation(df, data_prefix, topic, client, simulation_version, isLLM)
            # The model sub-directory is created for human runs too: the metrics pass
            # finds results by listing model directories inside each run directory.
            os.makedirs(os.path.join(output_dir, model), exist_ok=True)
            df.to_csv(output_file, index=False)
            print(f"Saved {output_file}")
            if not isLLM:
                # Human output is a single per-run file, so stop after writing it.
                break


def _collect_metrics(breadth_set: set, isLLM: bool, isSplit: bool, split_type: str, simulation_version: str) -> list:
    """Read annotated CSVs and gather per-run invalid ratios.

    LLM ratios are appended to the module-level ``validation_per_model``;
    human ratios are returned as a list.
    """
    human_ratios = []
    for data_prefix in os.listdir(_VALIDITY_RESULT_DIR):
        if data_prefix not in breadth_set or data_prefix.endswith('.csv'):
            continue
        if any(stamp in data_prefix for stamp in _EXCLUDED_RUN_STAMPS):
            continue
        run_dir = os.path.join(_VALIDITY_RESULT_DIR, data_prefix)
        for model in os.listdir(run_dir):
            if model.endswith('.csv'):
                continue
            if not isLLM:
                if model != _TARGET_MODEL:
                    continue
                human_file_path = os.path.join(run_dir, 'simulation-human.csv')
                if not os.path.exists(human_file_path):
                    print(f"Skipping {human_file_path} because it doesn't exist")
                    continue
                df = pd.read_csv(human_file_path)
                invalid_ratio, _, _ = get_metrics(df, isSplit, split_type, data_prefix, isLLM=False)
                if invalid_ratio is not None:
                    human_ratios.append(invalid_ratio)
                break

            file_path = os.path.join(run_dir, model, f'simulation-llm-{simulation_version}.csv')
            if not os.path.exists(file_path):
                print(f"Skipping {file_path} because it doesn't exist")
                continue
            if model != _TARGET_MODEL:
                continue
            print(f"Processing {file_path}")
            df = pd.read_csv(file_path)
            ratios = get_metrics(df, isSplit, split_type, data_prefix, isLLM=True)
            if model not in validation_per_model:
                print(f"Adding {model} to validation_per_model")
                validation_per_model[model] = {'combined': [], 'train': [], 'test': []}
            for bucket, ratio in zip(('combined', 'train', 'test'), ratios):
                if ratio is not None:
                    validation_per_model[model][bucket].append(ratio)
    return human_ratios


def _mean_or_zero(values: list):
    """Mean of *values*, or 0.0 for an empty list (this file's "no data" convention)."""
    return np.mean(values) if values else 0.0


def _summarize_metrics(isLLM: bool, human_ratios: list) -> pd.DataFrame:
    """Build the one-row-per-model summary table written by main()."""
    if not isLLM:
        rows = [{'model': "human", 'combined': _mean_or_zero(human_ratios)}]
    else:
        rows = []
        for model, buckets in validation_per_model.items():
            split_ratios = buckets['train'] + buckets['test']
            rows.append({
                'model': model,
                # Average whichever split ratios exist (a model may have only train or
                # only test runs); unsplit runs fall back to their whole-file ratios.
                'combined': _mean_or_zero(split_ratios or buckets['combined']),
                'train': _mean_or_zero(buckets['train']),
                'test': _mean_or_zero(buckets['test']),
            })
    # Built in one go instead of through the private DataFrame._append API.
    return pd.DataFrame(rows)


def main(simulation_eval_path: str, get_metrics_only: bool, isLLM: bool, isSplit: bool = False, split_type: str = "round", simulation_version: str = "v0"):
    """Optionally grade simulation messages, then write aggregate invalid-ratio metrics.

    Args:
        simulation_eval_path: Directory with one sub-directory per simulation run,
            each holding per-model ``simulation-<version>.csv`` files.
        get_metrics_only: Skip the API-calling validation pass and only summarise
            results already under ``../../result/eval/validity``.
        isLLM: Evaluate LLM-generated messages (True) or human messages (False).
        isSplit: Break LLM metrics down into train/test splits.
        split_type: "round", "topic" or "group" (see ``get_metrics``).
        simulation_version: Simulation file version to read (e.g. "v0", "v1", "v2").
    """
    breadth_set = _load_breadth_topics(os.path.join("../../data", "raw_data", "phase_2_breadth_topics"))

    if not get_metrics_only:
        # An API key/client is only needed for grading; metrics-only runs work without one.
        if not os.getenv("OPENAI_API_KEY"):
            with open("openai-key.txt", "r", encoding="utf-8") as f:
                os.environ["OPENAI_API_KEY"] = f.read().strip()
        client = create_client()
        _run_validation(simulation_eval_path, breadth_set, client, isLLM, simulation_version)

    human_ratios = _collect_metrics(breadth_set, isLLM, isSplit, split_type, simulation_version)
    metrics_df = _summarize_metrics(isLLM, human_ratios)

    if isLLM:
        metrics_path = os.path.join(_VALIDITY_RESULT_DIR, f'llm-{simulation_version}-{split_type}.csv')
    else:
        metrics_path = os.path.join(_VALIDITY_RESULT_DIR, 'human.csv')
    metrics_df.to_csv(metrics_path, index=False)

if __name__ == "__main__":
    # Other configurations the author ran from here: human metrics with
    # split_type="round", and LLM metrics swept over simulation_version in
    # ("v0", "v1", "v2") x split_type in ("round", "topic", "group").
    simulation_results_dir = os.path.join('..', '..', 'result', 'simulation')
    main(
        simulation_results_dir,
        get_metrics_only=True,
        isLLM=False,
        isSplit=True,
        split_type="round",
        simulation_version="v2",
    )
