import sys
import tqdm
import json
import copy
import pickle
import argparse
import numpy as np
import cvxpy as cp
import scipy as sp
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from itertools import combinations
from sklearn.model_selection import train_test_split

# Add project src directory to path
sys.path.append('../src')

# Custom utilities
from data_tools import (
    load_feedbackqa, load_helpsteer2, load_ultrafeedback, collect_judge_outputs, 
    is_valid_score
)
from pgm_tools import learn_structure, get_weights, majority_vote, uws_aggregate
from eval_tools import compute_mae

def parse_args():
    """Read the command line and return the parsed namespace.

    The only option is the required ``--dataset`` flag, which is restricted
    to the three dataset names listed in ``supported``.
    """
    supported = ['feedbackqa', 'helpsteer2', 'ultrafeedback']
    cli = argparse.ArgumentParser(
        description="Run greedy feature selection for specified dataset."
    )
    cli.add_argument(
        '--dataset',
        required=True,
        choices=supported,
        type=str,
        help='Dataset to process (feedbackqa, helpsteer2, or ultrafeedback)',
    )
    return cli.parse_args()

def load_and_prepare_data(dataset_name):
    """Load human ratings and judge outputs for one dataset and align them.

    Steps performed:
      1. Load the human ratings with the dataset-specific loader and the LLM
         judge outputs with ``collect_judge_outputs``.
      2. Drop the hard-coded row labels for the dataset (if any) from both
         tables and reset their indices.
      3. Read the program-judge CSV for the dataset, discard its
         ``question``/``answer`` columns and min-max rescale every remaining
         column except ``id`` onto the range of the human ``score`` column.
         A column with no spread is set to the midpoint of that range.
      4. Append the rescaled program-judge columns to the judge table and
         keep only rows in which every column passes ``is_valid_score``.

    Args:
        dataset_name: One of ``"feedbackqa"``, ``"helpsteer2"`` or
            ``"ultrafeedback"``.

    Returns:
        A tuple ``(scores, human_eval)``: the filtered judge table cast to
        float, and the human ``score`` entries selected by the same row mask.

    Raises:
        ValueError: If ``dataset_name`` is not a supported dataset.
    """
    # dataset -> (judge output directory, ratings file, row labels to drop)
    dataset_specs = {
        "feedbackqa": (
            "../judge_outputs/feedbackqa",
            "../data/feedbackqa/feedback_train.json",
            [3186],
        ),
        "helpsteer2": (
            "../judge_outputs/helpsteer2",
            "../data/helpsteer2/helpsteer2_valid.json",
            [],
        ),
        "ultrafeedback": (
            "../judge_outputs/ultrafeedback_sampled",
            "../data/ultrafeedback_sampled.csv",
            [1098, 1164],
        ),
    }
    # Fail early with a clear message instead of a later, unrelated
    # file-not-found / unbound-variable error.
    if dataset_name not in dataset_specs:
        raise ValueError(
            f"Unsupported dataset {dataset_name!r}; expected one of {sorted(dataset_specs)}"
        )
    judge_output_dir, data_path, dropped_rows = dataset_specs[dataset_name]

    if dataset_name == "feedbackqa":
        ratings = load_feedbackqa(data_path)
    elif dataset_name == "helpsteer2":
        ratings = load_helpsteer2(data_path)
    else:
        ratings = load_ultrafeedback(data_path)
    df = collect_judge_outputs(judge_output_dir)

    # The same row labels are removed from both tables so they stay aligned;
    # datasets without dropped rows keep their original index untouched.
    if dropped_rows:
        ratings = ratings.drop(index=dropped_rows).reset_index(drop=True)
        df = df.drop(index=dropped_rows).reset_index(drop=True)

    program_judges = pd.read_csv(f"../automated_program_synthesis/program_judge_outputs/{dataset_name}_program_judge_results.csv")
    program_judges_dropped = program_judges.drop(columns=["question", "answer"])
    human_eval = ratings['score']

    # Rescale program judges onto the human rating range and append them
    min_rating, max_rating = human_eval.min(), human_eval.max()
    for col in program_judges_dropped.columns:
        if col == 'id':
            continue
        raw = program_judges_dropped[col]
        min_val, max_val = raw.min(), raw.max()
        if max_val > min_val:
            rescaled = (raw - min_val) / (max_val - min_val) * (max_rating - min_rating) + min_rating
            df[col] = rescaled.values
        else:
            # No spread to rescale: use the midpoint of the rating range. A
            # full-length array keeps the length check on assignment.
            df[col] = np.full(len(raw), (max_rating + min_rating) / 2)

    # Filter invalid scores
    mask = df.map(lambda x: is_valid_score(x, min_rating=min_rating, max_rating=max_rating)).all(axis=1)
    return df[mask].astype(float), human_eval[mask]

