import os
import argparse
import torch
from data import LoadDataset, LoadTrainModel
from data.utils import set_seed
from utils import save_json
from attributors import attributors_dict
from attributors.properties import CrossEntropyProperty
import time
import pandas as pd

# Default keyword arguments for each attributor, keyed by attributor name.
# `_get_attributor_params` reads this table, replaces the `train_loss_fn`
# placeholder and adds the run-specific settings before the values are passed
# to `compute_group_attributions`.
#
# 'influence_gn' is nested one level deeper: it is keyed by dataset name, and
# every dataset has a '<name>_noisy' twin with identical settings.
attributor_params = {
    'trak': {
        'train_loss_fn': None,
        'num_subsets': 100,
        'subsampling_frac': 0.9,
        'projection_dim': 16,
        'soft_thresh_param': 0.001,
    },
    'influence_lissa': {
        'train_loss_fn': None,
        'damp': 0.001,
        'repeat': 20,
        'depth': 200,
        'scale': 50,
    },
    'influence_gn': {
        # Clean datasets first, then their noisy variants; each entry is a
        # separate dict object.
        f'{base_name}{suffix}': {
            'train_loss_fn': None,
            'projection_dim': proj_dim,
            'temperature': 1,
        }
        for suffix in ('', '_noisy')
        for base_name, proj_dim in (
            ('heloc', 16),
            ('mnist', 32),
            ('cifar10', 64),
            ('qnli', 32),
        )
    },
    'tracin': {
        'train_loss_fn': None,
        'temperature': 1,
    },
}

def _get_attributor_params(dataset_name, model_name, attributor_name, device):
    """Build the keyword arguments for one attributor run.

    Looks up the defaults for ``attributor_name`` in the module-level
    ``attributor_params`` table, narrows them to the ``dataset_name`` entry
    when the table has one, fills in the ``train_loss_fn`` placeholder and
    adds the ``use_model_cache`` and ``device`` settings.

    The result is a fresh dict: the module-level defaults are never modified,
    so repeated calls (and callers that add their own keys afterwards) cannot
    leak settings from one run into the next.

    Args:
        dataset_name: Dataset identifier, also passed to ``LoadTrainModel``.
        model_name: Model identifier passed to ``LoadTrainModel``.
        attributor_name: Key into the ``attributor_params`` table.
        device: Value stored under the ``'device'`` key.

    Returns:
        A new dict of keyword arguments for the attributor.

    Raises:
        KeyError: If ``attributor_name`` has no entry in ``attributor_params``.
    """
    try:
        defaults = attributor_params[attributor_name]
    except KeyError:
        raise KeyError(
            f"No default parameters defined for attributor '{attributor_name}'; "
            f"available: {sorted(attributor_params)}"
        ) from None

    # Some attributors keep one parameter set per dataset.
    if dataset_name in defaults:
        defaults = defaults[dataset_name]

    # Shallow copy so the assignments below do not mutate the shared table.
    attr_params = dict(defaults)
    if 'train_loss_fn' in attr_params:
        attr_params['train_loss_fn'] = LoadTrainModel(dataset_name, model_name)[1]['loss_fn']
    attr_params['use_model_cache'] = True
    attr_params['device'] = device
    return attr_params

