import math
import warnings

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from sklearn.metrics import confusion_matrix
from scipy.optimize import linear_sum_assignment as linear_assignment


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    """Fill ``tensor`` in place with a normal distribution truncated to [a, b].

    Public entry point: forwards every argument unchanged to
    ``_no_grad_trunc_normal_`` and hands back whatever it returns.
    """
    return _no_grad_trunc_normal_(tensor, mean=mean, std=std, a=a, b=b)

def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    """Fill ``tensor`` in place with N(mean, std**2) samples truncated to [a, b].

    Cut & paste from PyTorch official master until it's in a few official
    releases - RW. Method based on
    https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf

    Values are drawn uniformly between the normal CDF values of ``a`` and
    ``b`` (expressed in the ``2 * cdf - 1`` space that ``erfinv`` expects),
    mapped back through the inverse normal CDF, then clamped to [a, b] to
    absorb rounding at the edges. All work happens under ``torch.no_grad()``.

    Returns:
        ``tensor`` itself, after the in-place fill.

    Warns:
        Through ``warnings.warn`` when ``mean`` lies more than ``2 * std``
        outside [a, b], because the sampled distribution may then be inaccurate.
    """
    def norm_cdf(x):
        # Standard normal cumulative distribution function.
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        # Lower and upper CDF values of the truncation interval.
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], translated to
        # [2l-1, 2u-1] so that erfinv maps them to a standard normal.
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Inverse CDF transform -> truncated standard normal.
        tensor.erfinv_()

        # Transform to the requested mean and std.
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range.
        tensor.clamp_(min=a, max=b)
        return tensor


def all_sum_item(item):
    """Sum a scalar over every rank of the default process group.

    The value is wrapped in a tensor on the current CUDA device, reduced in
    place with ``dist.all_reduce`` (whose default reduction is a sum), and
    returned as a plain Python number via ``.item()``.
    """
    reduced = torch.tensor(item, device='cuda')
    dist.all_reduce(reduced)
    return reduced.item()

def cluster_acc_(y_true, y_pred):
    """
    Calculate clustering accuracy under the best one-to-one cluster-to-label mapping.

    Predicted cluster ids are matched to true labels with the Hungarian
    algorithm (``linear_assignment``, i.e. ``scipy.optimize.linear_sum_assignment``)
    on the contingency matrix so that the number of matched samples is maximal.

    When ``torch.distributed`` is available and initialised, the matched count
    and the sample count are summed over all ranks with ``all_sum_item``.
    NOTE: each rank solves the assignment on its own local samples, so the
    distributed result pools per-rank matchings rather than one global matching.

    # Arguments
        y_true: true labels, numpy.array with shape `(n_samples,)`; cast to int.
        y_pred: predicted cluster ids, numpy.array with shape `(n_samples,)`; cast to int.

    # Return
        accuracy, in [0,1]

    # Raises
        AssertionError: if ``y_true`` and ``y_pred`` differ in size.
        ValueError: if there are no samples at all (summed over ranks).
    """
    y_true = np.asarray(y_true).astype(int)
    y_pred = np.asarray(y_pred).astype(int)
    assert y_pred.size == y_true.size

    n_correct = 0
    if y_pred.size > 0:
        n_labels = max(y_pred.max(), y_true.max()) + 1
        # w[p, t] = number of samples predicted as p whose true label is t
        w = np.zeros((n_labels, n_labels), dtype=int)
        np.add.at(w, (y_pred, y_true), 1)
        # maximising matched counts == minimising (w.max() - w)
        row_ind, col_ind = linear_assignment(w.max() - w)
        n_correct = int(w[row_ind, col_ind].sum())
    n_total = int(y_pred.size)

    # Reduce across ranks only when a process group exists. Reduction errors
    # are not swallowed: a failure between the two reductions would otherwise
    # divide a global numerator by a local denominator.
    if dist.is_available() and dist.is_initialized():
        n_correct = all_sum_item(n_correct)
        n_total = all_sum_item(n_total)

    if n_total == 0:
        raise ValueError('cluster_acc_ received no samples')
    return n_correct / n_total


