import argparse
import torch
from data import LoadDataset, LoadTrainModel
from attributors.properties import CrossEntropyProperty, AccuracyProperty
import time

# Registry of evaluable properties, keyed by the value accepted for -p/--property_name.
# Each entry is a class that the script below instantiates with the test set.
properties_dict = dict(
    cross_entropy=CrossEntropyProperty,
    accuracy=AccuracyProperty,
)

def get_group_sizes(dataset_name):
    """Return the group sizes whose retraining outputs are evaluated for a dataset.

    Each ``*_noisy`` variant uses the same sizes as its clean counterpart.

    Args:
        dataset_name: Dataset identifier as given to -d/--dataset_name.

    Returns:
        A new list of increasing group sizes.

    Raises:
        ValueError: If ``dataset_name`` is not a supported dataset.
    """
    if dataset_name in ('heloc', 'heloc_noisy'):
        return [1, 2, 4, 8, 16, 32, 64, 128, 256]
    if dataset_name in ('mnist', 'mnist_noisy', 'cifar10', 'cifar10_noisy', 'qnli', 'qnli_noisy'):
        return [1, 4, 16, 64, 256, 1024]
    raise ValueError(f"Dataset {dataset_name} not supported")


def main():
    """Evaluate a property on saved retraining outputs and save the results.

    Computes the property for the original model and saves it as
    ``<save_dir>/base_<property_name>.pt``. Then, for every trial and group
    size, loads ``trial_<t>/test_outputs_group_size_<k>.pt``, evaluates the
    property on each saved test output, and saves the resulting tensor as
    ``trial_<t>/<property_name>_group_size_<k>.pt``.
    """
    # Command-line options selecting dataset, model, property and attribution run.
    parser = argparse.ArgumentParser()
    parser.add_argument('-d', '--dataset_name', type=str, default='heloc')
    parser.add_argument('-m', '--model_name', type=str, default='lr')
    # Restricting to known properties turns a typo into a usage error
    # instead of a KeyError after the datasets have been loaded.
    parser.add_argument('-p', '--property_name', type=str, default='accuracy', choices=list(properties_dict))
    parser.add_argument('-a', '--attributor_name', type=str, default='tracin')
    parser.add_argument('-g', '--grouping', type=str, default='equal', choices=['equal', 'kmeans', 'repr_kmeans', 'grad_kmeans'])
    parser.add_argument('-n', '--n_trials', type=int, default=1)
    parser.add_argument('-i', '--invert', action='store_true')
    parser.add_argument('-c', '--cuda', action='store_true')

    # Start global timer
    original_start_time = time.time()

    args = parser.parse_args()
    dataset_name, model_name = args.dataset_name, args.model_name
    property_name = args.property_name
    device = 'cuda' if args.cuda and torch.cuda.is_available() else 'cpu'

    # Resolve group sizes before any expensive work, so an unsupported dataset
    # fails immediately rather than after loading data and training the model.
    group_sizes = get_group_sizes(dataset_name)

    # Directory holding the retraining outputs (no trailing slash; sub-paths add their own '/').
    invert_str = '_invert' if args.invert else ''
    save_dir = (f'experiments/results/retrain_outputs/{dataset_name}_{model_name}/'
                f'{args.attributor_name}_{args.grouping}{invert_str}')

    # Get dataset and the property evaluator built on the test set.
    trainset, testset = LoadDataset(dataset_name, train=True), LoadDataset(dataset_name, train=False)
    property_fn = properties_dict[property_name](testset, batch_size=512)

    # Get the original model (use_model_cache=True presumably reuses a previously trained model).
    train_model, _ = LoadTrainModel(dataset_name, model_name)
    model = train_model(trainset, device=device, verbose=True, use_model_cache=True)

    # Compute and save base property value
    base_property = property_fn.forward(model, device=device)
    torch.save(base_property, f'{save_dir}/base_{property_name}.pt')

    # e.g. experiments/results/retrain_outputs/qnli_bert/influence_gn_equal/trial_1/test_outputs_group_size_1.pt
    for trial in range(1, args.n_trials + 1):
        trial_dir = f'{save_dir}/trial_{trial}'
        for group_size in group_sizes:
            test_outputs = torch.load(f'{trial_dir}/test_outputs_group_size_{group_size}.pt')

            # One property value per saved test output, collected into a tensor.
            property_vals = torch.tensor([property_fn.test_output_forward(test_out) for test_out in test_outputs])

            torch.save(property_vals, f'{trial_dir}/{property_name}_group_size_{group_size}.pt')

    # Print total time
    total_time = time.time() - original_start_time
    print(f"Total Time: {total_time:.3f}s")


if __name__ == '__main__':
    main()