def main():
    """Compute group attribution scores for every group size and record timings.

    Parses the command-line options, loads the dataset and model trainer,
    then, for each group size configured for the dataset, loads the
    precomputed group ids, runs the selected attributor ``n_trials`` times and
    saves:

    * the per-trial scores as ``.pt`` and ``.json`` files under
      ``scores_pt/`` and ``scores_json/``,
    * a per-group-size ``times_size_<size>.csv`` with grouping, attribution
      and total times,
    * a final ``mean_times.csv`` with the mean times per group size.

    Raises:
        ValueError: If the dataset or attributor is not supported, or if the
            precomputed grouping-time file holds fewer rows than ``n_trials``.
    """
    # Command-line interface
    parser = argparse.ArgumentParser()
    parser.add_argument('-d', '--dataset_name', type=str, default='heloc')
    parser.add_argument('-m', '--model_name', type=str, default='lr')
    parser.add_argument('-a', '--attributor_name', type=str, default='influence_cg')
    parser.add_argument('-g', '--grouping', type=str, default='equal', choices=['equal', 'kmeans', 'repr_kmeans', 'grad_kmeans'])
    parser.add_argument('-n', '--n_trials', type=int, default=1)
    parser.add_argument('-c', '--cuda', action='store_true')

    # Start global timer
    original_start_time = time.time()

    # Parse arguments
    args = parser.parse_args()
    dataset_name, model_name = args.dataset_name, args.model_name
    attributor_name = args.attributor_name
    grouping, n_trials = args.grouping, args.n_trials
    if n_trials < 1:
        # The mean-time computation below divides by the number of trials.
        parser.error("--n_trials must be at least 1")
    device = 'cuda' if args.cuda and torch.cuda.is_available() else 'cpu'

    # Keyword arguments forwarded to compute_group_attributions
    attr_kwargs = _get_attributor_params(dataset_name, model_name, attributor_name, device)

    # Get dataset, train_model, and property_fn
    trainset, testset = LoadDataset(dataset_name, train=True), LoadDataset(dataset_name, train=False)
    train_model, _ = LoadTrainModel(dataset_name, model_name)

    if dataset_name in ['heloc', 'heloc_noisy']:
        # NOTE: an unreachable 'heloc_small' case here used to select
        # [1, 2, 4, 8, 16, 32, 64]; that dataset never enters this branch.
        group_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256]
        property_fn = CrossEntropyProperty(testset, batch_size=None)
    elif dataset_name in ['mnist', 'mnist_noisy']:
        group_sizes = [1, 4, 16, 64, 256, 1024]
        property_fn = CrossEntropyProperty(testset, batch_size=1024)
    elif dataset_name in ['qnli', 'qnli_noisy']:
        group_sizes = [1, 4, 16, 64, 256, 1024]
        property_fn = CrossEntropyProperty(testset, batch_size=128)
    elif dataset_name in ['cifar10', 'cifar10_noisy']:
        group_sizes = [1, 4, 16, 64, 256, 1024]
        property_fn = CrossEntropyProperty(testset, batch_size=512)
    else:
        raise ValueError(f"Dataset {dataset_name} not supported")

    # Output directories for scores and timings
    save_dir = f'experiments/results/attr_scores/{dataset_name}_{model_name}/{attributor_name}_{grouping}/'
    scores_pt_dir = os.path.join(save_dir, 'scores_pt')
    scores_json_dir = os.path.join(save_dir, 'scores_json')
    for out_dir in (save_dir, scores_pt_dir, scores_json_dir):
        os.makedirs(out_dir, exist_ok=True)

    # Directory holding the precomputed group ids and grouping times
    cluster_dir = f'experiments/results/clusters/{dataset_name}_{model_name}/{grouping}/'

    # Ablate over group sizes
    mean_attribution_times, mean_grouping_times = [], []
    print("Group Sizes:", group_sizes)
    for group_size in group_sizes:
        attribution_times = []

        # Load the precomputed group ids for this group size
        group_ids = torch.load(os.path.join(cluster_dir, f'group_ids_size_{group_size}.pt'))

        for trial in range(1, n_trials+1):
            torch.cuda.empty_cache()
            print(f"\nTrial {trial}")
            print(f"\nGroup Size: {group_size}")

            # Initialize attributor
            if attributor_name in ['datamodels', 'trak', 'leave_one_out', 'random']:
                data_attributor = attributors_dict[attributor_name](dataset=trainset, group_ids=group_ids, train_model=train_model)
            elif attributor_name in ['influence_cg', 'influence_lissa', 'influence_lissa_hf', 'influence_gn', 'tracin', 'group_tracin']:
                model = train_model(trainset, device=device, verbose=True, use_model_cache=True)
                if attributor_name == 'group_tracin':
                    test_group_ids = torch.load(os.path.join(cluster_dir, f'test_group_ids_size_{group_size}.pt'))
                    attr_kwargs['test_group_ids'] = test_group_ids
                data_attributor = attributors_dict[attributor_name](dataset=trainset, group_ids=group_ids, model=model)
            else:
                # Without this, an unknown name would surface as a NameError
                # (or silently reuse the attributor from a previous trial).
                raise ValueError(f"Attributor {attributor_name} not supported")

            # Compute and store attributions
            set_seed(42)
            print("Computing Group Attributions...")
            start_time = time.time()
            scores_per_group = data_attributor.compute_group_attributions(property_fn, **attr_kwargs)
            attributions_time = time.time() - start_time
            attribution_times.append(attributions_time)
            print(f"Attributions Time: {attributions_time:.3f}s")

            # Save group IDs and scores
            group_ids_and_scores = {
                'group_ids': group_ids,
                'scores_per_group': list(scores_per_group.values())
            }
            scores_stem = f'group_scores_size_{group_size}_trial_{trial}'
            save_path = os.path.join(scores_pt_dir, f'{scores_stem}.pt')
            # Build the JSON path explicitly: a blanket 'pt' -> 'json' string
            # replacement would also rewrite any 'pt' inside the dataset,
            # model, attributor or grouping names.
            json_path = os.path.join(scores_json_dir, f'{scores_stem}.json')
            torch.save(group_ids_and_scores, save_path)
            save_json(scores_per_group, json_path)
            print(f"Saved group scores to {save_path}")

        # Save grouping and attribution times
        grouping_times = pd.read_csv(os.path.join(cluster_dir, f'times_size_{group_size}.csv'))['grouping_times'].tolist()
        grouping_times = grouping_times[:n_trials]
        if len(grouping_times) < n_trials:
            raise ValueError(
                f"Expected at least {n_trials} grouping times for group size {group_size}, "
                f"found {len(grouping_times)}"
            )
        print(f"Grouping Times: {grouping_times}")
        print(f"Attribution Times: {attribution_times}")

        total_times = [g_time + a_time for g_time, a_time in zip(grouping_times, attribution_times)]
        df = pd.DataFrame({'grouping_times': grouping_times, 'attribution_times': attribution_times,
                           'total_times': total_times})
        df.to_csv(os.path.join(save_dir, f'times_size_{group_size}.csv'), index=False)

        # Print mean time
        mean_grouping_time = sum(grouping_times) / len(grouping_times)
        mean_grouping_times.append(mean_grouping_time)
        mean_attribution_time = sum(attribution_times) / len(attribution_times)
        mean_attribution_times.append(mean_attribution_time)
        print(f"Mean Grouping Time (Precomputed): {mean_grouping_time:.3f}s")
        print(f"Mean Attribution Time: {mean_attribution_time:.3f}s")

    # Save mean grouping and attribution times
    mean_total_times = [g_time + a_time for g_time, a_time in zip(mean_grouping_times, mean_attribution_times)]
    df = pd.DataFrame({'group_size': group_sizes, 'mean_grouping_time': mean_grouping_times, 'mean_attribution_time': mean_attribution_times,
                       'mean_total_time': mean_total_times})
    df.to_csv(os.path.join(save_dir, 'mean_times.csv'), index=False)

    # Print total time
    total_time = time.time() - original_start_time
    print(f"Total Time: {total_time:.3f}s")

# Script entry point: run the experiment only when this file is executed directly.
if __name__ == "__main__":
    main()
