import argparse
from tqdm import tqdm
from itertools import cycle
import math
import statistics
import matplotlib.pyplot as plt
# Matplotlib 3.6 renamed the bundled seaborn style sheets to "seaborn-v0_8-*"
# and later releases removed the old names, so try the new name first and fall
# back to the legacy one. The style is cosmetic: if neither name is available
# the default style is kept instead of failing at import time.
for _style_name in ('seaborn-v0_8-dark-palette', 'seaborn-dark-palette'):
    if _style_name in plt.style.available:
        plt.style.use(_style_name)
        break
plt.rcParams['font.serif'] = ['Times New Roman']

import numpy as np
from scipy.stats import zscore
from scipy.spatial.distance import cdist
from scipy.optimize import linear_sum_assignment
import torch
import torch.nn as nn
from torch.optim import lr_scheduler
import torch.multiprocessing
torch.multiprocessing.set_sharing_strategy('file_system')

from utils.optimizer import build_optimizer, get_mean_lr
from utils.utils import AverageMeter, str2bool, init_experiment, set_seed
from utils.cluster_utils import log_accs_from_preds, cluster_acc, KMeansGPU
from utils.logger import setup_logger, func_info
from utils.metrics import accuracy
from model.backbones import vision_transformer as vits
from datasets.augmentations import get_transform
from datasets import make_dataloader, get_class_splits
from loss.supcon_loss import SupConLoss
from loss.cl_loss import info_nce_logits, pcl_loss
from loss.entropy import entropy, average_patch_entropy, mutual_information_loss
from config import cfg

# TODO: Debug
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)


