import random
import itertools
import os
import re
from collections import defaultdict
import shutil
import pandas as pd
import typing

# Fixed seed handed to random.seed() by the partition helpers in this module
# so that their random selections can be repeated across runs.
SEED = 42


def check_collection_type(dp, collection_type, breadth_set):
    """
    Tell whether a data prefix belongs to the requested collection type.

    Args:
        dp: Data prefix (file name without extension) to classify
        collection_type: "breadth" selects prefixes present in ``breadth_set``;
            any other value selects prefixes absent from it
        breadth_set: Collection of prefixes known to be breadth topics

    Returns:
        True if ``dp`` matches ``collection_type``, False otherwise
    """
    if collection_type == "breadth":
        return dp in breadth_set
    return dp not in breadth_set

def get_breadth_set():
    """
    Collect the base names of every CSV found under the breadth-topics folder.

    The folder ``../../data/raw_data/phase_2_breadth_topics`` is walked
    recursively. For each ``.csv`` file, the substrings ``_0.0.1`` and ``.csv``
    are stripped from its name and the remainder is stored.

    Returns:
        Set of stripped file names (also prints the size of the set)
    """
    topics_root = os.path.join("../../data", "raw_data", "phase_2_breadth_topics")
    names = {
        csv_name.replace('_0.0.1', '').replace('.csv', '')
        for _, _, entries in os.walk(topics_root)
        for csv_name in entries
        if csv_name.endswith('.csv')
    }
    print(f"Size of breadth set: {len(names)}")
    return names


def partition_by_experiment(directory_path):
    """
    Partition experiment files into training and test sets by experiment.

    A random 20% (rounded down) of the non-augmented CSV files is held out for
    testing; the remaining non-augmented files go to training. Files whose name
    contains 'augmented' are always placed in the training set.

    Args:
        directory_path: Path to the directory containing experiment data

    Returns:
        train_files, test_files: Tuple of two lists of file names
    """
    # Set random seed for reproducibility
    random.seed(SEED)
    PERCENT_TEST = 0.2

    # os.listdir returns entries in arbitrary order; sort so that the seeded
    # sample below picks the same files on every machine and run.
    all_csv_files = sorted(f for f in os.listdir(directory_path) if f.endswith('.csv'))
    augmented_files = [f for f in all_csv_files if 'augmented' in f]
    regular_files = [f for f in all_csv_files if 'augmented' not in f]

    # Randomly select regular files for test set
    test_files = random.sample(regular_files, int(len(regular_files) * PERCENT_TEST))
    held_out = set(test_files)

    # Training files include augmented files and remaining regular files
    train_files = augmented_files + [f for f in regular_files if f not in held_out]

    return train_files, test_files


