import argparse
import numpy as np
import os
import pickle
import torch

def compute_forgetting_statistics(diag_stats, ignore):
    """Find forgetting events and the first-learned presentation of each example.

    Args:
        diag_stats: dict keyed by example id. String keys (summary entries
            such as ``'acc'``) are skipped. For every other key, ``value[0]``
            is the per-presentation history of the example; a value of 1 is
            treated as "learned" and a 1 -> 0 step as a forgetting event
            (presumably 1 = correctly classified; confirm with the producer).
            Boolean histories are accepted and treated as 0/1.
        ignore: when equal to ``'ignore'``, the first recorded presentation of
            every example is dropped before analysis.

    Returns:
        ``(unlearned_per_presentation, first_learned)`` where

        - ``unlearned_per_presentation[id]`` is a numpy array of 1-based
          presentation numbers at which the example was forgotten, or ``[]``
          if it was never forgotten;
        - ``first_learned[id]`` is the 0-based index of the first presentation
          whose value is 1, or ``np.nan`` if the example never reached 1.
    """
    unlearned_per_presentation = {}
    first_learned = {}
    for example_id, example_stats in diag_stats.items():
        # String keys hold aggregate statistics, not examples.
        if isinstance(example_id, str):
            continue

        presentation_acc = np.array(example_stats[0])
        # numpy rejects '-' on boolean arrays; promote so bool histories work.
        # Every other dtype is left exactly as it was.
        if presentation_acc.dtype == bool:
            presentation_acc = presentation_acc.astype(np.int64)
        if ignore == 'ignore':
            presentation_acc = presentation_acc[1:]

        transitions = presentation_acc[1:] - presentation_acc[:-1]
        # transitions[i] compares presentation i+1 with i (0-based), so the
        # forgotten presentation is i+1 (0-based) == i+2 (1-based).
        forgotten = np.where(transitions == -1)[0]
        if len(forgotten) > 0:
            unlearned_per_presentation[example_id] = forgotten + 2
        else:
            unlearned_per_presentation[example_id] = []

        learned = np.where(presentation_acc == 1)[0]
        first_learned[example_id] = learned[0] if len(learned) > 0 else np.nan

    return unlearned_per_presentation, first_learned


def sort_examples_by_forgetting(unlearned_per_presentation, first_learned_all, npresentations):
    """Sort examples by number of forgetting events, in ascending order.

    Operates on the statistics of a single training run, as returned by
    ``compute_forgetting_statistics``. An example that was never learned
    (``first_learned_all[id]`` is NaN) is assigned ``npresentations`` instead
    of its forgetting count.

    Args:
        unlearned_per_presentation: dict mapping example id to the sequence of
            presentations at which that example was forgotten; its length is
            the forgetting count.
        first_learned_all: dict mapping example id to the index at which the
            example was first learned, or NaN if it never was. Despite the
            ``_all`` suffix this is a single dict, not a list of runs.
        npresentations: count assigned to never-learned examples, e.g. the
            number of training epochs to rank them last, or a negative value
            to rank them ahead of unforgettable (count 0) examples.

    Returns:
        Two numpy arrays: the example ids sorted by ascending count, and the
        corresponding counts.
    """
    example_ids = list(unlearned_per_presentation.keys())
    counts = [
        npresentations
        if np.isnan(first_learned_all[example_id])
        else len(unlearned_per_presentation[example_id])
        for example_id in example_ids
    ]
    # A single argsort drives both outputs, so ids and counts stay aligned.
    order = np.argsort(counts)
    return np.array(example_ids)[order], np.array(counts)[order]


def order_examples_of_forget(loaded, ignore):
    """Rank every client's examples by forgetting count.

    ``loaded`` is indexed as ``loaded[i][j]``: for each outer group ``i`` and
    client ``j`` it holds the per-example statistics dict consumed by
    ``compute_forgetting_statistics``.

    Args:
        loaded: nested sequence of per-client example-statistics dicts.
        ignore: forwarded to ``compute_forgetting_statistics``; ``'ignore'``
            drops each example's first recorded presentation.

    Returns:
        ``(ordered_examples_idx_all, ordered_forget_values_all)``: two lists
        with one inner list per outer group. Each inner list holds, per
        client, the example ids sorted by ascending forgetting count and the
        matching counts. At least two outer lists are always returned (the
        historical shape), even when ``loaded`` has fewer groups.
    """
    # The containers used to be hard-coded to two groups, which raised
    # IndexError when ``loaded`` had more. Keep a minimum of two so callers
    # indexing [0] / [1] behave as before.
    num_groups = max(2, len(loaded))
    ordered_examples_idx_all = [[] for _ in range(num_groups)]
    ordered_forget_values_all = [[] for _ in range(num_groups)]

    for group_idx, group in enumerate(loaded):
        for client_stats in group:
            unlearned_per_presentation, first_learned = compute_forgetting_statistics(client_stats, ignore)
            # NOTE(review): -1 is the count given to never-learned examples, so
            # they sort ahead of unforgettable (0-count) ones rather than last.
            # This looks deliberate; confirm with downstream consumers.
            ordered_examples_idx, ordered_forget_values = sort_examples_by_forgetting(
                unlearned_per_presentation, first_learned, -1
            )
            ordered_examples_idx_all[group_idx].append(ordered_examples_idx)
            ordered_forget_values_all[group_idx].append(ordered_forget_values)

    return ordered_examples_idx_all, ordered_forget_values_all


def count_acc_for_domain_class(args, idx, all_client_targets, example_state):
    """Return each client's per-class error share (percent) at the last presentation.

    For client ``i``, an example counts as an error when the last entry of its
    history (``value[0][-1]``) is 0, presumably meaning it was misclassified at
    the final presentation. Errors are tallied by the example's label from
    ``all_client_targets[i]``. Each tally is divided by the total number of
    examples that client tracks, so the entries of one client's array sum to
    its overall error percentage.

    The inputs are not modified. An optional ``"acc"`` entry in each client
    dict is ignored.

    Args:
        args: unused; kept for call-site compatibility.
        idx: unused; kept for call-site compatibility.
        all_client_targets: list of dicts, one per client, mapping example id
            to integer class label.
        example_state: list of dicts, one per client, mapping example id to its
            recorded statistics.

    Returns:
        List of numpy float arrays, one per client, with one entry per class.
    """
    # The class count is inferred from the distinct labels seen by client 0.
    # Labels are used directly as list indices, so they are assumed to be
    # 0..K-1 and all present in client 0 (TODO confirm; otherwise IndexError).
    num_classes = len(set(all_client_targets[0].values()))

    all_client_err = []
    for client_idx, client in enumerate(example_state):
        targets = all_client_targets[client_idx]
        # Skip the aggregate "acc" entry instead of popping it. Popping mutated
        # the caller's dict and made a second call raise KeyError.
        examples = {key: value for key, value in client.items() if key != "acc"}

        client_err = [0] * num_classes
        for example_id, value in examples.items():
            if value[0][-1] == 0:
                client_err[targets[example_id]] += 1
        total = len(examples)
        all_client_err.append(np.array(client_err) / total * 100)

    return all_client_err