def cal_acc_cluster(loader, netR, device="cuda"):
    """Evaluate the clustering head of ``netR`` over every batch of ``loader``.

    For each batch, inputs are read from ``data[0][0]`` (presumably the first
    of several augmented views -- TODO confirm with the dataset) and labels
    from ``data[1]``. ``netR(inputs)`` must return a mapping with a
    ``'cluster_logits'`` entry; its arg-max over dim 1 is the predicted
    cluster id, and the ids are scored with ``cluster_acc_``.

    Args:
        loader: iterable of batches.
        netR: model returning ``{'cluster_logits': tensor}``.
        device: device the inputs are moved to before the forward pass
            (default ``"cuda"``, the previously hard-coded behaviour).

    Returns:
        The clustering accuracy computed by ``cluster_acc_``.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    outputs, labels = [], []
    with torch.no_grad():
        for data in loader:
            inputs = data[0][0].to(device)
            outputs.append(netR(inputs)['cluster_logits'].float().cpu())
            labels.append(data[1].float())
    if not outputs:
        raise ValueError('cal_acc_cluster: loader produced no batches')

    # Concatenate once instead of re-allocating a growing tensor per batch.
    all_output = torch.cat(outputs, 0)
    all_label = torch.cat(labels, 0)
    _, predict = torch.max(all_output, 1)
    return cluster_acc_(all_label.numpy(), predict.numpy())


def cal_acc_both(loader, type, netF, netB, netC, netR, is_visda=False, device="cuda"):
    """Evaluate the classifier ``netC(netB(netF(x)))`` and, optionally, ``netR``.

    Args:
        loader: iterable of batches; labels are read from ``data[1]``.
        type: ``'target'`` -> inputs are ``data[0][0]``;
            ``'test'`` -> inputs are ``data[0]``.
        netF, netB, netC: callables chained as ``netC(netB(netF(inputs)))``
            to produce class logits; the arg-max over dim 1 is the prediction.
        netR: ``None`` or a model returning a mapping with a
            ``'cluster_logits'`` entry, scored with ``cluster_acc_``.
        is_visda: report per-class accuracies (in percent) instead of the
            overall accuracy.
        device: device inputs are moved to (default ``"cuda"``, the previously
            hard-coded behaviour).

    Returns:
        ``(acc_F, acc_R)`` where
        * ``acc_F`` is the overall accuracy in [0, 1], or when ``is_visda`` a
          tuple ``(mean_per_class_accuracy, per_class_accuracy_string)`` in
          percent. NOTE(review): the confusion matrix spans every label seen
          in either the labels or the predictions; a class that is only ever
          predicted has a zero row sum and yields NaN -- confirm callers
          cannot hit this.
        * ``acc_R`` is the ``cluster_acc_`` result for ``netR``, or ``None``.

    Raises:
        ValueError: for an unknown ``type`` or an empty ``loader``.
    """
    if type not in ('target', 'test'):
        raise ValueError("type must be 'target' or 'test', got {!r}".format(type))

    outputs, outputs_r, labels = [], [], []
    with torch.no_grad():
        for data in loader:
            inputs = data[0][0] if type == 'target' else data[0]
            inputs = inputs.to(device)
            outputs.append(netC(netB(netF(inputs))).float().cpu())
            if netR is not None:
                outputs_r.append(netR(inputs)['cluster_logits'].float().cpu())
            labels.append(data[1].float())
    if not outputs:
        raise ValueError('cal_acc_both: loader produced no batches')

    # Concatenate once instead of re-allocating growing tensors per batch.
    all_output = torch.cat(outputs, 0)
    all_label = torch.cat(labels, 0)
    _, predict = torch.max(all_output, 1)

    if is_visda:
        matrix = confusion_matrix(all_label, torch.squeeze(predict).float())
        acc = matrix.diagonal() / matrix.sum(axis=1) * 100
        aacc = acc.mean()
        acc_F = (aacc, ' '.join(str(np.round(i, 2)) for i in acc))
    else:
        acc_F = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(
            all_label.size()[0]
        )

    if netR is not None:
        _, predict_r = torch.max(torch.cat(outputs_r, 0), 1)
        acc_R = cluster_acc_(all_label.numpy(), predict_r.numpy())
    else:
        acc_R = None

    return acc_F, acc_R


def cal_hardness_acc(h_scores, preds, gts, law="power", num_partitions=20, alpha=6.0, is_visda=True,
                     save_path="aad_ours_rn101_hacc_pow.npy"):
    """Mean class accuracy averaged over hardness-ordered partitions of the data.

    Samples are sorted by ``h_scores`` and cut into ``num_partitions``
    contiguous partitions whose relative sizes follow ``law``:

    * ``"power"``: sorted by ``1 - h_scores`` ascending (highest score first);
      partition ``k`` gets weight ``exp(alpha * k / num_partitions)``.
    * ``"zipfs"``: sorted by ``h_scores`` ascending; partition ``k`` gets
      weight ``1 / (k + 1)``.

    Within each partition, the per-class mean of ``preds`` is averaged over the
    classes present; the function returns the mean of those partition values.

    Args:
        h_scores: per-sample hardness scores (1-D array).
        preds: per-sample values that are *summed* as "correct predictions",
            so presumably a 0/1 correctness indicator rather than predicted
            class ids -- TODO confirm with callers.
        gts: per-sample ground-truth class labels.
        law: ``"power"`` or ``"zipfs"``.
        num_partitions: number of partitions.
        alpha: exponent scale for ``"power"`` (unused by ``"zipfs"``).
        is_visda: accepted for compatibility; not used.
        save_path: file the per-partition accuracies are written to with
            ``np.save`` (which appends ``.npy`` if missing); ``None`` skips
            writing. Defaults to the previously hard-coded file name.

    Returns:
        The mean of the per-partition accuracies. An empty partition
        contributes ``np.mean([])`` (NaN), which propagates to the result.

    Raises:
        NotImplementedError: for an unknown ``law``.
    """

    def split_array(A, B, r):
        # Split A and B at the same indices into len(r) parts with sizes ~ r * len(A).
        assert len(A) == len(B)
        split_sizes = (r * len(A)).astype(int)
        # Give the rounding remainder to the last part so the sizes sum to len(A).
        split_sizes[-1] = len(A) - np.sum(split_sizes[:-1])
        split_indices = np.cumsum(split_sizes)[:-1]
        return np.split(A, split_indices), np.split(B, split_indices)

    def calc_mean_cls_acc(predictions, ground_truth):
        # Mean over the classes present of the per-class mean of ``predictions``.
        class_accuracies = []
        for cls in np.unique(ground_truth):
            in_class = ground_truth == cls
            class_accuracies.append(np.sum(predictions[in_class]) / np.sum(in_class))
        return np.mean(class_accuracies)

    h_scores = np.asarray(h_scores)
    preds = np.asarray(preds)
    gts = np.asarray(gts)

    # 0. ordering and 1. relative partition sizes
    if law == "power":
        sort_indices = np.argsort(1 - h_scores)
        ratio = np.exp(float(alpha) / num_partitions * np.arange(num_partitions))
    elif law == "zipfs":
        sort_indices = np.argsort(h_scores)
        ratio = 1.0 / (np.arange(num_partitions) + 1.0)
    else:
        raise NotImplementedError

    ratio = ratio / ratio.sum()
    gts = gts[sort_indices]
    preds = preds[sort_indices]

    # 2. partition the sorted data and score each partition
    preds_splits, gts_splits = split_array(preds, gts, ratio)
    accs = np.array([calc_mean_cls_acc(p, g) for p, g in zip(preds_splits, gts_splits)])

    if save_path is not None:
        # np.save opens and closes the file itself (no leaked handle).
        np.save(save_path, accs)

    return accs.mean()


class DistillLoss(nn.Module):
    """Cross-entropy self-distillation loss from a teacher to a student.

    The teacher temperature follows a per-epoch schedule that is fixed at
    construction: a linear ramp from ``warmup_teacher_temp`` to
    ``teacher_temp`` over ``warmup_teacher_temp_epochs`` epochs, then constant
    ``teacher_temp`` for the remaining ``nepochs - warmup_teacher_temp_epochs``.
    """

    def __init__(self, warmup_teacher_temp_epochs, nepochs,
                 ncrops=2, warmup_teacher_temp=0.07, teacher_temp=0.04,
                 student_temp=0.1):
        super().__init__()
        self.student_temp = student_temp
        self.ncrops = ncrops
        ramp = np.linspace(warmup_teacher_temp, teacher_temp,
                           warmup_teacher_temp_epochs)
        plateau = np.full(nepochs - warmup_teacher_temp_epochs, teacher_temp)
        self.teacher_temp_schedule = np.concatenate((ramp, plateau))

    def forward(self, student_output, teacher_output, epoch):
        """Average cross-entropy over (teacher view, different student view) pairs.

        Args:
            student_output: student logits for ``ncrops`` views stacked along dim 0.
            teacher_output: teacher logits, split into 2 views along dim 0.
            epoch: index into ``teacher_temp_schedule``.

        Returns:
            Scalar tensor: mean over view pairs of the batch-mean cross-entropy.
        """
        student_views = (student_output / self.student_temp).chunk(self.ncrops)

        # Sharpen the teacher with the scheduled temperature (no centering is
        # applied) and detach it so only the student receives gradients.
        temperature = self.teacher_temp_schedule[epoch]
        teacher_views = F.softmax(teacher_output / temperature, dim=-1).detach().chunk(2)

        pair_losses = [
            torch.sum(-target * F.log_softmax(logits, dim=-1), dim=-1).mean()
            for t_idx, target in enumerate(teacher_views)
            for s_idx, logits in enumerate(student_views)
            if s_idx != t_idx  # skip pairs where both networks see the same view
        ]
        return sum(pair_losses) / len(pair_losses)


def info_nce_logits(features, n_views=2, temperature=1.0, device='cuda'):
    """Build InfoNCE (SimCLR-style) logits and targets for stacked views.

    Row ``v * B + i`` of ``features`` is treated as view ``v`` of sample ``i``
    (``B = N // n_views``). Rows sharing ``i`` are positives of each other;
    every other non-self row is a negative.

    Args:
        features: 2-D tensor ``(N, D)``; rows are L2-normalised here, so the
            similarities are cosine similarities.
        n_views: views per sample; must be >= 2 and divide ``N``.
        temperature: logits are divided by this value.
        device: device for the generated masks and targets; should match the
            device of ``features``.

    Returns:
        logits: ``(N, N - 1)`` tensor -- the ``n_views - 1`` positive
            similarities first, then the ``N - n_views`` negatives.
        labels: ``(N,)`` long zeros, i.e. column 0 (the first positive) is the
            target. With ``n_views > 2`` the other positives are not targets.

    Raises:
        ValueError: if ``n_views < 2`` or ``N`` is not divisible by ``n_views``.
    """
    n_rows = int(features.size(0))
    if n_views < 2 or n_rows % n_views != 0:
        raise ValueError(
            'info_nce_logits: {} rows cannot be split into {} views'.format(n_rows, n_views))
    b_ = n_rows // n_views

    # sample index of every row; equal indices mark positive pairs
    sample_ids = torch.arange(b_).repeat(n_views)
    labels = (sample_ids.unsqueeze(0) == sample_ids.unsqueeze(1)).float().to(device)

    features = F.normalize(features, dim=1)
    similarity_matrix = torch.matmul(features, features.T)

    # discard the main diagonal (self-similarity) from labels and similarities
    mask = torch.eye(labels.shape[0], dtype=torch.bool).to(device)
    labels = labels[~mask].view(labels.shape[0], -1)
    similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1)

    # select and combine multiple positives, then the negatives
    positives = similarity_matrix[labels.bool()].view(labels.shape[0], -1)
    negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1)

    logits = torch.cat([positives, negatives], dim=1)
    labels = torch.zeros(logits.shape[0], dtype=torch.long).to(device)

    logits = logits / temperature
    return logits, labels


def get_mat(x1, topk=5):
    """Pairwise "same top-k index set" matrix for a batch of score rows.

    For each row of ``x1`` the indices of its ``topk`` largest entries are
    taken, ignoring their order. Entry ``(i, j)`` of the result is 1.0 when
    rows ``i`` and ``j`` have exactly the same top-k index set and 0.0
    otherwise, so the diagonal is always 1. ``x1`` is detached first, so no
    gradient flows through the result.

    Args:
        x1: 2-D tensor ``(bsz, n_scores)``.
        topk: size of the compared index set; clipped to ``n_scores``.

    Returns:
        float32 tensor ``(bsz, bsz)`` on the same device as ``x1``.
    """
    rank_feat = x1.detach()
    topk = min(topk, rank_feat.size(1))
    rank_idx = torch.argsort(rank_feat, dim=1, descending=True)[:, :topk]
    # sort the selected indices so set equality becomes element-wise equality
    rank_idx, _ = torch.sort(rank_idx, dim=-1)
    same_set = (rank_idx[None, :, :] == rank_idx[:, None, :]).all(dim=-1)
    return same_set.float()


"""
Code adapted from the implementation of
Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
It also supports the unsupervised contrastive loss of SimCLR.
"""
class SupConLoss(nn.Module):
    """Supervised contrastive loss with one batch-global normalisation.

    All (anchor, contrast) pairs are flattened, pairs whose mask entry is
    ``<= -1`` are discarded, and the remaining non-self pairs are normalised
    with a single log-sum-exp over the whole batch. With neither ``labels``
    nor ``mask`` it reduces to a SimCLR-style unsupervised loss, where each
    sample's own views are its only positives.
    """

    def __init__(self, temperature=0.07, contrast_mode='all',
                 base_temperature=0.07):
        super(SupConLoss, self).__init__()
        self.temperature = temperature
        self.contrast_mode = contrast_mode
        self.base_temperature = base_temperature

    def forward(self, features, labels=None, mask=None, type='feat'):
        """Compute the loss. If both `labels` and `mask` are None,
        it degenerates to SimCLR unsupervised loss:
        https://arxiv.org/pdf/2002.05709.pdf

        Args:
            features: hidden vector of shape [bsz, n_views, ...]; trailing
                dimensions are flattened.
            labels: ground truth of shape [bsz]; equal labels are positives.
            mask: pair mask of shape [bsz, bsz]. Entries > -1 are kept and
                weight the log-probabilities by their value (1 = positive,
                0 = negative); entries <= -1 exclude the pair entirely.
                Can be asymmetric. Mutually exclusive with ``labels``.
            type: accepted for API compatibility; not used.

        Returns:
            A loss scalar.

        Raises:
            ValueError: malformed ``features``, mismatched ``labels``, both
                ``labels`` and ``mask`` given, or unknown ``contrast_mode``.
        """
        device = features.device

        if len(features.shape) < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...],'
                             'at least 3 dimensions are required')
        if len(features.shape) > 3:
            features = features.view(features.shape[0], features.shape[1], -1)

        batch_size = features.shape[0]
        if labels is not None and mask is not None:
            raise ValueError('Cannot define both `labels` and `mask`')
        elif labels is None and mask is None:
            mask = torch.eye(batch_size, dtype=torch.float32).to(device)
        elif labels is not None:
            labels = labels.contiguous().view(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError('Num of labels does not match num of features')
            mask = torch.eq(labels, labels.T).float().to(device)
        else:
            mask = mask.float().to(device)

        contrast_count = features.shape[1]
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
        if self.contrast_mode == 'one':
            anchor_feature = features[:, 0]
            anchor_count = 1
        elif self.contrast_mode == 'all':
            anchor_feature = contrast_feature
            anchor_count = contrast_count
        else:
            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))

        # [anchor_count * bsz, contrast_count * bsz] similarity logits
        logits = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),
            self.temperature)

        # tile the mask to the logits layout and mask out self-contrast cases
        mask = mask.repeat(anchor_count, contrast_count)
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
            0
        )
        mask = mask * logits_mask

        # flatten and drop pairs excluded by the mask
        keep = mask.flatten() > -1
        mask = mask.flatten()[keep]
        logits = logits.flatten()[keep]
        logits_mask = logits_mask.flatten()[keep]

        # Batch-global log-softmax over the kept non-self pairs. A per-row max
        # shift would not cancel here because the normaliser spans all rows;
        # logsumexp is numerically stable without any shift.
        log_norm = torch.logsumexp(logits[logits_mask > 0], dim=0)
        log_prob = logits - log_norm

        # mean log-likelihood over positive pairs; guard against a batch with
        # no positive pair (e.g. labels [0, 1, 1, 2] with a single view)
        mask_pos_pairs = mask.sum()
        mask_pos_pairs = torch.where(mask_pos_pairs < 1e-6, 1, mask_pos_pairs)
        mean_log_prob_pos = (mask * log_prob).sum() / mask_pos_pairs

        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
        return loss

    
# Rolling histories of flattened (detached) similarity matrices shared across
# calls of ``get_ulb_sim_matrix_v2``; one per accepted ``list`` value.
cluster_list, cls_list = [], []


def get_ulb_sim_matrix_v2(mode, preds_t, diff_ratio, sim_ratio,
                          sim_threshold=0, max_sim_list=30, list='cluster'):
    """Turn pairwise similarities of a batch into a {1, 0, -1} pair-label matrix.

    The current batch's similarities are appended to a module-level rolling
    history (``cluster_list`` for ``list='cluster'``, ``cls_list`` for
    ``list='class'``) trimmed to the last ``max_sim_list`` batches (at least
    one). Thresholds are taken from the sorted history of length ``L``:

    * low threshold: value at sorted position ``int(min(L * diff_ratio, L - 1))``;
      pairs at or below it become 0;
    * high threshold: the ``int(L * sim_ratio)``-th largest value, never below
      ``sim_threshold``; pairs at or above it become 1. When that count is 0,
      no pair is marked 1 by thresholding;
    * when both thresholds are equal, strict comparisons are used instead;
    * all other pairs are -1 and the diagonal is forced to 1.

    Args:
        mode: ``'sim'`` -> dot products of L2-normalised rows; ``'prob'`` ->
            dot products of softmax probabilities (temperature 0.1 for the
            ``'cluster'`` history); any other value uses ``preds_t`` as given.
        preds_t: 2-D tensor ``(batch, dim)``. It is detached; the result
            carries no gradient.
        diff_ratio: fraction of history pairs treated as dissimilar.
        sim_ratio: fraction of history pairs treated as similar.
        sim_threshold: lower bound for the high threshold.
        max_sim_list: number of past batches kept in the history.
        list: ``'cluster'`` or ``'class'`` -- selects the history.

    Returns:
        float tensor ``(batch, batch)`` with entries in {1, 0, -1}.

    Raises:
        ValueError: if ``list`` is not ``'cluster'`` or ``'class'``.
    """
    if list == 'cluster':
        history = cluster_list
    elif list == 'class':
        history = cls_list
    else:
        raise ValueError("list must be 'cluster' or 'class', got {!r}".format(list))

    # The output is built from threshold comparisons only, so no gradient is
    # needed; detaching also keeps the history from pinning autograd graphs.
    preds_t = preds_t.detach()
    if mode == 'sim':
        preds_t = F.normalize(preds_t, dim=1)
    elif mode == 'prob':
        preds_t = F.softmax(preds_t / 0.1, dim=1) if list == 'cluster' else F.softmax(preds_t, dim=1)

    similarity = preds_t @ preds_t.T

    history.append(similarity.flatten())
    excess = len(history) - max(int(max_sim_list), 1)
    if excess > 0:
        del history[:excess]
    sim_all = torch.cat(history, dim=0)
    sim_all_sorted, _ = torch.sort(sim_all)

    n_total = len(sim_all)
    n_diff = int(min(n_total * diff_ratio, n_total - 1))
    n_sim = int(min(n_total * sim_ratio, n_total))

    low_threshold = sim_all_sorted[n_diff]
    if n_sim > 0:
        high_threshold = max(sim_threshold, sim_all_sorted[-n_sim])
    else:
        # sim_all_sorted[-0] would be the *smallest* value and mark almost
        # every pair as similar; with no similar pairs requested, none pass.
        high_threshold = float('inf')

    sim_matrix_ulb = -torch.ones_like(similarity).float()
    if high_threshold != low_threshold:
        sim_matrix_ulb[similarity >= high_threshold] = 1
        sim_matrix_ulb[similarity <= low_threshold] = 0
    else:
        sim_matrix_ulb[similarity > high_threshold] = 1
        sim_matrix_ulb[similarity < low_threshold] = 0

    # diagonal should be 1
    sim_matrix_ulb.fill_diagonal_(1)

    return sim_matrix_ulb