def partition_by_rounds(data_directory, collection_type: typing.Literal["depth", "breadth"] = "depth", model_name: typing.Optional[str] = None, output_data_path: str = "../../data/finetune_data/", training_method: typing.Literal["sft", "dpo"] = "sft"):  # round partition
    """
    Write per-experiment train and test CSVs split by chat round.

    For every ``.csv`` entry in ``data_directory`` whose name starts with
    2025 followed by a month from 03 to 08 and whose prefix matches
    ``collection_type``, the corresponding simulation file under
    ``../../result/eval/validity`` is loaded and written to
    ``<output_data_path>/round_split_data_<collection_type>/{train,test}``.

    With ``training_method == "dpo"``, rows where ``chat_round_order`` is 3 or
    ``event_type`` is 'Post Opinion' go to the test file and all other rows
    (including rows without round information) go to the train file.
    With any other ``training_method`` the complete file is written to both
    the train and the test directory.

    Args:
        data_directory: Path to the directory containing experiment data
        collection_type: "breadth" keeps prefixes in the breadth set, otherwise
            prefixes outside it are kept
        model_name: If given, read ``<prefix>/<model_name>/simulation-llm-v2.csv``
            instead of ``<prefix>/simulation-human.csv``
        output_data_path: Root directory under which the split is written
        training_method: "dpo" enables the row-level split described above

    Returns:
        train_files, test_files: Tuple of two lists with the written file names
    """
    breadth_set = get_breadth_set()
    error_count = 0
    train_files = []
    test_files = []
    train_dir = os.path.join(output_data_path, f"round_split_data_{collection_type}", "train")
    test_dir = os.path.join(output_data_path, f"round_split_data_{collection_type}", "test")
    os.makedirs(train_dir, exist_ok=True)
    os.makedirs(test_dir, exist_ok=True)

    entries = os.listdir(data_directory)
    print(f"Total files in data directory: {len(entries)}")
    # Sorted for a deterministic processing order.
    for file in sorted(entries):
        # Only CSV entries carry a 4-character extension that can be sliced off;
        # other entries would yield a bogus prefix and a spurious error count.
        if not file.endswith('.csv'):
            continue
        data_prefix = file[:-4]
        if not re.match(r'2025(03|04|05|06|07|08).*', data_prefix):
            continue
        if not check_collection_type(data_prefix, collection_type, breadth_set):
            continue
        if model_name is None:
            filepath = os.path.join('../../result/eval/validity', data_prefix, 'simulation-human.csv')
        else:
            filepath = os.path.join('../../result/eval/validity', data_prefix, model_name, 'simulation-llm-v2.csv')
        if not os.path.exists(filepath):
            # Missing simulation output is counted, not treated as fatal.
            error_count += 1
            continue

        original_data = pd.read_csv(filepath)

        if training_method == "dpo":
            held_out = (original_data['chat_round_order'] == 3) | (original_data['event_type'] == 'Post Opinion')
            train_data = original_data[~held_out].copy()
            test_data = original_data[held_out].copy()
        else:
            train_data = original_data
            test_data = original_data

        out_name = data_prefix + '.csv'
        train_data.to_csv(os.path.join(train_dir, out_name), index=False)
        train_files.append(out_name)
        test_data.to_csv(os.path.join(test_dir, out_name), index=False)
        test_files.append(out_name)
    print(f"[Round partition] Error count: {error_count}")

    return train_files, test_files

def partition_by_participants(data):
    """
    Partition experiment data into training and test sets based on participant pairs.

    All unordered pairs of distinct ``empirica_id`` values are enumerated and
    2 of them are drawn at random as test pairs. A row belongs to the test set
    when its (``sender_id``, ``recipient_id``) matches a test pair in either
    direction; every other row belongs to the training set.

    Args:
        data: DataFrame with ``empirica_id``, ``sender_id`` and ``recipient_id`` columns

    Returns:
        train_data, test_data: Tuple containing the partitioned datasets (copies)

    Raises:
        ValueError: If fewer than 2 participant pairs can be formed
    """
    # Set random seed for reproducibility
    random.seed(SEED)

    # Get unique participants and create all possible pairs
    unique_participants = sorted(set(data['empirica_id']))
    all_pairs = list(itertools.combinations(unique_participants, 2))
    if len(all_pairs) < 2:
        raise ValueError(
            f"Need at least 2 participant pairs to build a test set, found {len(all_pairs)}"
        )

    # Randomly select 2 pairs for testing; a set gives O(1) membership checks
    test_pairs = set(random.sample(all_pairs, 2))

    # Pairs are unordered, so check both directions of each message
    test_mask = pd.Series(
        [
            (sender, recipient) in test_pairs or (recipient, sender) in test_pairs
            for sender, recipient in zip(data['sender_id'], data['recipient_id'])
        ],
        index=data.index,
        dtype=bool,
    )

    test_data = data[test_mask].copy()
    train_data = data[~test_mask].copy()

    return train_data, test_data

def get_unique_topics(data_directory):
    """
    Get unique topics from all csv files in the given directory.

    The topic is the part of the file name between the
    ``<8 digits>_<6 digits>_`` prefix and the final ``_`` that is followed by
    at least 26 characters. Underscores are turned into spaces and runs of
    spaces are collapsed. CSV files whose name does not follow this pattern
    are skipped.

    Args:
        data_directory: Path to the directory containing the csv files

    Returns:
        Sorted list of distinct topic strings
    """
    topic_pattern = re.compile(r'\d{8}_\d{6}_(.*)_.{26}')
    topics = set()
    for filename in os.listdir(data_directory):
        if not filename.endswith('.csv'):
            continue
        match = topic_pattern.search(filename)
        if match is None:
            # Not an experiment file name; ignore instead of failing on None.
            continue
        topic = match.group(1).replace('_', ' ')
        topics.add(re.sub(r' +', ' ', topic))
    # Return sorted list for reproducible order
    return sorted(topics)

