import copy
import os

import matplotlib.pyplot as plt
import numpy as np
import PIL.Image as Image
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm

import wandb
from dataset import load_dataset
from submod import faciliy_location_order
from utils import load_model, plot_confusion_matrix, save_activations


def silhouette(args, preds):
    """Pick the k-means cluster count with the best mean silhouette score.

    Every k from 2 to ``args.max_clusters`` (inclusive) is tried with k-means
    (``random_state=args.seed``, 10 restarts) and each average score is printed.

    Returns:
        (best_k, best_score): the winning cluster count and its score.
        When ``args.max_clusters < 2`` no candidate is evaluated and
        ``np.argmax`` raises ValueError on the empty score list.
    """
    scores = []
    for k in range(2, args.max_clusters + 1):
        labels_k = KMeans(n_clusters=k, random_state=args.seed, n_init=10).fit_predict(preds)
        avg = silhouette_score(preds, labels_k)
        print("For n_clusters =", k, "The average silhouette_score is :", avg)
        scores.append(avg)

    # Candidates start at k = 2, so shift the winning position by two.
    best_k = np.argmax(scores) + 2
    return best_k, np.max(scores)


def clustering(args, group_array, preds, labels):
    """Cluster the training set based on the predicted logits.

    Args:
        args: run configuration. Reads ``dataset``, ``num_groups``,
            ``num_classes``, ``num_clusters``, ``cluster_method``,
            ``cluster_all``, ``silhouette``, ``seed``, ``equal`` and ``logger``.
            Side effects: may overwrite ``args.num_clusters`` and, when
            ``args.silhouette`` is set on the k-means path, sets
            ``args.silhouette_score``.
        group_array: per-sample group ids; only its length and its number of
            unique values are used here.
        preds: per-sample features to cluster (one row per sample).
        labels: per-sample integer class labels.

    Returns:
        np.ndarray with one cluster id per sample. Wherever the ids are offset
        by ``labels * args.num_clusters``, clusters of different classes never
        share an id.
    """
    labels = np.asarray(labels)
    if args.dataset not in ['cifar100sup', 'cifar10', 'imagenet', 'cmnist', 'balance_cmnist']:
        # Group-annotated datasets: one cluster per group within each class.
        args.num_clusters = args.num_groups // args.num_classes

    if args.cluster_method == 'kmeans':
        if args.cluster_all:
            # A single k-means over all samples, regardless of class.
            args.logger.info('Clustering with k-means into {} clusters'.format(args.num_clusters))
            if args.silhouette:
                args.logger.info('Using silhouette score to choose number of clusters')
                args.num_clusters, args.silhouette_score = silhouette(args, preds)
                args.logger.info('Number of clusters: {}'.format(args.num_clusters))
                args.logger.info('Silhouette score: {}'.format(args.silhouette_score))
            clusterer = KMeans(n_clusters=args.num_clusters,
                               random_state=args.seed,
                               n_init=10)
            cluster_labels = clusterer.fit_predict(preds)
        else:
            # A separate k-means inside every class.
            cluster_labels = np.zeros(len(group_array), dtype=np.int64)
            silhouette_scores = []
            for i in range(args.num_classes):
                class_indices = np.where(labels == i)[0]
                class_preds = preds[class_indices]
                if args.silhouette:
                    args.logger.info('Using silhouette score to choose number of clusters for class {}'.format(i))
                    num_clusters, score = silhouette(args, class_preds)
                    args.logger.info('Number of clusters: {}'.format(num_clusters))
                    args.logger.info('Silhouette score: {}'.format(score))
                    silhouette_scores.append(score)
                    # Never shrink: args.num_clusters is also the per-class id
                    # stride applied below, so it must cover every class.
                    args.num_clusters = max(num_clusters, args.num_clusters)
                clusterer = KMeans(n_clusters=args.num_clusters,
                                   random_state=args.seed,
                                   n_init=10)
                cluster_labels[class_indices] = clusterer.fit_predict(class_preds)
            if args.silhouette:
                args.logger.info('Silhouette scores: {}'.format(silhouette_scores))
                args.silhouette_score = np.min(silhouette_scores)
        # Offset by class so ids are unique across classes.
        cluster_labels = cluster_labels + labels * args.num_clusters
    elif args.cluster_all:
        # Facility location over all samples at once (a single pseudo-class).
        per_class = get_orders_and_weights(
            len(np.unique(group_array)), preds, 'euclidean',
            y=np.zeros(len(group_array)), equal_num=args.equal, return_cluster=True, args=args)
        # get_orders_and_weights returns one entry per class of ``y``; here that
        # is a 1-tuple. Unwrap it: adding the tuple to an ndarray would broadcast
        # to shape (1, N) instead of yielding one id per sample.
        cluster_labels = np.concatenate(list(per_class))
        cluster_labels = cluster_labels + labels * args.num_clusters
    else:
        # Facility location run separately inside every class.
        per_class = get_orders_and_weights(
            len(np.unique(group_array)), preds, 'euclidean',
            y=labels, equal_num=args.equal, return_cluster=True, args=args)
        cluster_labels = np.concatenate(list(per_class))

        # The concatenated entries are grouped by class in np.unique(labels)
        # order (the order get_orders_and_weights iterates). Assuming each
        # per-class array follows np.where(labels == c) order, the inverse
        # permutation restores the original sample order.
        indices_by_class = np.concatenate([np.where(labels == c)[0] for c in np.unique(labels)])
        cluster_labels = cluster_labels[np.argsort(indices_by_class)]

    return cluster_labels


