import os
import argparse
import torch
from data import LoadDataset, LoadTrainModel
from attributors.properties import CrossEntropyProperty
from attributors.evaluators import remove_and_retrain_evaluator_num_points
from utils import load_json
import time

def build_parser():
    """Build the command-line parser for the remove-and-retrain experiment."""
    parser = argparse.ArgumentParser()
    parser.add_argument('-d', '--dataset_name', type=str, default='heloc')
    parser.add_argument('-m', '--model_name', type=str, default='lr')
    parser.add_argument('-a', '--attributor_name', type=str, default='influence_cg')
    parser.add_argument('-g', '--grouping', type=str, default='equal', choices=['equal', 'kmeans', 'repr_kmeans', 'grad_kmeans'])
    parser.add_argument('-gs', '--group_size', type=int, default=1)
    parser.add_argument('-t', '--n_trials', type=int, default=5)
    parser.add_argument('-i', '--invert', action='store_true')
    parser.add_argument('-c', '--cuda', action='store_true')
    parser.add_argument('-u', '--use_model_cache', action='store_true')
    return parser


def random_group_scores(group_ids):
    """Assign an independent uniform random score in [0, 1) to every distinct group id.

    Args:
        group_ids: Per-datapoint group ids; anything ``torch.as_tensor`` accepts
            (tensor, list, numpy array).

    Returns:
        A dict mapping each distinct group id to a float score. Keys are plain
        Python numbers (not 0-dim tensors) and are inserted in ascending order.
    """
    # torch.unique returns its output sorted by default, so no extra sort is
    # needed. .tolist() yields plain Python numbers, matching the key type of
    # the scores loaded from JSON in the non-random code path.
    unique_group_ids = torch.unique(torch.as_tensor(group_ids)).tolist()
    # One draw per group, in ascending id order, from the global torch RNG.
    return {group_id: torch.rand(1).item() for group_id in unique_group_ids}


def rank_groups(scores_per_group, invert):
    """Return the group ids ordered by score.

    Args:
        scores_per_group: Mapping from group id to attribution score.
        invert: If False, highest-scoring groups come first; if True,
            lowest-scoring groups come first.

    Returns:
        A list of group ids. Ties keep the mapping's iteration order because
        ``sorted`` is stable.
    """
    return sorted(scores_per_group, key=scores_per_group.get, reverse=not invert)


def main(argv=None):
    """Run the remove-and-retrain evaluation for one attributor and group size.

    For every trial, the groups are ranked by attribution score (or by random
    scores for the 'random' attributor), passed to
    ``remove_and_retrain_evaluator_num_points``, and the resulting train masks,
    test outputs, property values, group ids and group ranking are saved under
    ``experiments/results/retrain_outputs/``.

    Args:
        argv: Optional list of command-line arguments; defaults to ``sys.argv[1:]``.
    """
    # Start global timer
    original_start_time = time.time()

    # Parse arguments
    args = build_parser().parse_args(argv)
    dataset_name, model_name = args.dataset_name, args.model_name
    attributor_name = args.attributor_name
    grouping, invert = args.grouping, args.invert
    n_trials, use_model_cache = args.n_trials, args.use_model_cache
    group_size = args.group_size
    # Fall back to CPU when CUDA is requested but not available
    device = 'cuda' if args.cuda and torch.cuda.is_available() else 'cpu'

    # Get dataset, train_model, and property_fn
    trainset, testset = LoadDataset(dataset_name, train=True), LoadDataset(dataset_name, train=False)
    # The second return value (training params) is not needed here
    train_model, _ = LoadTrainModel(dataset_name, model_name)
    property_fn = CrossEntropyProperty(testset, batch_size=2048)

    # Directory that receives the per-trial outputs
    invert_str = '_invert' if invert else ''
    save_dir = f'experiments/results/retrain_outputs/{dataset_name}_{model_name}/{attributor_name}_{grouping}{invert_str}/'
    os.makedirs(save_dir, exist_ok=True)

    # Fix number of points to remove
    percentages = [25, 50, 75] if invert else [1, 5, 10, 20]
    print(f"Percentages: {percentages}")
    num_points_to_remove = [int(len(trainset) * i / 100) for i in percentages]
    print(f"Num Points to Remove: {num_points_to_remove}")

    # The group assignment depends only on dataset/model/grouping/group_size,
    # not on the trial, so load it once instead of once per trial.
    # NOTE(review): torch.load unpickles the file; only point this at trusted
    # result files.
    group_ids_path = f'experiments/results/clusters/{dataset_name}_{model_name}/{grouping}/group_ids_size_{group_size}.pt'
    group_ids = torch.load(group_ids_path)

    # Directory holding the precomputed attribution scores (unused for 'random')
    load_dir = f'experiments/results/attr_scores/{dataset_name}_{model_name}/{attributor_name}_{grouping}/scores_json/'

    # Repeat the evaluation over trials
    for trial in range(1, n_trials + 1):
        trial_time = time.time()
        print(f"\nGroup Size: {group_size}, Trial: {trial}")

        # Load (or randomly draw) the per-group attribution scores
        if attributor_name == 'random':
            scores_per_group = random_group_scores(group_ids)
        else:
            group_str = f'group_scores_size_{group_size}_trial_{trial}.json'
            scores_per_group = load_json(os.path.join(load_dir, group_str), convert_keys_to_int=True)
        sorted_unique_group_ids = rank_groups(scores_per_group, invert)

        # Evaluate attributions via remove_and_retrain_evaluator
        train_masks, test_outputs = remove_and_retrain_evaluator_num_points(
            trainset, testset,
            sorted_unique_group_ids,
            group_ids, train_model,
            num_points_to_remove=num_points_to_remove,
            device=device,
            use_model_cache=use_model_cache,
            verbose=True,
            seed=trial,
        )

        # Compute property values
        property_vals = torch.tensor([property_fn.test_output_forward(test_out) for test_out in test_outputs])

        # Save train_masks, test outputs, group_ids and sorted_unique_group_ids
        trial_save_dir = os.path.join(save_dir, f'trial_{trial}')
        os.makedirs(trial_save_dir, exist_ok=True)
        torch.save(train_masks, os.path.join(trial_save_dir, f'train_masks_group_size_{group_size}.pt'))
        torch.save(test_outputs, os.path.join(trial_save_dir, f'test_outputs_group_size_{group_size}.pt'))
        torch.save(property_vals, os.path.join(trial_save_dir, f'property_vals_group_size_{group_size}.pt'))
        # as_tensor avoids the copy-construct warning when group_ids is already a tensor
        torch.save(torch.as_tensor(group_ids).int(), os.path.join(trial_save_dir, f'group_ids_group_size_{group_size}.pt'))
        torch.save(torch.tensor(sorted_unique_group_ids).int(), os.path.join(trial_save_dir, f'group_rankings_group_size_{group_size}.pt'))

        # Print trial time
        print(f"Trial Time: {time.time() - trial_time:.3f}s")

    # Print total time
    total_time = time.time() - original_start_time
    print(f"Total Time: {total_time:.3f}s")


if __name__ == '__main__':
    main()