def _copy_topic_split_files(files, dest_dir, collection_type, breadth_set, model_name):
    """Copy the simulation CSV of each eligible file into ``dest_dir`` and return how many were copied."""
    os.makedirs(dest_dir, exist_ok=True)
    copied = 0
    for f in files:
        # Remove .csv extension to get data prefix
        data_prefix = f[:-4]
        if not check_collection_type(data_prefix, collection_type, breadth_set):
            continue
        if model_name is None:
            sim_filepath = os.path.join("../../result/eval/validity", data_prefix, 'simulation-human.csv')
        else:
            sim_filepath = os.path.join("../../result/eval/validity", data_prefix, model_name, 'simulation-llm-v2.csv')
        # Files without simulation output are skipped silently
        if os.path.exists(sim_filepath):
            shutil.copyfile(sim_filepath, os.path.join(dest_dir, f))
            copied += 1
    return copied


def partition_by_different_topics(data_directory, test_topics, collection_type: typing.Literal["depth", "breadth"] = "depth", model_name: typing.Optional[str] = None, output_data_path: str = "../../data/finetune_data/"):  # topic partition
    """
    Partition experiment data into training and test sets based on topic.

    Files whose name contains one of ``test_topics`` (spaces replaced by
    underscores) form the test set; files whose name contains any other topic
    found in ``data_directory`` form the training set. For each selected file
    the matching simulation CSV under ``../../result/eval/validity`` is copied
    to ``<output_data_path>/topic_split_data_<collection_type>/{train,test}``.

    Args:
        data_directory: Path to the directory containing experiment data
        test_topics: Topic strings to hold out for testing
        collection_type: "breadth" keeps prefixes in the breadth set, otherwise
            prefixes outside it are kept
        model_name: If given, copy ``<prefix>/<model_name>/simulation-llm-v2.csv``
            instead of ``<prefix>/simulation-human.csv``
        output_data_path: Root directory under which the split is written

    Returns:
        None. The split is only written to disk.
    """
    breadth_set = get_breadth_set()

    # No random choice is made here (test topics are supplied by the caller);
    # the seed is still set so the global RNG state after the call is unchanged.
    random.seed(SEED)

    # Get unique topics
    topics = get_unique_topics(data_directory)
    # Keep only csv files dated 2025-03 through 2025-08
    all_csv_files = [f for f in os.listdir(data_directory)
                     if f.endswith('.csv') and re.match(r'2025(03|04|05|06|07|08).*\.csv', f)]

    train_topics = [topic for topic in topics if topic not in test_topics]

    # Get files corresponding to train and test topics.
    # NOTE(review): matching is by substring, so a file could land in both lists
    # if one topic's text is contained in another's — confirm topics are distinct.
    test_files = [f for f in all_csv_files if any(topic.replace(' ', '_') in f for topic in test_topics)]
    train_files = [f for f in all_csv_files if any(topic.replace(' ', '_') in f for topic in train_topics)]

    train_dir = os.path.join(output_data_path, f"topic_split_data_{collection_type}", "train")
    train_count = _copy_topic_split_files(train_files, train_dir, collection_type, breadth_set, model_name)
    print(f"[Topic partition] Saved {train_count} train files to {train_dir}")

    # Save test files
    test_dir = os.path.join(output_data_path, f"topic_split_data_{collection_type}", "test")
    test_count = _copy_topic_split_files(test_files, test_dir, collection_type, breadth_set, model_name)
    print(f"[Topic partition] Saved {test_count} test files to {test_dir}")