class SIPAManager:

    def __init__(self, model, projection_head, img_num1, img_num2, args, cfg):
        """Keep references to the networks/config and allocate the memory banks.

        Args:
            model: Backbone network called by ``feature_extract``.
            projection_head: Head applied to the backbone output.
            img_num1: Number of rows allocated for the first (source) memory bank.
            img_num2: Number of rows allocated for the second (target) memory bank.
            args: Namespace providing at least ``device`` and ``mlp_out_dim``.
            cfg: Config object providing at least ``OUTPUT_DIR``.
        """
        # Configuration, logging and the two networks.
        self.args = args
        self.cfg = cfg
        self.logger = setup_logger("unincd baseline", cfg.OUTPUT_DIR, if_train=True)
        self.model = model
        self.projection_head = projection_head
        self.device = args.device

        # Prototype / clustering state; filled in by the prototype update methods.
        self.proto_l = None
        self.proto_u = None
        self.centroids = None
        self.pseudo_labels = None
        self.esti_offset = None

        # Statistics of the best unlabeled-prototype candidate accepted so far.
        self.pre_proto_update_acc = 0
        self.pre_disti_sample_u = 1e8

        # Per-sample memory banks, created on the default (CPU) device.
        feat_dim = self.args.mlp_out_dim
        self.label_memory1 = torch.zeros(img_num1, dtype=torch.long)  # labels of source data
        self.label_memory2 = torch.zeros(img_num2, dtype=torch.long)  # labels of target data
        self.feat_memory1 = torch.zeros((img_num1, feat_dim), dtype=torch.float32)
        self.feat_memory2 = torch.zeros((img_num2, feat_dim), dtype=torch.float32)

    def feature_extract(self, imgs, labeled=True, ignore_logits=False):
        ####  Extract features with base model  ####
        logits = []
        if self.cfg.MODEL.ALL_PATCHES and not ignore_logits:
            feat_all = self.model(imgs, return_all_patches=True) # False
            cls_token, patch_tokens = feat_all[:, 0].unsqueeze(1), feat_all[:, 1:]
            B, D = feat_all.shape[0], feat_all.shape[2]
            H = W = int(math.sqrt(patch_tokens.shape[1]))
            pool_dim = self.cfg.MODEL.POOL_SIZE
            patch_tokens = patch_tokens.view(B, H, W, D).permute(0,3,1,2) 
            # Arrange the tensor to the shape [256, 768, 14, 14]. Because PyTorch expects the spatial dimensions in the last two dimensions
            pool_layer = nn.AdaptiveAvgPool2d((pool_dim, pool_dim))  # Define the pooling layer
            patch_tokens = pool_layer(patch_tokens) # Apply the pooling layer
            # After pooling, the output shape is [256, 768, 4, 4], reshape back to size [256, 16, 768]
            patch_tokens = patch_tokens.permute(0, 2, 3, 1).contiguous().view(B, pool_dim*pool_dim, D)
            # extract feat and logits with projection head
            cls_feat, cls_logits = self.projection_head(cls_token, labeled=labeled)
            feat = cls_feat[:, 0]
            _, patch_logits = self.projection_head(patch_tokens, labeled=labeled)
            logits = torch.cat([cls_logits, patch_logits], dim=1)
        else:
            feat = self.model(imgs, return_all_patches=False) # False
            feat, logits = self.projection_head(feat, labeled=labeled)
        feat = torch.nn.functional.normalize(feat, dim=-1) # L2-normalize features
        return feat, logits

    def alignment(self, km):
        if self.centroids is not None:
            old_centroids = self.centroids.cpu().numpy()
            new_centroids = km.cluster_centers_
            DistanceMatrix = np.linalg.norm(old_centroids[:,np.newaxis,:]-new_centroids[np.newaxis,:,:],axis=2) 
            _, col_ind = linear_sum_assignment(DistanceMatrix)
            alignment_labels = list(col_ind)
            pseudo2label = {label:i for i,label in enumerate(alignment_labels)}
            pseudo_labels = np.array([pseudo2label[label] for label in km.labels_])
        else:
            self.centroids = torch.tensor(km.cluster_centers_, dtype=torch.float32).to(self.device)
            pseudo_labels = km.labels_
        pseudo_labels = torch.tensor(pseudo_labels + self.esti_offset, dtype=torch.long).to(self.device)
        return pseudo_labels

    @func_info
    def proto_l_update(self, epoch):
        """Recompute the labeled-class prototypes from the first memory bank.

        The prototype of a class is the mean of the ``feat_memory1`` rows whose
        ``label_memory1`` entry equals that class. With ``epoch == 1`` (or when
        no prototypes exist yet) the means replace ``self.proto_l``; otherwise
        they are blended in with momentum ``cfg.MODEL.PROTO_MOMEN``.

        A class without any sample would previously yield a 0/0 = NaN prototype.
        Such a class now gets a zero vector on initialisation and keeps its
        previous prototype on later updates.

        Args:
            epoch: Current epoch; ``1`` selects (re-)initialisation.
        """
        num_known_classes = self.args.num_known_classes
        labels = self.label_memory1.long()
        # Accumulate in float64 to mirror the precision of the former NumPy sums.
        feats = self.feat_memory1.to(torch.float64)

        class_sums = torch.zeros((num_known_classes, feats.shape[1]), dtype=torch.float64)
        class_sums.index_add_(0, labels, feats)
        label_count = torch.bincount(labels, minlength=num_known_classes)[:num_known_classes]
        has_samples = label_count > 0

        # clamp(min=1) avoids the division by zero for empty classes.
        class_means = class_sums / label_count.clamp(min=1).to(torch.float64).unsqueeze(1)
        class_means = class_means.to(torch.float32).to(self.device)

        if epoch == 1 or self.proto_l is None:
            self.proto_l = class_means
        else:
            momentum = self.cfg.MODEL.PROTO_MOMEN
            blended = momentum * self.proto_l + (1 - momentum) * class_means
            # Only classes that actually have samples move; the rest stay put.
            keep_update = has_samples.to(self.device).unsqueeze(1)
            self.proto_l = torch.where(keep_update, blended, self.proto_l)

    @func_info
    def proto_u_update(self, epoch):
        """Re-cluster the second memory bank and propose new unlabeled prototypes.

        Steps:
          1. Run k-means with ``args.num_unknown_classes`` clusters on
             ``feat_memory2``.
          2. For every labeled prototype take the distance to its nearest
             cluster centre, z-score those distances and compare the absolute
             z-scores with their ``cfg.MODEL.PROTO_THRES`` quantile. Labeled
             prototypes that are not below the threshold are counted in
             ``self.esti_offset``.
          3. Reorder the cluster centres: first the ones Hungarian-matched to
             the below-threshold labeled prototypes, then all remaining ones.
          4. Align the k-means labels to that order (``alignment``) and score
             the candidate with ``eval_proto``.
          5. Accept the candidate only if its mean unlabeled distance is lower
             than the best one seen so far.

        When the candidate is rejected, ``self.esti_offset`` is restored to the
        value it had before the call. The kept ``self.pseudo_labels`` were
        shifted by that earlier offset and callers subtract ``self.esti_offset``
        from them, so leaving the new offset in place would shift them wrongly.

        Args:
            epoch: Not used by this method; kept for the call signature.
        """
        num_unknown_classes = self.args.num_unknown_classes
        previous_offset = self.esti_offset

        # Cluster the unlabeled features.
        self.logger.info("Performing k-means...")
        km = KMeansGPU(n_clusters = num_unknown_classes)
        km.fit(self.feat_memory2)
        self.logger.info("K-means finished")
        proto_u = torch.tensor(km.cluster_centers_, dtype=torch.float32).to(self.device)

        # Distance of each labeled prototype to its closest cluster centre.
        distance = cdist(self.proto_l.cpu().numpy(), proto_u.detach().cpu().numpy(), 'euclidean')
        nearest_dist = distance.min(1)
        signed_z_score = zscore(nearest_dist)
        z_score = np.abs(signed_z_score)
        quan_thres = np.quantile(z_score, self.cfg.MODEL.PROTO_THRES)
        is_matched = z_score < quan_thres
        # Negating the "<" mask (instead of using ">=") keeps NaN z-scores counted.
        self.esti_offset = int(np.count_nonzero(~is_matched))
        print(f'Distance for Each Labeled Prototypes:\n{nearest_dist}\nZ-score: {signed_z_score}\nQuantile Threshold: {quan_thres}' )

        # Matched cluster centres first, all others afterwards.
        _, col_ind = linear_sum_assignment(distance)
        col_ind_masked = col_ind[is_matched]
        matched_clusters = set(col_ind_masked.tolist())
        pro_l = [proto_u[u_idx] for u_idx in col_ind_masked]
        pro_u = [proto_u[j] for j in range(num_unknown_classes) if j not in matched_clusters]
        print('Prototypes saved')
        proto_u_aligned = torch.stack(pro_l + pro_u)
        self.centroids = proto_u_aligned.clone().detach()

        # Pseudo labels expressed in the aligned prototype order.
        pseudo_labels_update = self.alignment(km)
        km_acc = cluster_acc(km.labels_ + self.esti_offset, self.label_memory2.numpy())
        acc_update = accuracy(pseudo_labels_update.cpu().numpy(), self.label_memory2.numpy())
        print('Pseudo labels accuracy: KMeans {} | Prev Best {} | Updated {}'.format(km_acc, self.pre_proto_update_acc, acc_update))
        _, disti_sample_u =  self.eval_proto(pseudo_labels_update.cpu().numpy() - self.esti_offset)

        if disti_sample_u < self.pre_disti_sample_u:
            # Candidate is better: commit prototypes, labels and statistics.
            self.logger.info(f'Find Better Unlabeled Prototypes (Offset: {self.esti_offset}) ....')
            self.proto_u = proto_u_aligned
            self.pseudo_labels = pseudo_labels_update
            self.pre_disti_sample_u = disti_sample_u
            self.pre_proto_update_acc = acc_update
        else:
            self.logger.info('Keep the Previous Best Prototypes Unchanged....')
            # Keep the offset in sync with the prototypes/labels that are kept.
            if previous_offset is not None:
                self.esti_offset = previous_offset
        print('Pseudo Labels: {}'.format(self.pseudo_labels))

    def eval_proto(self, pseudo_labels):
        proto_u = self.centroids.cpu()
        # calculate the distance from each feature to their nearest prototypes (sum and mean)
        dist_l = cdist(self.feat_memory1, self.proto_l.cpu())
        dist_u = cdist(self.feat_memory2, proto_u)
        acc_l = (dist_l.argmin(axis=1) == self.label_memory1.cpu().numpy()).mean()
        acc_u = (dist_u.argmin(axis=1) + self.esti_offset == self.label_memory2.cpu().numpy()).mean()

        proto_2_ins = {i: [] for i in range(len(proto_u))}
        proto_rs_dist = {i: [] for i in range(len(proto_u))}
        for idx, pl in enumerate(pseudo_labels): proto_2_ins[pl].append(idx)
        # find prototypes and their corresponding instances index
        for pro_idx, pro_u in enumerate(proto_u):
            pro_u = pro_u.unsqueeze(0)
            # Calculate pairwise distances between instance and proto_u using Euclidean distance
            p2p_dist = cdist(pro_u.numpy(), torch.cat((proto_u[:1], proto_u[1+1:]), dim=0).numpy(), 'euclidean')
            p2p_nearest_indices_top3 = set(p2p_dist.argsort()[0][:3]) # Sort the indices based on distances
            for ins_idx in proto_2_ins[pro_idx]:
                i2p_dist = cdist(self.feat_memory2[ins_idx].unsqueeze(0), proto_u)
                i2p_nearest_indices = i2p_dist.argsort()[0] # Sort the indices based on distances
                if len(set(i2p_nearest_indices[:3]).intersection(p2p_nearest_indices_top3)) > 2:
                    proto_rs_dist[pro_idx].append(cdist(self.feat_memory2[ins_idx].unsqueeze(0), pro_u.cpu())[0][0])
            if len(proto_rs_dist[pro_idx]) < int(len(proto_2_ins[1]) * 1e-1):
                proto_rs_dist[pro_idx] = dist_u.min(axis=1)[proto_2_ins[pro_idx]]
        rs_res = np.array([[statistics.mean(values), len(values)] for _, values in proto_rs_dist.items()])

        # Log res
        self.logger.info('[Instacne Distance Mean]\n\tLabeled: {:.4f} | Unlabeled: {:.4f}'.\
            format(dist_l.min(axis=1).mean(), dist_u.min(axis=1).mean()))
        self.logger.info('[RS Distance Mean]\n\t: {} (Mean: {:.4f})\n[RS Distance Count]: {}'.\
            format(rs_res[:, 0], rs_res[:, 0].mean(), rs_res[:, 1]))
        self.logger.info('[Protypical Accuracy] Labeled: {:.4f} | Unlabeled: {:.4f}'.format(acc_l, acc_u))

        return dist_l.min(axis=1).mean(), dist_u.min(axis=1).mean()

    def update_feat(self, train_loader1, train_loader2, epoch=0):
        """Refresh both feature memory banks with the current networks.

        Every batch yields ``(img, vid, idx)``; ``idx`` selects the memory rows
        to write. With ``epoch == 1`` features and labels are stored directly,
        otherwise the stored features are blended with the new ones using
        momentum ``cfg.MODEL.FEAT_MOMEN`` (labels stay untouched).

        The networks run in eval mode without gradients during the pass and
        their previous train/eval flags are restored afterwards. ``train()``
        enables train mode before calling this method, so leaving the networks
        in eval mode made the rest of that epoch train in eval mode.

        Args:
            train_loader1: Loader feeding ``feat_memory1`` / ``label_memory1``
                (features extracted with ``labeled=True``).
            train_loader2: Loader feeding ``feat_memory2`` / ``label_memory2``
                (features extracted with ``labeled=False``).
            epoch: ``1`` overwrites the banks; any other value blends.
        """
        model_was_training = self.model.training
        head_was_training = self.projection_head.training
        self.model.eval()
        self.projection_head.eval()

        passes = (
            (train_loader1, self.feat_memory1, self.label_memory1, True,
             'Updating Features for Labelled Dataset'),
            (train_loader2, self.feat_memory2, self.label_memory2, False,
             'Updating Features for Unlabelled Dataset'),
        )
        try:
            with torch.no_grad():
                for loader, feat_memory, label_memory, labeled, desc in passes:
                    for img, vid, idx in tqdm(loader, desc=desc):
                        feat, _ = self.feature_extract(
                            img.to(self.device), labeled=labeled, ignore_logits=True)
                        feat = feat.detach().cpu()
                        if epoch == 1:
                            feat_memory[idx] = feat
                            label_memory[idx] = vid
                        else:
                            # Use the instance config, not the module-level global.
                            momentum = self.cfg.MODEL.FEAT_MOMEN
                            feat_memory[idx] = feat_memory[idx] * momentum + feat * (1 - momentum)
        finally:
            self.model.train(model_was_training)
            self.projection_head.train(head_was_training)

    def train(self,
            train_loader_labeled,
            train_loader_unlabeled,
            test_loader,
            ):
        """Run the full training loop.

        Per epoch: optionally refresh the memory banks and prototypes, iterate
        over the labeled loader (pairing every labeled batch with an unlabeled
        one), optimise the combined prototype / entropy / classification loss,
        evaluate with k-means on the unlabeled training data and on the test
        set, step the LR scheduler and write checkpoints.

        Fixes relative to the previous implementation:
          * ``itertools.cycle`` kept a copy of every unlabeled batch and replayed
            those cached batches once the loader was exhausted; the loader is
            now simply re-iterated when it runs out;
          * train mode is enabled *after* the feature/prototype refresh, because
            ``update_feat`` switches the networks to eval mode;
          * loss meters receive plain floats instead of live tensors;
          * the visualised label belongs to the visualised image (index 0);
          * the epoch summary no longer passes unused format arguments;
          * the "best" log message names the metric that is actually compared;
          * ``self.cfg`` is used throughout instead of a mix with the global.

        Args:
            train_loader_labeled: Loader yielding ``(imgs, labels, idx)`` batches
                of labeled data.
            train_loader_unlabeled: Loader yielding ``(imgs, labels, idx)``
                batches of unlabeled data; ``idx`` indexes ``self.pseudo_labels``.
            test_loader: Loader passed to ``test_kmeans`` for evaluation.
        """
        args = self.args
        cfg = self.cfg
        optimizer = build_optimizer(cfg, list(self.projection_head.parameters()) + list(self.model.parameters()))
        exp_lr_scheduler = lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=cfg.SOLVER.MAX_EPOCHS,
            eta_min=0.002 * cfg.SOLVER.BASE_LR,
        )
        best_test_acc_lab = 0

        # Initial memory banks and prototypes.
        self.update_feat(train_loader_labeled, train_loader_unlabeled, epoch=1)
        self.proto_l_update(epoch=1)
        self.proto_u_update(epoch=1)
        for epoch in range(1, cfg.SOLVER.MAX_EPOCHS + 1):
            pun_loss_record, pla_loss_record, se_loss_record, ce_l_loss_record, ce_u_loss_record, loss_record = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter()
            ce_acc_l_record, ce_acc_u_record = AverageMeter(), AverageMeter()
            if epoch % cfg.SOLVER.UPDATE_MAX == 0:
                self.update_feat(train_loader_labeled, train_loader_unlabeled, epoch)
                self.proto_l_update(epoch)
            if epoch <= cfg.SOLVER.UPDATE_UP_MAX:
                self.update_feat(train_loader_labeled, train_loader_unlabeled, epoch)
                self.proto_l_update(epoch)
                self.proto_u_update(epoch)

            # Must come after the refresh above: update_feat() calls .eval().
            self.projection_head.train()
            self.model.train()

            train_loader_unlabeled_iter = iter(train_loader_unlabeled)
            for b_idx, batch_l in enumerate(tqdm(train_loader_labeled, desc='SIPA Training')):
                try:
                    batch_u = next(train_loader_unlabeled_iter)
                except StopIteration:
                    # Unlabeled loader exhausted: start a fresh pass over it.
                    train_loader_unlabeled_iter = iter(train_loader_unlabeled)
                    batch_u = next(train_loader_unlabeled_iter)
                imgs_l, labels_l, _ = batch_l
                imgs_u, labels_u, idxs_u = batch_u
                labels_l, labels_u = labels_l.to(self.device), labels_u.to(self.device)
                mask_l, mask_u = labels_l.cpu() >= self.esti_offset, self.pseudo_labels.cpu()[idxs_u] < self.args.num_known_classes
                feat_l, logits_l_all_p = self.feature_extract(imgs_l.to(self.device))
                feat_u, logits_u_all_p = self.feature_extract(imgs_u.to(self.device), labeled=False)
                if b_idx == np.random.randint(len(train_loader_labeled)) and epoch == 1:
                    self.vis_img(imgs_l[0], epoch=epoch, num_cols=cfg.MODEL.POOL_SIZE, num_rows=cfg.MODEL.POOL_SIZE, label=str(labels_l[0].item()))
                if cfg.MODEL.ALL_PATCHES:
                    # Index 0 along dim 1 holds the class-token logits.
                    logits_l, logits_u = logits_l_all_p[:, 0], logits_u_all_p[:, 0]
                else:
                    logits_l, logits_u = logits_l_all_p, logits_u_all_p
                del imgs_l, imgs_u

                #------------------------------------
                # Compute loss with semantic similarity
                # ------------------------------------
                ####  Unlabeled samples + labeled samples selected by mask_l  ####
                features_cu = torch.cat([feat_l[mask_l], feat_u])
                loss_pro_u = pcl_loss(features_cu, self.proto_u, args.temperature)
                pun_loss_record.update(loss_pro_u.item(), features_cu.size(0))
                del features_cu
                ####  Labeled samples + unlabeled samples selected by mask_u  ####
                features_cl = torch.cat([feat_l, feat_u[mask_u]])
                loss_pro_l = pcl_loss(features_cl, self.proto_l, args.temperature)
                pla_loss_record.update(cfg.LOSS.W_PRO_L * loss_pro_l.item(), features_cl.size(0))
                del features_cl

                loss_pro = loss_pro_u + cfg.LOSS.W_PRO_L * loss_pro_l

                #------------------------------------
                # Entropy-based Uniformity Regularization Loss
                # ------------------------------------
                if cfg.MODEL.ALL_PATCHES:
                    m = nn.Softmax(dim=2)
                    prob_l_all_p, prob_u_all_p = m(logits_l_all_p[:, 1:, :]), m(logits_u_all_p[:, 1:, :])
                    if b_idx == 0: self.vis_patches(prob_l_all_p[0], epoch)
                    loss_eur = average_patch_entropy(prob_l_all_p, sdr=cfg.LOSS.SDR, sdr_kl_w=cfg.LOSS.SDR_KL_W, kldu_w=cfg.LOSS.KLDU_W, batch_wise=True) \
                        + average_patch_entropy(prob_u_all_p, sdr=cfg.LOSS.SDR, sdr_kl_w=cfg.LOSS.SDR_KL_W, kldu_w=cfg.LOSS.KLDU_W, batch_wise=True)
                else:
                    m = nn.Softmax(dim=1)
                    loss_eur = entropy(m(logits_l)) + entropy(m(logits_u))
                # float(...) so the meter stores a number, not a tensor that is
                # still attached to the autograd graph.
                se_loss_record.update(float(loss_eur) * cfg.LOSS.W_SE, logits_l_all_p.size(0) + logits_u_all_p.size(0))

                #------------------------------------
                # Cross Entropy Loss (labeled + unlabeled) or Mutual Information Maximum (labeled + unlabeled)
                # ------------------------------------
                targets_u = self.pseudo_labels[idxs_u] - self.esti_offset
                loss_ce_l = nn.CrossEntropyLoss()(logits_l, labels_l) if not cfg.LOSS.MIM_L\
                    else mutual_information_loss(logits_l, labels_l, args.num_known_classes, self.device)
                loss_ce_u = nn.CrossEntropyLoss()(logits_u, targets_u) if not cfg.LOSS.MIM_U \
                    else mutual_information_loss(logits_u, targets_u, args.num_unknown_classes, self.device)
                ce_l_loss_record.update(float(loss_ce_l), logits_l.size(0))
                ce_u_loss_record.update(float(loss_ce_u) * cfg.LOSS.W_UCE, logits_u.size(0))
                _, pred_l = torch.max(logits_l, 1)
                _, pred_u = torch.max(logits_u, 1)
                ce_acc_l = accuracy(pred_l.cpu().numpy(), labels_l.cpu().numpy())
                ce_acc_u = accuracy(pred_u.cpu().numpy() + self.esti_offset, labels_u.cpu().numpy())
                ce_acc_l_record.update(ce_acc_l, pred_l.size(0))
                ce_acc_u_record.update(ce_acc_u, pred_u.size(0))

                loss = cfg.LOSS.W_PRO * loss_pro + loss_eur * cfg.LOSS.W_SE + (loss_ce_l + loss_ce_u * cfg.LOSS.W_UCE)
                loss_record.update(float(loss), labels_l.size(0) + labels_u.size(0))

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            print(f"\n\n{'='*30} Train Epoch: {epoch} {'='*30}")
            print('PRO_L Loss: {:.4f} | PRO_U Loss: {:.4f} | SE Loss: {:.4f} | CE (L) Loss: {:.4f} | CE (U) Loss: {:.4f} | Total Loss: {:.4f}'.\
                format(pla_loss_record.avg, pun_loss_record.avg, se_loss_record.avg, ce_l_loss_record.avg, ce_u_loss_record.avg, loss_record.avg))

            with torch.no_grad():
                print('Testing on unlabelled examples in the training data...')
                all_acc, old_acc, new_acc = self.test_kmeans(
                    train_loader_unlabeled, epoch=epoch, save_name='Train ACC Unlabelled'
                )
                print('Testing on disjoint test set...')
                all_acc_test, old_acc_test, new_acc_test = self.test_kmeans(
                    test_loader, epoch=epoch, save_name='Test ACC'
                )

            # ----------------
            # LOG
            # ----------------
            args.writer.add_scalar('Loss', loss_record.avg, epoch)
            args.writer.add_scalar('LR', get_mean_lr(optimizer), epoch)
            print('| CE (L) Acc: {:.4f} | CE (U) Acc: {:.4f}'.format(ce_acc_l_record.avg, ce_acc_u_record.avg))
            print('Train Accuracies: All {:.4f} | Old {:.4f} | New {:.4f}'.format(all_acc, old_acc, new_acc))
            print('Test Accuracies: All {:.4f} | Old {:.4f} | New {:.4f}'.format(all_acc_test, old_acc_test, new_acc_test))

            # Step schedule
            exp_lr_scheduler.step()

            # Latest checkpoint. The slicing assumes model_path ends in a
            # three-character suffix such as ".pt".
            checkpoint_stem = args.model_path[:-3]
            torch.save(self.model.state_dict(), args.model_path)
            print("model saved to {}.".format(args.model_path))
            torch.save(self.projection_head.state_dict(), checkpoint_stem + '_proj_head.pt')
            print("projection head saved to {}.".format(checkpoint_stem + '_proj_head.pt'))

            if all_acc_test > best_test_acc_lab:
                # The comparison uses the accuracy over all classes.
                print(f'Best ACC on all Classes on disjoint test set: {all_acc_test:.4f}...')
                print('Best Train Accuracies: All {:.4f} | Old {:.4f} | New {:.4f}'.format(all_acc, old_acc,
                                                                                    new_acc))

                torch.save(self.model.state_dict(), checkpoint_stem + '_best.pt')
                print("model saved to {}.".format(checkpoint_stem + '_best.pt'))
                torch.save(self.projection_head.state_dict(), checkpoint_stem + '_proj_head_best.pt')
                print("projection head saved to {}.".format(checkpoint_stem + '_proj_head_best.pt'))

                best_test_acc_lab = all_acc_test

    def test_kmeans(self, test_loader, epoch, save_name):
        """Cluster the features of ``test_loader`` and report clustering accuracy.

        If ``save_name`` contains ``'Train'`` the data is clustered into
        ``args.num_unknown_classes`` clusters and the cluster ids are shifted by
        ``self.esti_offset``; otherwise ``args.num_all_classes`` clusters are
        used without a shift.

        Changes relative to the previous implementation: images are moved to
        ``self.device`` (not unconditionally to CUDA), feature extraction runs
        under ``torch.no_grad()`` regardless of the caller, and targets / mask
        are collected per batch and concatenated once.

        Args:
            test_loader: Loader yielding ``(images, label, _)`` batches.
            epoch: Forwarded to ``log_accs_from_preds`` as ``T``.
            save_name: Tag used for the progress bar and logging; also selects
                the clustering mode as described above.

        Returns:
            ``(all_acc, old_acc, new_acc)`` as returned by ``log_accs_from_preds``.
        """
        self.model.eval()
        self.projection_head.eval()

        num_known = self.args.num_known_classes
        all_feats, target_chunks, mask_chunks = [], [], []

        # First extract all features
        with torch.no_grad():
            for images, label, _ in tqdm(test_loader, desc=save_name):
                images = images.to(self.device)
                feat, _ = self.feature_extract(images, ignore_logits=True)
                all_feats.append(feat)
                label_np = label.cpu().numpy()
                target_chunks.append(label_np)
                # True where the label lies in [0, num_known_classes).
                mask_chunks.append((label_np >= 0) & (label_np < num_known))

        # Same dtype as before (float64, produced by appending to an empty array).
        targets = np.concatenate(target_chunks).astype(np.float64) if target_chunks else np.array([])
        mask = np.concatenate(mask_chunks).astype(np.float64) if mask_chunks else np.array([])

        # -----------------------
        # K-MEANS
        # -----------------------
        all_feats = torch.cat(all_feats, dim=0)
        is_train_split = 'Train' in save_name
        num_cluster = self.args.num_unknown_classes if is_train_split else self.args.num_all_classes
        km = KMeansGPU(n_clusters = num_cluster, kmeans_max_iter=self.cfg.SOLVER.KMEANS_MAX_ITER)
        km.fit(all_feats)
        preds = km.labels_ + self.esti_offset if is_train_split else km.labels_
        all_acc, old_acc, new_acc = log_accs_from_preds(
            y_true=targets, y_pred=preds, mask=mask,
            T=epoch, eval_funcs=self.args.eval_funcs, save_name=save_name,
            writer=self.args.writer, print_output=True
        )

        return all_acc, old_acc, new_acc
    
    def vis_img(self, image_array, epoch, num_rows=4, num_cols=4, label=''):
        """Log an image, cut into a ``num_rows x num_cols`` grid of tiles.

        Changes relative to the previous implementation: tensors are always
        converted to NumPy (previously only when the first dimension was 3),
        constant images no longer divide by zero, grids with a single row or
        column work, and the trailing ``plt.axis('off')`` was dropped -- every
        tile axis is already switched off, and if the writer closes the figure
        inside ``add_figure`` that call would implicitly open a new figure.

        Args:
            image_array: Image as tensor or array. A first dimension of size 3
                is treated as the channel axis and moved last.
            epoch: Global step passed to the writer.
            num_rows: Number of tile rows.
            num_cols: Number of tile columns.
            label: Text shown in the figure title.
        """
        if torch.is_tensor(image_array):
            if image_array.shape[0] == 3:
                image_array = image_array.permute(1, 2, 0)
            image_array = image_array.detach().cpu().numpy()
        elif image_array.shape[0] == 3:
            image_array = np.transpose(image_array, (1, 2, 0))

        # Rescale to [1, 255]; a constant image has no range to stretch.
        lowest, highest = image_array.min(), image_array.max()
        if highest > lowest:
            image_array = ((image_array - lowest) / (highest - lowest)) * 254 + 1
        else:
            image_array = np.full(image_array.shape, 128.0)
        image_array = image_array.astype(np.uint8)

        patch_height = image_array.shape[0] // num_rows
        patch_width = image_array.shape[1] // num_cols
        # squeeze=False keeps ``axes`` two-dimensional for any grid size.
        fig, axes = plt.subplots(ncols=num_cols, nrows=num_rows, figsize=(20,20), squeeze=False)
        # Loop through the rows and columns to extract and visualize each patch
        for i in range(num_rows):
            for j in range(num_cols):
                patch = image_array[i * patch_height : (i + 1) * patch_height, j * patch_width : (j + 1) * patch_width]
                axes[i][j].imshow(patch)
                axes[i][j].axis('off')
        plt.suptitle(f"label id: {label}", fontsize=patch_height)
        self.args.writer.add_figure(f'Input Image', fig, global_step=epoch)
        plt.close(fig)

    def vis_patches(self, prob, epoch):
        """Log one bar chart per row of ``prob``.

        Each row is drawn as a bar chart over its entries, with the largest
        entry highlighted. ``squeeze=False`` keeps the axes indexable when there
        is only one row; previously ``plt.subplots`` returned a bare Axes in
        that case and ``axs[i]`` failed.

        Args:
            prob: 2-D tensor; one chart is drawn per row.
            epoch: Global step passed to the writer.
        """
        prob = prob.cpu().detach().numpy()
        num_patches = prob.shape[0]
        fig, axs = plt.subplots(num_patches, 1, figsize=(6, 2 * num_patches), squeeze=False)
        prototypes = list(range(prob.shape[1]))
        for i, ax in enumerate(axs[:, 0]):
            max_prob_index = np.argmax(prob[i]) # Find the index of the bar with the highest probability
            # Highlight the highest bar, use the base colour for all others.
            colors = ['gold' if k == max_prob_index else 'sienna' for k in range(len(prob[i]))]
            ax.bar(prototypes, prob[i], align='edge', color=colors, width=-0.6, edgecolor='black')
            ax.set_title(f'Patch {i + 1}')
            ax.set_xlabel('Prototypes')
            ax.set_ylabel('Prob')
            ax.set_xticklabels([])
            ax.set_yticklabels([])

        self.args.writer.add_figure(f'Patches', fig, global_step=epoch)
        plt.close(fig)

