import argparse
import datetime

import numpy as np
from datasets import load_from_disk
import json
from pathlib import Path


def parse_args():
    """Parse the command-line options for the reward comparison.

    Returns:
        argparse.Namespace with ``dataset1_path`` and ``dataset2_path``
        (both required strings) and ``log_tag`` (string, default ``""``).
    """
    cli = argparse.ArgumentParser(description='Compare rewards between two datasets')
    required_paths = (
        ('--dataset1_path', 'Path to first processed dataset'),
        ('--dataset2_path', 'Path to second processed dataset'),
    )
    for flag, help_text in required_paths:
        cli.add_argument(flag, type=str, required=True, help=help_text)
    cli.add_argument("--log_tag", type=str, default="")
    parsed = cli.parse_args()
    return parsed


def load_and_validate_dataset(path):
    """Load a processed dataset from ``path`` and check it carries rewards.

    Args:
        path: Directory previously written by the datasets library's
            ``save_to_disk``; it is passed straight to ``load_from_disk``.

    Returns:
        The loaded dataset object.

    Raises:
        ValueError: If ``'golden_reward'`` is not among the dataset's features.
    """
    ds = load_from_disk(path)
    has_rewards = 'golden_reward' in ds.features
    if has_rewards:
        return ds
    raise ValueError(f"Dataset at {path} does not have reward values computed. "
                     "Please run compute_rewards.py first.")


def _summarize(rewards):
    """Return mean/std/min/max/median of ``rewards`` as plain Python floats."""
    return {
        "mean": float(np.mean(rewards)),
        "std": float(np.std(rewards)),
        "min": float(np.min(rewards)),
        "max": float(np.max(rewards)),
        "median": float(np.median(rewards)),
    }


def compute_statistics(rewards1, rewards2):
    """Compute pairwise win/tie/loss statistics of ``rewards1`` vs ``rewards2``.

    Element ``i`` of ``rewards1`` is compared against element ``i`` of
    ``rewards2``. "Win" means ``rewards1`` is strictly greater, "tie" means
    the two are equal, and everything else is a "loss" (a win for
    ``rewards2``).

    Args:
        rewards1: Array-like of numeric rewards (NumPy array or list).
        rewards2: Array-like of numeric rewards with the same shape.

    Returns:
        A dict with:
        - ``win_rate``, ``tie_rate`` and ``loss_rate`` as fractions in [0, 1];
        - integer counts (``total_comparisons``, ``wins``, ``ties``, ``losses``);
        - per-input summary dicts (``dataset1_stats``, ``dataset2_stats``).

    Raises:
        ValueError: If the inputs differ in shape or are empty.
    """
    # asarray lets plain lists work; list > list would raise TypeError.
    rewards1 = np.asarray(rewards1)
    rewards2 = np.asarray(rewards2)

    # Guard against silent NumPy broadcasting, e.g. a length-1 array vs length N.
    if rewards1.shape != rewards2.shape:
        raise ValueError(
            f"Reward arrays have different shapes: {rewards1.shape} vs {rewards2.shape}"
        )
    # Use .size (not len) so the total counts the same elements that the
    # comparison sums below count.
    total = int(rewards1.size)
    if total == 0:
        raise ValueError("Cannot compute statistics on empty reward arrays")

    wins = int((rewards1 > rewards2).sum())
    ties = int((rewards1 == rewards2).sum())
    # NOTE: any pair involving NaN compares False for both > and ==, so it is
    # counted here as a loss (i.e. a win for rewards2).
    losses = total - wins - ties

    return {
        "win_rate": wins / total,
        "tie_rate": ties / total,
        "loss_rate": losses / total,
        "total_comparisons": total,
        "wins": wins,
        "ties": ties,
        "losses": losses,
        "dataset1_stats": _summarize(rewards1),
        "dataset2_stats": _summarize(rewards2),
    }

def save_comparison_results(stats, args, output_path="outputs/golden_win_rates.jsonl"):
    """Append one JSON line summarizing the comparison to ``output_path``.

    The logged ``win_rate`` is dataset2's win rate, in percent. It equals
    ``stats['loss_rate'] * 100``, because a loss for dataset1 is a win for
    dataset2.

    Args:
        stats: Dict as produced by ``compute_statistics``. Only
            ``'loss_rate'`` is read.
        args: Parsed CLI namespace. ``log_tag``, ``dataset1_path`` and
            ``dataset2_path`` are read.
        output_path: JSONL file to append to. Missing parent directories
            are created.
    """
    logging_results = {
        "tag": args.log_tag,
        "target_gen_path": args.dataset2_path,
        "win_rate": stats['loss_rate'] * 100,
        "timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "base_dataset1_path": args.dataset1_path,
    }

    out_file = Path(output_path)
    # Without this, appending fails with FileNotFoundError on a fresh checkout
    # that has no outputs/ directory yet.
    out_file.parent.mkdir(parents=True, exist_ok=True)
    # Append mode keeps earlier runs; each line is one self-contained JSON record.
    with out_file.open("a", encoding="utf-8") as f:
        f.write(json.dumps(logging_results) + "\n")

def main():
    """CLI entry point: compare golden rewards of two datasets and log the result.

    The two datasets must have the same number of rows. Rows are compared
    position by position. The script then prints dataset2's win rate and
    appends it to the JSONL log via ``save_comparison_results``.
    """
    args = parse_args()

    # Load datasets (each must already have a 'golden_reward' column).
    dataset1 = load_and_validate_dataset(args.dataset1_path)
    dataset2 = load_and_validate_dataset(args.dataset2_path)

    # Comparison is row-by-row, so the datasets must line up exactly.
    if len(dataset1) != len(dataset2):
        raise ValueError(f"Datasets have different lengths: {len(dataset1)} vs {len(dataset2)}")

    # Get reward arrays
    rewards1 = np.array(dataset1['golden_reward'])
    rewards2 = np.array(dataset2['golden_reward'])

    # Compute statistics
    stats = compute_statistics(rewards1, rewards2)

    # dataset1's losses are dataset2's wins. Use a single-line f-string: the
    # previous triple-quoted literal also printed a stray whitespace-only line.
    summary = f"{args.dataset2_path} wins: {stats['loss_rate']*100:.2f} % ({stats['losses']} times)"
    print(summary)

    # Save results
    save_comparison_results(stats, args)


if __name__ == "__main__":
    # Run the comparison only when executed as a script, not when imported.
    main()