def topic_split_breadth(data_directory, collection_type: typing.Literal["depth", "breadth"] = "breadth", model_name: typing.Optional[str] = None, output_data_path: str = "../../data/finetune_data/"):
    """
    Partition breadth topic data into training and test sets using random 72-28 split.

    CSV files dated 2025-03 through 2025-08 that match ``collection_type`` are
    shuffled with a fixed seed; the first 72% (rounded down) become the
    training set and the rest the test set. For each file the matching
    simulation CSV under ``../../result/eval/validity`` is copied to
    ``<output_data_path>/topic_split_data_<collection_type>/{train,test}``.

    Args:
        data_directory: Path to the directory containing experiment data
        collection_type: Type of collection ("depth" or "breadth"), defaults to "breadth"
        model_name: Optional model name for specific file paths
        output_data_path: Path to save the partitioned data

    Returns:
        train_files, test_files: Tuple of two lists with the file names assigned
        to each split (whether or not a simulation file existed to copy).
        Two empty lists when no file matches.
    """
    breadth_set = get_breadth_set()

    # Set random seed for reproducibility
    random.seed(SEED)

    # os.listdir order is arbitrary; sort so the seeded shuffle below yields
    # the same split on every run.
    all_csv_files = sorted(f for f in os.listdir(data_directory) if f.endswith('.csv'))

    print(f"Total files in data directory: {len(all_csv_files)}")
    all_csv_files = [f for f in all_csv_files if re.match(r'2025(03|04|05|06|07|08).*\.csv', f)]
    # Filter for the requested collection type (prefix = name without .csv)
    breadth_files = [f for f in all_csv_files
                     if check_collection_type(f[:-4], collection_type, breadth_set)]

    if not breadth_files:
        print(f"[Random breadth split] No breadth files found in {data_directory}")
        return [], []

    print(f"[Random breadth split] Found {len(breadth_files)} breadth files")

    # Randomly shuffle the files for random split
    random.shuffle(breadth_files)

    # Calculate split sizes (72% train, 28% test)
    total_files = len(breadth_files)
    print(f"[Random breadth split] Total files: {total_files}")
    train_size = int(total_files * 0.72)

    # Split the files
    train_files = breadth_files[:train_size]
    test_files = breadth_files[train_size:]

    print(f"[Random breadth split] Train files: {len(train_files)}, Test files: {len(test_files)}")

    def copy_split(files, dest_dir):
        """Copy available simulation CSVs into dest_dir; return (copied, missing)."""
        os.makedirs(dest_dir, exist_ok=True)
        copied = 0
        missing = 0
        for f in files:
            # Remove .csv extension to get data prefix
            data_prefix = f[:-4]
            if model_name is None:
                sim_filepath = os.path.join("../../result/eval/validity", data_prefix, 'simulation-human.csv')
            else:
                sim_filepath = os.path.join("../../result/eval/validity", data_prefix, model_name, 'simulation-llm-v2.csv')
            if os.path.exists(sim_filepath):
                shutil.copyfile(sim_filepath, os.path.join(dest_dir, f))
                copied += 1
            else:
                missing += 1
        return copied, missing

    # Create train directory and copy files
    train_dir = os.path.join(output_data_path, f"topic_split_data_{collection_type}", "train")
    train_count, error_count = copy_split(train_files, train_dir)
    print(f"[Random breadth split] Error count: {error_count}")
    print(f"[Random breadth split] Saved {train_count} train files to {train_dir}")

    # Create test directory and copy files
    test_dir = os.path.join(output_data_path, f"topic_split_data_{collection_type}", "test")
    test_count, test_missing = copy_split(test_files, test_dir)
    # The second error count is cumulative over both splits
    error_count += test_missing
    print(f"[Random breadth split] Error count: {error_count}")
    print(f"[Random breadth split] Saved {test_count} test files to {test_dir}")

    return train_files, test_files