# Script entry point: parse the command line, merge it into the global config,
# build data loaders, backbone and projection head, then run SIPAManager.train.
if __name__ == "__main__":

    parser = argparse.ArgumentParser(
            description='cluster',
            formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        "--config_file", default="", help="path to config file", type=str
    )
    # Everything left on the command line is collected here and later merged
    # into cfg via cfg.merge_from_list.
    parser.add_argument("opts", help="Modify config options using the command-line", default=None,
                        nargs=argparse.REMAINDER)
    parser.add_argument('--eval_funcs', nargs='+', help='Which eval functions to use', default=['v1', 'v2'])
    parser.add_argument('--grad_from_block', type=int, default=11)
    parser.add_argument('--grad_head_layer', type=str, default='')
    parser.add_argument('--transform', type=str, default='imagenet')
    parser.add_argument('--contrast_unlabel_only', type=str2bool, default=False)

    # ----------------------
    # INIT
    # ----------------------
    # Merge args and config
    args = parser.parse_args()
    if args.config_file != "":
        cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    # Mirror selected config values onto args, which is what the helpers and
    # SIPAManager read.
    args.dataset_name = cfg.DATASETS.NAMES
    args = get_class_splits(args)
    args.batch_size = cfg.SOLVER.IMS_PER_BATCH
    args.n_views = cfg.DATALOADER.N_VIEWS
    args.temperature = cfg.LOSS.TEMP_INFO
    args.exp_root = cfg.OUTPUT_DIR
    # Load experiments settings
    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.n_gpu = torch.cuda.device_count()
    # NOTE(review): train() later reads args.writer and args.model_path, which
    # are not set in this file -- presumably init_experiment provides them.
    init_experiment(args, runner_name=['SIPA'])
    print(f"\n\n\n{'='*100}\n{'='*40} Experiments Configuration {'='*40}\n{'='*100}")
    print('Num of GPUs: {}'.format(args.n_gpu))
    print(f'Using evaluation function {args.eval_funcs[0]} to print results')
    set_seed(cfg.SOLVER.SEED)

    # NOTE: Hardcoded image size as we do not finetune the entire ViT model
    args.image_size = 224
    args.feat_dim = 768
    args.num_mlp_layers = 3
    # mlp_out_dim is also the width of the feature memory banks allocated in
    # SIPAManager.__init__.
    args.mlp_out_dim = 65536
    args.interpolation = 3
    args.crop_pct = 0.875

    # --------------------
    # CONTRASTIVE TRANSFORM
    # --------------------
    train_transforms, test_transforms = get_transform(args.transform, image_size=args.image_size, args=args)
    # train_transforms = ContrastiveLearningViewGenerator(base_transform=train_transforms, n_views=args.n_views)
    # A `global` statement at module level has no effect; the two names below
    # are ordinary module-level variables.
    global TRAIN_TRANSFORM, TEST_TRANSFORM
    TRAIN_TRANSFORM, TEST_TRANSFORM = train_transforms, test_transforms

    # --------------------
    # DATALOADER
    # --------------------
    # img_num1 / img_num2 size the two memory banks of SIPAManager.
    train_loader1, train_loader2, test_loader, num_known_classes, num_unknown_classes, num_all_classes, img_num1, img_num2 = \
            make_dataloader(cfg, args, train_transforms=train_transforms, test_transforms=test_transforms)
    args.num_private_classes = len(args.private_classes)
    args.num_known_classes = num_known_classes
    args.num_unknown_classes = num_unknown_classes
    args.num_all_classes = num_all_classes
    
    # ----------------------
    # BASE MODEL
    # ----------------------
    dino_pretrain_path = cfg.MODEL.PRETRAIN_PATH
    pretrain_choice = cfg.MODEL.PRETRAIN_CHOICE
    print("cfg.MODEL.PRETRAIN_PATH: {}\nPretrain Choice: {}".format(dino_pretrain_path, pretrain_choice))
    if cfg.MODEL.NAME == 'dino':
        model = vits.__dict__['vit_base'](img_size=cfg.INPUT.SIZE_CROP, \
            stride_size=cfg.MODEL.STRIDE_SIZE, drop_path_rate=cfg.MODEL.DROP_PATH, block_pattern=cfg.MODEL.BLOCK_PATTERN, cfg=cfg, args=args)
        # Any other pretrain_choice value leaves the model as constructed.
        if pretrain_choice == 'imagenet':
            model.load_param(dino_pretrain_path)
            print('Loading pretrained ImageNet model......from {}'.format(dino_pretrain_path))
        elif pretrain_choice == 'un_pretrain':
            model.load_un_param(dino_pretrain_path)
            print('Loading trans_tune model......from {}'.format(dino_pretrain_path))
        elif pretrain_choice == 'pretrain':
            if dino_pretrain_path == '':
                print('make model without initialization')
            else:
                model.load_param(dino_pretrain_path)
                print('Loading pretrained model......from {}'.format(dino_pretrain_path))

        model.to(args.device)
        if args.n_gpu > 1:
            model = nn.DataParallel(model)
            print(f'Training model on {args.n_gpu} GPUs.')
        # ----------------------
        # HOW MUCH OF BASE MODEL TO FINETUNE
        # ----------------------
        # Freeze everything first, then re-enable the selected blocks.
        for m in model.parameters():
            m.requires_grad = False
        # Only finetune layers from block 'args.grad_from_block' onwards
        for name, m in model.named_parameters():
            if 'block' in name:
                # The block index sits one position further right in the dotted
                # name when the model is wrapped in nn.DataParallel (presumably
                # because of the added "module." prefix).
                if args.n_gpu > 1:
                    block_num = int(name.split('.')[2])
                else:
                    block_num = int(name.split('.')[1])
                if block_num >= args.grad_from_block:
                    m.requires_grad = True

        # ----------------------
        # PROJECTION HEAD
        # ----------------------
        projection_head = vits.__dict__['DINOHead'](in_dim=args.feat_dim,
                                out_dim=args.mlp_out_dim, nlayers=args.num_mlp_layers, args=args)
        if cfg.MODEL.PRETRAIN_PROJ_PATH != '':
            projection_head.load_param_finetune(cfg.MODEL.PRETRAIN_PROJ_PATH)
        projection_head.to(args.device)
        if args.n_gpu > 1:
            projection_head = nn.DataParallel(projection_head)
            print(f'Training projection head on {args.n_gpu} GPUs.')
        # With --grad_head_layer set, only that layer and any layer whose name
        # contains 'head' stay trainable; otherwise the whole head is trained.
        if args.grad_head_layer != '':
            for m in projection_head.parameters():
                m.requires_grad = False
            for name, m in projection_head.named_parameters():
                layer_name = name.split('.')[1] if args.n_gpu > 1 else name.split('.')[0]
                if layer_name == args.grad_head_layer or 'head' in layer_name:
                    print(f'{name} requires grad back-propagation...')
                    m.requires_grad = True

    else:

        # Only the 'dino' backbone is supported.
        raise NotImplementedError

    # ----------------------
    # TRAIN
    # ----------------------
    trainer = SIPAManager(model, projection_head, img_num1=img_num1, img_num2=img_num2, args=args, cfg=cfg)
    trainer.train(train_loader1, train_loader2, test_loader)