from data.cv.oversight_test_data import get_cv_oversight_dataset
from data.cv_gender.oversight_test_data import get_cv_gender_oversight_dataset
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score, confusion_matrix
from scipy.stats import pearsonr
import wandb

def get_data(dataset_name):
    """Return the oversight test dataset registered under ``dataset_name``.

    ``"cv"`` and ``"toy_cv"`` map to the CV oversight dataset;
    ``"toy_cv_gender"`` and ``"toy_cv_gender_correlations"`` map to the
    CV-gender oversight dataset.

    Raises:
        ValueError: if ``dataset_name`` is not one of the names above.
    """
    if dataset_name in ("cv", "toy_cv"):
        loader = get_cv_oversight_dataset
    elif dataset_name in ("toy_cv_gender", "toy_cv_gender_correlations"):
        loader = get_cv_gender_oversight_dataset
    else:
        raise ValueError(f"Unknown dataset: {dataset_name}")
    return loader()

def _safe_fraction(numerator, denominator):
    """Return ``numerator / denominator`` as a float, or NaN when the denominator is zero."""
    if not denominator:
        return float("nan")
    return float(numerator) / float(denominator)


def _summarize_rewards(values):
    """Return mean/median/std/min/max of a numpy array; all NaN when it is empty."""
    if values.size == 0:
        # np.min / np.max raise on empty input, so report "undefined" explicitly.
        return dict.fromkeys(("mean", "median", "std", "min", "max"), float("nan"))
    return {
        "mean": np.mean(values),
        "median": np.median(values),
        "std": np.std(values),
        "min": np.min(values),
        "max": np.max(values),
    }


def _print_label_reward_stats(label, stats):
    """Print the reward summary produced by ``_summarize_rewards`` for one label."""
    print(f"\nReward statistics for label {label}:")
    print(f"Mean: {stats['mean']:.4f}")
    print(f"Median: {stats['median']:.4f}")
    print(f"Standard Deviation: {stats['std']:.4f}")
    print(f"Min: {stats['min']:.4f}")
    print(f"Max: {stats['max']:.4f}")