def misclass(preds, labels, args):
    """Split each class into a "predicted correctly" and a "misclassified" cluster.

    For class ``c`` the correctly predicted samples (argmax of ``preds`` equals
    ``c``) get id ``c * args.num_clusters`` and all others get
    ``c * args.num_clusters + 1``. The per-class prediction histogram and the
    resulting cluster sizes are printed.

    For datasets outside the built-in list, ``args.num_clusters`` is first
    overwritten with ``args.num_groups // args.num_classes``.

    Returns:
        np.ndarray of int32 cluster ids, one per sample.
    """
    if args.dataset not in ['cifar100sup', 'cifar10', 'imagenet', 'cmnist', 'balance_cmnist']:
        args.num_clusters = args.num_groups // args.num_classes

    assignments = np.zeros(len(labels), dtype=np.int32)
    for cls in range(args.num_classes):
        members = np.where(labels == cls)[0]

        # Histogram of what the model predicted for this class.
        predicted = preds[members].argmax(axis=1)
        print(f'Class {cls}: {np.bincount(predicted)}')

        offset = cls * args.num_clusters
        assignments[members] = np.where(predicted == cls, offset, offset + 1)

        print(f'Class {cls} cluster sizes: {np.bincount(assignments[members])}')
    return assignments


def eiil(logits, labels, args):
    """Learn soft environment assignment (environment inference).

    A free per-sample weight ``env_w`` is optimised with Adam to *maximise*
    the mean, over two soft environments, of the squared gradient of the
    environment loss w.r.t. a dummy classifier scale (an IRM-style penalty).
    Samples whose ``sigmoid(env_w)`` ends above 0.5 form environment 1, the
    rest environment 0.

    Args:
        logits: numpy array of model outputs, fed to ``CrossEntropyLoss`` as
            class scores (presumably shape (N, num_classes) - TODO confirm).
        labels: integer class labels, length N.
        args: reads ``eiil_lr``, ``eiil_steps`` and ``logger``.

    Returns:
        numpy int32 array of length N holding ``environment + 2 * label``, so
        every class owns two group ids and different classes never collide.
    """
    # Use the GPU when present, but still work on CPU-only machines.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    labels = np.asarray(labels)

    # Dummy classifier scale; the penalty is the gradient of the loss w.r.t. it.
    scale = torch.tensor(1., device=device).requires_grad_()
    train_criterion = nn.CrossEntropyLoss(reduction='none')
    # Per-sample losses are fixed; only the environment weights are learned.
    loss = train_criterion(torch.from_numpy(logits).to(device) * scale,
                           torch.from_numpy(labels).long().to(device)).squeeze()

    # Drawn from the CPU generator before moving, so seeding behaves as before.
    env_w = torch.randn(len(logits)).to(device).requires_grad_()
    optimizer = optim.Adam([env_w], lr=args.eiil_lr)

    args.logger.info('learning soft environment assignments')
    for _ in tqdm(range(args.eiil_steps)):
        env_prob = env_w.sigmoid()
        # penalty for env a
        grad_a = autograd.grad((loss * env_prob).mean(), [scale], create_graph=True)[0]
        penalty_a = torch.sum(grad_a ** 2)
        # penalty for env b
        grad_b = autograd.grad((loss * (1 - env_prob)).mean(), [scale], create_graph=True)[0]
        penalty_b = torch.sum(grad_b ** 2)
        # Negate: the optimiser minimises, but the penalty is to be maximised.
        npenalty = -torch.stack([penalty_a, penalty_b]).mean()

        optimizer.zero_grad()
        # Keep the graph behind ``loss`` alive for the next iteration.
        npenalty.backward(retain_graph=True)
        optimizer.step()

    # sigmoid to get the hard environment assignment
    env = (env_w.sigmoid() > .5).int().detach().cpu().numpy()
    # Two environments per class. For binary labels this equals the former
    # ``group_labels[labels == 1] += 2``; for more classes it keeps ids disjoint.
    group_labels = env + 2 * labels.astype(env.dtype)
    return group_labels