def partition_by_same_topic(data_directory, collection_type: typing.Literal["depth", "breadth"] = "depth", model_name: typing.Optional[str] = None, output_data_path: str = "../../data/finetune_data/"):  # group partition
    """
    Partition experiment data into training and test sets within each topic.

    CSV files dated 2025-03 through 2025-08 are grouped by the topic embedded
    in their name. Topics with fewer than 2 files are dropped. For each
    remaining topic, 20% of its files (at least 1) are drawn at random for the
    test set and the rest go to the training set. For every selected file that
    matches ``collection_type``, the simulation CSV under
    ``../../result/eval/validity`` is copied to
    ``<output_data_path>/group_split_data_<collection_type>/{train,test}``.

    Args:
        data_directory (str): Path to directory containing CSV files
        collection_type: "breadth" keeps prefixes in the breadth set, otherwise
            prefixes outside it are kept
        model_name: If given, copy ``<prefix>/<model_name>/simulation-llm-v2.csv``
            instead of ``<prefix>/simulation-human.csv``
        output_data_path: Root directory under which the split is written

    Returns:
        tuple: Lists of training and test files, as assigned before the
        collection-type and file-existence checks applied while copying

    Raises:
        ValueError: If no topic has at least 2 files

    Note:
        Uses random seed for reproducibility
    """
    breadth_set = get_breadth_set()

    # Set random seed for reproducibility
    random.seed(SEED)

    # Get all csv files in the given directory and sort for deterministic order
    all_csv_files = sorted([f for f in os.listdir(data_directory) if f.endswith('.csv')
                           and re.match(r'2025(03|04|05|06|07|08).*\.csv', f)])
    topic_to_files = defaultdict(list)

    for f in all_csv_files:
        topic = re.search(r'\d{8}_\d{6}_(.*)_.{26}', f).group(1).replace('_', ' ')
        topic = re.sub(r' +', ' ', topic)
        topic_to_files[topic].append(f)

    # Only keep topics with at least 2 files for train/test split
    valid_topics = {t: files for t, files in topic_to_files.items() if len(files) >= 2}
    invalid_topics = {t: files for t, files in topic_to_files.items() if len(files) < 2}
    print(f"[Group partition] Invalid topics: {len(invalid_topics)}")
    print(f"[Group partition] Valid topics: {len(valid_topics)}")

    if not valid_topics:
        raise ValueError("No topics found with at least 2 files for train/test split")

    train_files = []
    test_files = []

    # Split files for each topic 80-20
    for files in valid_topics.values():
        num_test = max(1, int(len(files) * 0.2))
        topic_test = random.sample(files, num_test)
        topic_train = [f for f in files if f not in topic_test]

        train_files.extend(topic_train)
        test_files.extend(topic_test)

    # Copy the train split first, then the test split, into sibling directories
    for split_name, split_files in (("train", train_files), ("test", test_files)):
        split_dir = os.path.join(output_data_path, f"group_split_data_{collection_type}", split_name)
        os.makedirs(split_dir, exist_ok=True)
        count = 0
        for f in split_files:
            # Remove .csv extension to get data prefix
            data_prefix = f[:-4]
            if not check_collection_type(data_prefix, collection_type, breadth_set):
                continue
            if model_name is None:
                sim_filepath = os.path.join("../../result/eval/validity", data_prefix, 'simulation-human.csv')
            else:
                sim_filepath = os.path.join("../../result/eval/validity", data_prefix, model_name, 'simulation-llm-v2.csv')
            if os.path.exists(sim_filepath):
                shutil.copyfile(sim_filepath, os.path.join(split_dir, f))
                count += 1
        print(f"[Group partition] Saved {count} {split_name} files to {split_dir}")

    return train_files, test_files

def select_train_test_files(train_dir, test_dir):
    """
    Pick up to 10 train files and up to 10 test files, each with a distinct topic.

    Directory listings are sorted, then visited in a seeded random order; a
    file is kept only if its topic has not been seen yet in the same split and
    the split has not reached its limit.

    Args:
        train_dir: Directory holding the training files
        test_dir: Directory holding the test files

    Returns:
        train_files, test_files: Tuple of two lists with the chosen file names
    """
    # Set random seed for reproducibility
    random.seed(SEED)

    limit_train = 10
    limit_test = 10

    def topic_of(filename):
        # Topic is embedded between the timestamp prefix and the trailing id
        raw = re.search(r'\d{8}_\d{6}_(.*)_.{26}', filename).group(1)
        return re.sub(r' +', ' ', raw.replace('_', ' '))

    def pick_distinct(candidates, topic_by_file, limit):
        chosen = []
        seen = set()
        for name in random.sample(candidates, len(candidates)):
            topic = topic_by_file[name]
            if topic in seen or len(chosen) >= limit:
                continue
            chosen.append(name)
            seen.add(topic)
        return chosen

    train_listing = sorted(os.listdir(train_dir))
    test_listing = sorted(os.listdir(test_dir))

    # Resolve topics for both splits before any random draw takes place
    train_topic_map = {name: topic_of(name) for name in train_listing}
    test_topic_map = {name: topic_of(name) for name in test_listing}

    picked_train = pick_distinct(train_listing, train_topic_map, limit_train)
    picked_test = pick_distinct(test_listing, test_topic_map, limit_test)

    return picked_train, picked_test


