import os
import json
import argparse
import numpy as np
import pandas as pd
from tqdm import tqdm
from joblib import Parallel, delayed
from itertools import combinations

from snorkel.labeling import LFAnalysis
from snorkel.labeling.model import LabelModel, MajorityLabelVoter
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    roc_auc_score, classification_report
)

def load_data(file_path):
    """Load JSON data from the specified file path.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        The decoded JSON content. Its length is reported on stdout, so the
        top-level value must support ``len()``.

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the file is not valid UTF-8.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    # Decode as UTF-8 explicitly: without it, open() falls back to the
    # platform locale encoding and non-ASCII text can be misread.
    with open(file_path, 'r', encoding='utf-8') as file:
        data = json.load(file)
    print(f"Loaded {len(data)} samples from {file_path}")
    return data

def prepare_data(data):
    """Build the weak-label matrix and the ground-truth vector.

    Every sample in ``data`` contributes its ``"weak_labels"`` entry as one
    row and its ``"winner_index"`` entry as the matching ground-truth label,
    in the mapping's iteration order.

    Returns:
        A ``(weak_label_matrix, gt_winners)`` tuple of NumPy arrays.
    """
    label_rows = []
    winners = []
    for sample in tqdm(data.values(), desc="Preparing data"):
        winners.append(sample["winner_index"])
        label_rows.append(sample["weak_labels"])
    return np.array(label_rows), np.array(winners)

def analyze_lfs(weak_label_matrix, gt_winners, data):
    """Analyze labeling functions and return a summary indexed by LF name.

    Args:
        weak_label_matrix: Label matrix handed to ``LFAnalysis`` as ``L``.
        gt_winners: Ground-truth labels handed to ``lf_summary``.
        data: Loaded samples. LF names are taken from the keys of one
            sample's ``["judging_results"]["response1"]`` mapping.

    Returns:
        The ``lf_summary`` DataFrame with its index replaced by the LF names.
    """
    lf_analysis = LFAnalysis(L=weak_label_matrix).lf_summary(gt_winners)

    # Sample "0" stays the preferred source of names so existing datasets
    # behave as before; datasets whose keys do not include "0" fall back to
    # the first sample instead of failing with a KeyError.
    reference_sample = data["0"] if "0" in data else next(iter(data.values()))

    # NOTE(review): assumes the key order of "response1" matches the column
    # order of weak_label_matrix -- confirm against the producer of the file.
    lf_analysis.index = list(reference_sample["judging_results"]["response1"].keys())
    return lf_analysis

def filter_lfs(lf_analysis, min_acc=0.6, min_coverage=0.3):
    """Select the labeling functions that clear both quality thresholds.

    A row of ``lf_analysis`` is kept when its ``"Emp. Acc."`` value is at
    least ``min_acc`` and its ``"Coverage"`` value is at least
    ``min_coverage``.

    Returns:
        The index labels of the rows that were kept.
    """
    accurate_enough = lf_analysis["Emp. Acc."] >= min_acc
    covered_enough = lf_analysis["Coverage"] >= min_coverage
    kept = lf_analysis.loc[accurate_enough & covered_enough].index
    print(f"Filtered to {len(kept)} LFs with accuracy >= {min_acc} and coverage >= {min_coverage}")
    return kept

def evaluate_subset(lf_indices, weak_label_matrix, gt_winners):
    """Fit a LabelModel on the selected LF columns and score it.

    The model is trained and scored on the same column subset of
    ``weak_label_matrix``; scoring compares against ``gt_winners``.

    Returns:
        A ``(accuracy, number_of_lfs, lf_indices)`` tuple.
    """
    subset_matrix = weak_label_matrix[:, lf_indices]

    fit_options = {
        "n_epochs": 500,
        "log_freq": 100,
        "seed": 123,
        "l2": 1,
        "optimizer": "adam",
    }
    subset_model = LabelModel(cardinality=2, verbose=False)
    subset_model.fit(L_train=subset_matrix, **fit_options)

    scores = subset_model.score(
        L=subset_matrix,
        Y=gt_winners,
        tie_break_policy="random",
    )
    return (scores["accuracy"], len(lf_indices), lf_indices)

def train_label_model(weak_label_matrix, n_epochs=500, l2=1, seed=123):
    """Fit a verbose, two-class Snorkel LabelModel and return it.

    Args:
        weak_label_matrix: Label matrix used as ``L_train``.
        n_epochs: Number of training epochs.
        l2: L2 regularization strength.
        seed: Seed forwarded to ``LabelModel.fit``.

    Returns:
        The fitted ``LabelModel``.
    """
    fit_options = dict(
        n_epochs=n_epochs,
        log_freq=100,
        seed=seed,
        l2=l2,
        optimizer="adam",
    )
    fitted_model = LabelModel(cardinality=2, verbose=True)
    fitted_model.fit(L_train=weak_label_matrix, **fit_options)
    return fitted_model

def evaluate_model(model, weak_label_matrix, gt_winners, model_name):
    """Evaluate the model against the ground truth and print its metrics.

    Args:
        model: Object exposing ``predict(L, tie_break_policy=...)`` and
            ``predict_proba(L)``; both models used in this file do.
        weak_label_matrix: Label matrix to predict on.
        gt_winners: Ground-truth labels the predictions are compared with.
        model_name: Display name used in the printed header.

    Returns:
        A ``(pseudo_labels, pseudo_probas, metrics, class_report)`` tuple:
        the hard predictions, the second column of ``predict_proba``, a dict
        of scalar metrics, and the classification report as a dict.
    """
    pseudo_labels = model.predict(weak_label_matrix, tie_break_policy="random")
    pseudo_probas = model.predict_proba(weak_label_matrix)[:, 1]
    metrics = {
        "accuracy": accuracy_score(gt_winners, pseudo_labels),
        "precision": precision_score(gt_winners, pseudo_labels),
        "recall": recall_score(gt_winners, pseudo_labels),
        "f1": f1_score(gt_winners, pseudo_labels),
        # ROC AUC measures how well scores rank the samples, so it must be
        # given the class-1 probability; feeding it hard labels collapses
        # the curve to a single operating point.
        "roc_auc": roc_auc_score(gt_winners, pseudo_probas)
    }
    class_report = classification_report(gt_winners, pseudo_labels, output_dict=True)

    print(f"\n{model_name} Performance:")
    for metric, value in metrics.items():
        print(f"{metric.capitalize():<12}: {value:.3f}")
    # The dict form above is returned for serialization; the text form is
    # printed for the console.
    print(classification_report(gt_winners, pseudo_labels))

    return pseudo_labels, pseudo_probas, metrics, class_report

def save_results(pseudo_labels, pseudo_probas, metrics, class_report, model_name, data, output_dir, dataset_name):
    """Attach predictions to every sample and write two JSON files.

    Each sample in ``data`` is updated in place with its pseudo-label and
    pseudo-probability, matched by position. The annotated ``data`` and a
    performance report (``metrics`` plus ``class_report``) are then written
    to ``output_dir``. File names are built from ``dataset_name`` and a
    lower-cased ``model_name`` with spaces turned into underscores.
    """
    for position, record in enumerate(data.values()):
        record["label_model_pseudo_label"] = int(pseudo_labels[position])
        record["label_model_pseudo_probability"] = float(pseudo_probas[position])

    # Shared prefix of both output files.
    file_stem = f"{dataset_name}_{model_name.lower().replace(' ', '_')}"

    pseudo_label_file = os.path.join(output_dir, f"{file_stem}_pseudo_labels.json")
    with open(pseudo_label_file, 'w') as labels_out:
        json.dump(data, labels_out, indent=4)
    print(f"Saved pseudo-labels to {pseudo_label_file}")

    report_file = os.path.join(output_dir, f"{file_stem}_performance.json")
    report_payload = {"metrics": metrics, "classification_report": class_report}
    with open(report_file, 'w') as report_out:
        json.dump(report_payload, report_out, indent=4)
    print(f"Saved performance report to {report_file}")

def main():
    """Run the label-aggregation pipeline from the command line.

    Steps: load ``program_outputs/<dataset>_outputs.json``, build the label
    matrix, write an LF summary CSV, filter LFs by accuracy and coverage,
    brute-force every subset of at least three filtered LFs in parallel,
    then train and evaluate a LabelModel and a majority-vote baseline on the
    best subset and save their results under ``--output_dir``.

    Raises:
        ValueError: If fewer than three LFs pass the filter, since no
            subset can then be evaluated.
    """
    parser = argparse.ArgumentParser(description="Weak Supervision Label Aggregation")
    parser.add_argument("--dataset", choices=["judgeLM", "pandaLM", "shp", "rlhf", "prometheus"], required=True, help="Name of input dataset")
    parser.add_argument("--output_dir", default="label_model_results", help="Directory to save outputs")
    parser.add_argument("--n_epochs", type=int, default=500, help="Number of epochs for LabelModel")
    parser.add_argument("--l2", type=float, default=1.0, help="L2 regularization strength")
    parser.add_argument("--min_acc", type=float, default=0.65, help="Minimum accuracy for LF filtering")
    parser.add_argument("--min_coverage", type=float, default=0.3, help="Minimum coverage for LF filtering")
    args = parser.parse_args()

    # Create output directory if it doesn't exist
    os.makedirs(args.output_dir, exist_ok=True)

    # Load and prepare data
    input_file = os.path.join("program_outputs", f"{args.dataset}_outputs.json")
    data = load_data(input_file)
    weak_label_matrix, gt_winners = prepare_data(data)

    # Analyze labeling functions
    lf_analysis = analyze_lfs(weak_label_matrix, gt_winners, data)
    lf_analysis.to_csv(os.path.join(args.output_dir, f"{args.dataset}_lf_summary.csv"))

    # Filter LFs; the set gives O(1) membership tests while the original
    # column order is preserved by walking lf_analysis.index.
    filtered_lf_set = set(filter_lfs(lf_analysis, args.min_acc, args.min_coverage))
    filtered_lf_idx = [i for i, name in enumerate(lf_analysis.index) if name in filtered_lf_set]
    filtered_weak_label_matrix = weak_label_matrix[:, filtered_lf_idx]

    # Fail early with an actionable message: with fewer LFs than the smallest
    # subset size there are no combinations, and max() below would otherwise
    # raise an opaque "empty sequence" ValueError after the pool has started.
    min_subset_size = 3
    num_lfs = len(filtered_lf_idx)
    if num_lfs < min_subset_size:
        raise ValueError(
            f"Only {num_lfs} labeling function(s) passed filtering "
            f"(min_acc={args.min_acc}, min_coverage={args.min_coverage}); "
            f"at least {min_subset_size} are required for the subset search. "
            "Lower --min_acc or --min_coverage."
        )

    # Parallel brute force search for best LF combination.
    # The number of subsets grows exponentially with num_lfs.
    print(f"Generating combinations for {num_lfs} filtered LFs...")
    all_combinations = [np.array(combo) for r in range(min_subset_size, num_lfs + 1)
                       for combo in combinations(range(num_lfs), r)]
    print(f"Evaluating {len(all_combinations)} combinations in parallel...")

    acc_results = Parallel(n_jobs=-1)(
        delayed(evaluate_subset)(idx, filtered_weak_label_matrix, gt_winners)
        for idx in tqdm(all_combinations, desc="Evaluating combinations")
    )

    # Find the best combination; on ties max() keeps the first one, which is
    # the earliest generated (smallest) subset.
    best_acc, best_n, best_indices = max(acc_results, key=lambda x: x[0])
    print(f"Best accuracy: {best_acc:.3f} with {best_n} LFs")

    # Map best_indices (positions within the filtered matrix) back to LF names
    filtered_lf_names = [lf_analysis.index[i] for i in filtered_lf_idx]  # Names of filtered LFs
    best_lf_names = [filtered_lf_names[i] for i in best_indices]        # Names of best LFs
    print("\nLabeling Functions used in the best combination:")
    for i, lf_name in enumerate(best_lf_names, 1):
        print(f"{i}. {lf_name}")

    best_weak_label_matrix = filtered_weak_label_matrix[:, best_indices]

    # Train final model with best combination
    label_model = train_label_model(best_weak_label_matrix, args.n_epochs, args.l2)
    label_pseudo_labels, label_pseudo_probas, label_metrics, label_class_report = evaluate_model(
        label_model, best_weak_label_matrix, gt_winners, "Best Label Model"
    )
    save_results(label_pseudo_labels, label_pseudo_probas, label_metrics, label_class_report,
                "Best Label Model", data, args.output_dir, args.dataset)

    # Majority Vote baseline with best combination.
    # save_results writes into the same `data` dict, so the fields stored by
    # the previous call are overwritten before the second file is written.
    majority_model = MajorityLabelVoter()
    majority_pseudo_labels, majority_pseudo_probas, majority_metrics, majority_class_report = evaluate_model(
        majority_model, best_weak_label_matrix, gt_winners, "Majority Vote"
    )
    save_results(majority_pseudo_labels, majority_pseudo_probas, majority_metrics, majority_class_report,
                "Majority Vote", data, args.output_dir, args.dataset)

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()