import pandas as pd
import datasets as dt
import itertools
from tqdm import tqdm

def find_closest_pair_indices(numbers):
    """
    Finds the two indices in the list 'numbers' where the corresponding values are the closest.

    The values are sorted (stable sort) and only neighbours in sorted order are
    compared, giving O(n log n) instead of checking every pair. When several
    neighbouring pairs share the minimal difference, the first one in sorted
    order wins.

    Parameters:
    numbers (List[float or int]): The list of numbers to evaluate.

    Returns:
    Tuple[int, int]: A tuple containing the two indices with the smallest difference,
                     ordered as (index of the smaller value, index of the larger value).
                     Returns None (after printing a message) if the list has fewer
                     than two elements. Returns (None, None) if no difference compares
                     less than infinity (e.g. all differences are NaN).
    """
    if len(numbers) < 2:
        print("List must contain at least two elements.")
        return None

    # Tuples are (original_index, value), as produced by enumerate.
    indexed_numbers = list(enumerate(numbers))

    # Sort on the value so the closest pair must be adjacent.
    sorted_numbers = sorted(indexed_numbers, key=lambda item: item[1])

    min_diff = float('inf')
    closest_pair = (None, None)

    # Walk adjacent (current, next) pairs in sorted order.
    for (cur_idx, cur_val), (next_idx, next_val) in zip(sorted_numbers, sorted_numbers[1:]):
        diff = abs(next_val - cur_val)
        # Strict '<' keeps the first pair found on ties.
        if diff < min_diff:
            min_diff = diff
            closest_pair = (cur_idx, next_idx)

    return closest_pair

def find_optimal_pairs(p):
    """
    Finds two optimal pairs of rows in DataFrame 'p' based on specified criteria.

    - 'optimal_chosen_pair': the pair with the largest absolute gap in
      'chosen_score'.
    - 'optimal_rejected_pair': the pair with the largest absolute gap in
      'rejected_score'.

    Ties on the gap are broken by the smallest absolute difference of
    (chosen_score - rejected_score) between the two rows. Any remaining tie
    goes to the pair that comes first in itertools.combinations order over
    p.index. All pairs are enumerated, so cost is O(len(p)**2).

    The input DataFrame is not modified.

    Parameters:
    p (pd.DataFrame): DataFrame containing 'chosen_score' and 'rejected_score' columns.
                      Must contain at least two rows.

    Returns:
    dict: {'optimal_chosen_pair': (label_i, label_j),
           'optimal_rejected_pair': (label_i, label_j)},
          where the labels are taken from p.index.

    Raises:
    KeyError: if a required column is missing, or if p has fewer than two rows
              (no pairs can be formed).
    """
    # Per-row margin, kept as a standalone Series so the caller's DataFrame is
    # never written to (assigning into a groupby slice would also raise
    # SettingWithCopyWarning).
    score_diff = p['chosen_score'] - p['rejected_score']

    # Metrics for every unordered pair of row labels.
    pairs_info = []
    for i, j in itertools.combinations(p.index.tolist(), 2):
        pairs_info.append({
            'pair': (i, j),
            'chosen_diff': abs(p.at[i, 'chosen_score'] - p.at[j, 'chosen_score']),
            'rejected_diff': abs(p.at[i, 'rejected_score'] - p.at[j, 'rejected_score']),
            'score_diff_diff': abs(score_diff.at[i] - score_diff.at[j]),
        })

    pairs_df = pd.DataFrame(pairs_info)

    def _select(metric):
        """Return the first pair maximizing `metric`, breaking ties by minimal score_diff_diff."""
        top = pairs_df[pairs_df[metric] == pairs_df[metric].max()]
        best = top[top['score_diff_diff'] == top['score_diff_diff'].min()]
        return best.iloc[0]['pair']

    return {
        'optimal_chosen_pair': _select('chosen_diff'),
        'optimal_rejected_pair': _select('rejected_diff'),
    }

