import os
import argparse
import torch
from data import LoadDataset, LoadTrainModel
from data.clustering import (
    equal_clustering, kmeans_clustering,
    grad_kmeans_clustering, repr_kmeans_clustering
)
import time
import pandas as pd

def _group_sizes_for(dataset_name):
    """Return the group sizes (target points per group) to ablate over for ``dataset_name``.

    Raises:
        ValueError: if ``dataset_name`` is not a supported dataset.
    """
    if dataset_name == 'heloc_small':
        return [1, 2, 4, 8, 16, 32, 64]
    if dataset_name in ('heloc', 'heloc_noisy', 'heloc_wd'):
        return [1, 2, 4, 8, 16, 32, 64, 128, 256]
    if dataset_name in ('mnist', 'mnist_noisy', 'cifar10', 'cifar10_noisy', 'qnli', 'qnli_noisy'):
        return [1, 4, 16, 64, 256, 1024]
    raise ValueError(f"Dataset {dataset_name} not supported")


def _compute_group_ids(grouping, trainset, group_size, dataset_name, train_model, train_params, device):
    """Run one grouping trial and return ``(group_ids, elapsed_seconds)``.

    For the model-based methods ('grad_kmeans', 'repr_kmeans') the model is
    obtained from ``train_model(..., use_model_cache=True)`` *before* the timer
    starts, so ``elapsed_seconds`` covers only the clustering call.

    Raises:
        ValueError: if ``grouping`` is not a known method.
    """
    if grouping == 'equal':
        start_time = time.perf_counter()
        group_ids = equal_clustering(trainset, group_size=group_size)
        return group_ids, time.perf_counter() - start_time

    # Choose the cluster count so clusters hold about ``group_size`` points on average.
    num_clusters = len(trainset) // group_size
    if grouping == 'kmeans':
        start_time = time.perf_counter()
        group_ids = kmeans_clustering(trainset, num_clusters=num_clusters, device=device)
    elif grouping == 'grad_kmeans':
        extra_kwargs = {}
        if dataset_name in ('qnli', 'qnli_noisy'):
            # For QNLI, train_model also returns a model directory, which is
            # forwarded to grad_kmeans_clustering as ``model_dir``.
            model, model_dir = train_model(trainset, device=device, verbose=True, use_model_cache=True, return_model_dir=True)
            extra_kwargs['model_dir'] = model_dir
        else:
            model = train_model(trainset, device=device, verbose=True, use_model_cache=True)
        start_time = time.perf_counter()
        group_ids = grad_kmeans_clustering(trainset, num_clusters, model, train_params['loss_fn'], device=device, **extra_kwargs)
    elif grouping == 'repr_kmeans':
        model = train_model(trainset, device=device, verbose=True, use_model_cache=True)
        start_time = time.perf_counter()
        group_ids = repr_kmeans_clustering(trainset, num_clusters, model, device=device)
    else:
        raise ValueError(f"Unknown grouping method: {grouping}")
    return group_ids, time.perf_counter() - start_time


def main():
    """Time a grouping method across group sizes and save the resulting group IDs.

    For every group size configured for the dataset, runs ``--n_trials`` trials of
    the chosen grouping method, then writes into
    ``experiments/results/clusters/{dataset_name}_{model_name}/{grouping}/``:

    * ``group_ids_size_{k}.pt`` - group IDs from the last trial,
    * ``times_size_{k}.csv``    - per-trial grouping times,
    * ``mean_times.csv``        - mean grouping time per group size.
    """
    parser = argparse.ArgumentParser(description="Compute clusters for data attribution experiments.")
    parser.add_argument('-d', '--dataset_name', type=str, default='heloc', help="Name of the dataset")
    parser.add_argument('-m', '--model_name', type=str, default='lr', help="Name of the model")
    parser.add_argument('-g', '--grouping', type=str, default='equal', choices=['equal', 'kmeans', 'repr_kmeans', 'grad_kmeans'], help="Grouping method")
    parser.add_argument('-n', '--n_trials', type=int, default=5, help="Number of trials")
    parser.add_argument('-c', '--cuda', action='store_true', help="Use CUDA if available")

    # Start global timer (perf_counter is monotonic and meant for measuring intervals).
    original_start_time = time.perf_counter()

    args = parser.parse_args()
    dataset_name, model_name = args.dataset_name, args.model_name
    grouping, n_trials = args.grouping, args.n_trials
    if n_trials < 1:
        # With zero trials there would be no group IDs to save and the mean time
        # would divide by zero.
        parser.error("--n_trials must be at least 1")
    device = 'cuda' if args.cuda and torch.cuda.is_available() else 'cpu'

    # Validate the dataset before paying for dataset/model loading.
    group_sizes = _group_sizes_for(dataset_name)

    # Load the training set and the model-training function/parameters.
    trainset = LoadDataset(dataset_name, train=True)
    train_model, train_params = LoadTrainModel(dataset_name, model_name)

    # Output directory for group IDs and timing CSVs.
    save_dir = f'experiments/results/clusters/{dataset_name}_{model_name}/{grouping}/'
    os.makedirs(save_dir, exist_ok=True)

    # Ablate over group sizes.
    mean_grouping_times = []
    print("Group Sizes:", group_sizes)
    for group_size in group_sizes:
        grouping_times = []
        for trial in range(1, n_trials + 1):
            print(f"\nTrial {trial}")
            print(f"\nGroup Size: {group_size}")
            torch.cuda.empty_cache()
            group_ids, grouping_time = _compute_group_ids(
                grouping, trainset, group_size, dataset_name, train_model, train_params, device
            )
            grouping_times.append(grouping_time)
            print(f"Grouping Time: {grouping_time:.3f}s")

            print("First 100 group ID counts:")
            print(group_ids.bincount()[:100])

        # Save group IDs from the last trial only.
        save_path = os.path.join(save_dir, f'group_ids_size_{group_size}.pt')
        torch.save(group_ids, save_path)
        print(f"Saved group IDs to {save_path}")

        # Save per-trial grouping times.
        df = pd.DataFrame({'grouping_times': grouping_times})
        df.to_csv(os.path.join(save_dir, f'times_size_{group_size}.csv'), index=False)

        mean_grouping_time = sum(grouping_times) / len(grouping_times)
        mean_grouping_times.append(mean_grouping_time)
        print(f"Mean Grouping Time: {mean_grouping_time:.3f}s")

    # Save mean grouping time per group size.
    df = pd.DataFrame({'group_size': group_sizes, 'mean_grouping_time': mean_grouping_times})
    df.to_csv(os.path.join(save_dir, 'mean_times.csv'), index=False)

    total_time = time.perf_counter() - original_start_time
    print(f"\nTotal Experiment Time: {total_time:.3f}s")

if __name__ == "__main__":
    # Run the clustering experiment only when executed as a script, not on import.
    main()