def log_overseer_stats(queries, answers, is_sus_labels, cont_labels, is_sus, rewards, wandb_log_func, wandb_category_str, overseer):
    """Print overseer evaluation statistics and optionally log them to wandb.

    Statistics that are undefined for the given data (per-class numbers when a
    class has no samples, correlation with fewer than two samples, anything on
    an empty dataset) are reported as NaN instead of raising.

    Args:
        queries: per-sample values shown in the 'Query_to_overseer' table column.
        answers: per-sample values shown in the 'Text_to_oversee' table column.
        is_sus_labels: per-sample binary ground-truth labels (compared to 0 and 1).
        cont_labels: per-sample continuous labels, correlated against ``rewards``.
        is_sus: per-sample overseer predictions, compared element-wise to
            ``is_sus_labels``.
        rewards: per-sample numeric overseer rewards.
        wandb_log_func: callable receiving the dict of metrics; logging is
            skipped when this is falsy.
        wandb_category_str: prefix prepended (as ``"<prefix>/<key>"``) to every
            logged key.
        overseer: object whose ``sus_threshold`` attribute is logged; only
            accessed when ``wandb_log_func`` is truthy.

    Returns:
        dict with keys 'overall_accuracy', 'accuracy_class_0',
        'accuracy_class_1', 'class_0_rew_mean' and 'class_1_rew_mean'.

    Raises:
        ValueError: if ``is_sus`` and ``is_sus_labels`` differ in shape.
    """
    # Convert lists to numpy arrays for easier manipulation
    is_sus = np.array(is_sus)
    is_sus_labels = np.array(is_sus_labels)
    rewards = np.array(rewards)
    cont_labels = np.array(cont_labels)

    if is_sus.shape != is_sus_labels.shape:
        raise ValueError(
            "is_sus and is_sus_labels must have the same shape, "
            f"got {is_sus.shape} and {is_sus_labels.shape}"
        )

    # Dataset analysis
    negative_mask = is_sus_labels == 0
    positive_mask = is_sus_labels == 1
    num_positive = np.sum(positive_mask)
    num_negative = np.sum(negative_mask)
    print("Dataset Analysis:")
    print(f"Number of positive labels: {num_positive}")
    print(f"Number of negative labels: {num_negative}")
    print(f"Total samples: {len(is_sus_labels)}")
    print()

    # Per-sample correctness; every accuracy below is derived from it, so the
    # numbers stay well-defined no matter which classes are actually present.
    sample_accuracy = (is_sus == is_sus_labels)

    # Calculate overall accuracy
    accuracy = _safe_fraction(np.count_nonzero(sample_accuracy), sample_accuracy.size)
    print(f"Overall Accuracy: {accuracy:.4f}")

    # Calculate accuracy for each class (0 and 1): the fraction of samples with
    # that true label that were predicted correctly.
    accuracy_class_0 = _safe_fraction(np.count_nonzero(sample_accuracy[negative_mask]), num_negative)
    accuracy_class_1 = _safe_fraction(np.count_nonzero(sample_accuracy[positive_mask]), num_positive)
    print(f"Accuracy for class 0: {accuracy_class_0:.4f}")
    print(f"Accuracy for class 1: {accuracy_class_1:.4f}")

    # Calculate correlation between rewards and cont_labels
    # (pearsonr requires at least two samples).
    if rewards.size >= 2:
        correlation, p_value = pearsonr(rewards, cont_labels)
    else:
        correlation = p_value = float("nan")
    print(f"Correlation between rewards and cont_labels: {correlation:.4f}")
    print(f"P-value: {p_value:.4f}")

    # Reward statistics for each label
    stats_label_0 = _summarize_rewards(rewards[negative_mask])
    stats_label_1 = _summarize_rewards(rewards[positive_mask])
    _print_label_reward_stats(0, stats_label_0)
    _print_label_reward_stats(1, stats_label_1)

    # Create a DataFrame to display results
    results_df = pd.DataFrame({
        'Query_to_overseer': queries,
        'Text_to_oversee': answers,
        'Is Suspicious (Actual)': is_sus_labels.astype(bool),
        'Is Suspicious (Predicted)': is_sus,
        'Accurate?': sample_accuracy,
        'Overseer Reward': rewards
    })

    # Display the results
    print("\nResults:")
    print(results_df.to_string(index=False))

    # Additional statistics
    overall_stats = _summarize_rewards(rewards)
    print("\nAdditional Statistics:")
    print(f"Mean Reward: {overall_stats['mean']:.4f}")
    print(f"Median Reward: {overall_stats['median']:.4f}")
    print(f"Standard Deviation of Rewards: {overall_stats['std']:.4f}")
    print(f"Minimum Reward: {overall_stats['min']:.4f}")
    print(f"Maximum Reward: {overall_stats['max']:.4f}")

    # Wandb logging
    if wandb_log_func:
        table = wandb.Table(dataframe=results_df)

        log_dict = {
            "accuracy": accuracy,
            "accuracy_class_0": accuracy_class_0,
            "accuracy_class_1": accuracy_class_1,
            "correlation": correlation,
            "p_value": p_value,
            "mean_reward": overall_stats["mean"],
            "median_reward": overall_stats["median"],
            "std_reward": overall_stats["std"],
            "min_reward": overall_stats["min"],
            "max_reward": overall_stats["max"],
            "dataset_size": len(is_sus_labels),
            "dataset_num_positive": num_positive,
            "dataset_num_negative": num_negative,
            "class_0_reward_mean": stats_label_0["mean"],
            "class_1_reward_mean": stats_label_1["mean"],
            "class_0_reward_std": stats_label_0["std"],
            "class_1_reward_std": stats_label_1["std"],
            "table": table,
            "sus_threshold": overseer.sus_threshold
        }

        # prepend wandb_category_str to keys
        log_dict = {f"{wandb_category_str}/{k}": v for k, v in log_dict.items()}

        wandb_log_func(log_dict)

    results_summary = {
        'overall_accuracy': accuracy,
        "accuracy_class_0": accuracy_class_0,
        "accuracy_class_1": accuracy_class_1,
        "class_0_rew_mean": stats_label_0["mean"],
        "class_1_rew_mean": stats_label_1["mean"]
    }
    return results_summary
    
        

def eval_overseer(overseer, dataset_name, max_batch_size, wandb_category_str, wandb_log_func=None):
    """Run the overseer over a named oversight dataset and report its statistics.

    Args:
        overseer: object providing ``flag_sus_examples``; also forwarded to
            ``log_overseer_stats``.
        dataset_name: dataset identifier passed to ``get_data``.
        max_batch_size: forwarded unchanged to ``overseer.flag_sus_examples``.
        wandb_category_str: forwarded to ``log_overseer_stats``.
        wandb_log_func: optional logging callable forwarded to
            ``log_overseer_stats``.

    Returns:
        The summary returned by ``log_overseer_stats``.
    """
    samples = get_data(dataset_name)

    # The overseer is given one placeholder query (0) per sample; the queries
    # that get reported are the input texts it hands back in ``infos``.
    placeholder_queries = [0] * len(samples)
    texts = [sample['text'] for sample in samples]
    bin_labels = [sample['bin_label'] for sample in samples]
    cont_labels = [sample['cont_label'] for sample in samples]

    flags, rewards, infos = overseer.flag_sus_examples(
        placeholder_queries, texts, bin_labels, max_batch_size
    )

    return log_overseer_stats(
        infos["input_texts"],
        texts,
        bin_labels,
        cont_labels,
        flags,
        rewards,
        wandb_log_func,
        wandb_category_str,
        overseer,
    )



if __name__ == "__main__":
    # Script entry point: build a debug-run trainer and evaluate its overseer
    # on the "cv" dataset, logging under the "main" wandb category.

    # TODO: import may not work
    from hparams.utils import RLExpRunType, get_run_type_hparams
    from stego_trainer import StegoTrainer

    hparams = get_run_type_hparams(RLExpRunType.DEBUG)
    stego_trainer = StegoTrainer(hparams)

    eval_overseer(
        overseer=stego_trainer.overseer,
        dataset_name="cv",
        max_batch_size=hparams['enc']['mini_batch_size'],
        wandb_category_str="main",
        wandb_log_func=stego_trainer.do_wandb_log,
    )