def generate_sft_data():
    """
    Build the SFT datasets for both the depth and breadth collections.

    For each collection type this runs the topic split (held-out topics for
    depth, random split for breadth), the same-topic group split and the round
    split with ``training_method="sft"``, writing everything below the SFT
    output directory.
    """
    print("[SFT data] Generating SFT data")
    processed_dir = "../../data/processed_data"
    held_out_topics = ["The US deficit increased after President Obama was elected", "Regular fasting will improve your health"]
    output_data_path = "/mnt/dv/wid/projects3/XXXX-3-XXXX-5-human-ai/mini-twitter-llm-agent-binary/data/sft_data"
    os.makedirs(output_data_path, exist_ok=True)

    for collection_type in ("depth", "breadth"):
        # Topic split: fixed held-out topics for depth, random split for breadth
        if collection_type != "depth":
            topic_split_breadth(processed_dir, collection_type, output_data_path=output_data_path)
        else:
            partition_by_different_topics(processed_dir, held_out_topics, collection_type, output_data_path=output_data_path)

        # Group split
        partition_by_same_topic(processed_dir, collection_type, output_data_path=output_data_path)

        # Round split
        round_train, round_test = partition_by_rounds(processed_dir, collection_type, output_data_path=output_data_path, training_method="sft")
        print(f"[Round partition] Number of train files: {len(round_train)}")
        print(f"[Round partition] Number of test files: {len(round_test)}")


def generate_dpo_data():
    """
    Build the DPO datasets for every model source and both collection types.

    Each listed model (and ``None``, stored under "human") gets its own output
    directory. Within it, the topic split, the same-topic group split and the
    round split with ``training_method="dpo"`` are generated for the depth and
    breadth collections.
    """
    print("[DPO data] Generating DPO data")
    processed_dir = "../../data/processed_data"
    dpo_root = "/mnt/dv/wid/projects3/XXXX-3-XXXX-5-human-ai/mini-twitter-llm-agent-binary/data/dpo_data"
    held_out_topics = ["The US deficit increased after President Obama was elected", "Regular fasting will improve your health"]

    for model in ("gpt-4o-mini-2024-07-18", "gpt-4.1-nano-2025-04-14", "Llama-3.1-8B-Instruct", None):
        print(f"[DPO data] Processing {model}")
        # Data with no model name is written to the "human" directory
        subdir = "human" if model is None else model
        output_data_path = os.path.join(dpo_root, subdir)
        os.makedirs(output_data_path, exist_ok=True)

        for collection_type in ("depth", "breadth"):
            print(f"[DPO data] Processing {model} {collection_type}")

            # topic partition
            if collection_type != "depth":
                topic_split_breadth(processed_dir, collection_type, model_name=model, output_data_path=output_data_path)
            else:
                partition_by_different_topics(processed_dir, held_out_topics, collection_type, model_name=model, output_data_path=output_data_path)

            # group partition
            partition_by_same_topic(processed_dir, collection_type, model_name=model, output_data_path=output_data_path)

            # round partition
            round_train, round_test = partition_by_rounds(processed_dir, collection_type, model_name=model, output_data_path=output_data_path, training_method="dpo")
            print(f"[Round partition] Number of train files: {len(round_train)}")
            print(f"[Round partition] Number of test files: {len(round_test)}")


# Script entry point: only the DPO datasets are generated when this file is
# run directly; the SFT generation call is currently commented out.
if __name__ == "__main__":
    # generate_sft_data()
    generate_dpo_data()