def ssa(args, train_dataset):
    """Pseudo-label the groups of the training set with an SSA-style procedure.

    Half of the group-annotated validation set is used as labeled training data
    and the other half for model selection. The training set is split into
    ``args.num_splits`` folds. For each fold, a fresh model from ``load_model``
    is trained to predict the spurious attribute. Its training data is the
    labeled half plus, per group, the most confident unlabeled samples of the
    other folds. The best checkpoint (by validation accuracy) then labels the
    held-out fold.

    Group ids are assumed to be laid out as
    ``class * (num_groups // num_classes) + spurious``; they are decoded this
    way below and re-encoded the same way at the end.

    Args:
        args: run configuration (reads ``num_groups``, ``num_classes``,
            ``num_splits``, ``num_iters``, ``batch_size``, ``num_workers``,
            ``val_freq``, ``smallest_group_conf_thresh`` and ``logger``).
        train_dataset: training set exposing ``targets``; its batches are
            3-tuples whose first two entries are the input and the class target.

    Returns:
        np.ndarray of int32 pseudo group ids, one per training sample.
    """
    train_criterion = nn.CrossEntropyLoss()
    groups_per_class = args.num_groups // args.num_classes

    # split the validation set into two halves: labeled train / labeled val
    args.logger.info('splitting the labeled validation set')
    val_dataset = load_dataset(args, split='val', group=None, augment=False)
    group_labels = torch.tensor(val_dataset.group_array).long()
    spurious_labels = torch.tensor(val_dataset.spurious).long()
    half = len(val_dataset) // 2
    labeled_train, labeled_val = torch.utils.data.random_split(val_dataset, [half, len(val_dataset) - half])

    # Find the smallest group in the labeled train half. minlength makes a group
    # without labeled samples count as size 0 instead of being ignored.
    group_labels = group_labels[labeled_train.indices]
    group_counts = np.bincount(group_labels.numpy(), minlength=args.num_groups)
    smallest_group = group_counts.argmin()
    smallest_group_class = smallest_group // groups_per_class
    smallest_group_spurious = smallest_group % groups_per_class
    num_labeled_smallest_group = int(group_counts[smallest_group])

    # The model is trained on spurious-attribute targets, so validation must
    # compare its predictions against spurious labels as well (not group ids).
    val_targets = spurious_labels[labeled_val.indices].cuda()

    labeled_train_loader = torch.utils.data.DataLoader(labeled_train, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
    labeled_val_loader = torch.utils.data.DataLoader(labeled_val, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)

    # split the training set into args.num_splits folds (the last takes the remainder)
    args.logger.info('splitting the unlabeled training set')
    fold_size = len(train_dataset) // args.num_splits
    fold_sizes = [fold_size] * (args.num_splits - 1) + [len(train_dataset) - (args.num_splits - 1) * fold_size]
    unlabeled_splits = torch.utils.data.random_split(train_dataset, fold_sizes)

    # buffer for the pseudo labels
    pseudo_labels = np.zeros(len(train_dataset), dtype=np.int32)

    for k in range(args.num_splits):
        # load a fresh model for every fold
        model, optimizer, scheduler = load_model(args, infer=False, t_total=args.num_iters)

        # leave the k-th fold out
        args.logger.info('leave the {}-th fold out'.format(k))
        unlabeled_train = torch.utils.data.ConcatDataset([unlabeled_splits[i] for i in range(args.num_splits) if i != k])
        unlabeled_train_loader = torch.utils.data.DataLoader(unlabeled_train, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)

        labeled_train_iter = iter(labeled_train_loader)
        unlabeled_train_iter = iter(unlabeled_train_loader)

        best_val_acc = 0
        best_model = copy.deepcopy(model.state_dict())

        for it in range(args.num_iters):
            args.logger.info('iteration {}'.format(it))
            # draw one batch from each loader, restarting exhausted iterators
            try:
                labeled_train_batch = next(labeled_train_iter)
            except StopIteration:
                labeled_train_iter = iter(labeled_train_loader)
                labeled_train_batch = next(labeled_train_iter)

            try:
                unlabeled_train_batch = next(unlabeled_train_iter)
            except StopIteration:
                unlabeled_train_iter = iter(unlabeled_train_loader)
                unlabeled_train_batch = next(unlabeled_train_iter)

            data_labeled, _, index_labeled = labeled_train_batch
            data_unlabeled, target_unlabeled, _ = unlabeled_train_batch
            target_unlabeled = target_unlabeled.cuda()

            # Supervised target: spurious attribute of the labeled samples.
            # Assumes the 3rd batch entry is the index into val_dataset.
            group = spurious_labels[index_labeled].cuda()

            labeled_train_logits = model(data_labeled.cuda())
            unlabeled_train_logits = model(data_unlabeled.cuda())
            unlabeled_train_softmax = F.softmax(unlabeled_train_logits, dim=1)

            # supervised loss
            labeled_train_loss = train_criterion(labeled_train_logits, group)

            # Target size for every group: labeled size of the smallest group
            # plus the unlabeled samples in this batch confidently assigned to it.
            num_unlabeled_smallest_group = (unlabeled_train_softmax[target_unlabeled == smallest_group_class, smallest_group_spurious] > args.smallest_group_conf_thresh).sum().item()
            group_size = num_labeled_smallest_group + num_unlabeled_smallest_group
            print('group size: {}'.format(group_size))

            # unsupervised loss on the most confident unlabeled samples per group
            unlabeled_train_loss = torch.tensor(0.0).cuda()
            for group_idx in range(args.num_groups):
                group_class = group_idx // groups_per_class
                group_spurious = group_idx % groups_per_class
                class_mask = target_unlabeled == group_class

                # rank this class's unlabeled samples by confidence in the group's spurious value
                _, confident_order = unlabeled_train_softmax[class_mask, group_spurious].sort(descending=True)

                num_labeled_group = int(group_counts[group_idx])
                num_unlabeled_group = max(group_size - num_labeled_group, 0)

                print('group_idx: {}, num_labeled_group: {}, num_unlabeled_group: {}'.format(group_idx, num_labeled_group, num_unlabeled_group))

                # The batch may hold fewer samples of this class than requested:
                # clamp so logits and targets have equal length, and skip empty
                # selections (mean cross-entropy over zero samples is NaN).
                num_selected = min(num_unlabeled_group, confident_order.numel())
                if num_selected > 0:
                    selected_logits = unlabeled_train_logits[class_mask][confident_order[:num_selected]]
                    pseudo_targets = torch.full((num_selected,), group_spurious, dtype=torch.long, device=selected_logits.device)
                    unlabeled_train_loss = unlabeled_train_loss + train_criterion(selected_logits, pseudo_targets)

            total_loss = labeled_train_loss + unlabeled_train_loss / args.num_groups

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()
            scheduler.step()

            # validation
            if it % args.val_freq == 0:
                args.logger.info('iter: {}, labeled_train_loss: {:.4f}, unlabeled_train_loss: {:.4f}, total_loss: {:.4f}'.format(it, labeled_train_loss.item(), unlabeled_train_loss.item(), total_loss.item()))
                args.logger.info('validating the model')
                model.eval()
                with torch.no_grad():
                    val_logits = torch.cat([model(val_batch[0].cuda()) for val_batch in labeled_val_loader], dim=0)
                    val_acc = (val_logits.argmax(dim=1) == val_targets).sum().item() / len(val_targets)
                    args.logger.info('val_acc: {:.4f}'.format(val_acc))
                model.train()

                # keep the best checkpoint of this fold
                if val_acc > best_val_acc:
                    best_val_acc = val_acc
                    best_model = copy.deepcopy(model.state_dict())

        # label the left-out fold with the best checkpoint
        args.logger.info('labeling the {}-th fold'.format(k))
        model.load_state_dict(best_model)
        held_out_loader = torch.utils.data.DataLoader(unlabeled_splits[k], batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
        model.eval()
        with torch.no_grad():
            held_out_logits = torch.cat([model(batch[0].cuda()) for batch in held_out_loader], dim=0)
        held_out_preds = torch.argmax(held_out_logits, dim=1).cpu().numpy()
        pseudo_labels[unlabeled_splits[k].indices] = held_out_preds

    # Re-encode predicted spurious values as group ids (class-major layout).
    targets = np.array(train_dataset.targets)
    for class_idx in range(args.num_classes):
        pseudo_labels[targets == class_idx] += class_idx * groups_per_class

    return pseudo_labels


def _points_per_class(B, y, classes, N, equal_num):
    """Return an int32 array with the number of points to pick from each class.

    Entries are positional, i.e. aligned with ``classes`` (``np.unique(y)``),
    so class values need not be 0..C-1 integers.
    """
    C = len(classes)  # number of classes
    if not equal_num:
        # proportional to class size
        return np.int32(np.ceil(np.divide([sum(y == c) for c in classes], N) * B))

    quota = np.ceil(B / C)
    class_nums = np.array([sum(y == c) for c in classes])
    num_per_class = int(quota) * np.ones(C, dtype=np.int32)
    minority = class_nums < quota
    num_minority = int(minority.sum())
    num_majority = C - num_minority
    if num_minority > 0 and num_majority > 0:
        # Classes smaller than the quota cannot fill it; spread their total
        # shortfall evenly over the classes that can (not over the minority
        # classes, which would over-allocate by a factor of num_majority).
        extra = np.maximum(0, quota - class_nums).sum()
        num_per_class[~minority] += int(np.ceil(extra / num_majority))
    return num_per_class


def get_orders_and_weights(
    B, X, metric, y=None, weights=None, equal_num=False, mode='dense', num_n=None,
    stop_zero_gain=True, separate_rep=False, data_rep=None, return_cluster=False,
    args=None,
):
    """Run facility-location selection per class and return its cluster output.

    Args
    - B: int, total number of points to select across all classes
    - X: np.array, shape [N, d], features of all points
    - metric: str, one of ['cosine', 'euclidean'], for similarity
    - y: np.array, shape [N], class labels; if None all points share one class
    - equal_num: bool, if True request ceil(B / C) points from every class and
      hand the shortfall of smaller classes to the remaining ones; otherwise
      allocate B proportionally to class size
    - weights, mode, num_n, stop_zero_gain, separate_rep, data_rep, args:
      forwarded unchanged to ``faciliy_location_order``
    - return_cluster: accepted for API compatibility; the cluster output is
      always what is returned

    Returns
    - tuple with one entry per class, in ``np.unique(y)`` order: the third
      value returned by ``faciliy_location_order`` for that class (presumably
      the cluster assignment of that class's points - see submod)
    """
    N = X.shape[0]
    if y is None:
        y = np.zeros(N, dtype=np.int32)  # assign every point to the same class
    classes = np.unique(y)

    num_per_class = _points_per_class(B, y, classes, N, equal_num)
    print(f"Greedy: selecting {num_per_class} elements")

    # One facility-location run per class; num_per_class is indexed by position.
    results = [
        faciliy_location_order(
            c, X, y, metric, num_per_class[pos], weights, mode, num_n, stop_zero_gain,
            separate_rep, data_rep, args,
        )
        for pos, c in enumerate(classes)
    ]
    greedy_times, similarity_times, cluster_all = zip(*results)
    print(
        f"time (sec) for computing facility location: {greedy_times} similarity time {similarity_times}",
    )

    return cluster_all


def infer_one_step(args, model, train_dataset, train_val_loader, val_loader, steps, test_loader=None):
    cluster_file = os.path.join(args.checkpoint_path, 
                                        '{}{}/{}_seed{}'.format(args.save_unit, '-'.join([str(step) for step in steps]), args.infer, args.seed))


    if args.infer == 'cluster':
        if args.dataset not in ['cifar100sup', 'cifar10', 'imagenet', 'cmnist', 'balance_cmnist']:
            cluster_file = os.path.join(cluster_file, '{}clusters_{}_{}_labels.pt'.format(args.num_clusters, args.cluster_metric, args.cluster_method))
        elif args.sep_conf:
            cluster_file = os.path.join(cluster_file, 'cluster_{}_{}_labels_conf_{}.pt'.format(args.cluster_metric, args.cluster_method, args.conf_thresh))
        else:
            cluster_file = os.path.join(cluster_file, 'cluster_{}_{}_labels.pt'.format(args.cluster_metric, args.cluster_method))

        if len(steps) > 0:
            traject = []

            ckpt_indices = steps

            for ckpt_idx in ckpt_indices:  

                # initialize an empty list to store the gradients
                predictions = []

                # compute the gradient by subtracting the labels from the predictions
                if ckpt_idx != 0:
                    try:
                        ckpt_path = os.path.join(args.checkpoint_path, 
                                                    f'{args.save_unit}{ckpt_idx}/checkpoint_seed{args.seed}.pt')
                        args.logger.info("=> loading checkpoint from '{}'".format(ckpt_path))
                        checkpoint = torch.load(ckpt_path)
                    except:
                        ckpt_path = os.path.join(args.checkpoint_path, 
                                                    f'{args.save_unit}{ckpt_idx}/checkpoint_seed{args.seed}.pth')
                        args.logger.info("=> loading checkpoint from '{}'".format(ckpt_path))
                        checkpoint = torch.load(ckpt_path)
                    if args.infer_arch != '':
                        model, _, _ = load_model(args, infer=True)
                    model.load_state_dict(checkpoint['model_state_dict'])

                    if args.freeze:
                        # freeze the model except the last layer
                        for name, param in model.named_parameters():
                            if 'fc' not in name:
                                param.requires_grad = False

                # save embeddings and predictions
                activation_path = os.path.join(args.checkpoint_path, 
                                            '{}{}/activations_seed{}.pt'.format(args.save_unit, ckpt_idx, args.seed))    
                if not os.path.exists(activation_path):
                    os.makedirs(os.path.dirname(activation_path), exist_ok=True)
                    embeds, preds, outputs, labels = save_activations(model, train_val_loader, args)
                    torch.save({
                        'embeds': embeds,
                        'preds': preds,
                        'outputs': outputs,
                        'labels': labels,
                    }, activation_path)
                else:
                    checkpoint = torch.load(activation_path)
                    preds = checkpoint['preds']
                    labels = checkpoint['labels']
                    embeds = checkpoint['embeds']
                    outputs = checkpoint['outputs']

                # check if each prediction is correct
                correct = np.argmax(preds, axis=1) == labels

                # log accuracy for each group at the inference stage
                for group in np.unique(train_dataset.group_array):
                    args.logger.info('Group {} accuracy: {:.2f}'.format(group, 
                        np.mean(correct[train_dataset.group_array == group])))
                        
                # append the predictions to the list
                if args.cluster_metric == 'pred':
                    predictions.append(preds)
                elif args.cluster_metric == 'embed':
                    predictions.append(embeds)
                elif args.cluster_metric == 'logit':
                    predictions.append(outputs)
                elif args.cluster_metric == 'conf':
                    predictions.append(np.amax(preds, axis=1, keepdims=True))
                elif args.cluster_metric == 'true_conf':
                    predictions.append(preds[np.arange(preds.shape[0]), labels][:, np.newaxis])
                elif args.cluster_metric == 'grad':
                    g0 = preds - np.eye(preds.shape[1])[labels]
                    g0_expand = np.repeat(g0, np.shape(embeds)[1], axis=1)
                    predictions.append(g0_expand * np.tile(embeds, np.shape(preds)[1]))
                else:
                    raise ValueError('Invalid metric for trajectory inference')
                
                predictions = np.array(predictions)
                
                traject.append(predictions)
            
            traject = np.concatenate(traject, axis=0)
            if args.sep_infer:              
                traject = traject[labels, np.arange(traject.shape[1])]
            else:
                traject = traject.reshape(-1, traject.shape[-1])
        else:
            embeds, preds, outputs, labels = save_activations(model, train_val_loader, args)
            if args.cluster_metric == 'pred':
                traject = preds
            elif args.cluster_metric == 'embed':
                traject = embeds
            elif args.cluster_metric == 'logit':
                traject = outputs
            elif args.cluster_metric == 'conf':
                traject = np.amax(preds, axis=1, keepdims=True)
            elif args.cluster_metric == 'true_conf':
                traject = preds[np.arange(preds.shape[0]), labels][:, np.newaxis]
            elif args.cluster_metric == 'grad':
                g0 = preds - np.eye(preds.shape[1])[labels]
                g0_expand = np.repeat(g0, np.shape(embeds)[1], axis=1)
                traject = g0_expand * np.tile(embeds, np.shape(preds)[1])
            else:
                raise ValueError('Invalid metric for trajectory inference')
            
        groups = train_dataset.group_array
            
        if test_loader is not None:
            args.logger.info('Inferring groups with the test set')
            test_embeds, test_preds, test_outputs, test_labels = save_activations(model, test_loader, args)
            if args.cluster_metric == 'pred':
                test_traject = test_preds
            elif args.cluster_metric == 'embed':
                test_traject = test_embeds
            elif args.cluster_metric == 'logit':
                test_traject = test_outputs
            elif args.cluster_metric == 'conf':
                test_traject = np.amax(test_preds, axis=1, keepdims=True)
            elif args.cluster_metric == 'true_conf':
                test_traject = test_preds[np.arange(test_preds.shape[0]), test_labels][:, np.newaxis]
            elif args.cluster_metric == 'grad':
                g0 = test_preds - np.eye(test_preds.shape[1])[test_labels]
                g0_expand = np.repeat(g0, np.shape(test_embeds)[1], axis=1)
                test_traject = g0_expand * np.tile(test_embeds, np.shape(test_preds)[1])
            else:
                raise ValueError('Invalid metric for trajectory inference')
            
            # save test activations
            test_activation_path = os.path.join(args.checkpoint_path,
                                            '{}{}/test_activations_seed{}.pt'.format(args.save_unit, ckpt_idx, args.seed))
            if not os.path.exists(test_activation_path):
                os.makedirs(os.path.dirname(test_activation_path), exist_ok=True)
                torch.save({
                    'embeds': test_embeds,
                    'preds': test_preds,
                    'outputs': test_outputs,
                    'labels': test_labels,
                }, test_activation_path)
            
            traject = np.concatenate([traject, test_traject], axis=0)
            groups = np.concatenate([groups, test_loader.dataset.group_array], axis=0)
            labels = np.concatenate([labels, test_labels], axis=0)

        if args.visualize:
            visualize_val(args, val_loader, model, steps)

        if not os.path.exists(cluster_file) or not args.reuse_clusters or args.visualize or test_loader is not None:
            os.makedirs(os.path.dirname(cluster_file), exist_ok=True)
            args.logger.info('Inferring groups with clustering algorithm: {}'.format(args.cluster_method))
            cluster_labels = clustering(args, groups, traject, labels)
            if test_loader is not None:
                train_cluster_labels = cluster_labels[:len(train_dataset)]
                test_cluster_labels = cluster_labels[len(train_dataset):]
                cluster_labels = train_cluster_labels
                labels = labels[:len(train_dataset)]
                test_cluster_file = cluster_file.replace('.pt', '_test.pt')
                torch.save(test_cluster_labels, test_cluster_file)
                test_misclass = misclass(test_preds, test_labels, args)
                test_misclass = test_misclass - args.num_clusters * test_labels

                # plot confusion matrix for misclassified samples and test clusters
                if args.visualize:
                    fig, test_cluster_labels = plot_confusion_matrix(test_cluster_labels, test_misclass, args)
                    plt.savefig("{}/confusion_test_{}{}.png".format(args.save_dir, args.save_unit, '-'.join([str(step) for step in steps])), dpi=50)
                    plt.close()
                
                # for each class, find the smallest cluster and calculate the recall of misclassified samples
                log_dict = {}
                recalls = []
                for i in range(args.num_classes):
                    class_indices = np.where(test_labels == i)[0]
                    class_misclass = test_misclass[class_indices]
                    class_cluster_labels = test_cluster_labels[class_indices]
                    
                    # find the smallest cluster
                    class_clusters, class_counts = np.unique(class_cluster_labels, return_counts=True)
                    smallest_cluster = class_clusters[np.argmin(class_counts)]

                    # calculate recall for the smallest cluster
                    smallest_cluster_indices = np.where(class_cluster_labels == smallest_cluster)[0]
                    smallest_cluster_misclass = class_misclass[smallest_cluster_indices]
                    if np.sum(class_misclass) > 0:
                        recall = np.sum(smallest_cluster_misclass) / np.sum(class_misclass)
                        log_dict['test_misclass_recall/class_{}'.format(i)] = recall
                        recalls.append(recall)
                        args.logger.info('Recall for class {}: {}'.format(i, recall))
                    log_dict['test_acc/class_{}'.format(i)] = 1 - np.sum(class_misclass) / len(class_misclass)
                
                # calculate average recall
                log_dict['test_misclass_recall/avg'] = np.mean(recalls)
                args.logger.info('Average recall: {}'.format(log_dict['test_misclass_recall/avg']))
                log_dict['test_misclass_recall/std'] = np.std(recalls)
                args.logger.info('Standard deviation of recall: {}'.format(log_dict['test_misclass_recall/std']))
                log_dict['test_acc/avg'] = 1 - np.sum(test_misclass) / len(test_misclass)
                args.logger.info('Average accuracy: {}'.format(log_dict['test_acc/avg']))

                # save log to wandb
                if args.use_wandb:
                    log_dict['step'] = steps[-1]
                    wandb.log(log_dict)

                # split test loader based on cluster labels
                test_dataset = test_loader.dataset
                test_loader = []
                for i in range(args.num_clusters):
                    cluster_indices = np.where(test_cluster_labels == i)[0]
                    cluster_dataset = Subset(test_dataset, cluster_indices)
                    cluster_loader = DataLoader(cluster_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
                    test_loader.append(cluster_loader)

            # save cluster labels
            torch.save(cluster_labels, cluster_file)
        
        else:
            cluster_labels = torch.load(cluster_file)
    elif args.infer == 'ssa':
        cluster_file = os.path.join(args.checkpoint_path, 
                                    '{}_seed{}'.format(args.infer, args.seed))
        cluster_file = os.path.join(cluster_file, 'cluster_labels.pt')
        if not os.path.exists(cluster_file):
            os.makedirs(os.path.dirname(cluster_file), exist_ok=True)
            cluster_labels = ssa(args, train_dataset)
            # save cluster labels
            torch.save(cluster_labels, cluster_file)
        cluster_labels = torch.load(cluster_file)
        labels = train_dataset.targets
    else:
        if len(args.infer_steps) > 0:
            ckpt_path = os.path.join(args.checkpoint_path, '{}{}/checkpoint_seed{}.pt'.format(args.save_unit, args.infer_steps[0], args.seed))
            # find checkpoint with infer_step            
            checkpoint = torch.load(ckpt_path)
            model.load_state_dict(checkpoint['model_state_dict'])
            args.logger.info(f'loaded checkpoint from {ckpt_path}')
    
            activation_path = os.path.join(args.checkpoint_path, 
                                                '{}{}/activations_seed{}.pt'.format(args.save_unit, args.infer_steps[0], args.seed))
            args.logger.info('Loading activations from {}'.format(activation_path))
            checkpoint = torch.load(activation_path)
            embeds = checkpoint['embeds']
            preds = checkpoint['preds']
            outputs = checkpoint['outputs']
            labels = checkpoint['labels']
        else:
            embeds, preds, outputs, labels = save_activations(model, train_val_loader, args)

        if args.infer == 'eiil':
            cluster_labels = eiil(outputs, labels, args)
        elif args.infer == 'misclass':
            args.logger.info('Using misclassification inference')
            cluster_labels = misclass(preds, labels, args)
        else:
            raise NotImplementedError

        # save cluster labels
        cluster_file = os.path.join(args.save_dir, 'cluster_labels.pt')
        torch.save(cluster_labels, cluster_file)
    
    if args.visualize:
        # initialize figure that fits 4 rows and 10 columns
        fig = plt.figure(figsize=(10, 6))
        for c in range(args.num_classes):
            class_cluster = cluster_labels[np.where(np.array(labels)==c)[0]]
            class_indices = np.where(np.array(labels)==c)[0]
            
            num_examples = 10

            # sort clusters by count
            unique_clusters, cluster_counts = np.unique(class_cluster, return_counts=True)
            # rank clusters by count
            cluster_rank = np.argsort(cluster_counts)[::-1]
            # sort clusters by rank
            unique_clusters = unique_clusters[cluster_rank]
            
            num_clusters = len(np.unique(class_cluster))
            # for each cluster
            for cluster_idx, cluster in enumerate(unique_clusters):
                cluster_indices = class_indices[class_cluster == cluster]
                cluster_acc = np.mean(preds[cluster_labels == cluster].argmax(axis=1) == labels[cluster_labels == cluster])
                # plot 5 examples
                for i in range(min(num_examples, len(cluster_indices))):
                    img_path = train_dataset.dataset.all_data[cluster_indices[i]][0]
                    img = Image.open(img_path).convert('RGB')
                    ax = fig.add_subplot(num_clusters*args.num_classes, num_examples, cluster_idx*num_examples + i + 1 + c*num_clusters*num_examples)
                    ax.imshow(img)
                    ax.axis('off')
                    # label the cluster accuarcy and count
                    if i == 0:
                        ax.set_title('cluster acc: {:.2f} \n count: {}'.format(cluster_acc, len(cluster_indices)))
        plt.tight_layout()
        os.makedirs(os.path.dirname("{}/cluster_examples_{}{}.png".format(args.save_dir, args.save_unit, '-'.join([str(step) for step in steps]), c)), exist_ok=True)
        plt.savefig("{}/cluster_examples_{}{}.png".format(args.save_dir, args.save_unit, '-'.join([str(step) for step in steps]), c), dpi=50)

    if args.dataset in ['cifar100sup', 'cifar10', 'imagenet']:
        # for each class, plot 5 examples in each cluster
        if args.num_clusters == 2:
            num_examples = 10
        else:
            num_examples = 5
        for c in range(args.num_classes):
            # initialize figure
            fig = plt.figure(figsize=(10, 10))

            class_cluster = cluster_labels[np.where(np.array(labels)==c)[0]]
            class_indices = np.where(np.array(labels)==c)[0]

            # sort clusters by count
            unique_clusters, cluster_counts = np.unique(class_cluster, return_counts=True)
            # rank clusters by count
            cluster_rank = np.argsort(cluster_counts)[::-1]
            # sort clusters by rank
            unique_clusters = unique_clusters[cluster_rank]
            
            num_clusters = len(np.unique(class_cluster))
            # for each cluster
            for cluster_idx, cluster in enumerate(unique_clusters):
                cluster_indices = class_indices[class_cluster == cluster]
                cluster_acc = np.mean(preds[cluster_labels == cluster].argmax(axis=1) == labels[cluster_labels == cluster])
                # plot 5 examples
                for i in range(min(num_examples, len(cluster_indices))):
                    if args.dataset == 'imagenet':
                        img_path = train_dataset.dataset.samples[cluster_indices[i]][0]
                        img = train_dataset.dataset.loader(img_path)
                        # resize to 224x224
                        img = img.resize((224, 224))
                    else:
                        img = train_dataset.dataset.data[cluster_indices[i]]
                    ax = fig.add_subplot(num_clusters*num_examples//5, 5, cluster_idx*num_examples + i + 1)
                    ax.imshow(img)
                    ax.axis('off')
                    # label the cluster accuarcy and count
                    if i == 0:
                        ax.set_title('cluster acc: {:.2f} \n count: {}'.format(cluster_acc, len(cluster_indices)))
            plt.tight_layout()
            os.makedirs(os.path.dirname("{}/{}clusters_examples_{}{}/class{}.png".format(args.save_dir, args.num_clusters, args.save_unit, '-'.join([str(step) for step in steps]), c)), exist_ok=True)
            plt.savefig("{}/{}clusters_examples_{}{}/class{}.png".format(args.save_dir, args.num_clusters, args.save_unit, '-'.join([str(step) for step in steps]), c), dpi=50)
            # plt.savefig("{}/cluster_examples_{}{}/class{}.pdf".format(args.save_dir, args.save_unit, '-'.join([str(step) for step in steps]), c), dpi=50)
    elif 'cmnist' not in args.dataset:
        # plot confusion matrix
        fig, cluster_labels = plot_confusion_matrix(train_dataset.group_array, cluster_labels, args)
        # plt.savefig("{}/confusion_train_{}{}.pdf".format(args.save_dir, args.save_unit, '-'.join([str(step) for step in steps])), dpi=50)
        plt.savefig("{}/confusion_train_{}{}.png".format(args.save_dir, args.save_unit, '-'.join([str(step) for step in steps])), dpi=50)
        plt.close()

    # compute accuracy for each cluster with preds and cluster labels
    for group in np.unique(cluster_labels):
        # print cluster size
        args.logger.info('Cluster {} size: {}'.format(group, np.sum(cluster_labels == group)))
        if args.infer != 'ssa':
            args.logger.info('Cluster {} accuracy at inference: {:.2f}'.format(group, 
                np.mean(preds[cluster_labels == group].argmax(axis=1) == labels[cluster_labels == group])))
        
    # compute number of correct predictions for each cluster and group combination
    log_dict = {}
    for group in np.unique(train_dataset.group_array):
        if args.infer != 'ssa':
            args.logger.info('Group {} accuracy at inference: {:.2f}'.format(group, 
                np.mean(preds[train_dataset.group_array == group].argmax(axis=1) == labels[train_dataset.group_array == group])))
        group_indices = np.where(train_dataset.group_array == group)[0]
        group_cluster_labels = cluster_labels[group_indices]
        for cluster in np.unique(group_cluster_labels):
            group_cluster_indices = group_indices[group_cluster_labels == cluster]
            if args.infer != 'ssa':
                class_preds = preds[group_cluster_indices].argmax(axis=1)
                args.logger.info('Group {} Cluster {} accuracy at inference: {:d}/{:d} ({:.2f})'.format(group, cluster, 
                    np.sum(class_preds == labels[group_cluster_indices]), len(class_preds), 
                    np.mean(class_preds == labels[group_cluster_indices])))
                log_dict['group_{}/cluster_{}_acc'.format(group, cluster)] = np.mean(class_preds == labels[group_cluster_indices])
            else:
                # only print the size of the cluster
                args.logger.info('Group {} Cluster {} size: {}'.format(group, cluster, len(group_cluster_indices)))

        log_dict['group_{}_acc'.format(group)] = np.mean(preds[group_indices].argmax(axis=1) == labels[group_indices])
        if len(np.unique(group_cluster_labels)) > 1 and 'cmnist' not in args.dataset:
            log_dict['group_{}_acc_diff'.format(group)] = log_dict['group_{}/cluster_{}_acc'.format(group, cluster-1)] - log_dict['group_{}/cluster_{}_acc'.format(group, cluster)]

    if args.use_wandb:
        if len(steps) > 1:
            log_dict['step'] = steps[-1]
        wandb.log(log_dict)

    if test_loader is not None:
        return cluster_labels, test_loader
    return cluster_labels


def _last_layer_grads(preds, embeds, labels):
    """Build per-sample flattened outer products ``(preds - onehot(labels)) x embeds``.

    Row ``n`` equals ``np.outer(preds[n] - onehot(labels[n]), embeds[n]).ravel()``,
    i.e. element ``c * D + d`` is ``(preds[n, c] - [labels[n] == c]) * embeds[n, d]``
    where ``D = embeds.shape[1]``. If ``preds`` are softmax probabilities this is
    the cross-entropy gradient w.r.t. a final linear layer's weights
    (NOTE(review): assumes preds are softmax outputs -- confirm in save_activations).

    Args:
        preds: (N, C) array of per-class predictions.
        embeds: (N, D) array of embeddings.
        labels: (N,) integer class labels used to index the one-hot rows.

    Returns:
        (N, C * D) array.
    """
    residual = preds - np.eye(preds.shape[1])[labels]
    return np.einsum('nc,nd->ncd', residual, embeds).reshape(preds.shape[0], -1)


def visualize_val(args, val_loader, model, steps):
    """Collect per-group validation activations, plot them and cluster them.

    Runs ``save_activations`` on every loader in ``val_loader``; a loader's
    position in ``val_loader`` is used as its group id. Each group's accuracy
    is logged. A scatter of the first two prediction columns, coloured by
    group, is saved to ``<save_dir>/pred-init_val.png``. The concatenated
    validation set is then clustered with ``clustering``, using the feature
    selected by ``args.cluster_metric``. When ``args.visualize`` is set, a
    group-vs-cluster confusion matrix is also saved.

    Args:
        args: Run configuration. Fields read here: ``cluster_metric`` (one of
            'pred', 'embed', 'logit', 'grad'), ``cmap``, ``save_dir``,
            ``save_unit``, ``visualize`` and ``logger``. ``clustering`` and
            ``save_activations`` may read more.
        val_loader: Sequence of data loaders, one per validation group.
        model: Model forwarded to ``save_activations``.
        steps: Iterable of checkpoint steps; only used in the confusion-matrix
            filename.

    Returns:
        Tuple ``(val_groups, val_embeds, val_preds, val_outputs, val_labels,
        val_cluster, val_conf)``.
        - The first five are per-sample arrays concatenated across groups.
        - ``val_cluster`` is the return value of ``clustering``.
        - ``val_conf[g]`` is the 75th percentile, over group ``g``'s samples,
          of the ``preds`` entry at the true label.

    Raises:
        NotImplementedError: If ``args.cluster_metric`` is not supported.
    """
    # Fail fast, before the (expensive) activation pass over every loader;
    # otherwise an unknown metric would leave the clustering result unbound.
    supported_metrics = ('pred', 'embed', 'logit', 'grad')
    if args.cluster_metric not in supported_metrics:
        raise NotImplementedError(
            'Unsupported cluster_metric {!r}; expected one of {}'.format(
                args.cluster_metric, supported_metrics))

    group_ids, embeds_list, preds_list, outputs_list, labels_list = [], [], [], [], []
    # 75th-percentile true-label confidence of each group
    val_conf = np.zeros(len(val_loader))
    for group_idx, group_loader in enumerate(val_loader):
        embeds_, preds_, outputs_, labels_ = save_activations(model, group_loader, args)
        group_ids.append(np.ones(preds_.shape[0], dtype=int) * group_idx)
        embeds_list.append(embeds_)
        preds_list.append(preds_)
        outputs_list.append(outputs_)
        labels_list.append(labels_)

        args.logger.info('Group {} accuracy: {:.2f}'.format(
            group_idx, np.mean(np.argmax(preds_, axis=1) == labels_)))

        # pick each sample's prediction at its own label, then take the 75th percentile
        true_label_conf = preds_[np.arange(preds_.shape[0]), labels_]
        val_conf[group_idx] = np.percentile(true_label_conf, 75)

    val_groups = np.concatenate(group_ids)
    val_embeds = np.concatenate(embeds_list)
    val_preds = np.concatenate(preds_list)
    val_outputs = np.concatenate(outputs_list)
    val_labels = np.concatenate(labels_list)

    # Use a fresh figure so figures left open by earlier plotting are not
    # drawn into this image. plt.get_cmap replaces plt.cm.get_cmap, which was
    # removed in Matplotlib 3.9.
    unique_groups = np.unique(val_groups)
    plt.figure()
    plt.scatter(val_preds[:, 0], val_preds[:, 1], c=val_groups, s=1.0,
                cmap=plt.get_cmap(args.cmap, len(unique_groups)))
    plt.colorbar(ticks=unique_groups)
    fpath = os.path.join(args.save_dir, 'pred-init_val.png')
    plt.savefig(fname=fpath, dpi=300, bbox_inches="tight")
    plt.close()
    args.logger.info(f'Saved val visualization to {fpath}!')

    if args.cluster_metric == 'pred':
        features = val_preds
    elif args.cluster_metric == 'embed':
        features = val_embeds
    elif args.cluster_metric == 'logit':
        features = val_outputs
    else:  # 'grad' (validated above)
        features = _last_layer_grads(val_preds, val_embeds, val_labels)
    val_cluster = clustering(args, val_groups, features, val_labels)

    if args.visualize:
        # plot confusion matrix
        fig, _ = plot_confusion_matrix(val_groups, val_cluster, args)
        plt.savefig("{}/confusion_val_{}{}.png".format(args.save_dir, args.save_unit, '-'.join([str(step) for step in steps])), dpi=50)
        plt.close()

    return val_groups, val_embeds, val_preds, val_outputs, val_labels, val_cluster, val_conf