def find_optimal_pairs_detailed(p):
    """
    Picks two pairs of rows from 'p' and reports which member of each pair scores higher.

    - Chosen pair: the two rows whose 'chosen_score' values are furthest apart.
    - Rejected pair: the two rows whose 'rejected_score' values are furthest apart.

    When several pairs tie on the gap, the one whose per-row margin
    (chosen_score - rejected_score) differs least wins. Any remaining tie goes
    to the earliest pair in itertools.combinations order over p.index. Within
    a pair, the row with the strictly larger score is "high"; on equal scores
    the second label of the pair is reported as "high".

    Works on a copy, so the caller's DataFrame is left untouched.

    Parameters:
    p (pd.DataFrame): DataFrame containing 'chosen_score' and 'rejected_score' columns.

    Returns:
    dict: Index labels under the keys
          'optimal_chosen_pair_high_chosen_score', 'optimal_chosen_pair_low_chosen_score',
          'optimal_rejected_pair_high_rejected_score', 'optimal_rejected_pair_low_rejected_score'.
    """
    frame = p.copy()  # private copy; adding a column to a groupby slice would warn
    frame['score_diff'] = frame['chosen_score'] - frame['rejected_score']

    def _gap(a, b, column):
        return abs(frame.at[a, column] - frame.at[b, column])

    candidates = pd.DataFrame([
        {
            'pair': (a, b),
            'chosen_diff': _gap(a, b, 'chosen_score'),
            'rejected_diff': _gap(a, b, 'rejected_score'),
            'score_diff_diff': _gap(a, b, 'score_diff'),
        }
        for a, b in itertools.combinations(frame.index.tolist(), 2)
    ])

    def _best(metric):
        # Widest gap on `metric`, then the most similar margins, then first in order.
        widest = candidates[candidates[metric] == candidates[metric].max()]
        tightest = widest[widest['score_diff_diff'] == widest['score_diff_diff'].min()]
        return tightest['pair'].tolist()[0]

    def _high_low(pair, column):
        first, second = pair
        if frame.at[first, column] > frame.at[second, column]:
            return first, second
        return second, first

    high_chosen, low_chosen = _high_low(_best('chosen_diff'), 'chosen_score')
    high_rejected, low_rejected = _high_low(_best('rejected_diff'), 'rejected_score')

    return {
        'optimal_chosen_pair_high_chosen_score': high_chosen,
        'optimal_chosen_pair_low_chosen_score': low_chosen,
        'optimal_rejected_pair_high_rejected_score': high_rejected,
        'optimal_rejected_pair_low_rejected_score': low_rejected,
    }

ds = dt.load_from_disk("generated_data/oasst2_en_dpo_w_reward")
# Add the per-example reward margin as an extra 'diff' column.
ds = ds.map(lambda x: {"diff": x['chosen_score'] - x['rejected_score']})
df = ds.to_pandas()

dfg = df.groupby("prompt")

# For every prompt group, record the *index labels* (into ``df``) of the rows to
# export, one list per output file. Slicing ``df`` once at the end with these
# labels keeps each column's original dtype. Stacking per-row Series with
# ``pd.concat(..., axis=1).T`` would coerce every column to ``object``.
chosen_high_labels = []
chosen_low_labels = []
rejected_high_labels = []
rejected_low_labels = []

for _, p in tqdm(dfg):
    if len(p) < 2:
        # A lone response cannot form a pair: export that same row in all four files.
        high_c = low_c = high_r = low_r = p.index[0]
    else:
        optimal_pairs_detailed = find_optimal_pairs_detailed(p)
        high_c = optimal_pairs_detailed['optimal_chosen_pair_high_chosen_score']
        low_c = optimal_pairs_detailed['optimal_chosen_pair_low_chosen_score']
        high_r = optimal_pairs_detailed['optimal_rejected_pair_high_rejected_score']
        low_r = optimal_pairs_detailed['optimal_rejected_pair_low_rejected_score']

    chosen_high_labels.append(high_c)
    chosen_low_labels.append(low_c)
    rejected_high_labels.append(high_r)
    rejected_low_labels.append(low_r)

# Groups keep df's original index labels, so df.loc recovers exactly the
# selected rows in group order. Assumes df's index is unique (to_pandas()
# presumably yields a RangeIndex; confirm if the loading step changes).
df_c_high = df.loc[chosen_high_labels].reset_index(drop=True)
df_c_low = df.loc[chosen_low_labels].reset_index(drop=True)
df_r_high = df.loc[rejected_high_labels].reset_index(drop=True)
df_r_low = df.loc[rejected_low_labels].reset_index(drop=True)

df_c_high.to_parquet("oasst2_chosen_high.parquet")
df_c_low.to_parquet("oasst2_chosen_low.parquet")
df_r_high.to_parquet("oasst2_rejected_high.parquet")
df_r_low.to_parquet("oasst2_rejected_low.parquet")