def greedy_feature_selection_multi(df, human_eval, gamma=1, threshold=1, start_k=3, val_size=0.1, random_state=42):
    """Greedy forward selection of judge columns for several aggregators.

    For each aggregation method ('caresl', 'majority_vote', 'simple_average',
    'uws') the search first picks the ``start_k``-column subset with the
    lowest MAE on a held-out validation split, then repeatedly adds the
    remaining column that gives the lowest validation MAE until every column
    has been selected. After the initial choice and after every added column,
    the MAE of the current subset on the full ``df`` is recorded.

    Args:
        df: Table with one column per judge and one row per example.
        human_eval: Reference scores, one per row of ``df``.
        gamma: Passed to ``learn_structure`` for the 'caresl' method.
        threshold: Passed to ``get_weights`` for the 'caresl' method.
        start_k: Size of the exhaustively searched initial subset.
        val_size: ``test_size`` given to ``train_test_split``.
        random_state: Seed given to ``train_test_split``.

    Returns:
        Dict keyed by method name. Each value is a dict with
        ``'selected_counts'`` (subset size per step), ``'maes'`` (MAE on the
        full data per step), ``'final_judges'`` (column positions in
        selection order) and ``'judge_sequence'`` (the subset at each step).

    Raises:
        ValueError: If ``start_k`` exceeds the number of columns, or if no
            candidate subset yields a validation MAE below infinity (e.g.
            every candidate evaluates to NaN).
    """
    n_features = df.shape[1]
    # Without this check the initial search has no candidates and the code
    # would fail later with an unrelated AttributeError on ``None``.
    if start_k > n_features:
        raise ValueError(
            f"start_k={start_k} exceeds the number of available features ({n_features})"
        )

    # Split into train/validation
    df_train, df_val, _gt_train, gt_val = train_test_split(
        df, human_eval, test_size=val_size, random_state=random_state
    )

    def compute_mae_for_features(df_eval, gt_eval, features, method):
        """Compute MAE for given features using specified method."""
        df_subset = df_eval.iloc[:, features]
        if method == 'majority_vote':
            pred = majority_vote(df_subset)
        elif method == 'simple_average':
            pred = df_subset.mean(axis=1).to_numpy()
        elif method == 'uws':
            pred = uws_aggregate(df_subset)
        elif method == 'caresl':
            # Weights always come from the training split, whatever data is
            # being evaluated.
            corr = df_train.iloc[:, features].corr()
            # Clean up on an explicit NumPy copy: writing through
            # ``DataFrame.values`` relies on it being a writable view, which
            # is not guaranteed (it is read-only under pandas Copy-on-Write).
            corr_values = corr.to_numpy(dtype=float, copy=True)
            corr_values[~np.isfinite(corr_values)] = 0.0
            np.fill_diagonal(corr_values, 1.0)
            corr_matrix = pd.DataFrame(corr_values, index=corr.index, columns=corr.columns)
            _, L_est = learn_structure(corr_matrix, gamma=gamma)
            weights = get_weights(L_est, threshold=threshold)
            pred = sum(weights[i] * df_subset.iloc[:, i].to_numpy() for i in range(len(weights))) / sum(weights)
        else:
            raise ValueError(f"Unknown aggregation method: {method!r}")
        return compute_mae(pred, gt_eval)

    def best_on_validation(candidate_sets, method):
        """Return the candidate feature list with the lowest validation MAE.

        Ties keep the earliest candidate. Candidates whose MAE is NaN never
        compare lower and are therefore skipped.
        """
        best_set = None
        best_mae = float("inf")
        for features in candidate_sets:
            mae_val = compute_mae_for_features(df_val, gt_val, features, method)
            if mae_val < best_mae:
                best_mae = mae_val
                best_set = features
        if best_set is None:
            raise ValueError(
                f"No candidate feature set produced a finite validation MAE for method {method!r}"
            )
        return best_set

    methods = ['caresl', 'majority_vote', 'simple_average', 'uws']
    results = {}

    for method in methods:
        print(f"Processing {method}...")
        # Step 1: Choose best start_k features
        best_init = best_on_validation(
            (list(comb) for comb in combinations(range(n_features), start_k)),
            method,
        )

        selected_features = list(best_init)
        remaining_features = [f for f in range(n_features) if f not in selected_features]
        maes = [compute_mae_for_features(df, human_eval, selected_features, method)]
        selected_counts = [len(selected_features)]
        feature_sequence = [selected_features.copy()]

        # Step 2: Greedy expansion
        while remaining_features:
            best_trial = best_on_validation(
                (selected_features + [feat] for feat in remaining_features),
                method,
            )
            # The candidate is always appended last in its trial list.
            best_feature = best_trial[-1]
            selected_features.append(best_feature)
            remaining_features.remove(best_feature)
            maes.append(compute_mae_for_features(df, human_eval, selected_features, method))
            selected_counts.append(len(selected_features))
            feature_sequence.append(selected_features.copy())

        results[method] = {
            'selected_counts': selected_counts,
            'maes': maes,
            'final_judges': selected_features,
            'judge_sequence': feature_sequence
        }

    return results

def main():
    """Run greedy judge selection for the dataset named on the command line.

    Loads the filtered judge table and human scores, runs the greedy search
    with gamma=1, threshold=1 and start_k=3, and pickles the result dict to
    ``<dataset>_greedy_selection_with_new_judges.pkl`` in the working
    directory.
    """
    dataset_name = parse_args().dataset

    # Load and prepare data
    judge_scores, human_scores = load_and_prepare_data(dataset_name)

    # Run greedy feature selection with the fixed hyperparameters
    results = greedy_feature_selection_multi(
        judge_scores, human_scores, gamma=1, threshold=1, start_k=3
    )

    # Plot results (disabled: the code below is an inert string literal)
    """
    plt.figure(figsize=(6, 5))
    colors = ['blue', 'red', 'green', 'orange']
    method_name = ["CARESL", "MV", "AVG", "UWS"]
    
    for i, (method, data) in enumerate(results.items()):
        plt.plot(data['selected_counts'][:20], data['maes'][:20], 
                color=colors[i], linestyle='-', 
                marker='o', markersize=5, label=method_name[i], linewidth=1)
    
    plt.xticks(np.arange(3, 23), np.arange(3, 23))
    plt.xlabel('Number of Selected Judge', fontsize=12)
    plt.ylabel('MAE', fontsize=12)
    plt.title(f'Greedy Forward Selection - {dataset_name}', fontsize=14)
    plt.legend(fontsize=10)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()
    """

    # Save results
    output_path = f"{dataset_name}_greedy_selection_with_new_judges.pkl"
    with open(output_path, "wb") as out_file:
        pickle.dump(results, out_file)

# Run the selection pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()