import torch
import torch.nn as nn
from random import sample
import numpy as np
import torch.nn.functional as F

class Model(nn.Module):
    """Pair of identically constructed encoders plus candidate-label scoring.

    ``encoder_q`` is the trainable encoder used by :meth:`forward`.
    ``encoder_k`` (the "momentum" encoder) is initialised as an exact copy of
    ``encoder_q`` and has gradients disabled; it is not called in ``forward``.
    """

    def __init__(self, args, base_encoder):
        """Build both encoders and copy the query weights into the key encoder.

        Args:
            args: Configuration object. The attributes ``dataset``,
                ``num_class``, ``low_dim`` and ``arch`` are read here.
            base_encoder: Callable invoked as
                ``base_encoder(num_class=..., feat_dim=..., name=...,
                pretrained=...)``. The returned object must expose
                ``parameters()`` and, when called, return a 2-tuple whose
                first element holds the class logits.
        """
        super().__init__()

        # Pretrained weights are enabled for CUB200 and Flower102 only; per
        # the original note, the network will not converge on CUB200 otherwise.
        pretrained = args.dataset in ('cub200', 'flower102')

        self.args = args
        self.encoder_q = base_encoder(
            num_class=args.num_class,
            feat_dim=args.low_dim,
            name=args.arch,
            pretrained=pretrained,
        )
        # momentum encoder
        self.encoder_k = base_encoder(
            num_class=args.num_class,
            feat_dim=args.low_dim,
            name=args.arch,
            pretrained=pretrained,
        )

        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data.copy_(param_q.data)  # initialize
            param_k.requires_grad = False  # not update by gradient

    def forward(self, img_q, im_k=None, Y_ori=None, args=None, eval_only=False, stop_warmup=False):
        """Run the query encoder and derive candidate-restricted class scores.

        Args:
            img_q: Input batch passed to ``encoder_q``.
            im_k: Unused; kept for interface compatibility.
            Y_ori: Mask multiplied element-wise with the softmax output.
                Required unless ``eval_only`` is true. Presumably a 0/1
                candidate-label matrix shaped like the logits -- confirm
                against the caller.
            args: Unused; kept for interface compatibility. The normalisation
                no longer reads ``args.num_class``, so omitting it is safe.
            eval_only: When true, return only the logits.
            stop_warmup: Unused; kept for interface compatibility.

        Returns:
            ``output`` alone when ``eval_only`` is true. Otherwise the tuple
            ``(output, predicted_scores, all_max_cls_conf)`` where

            * ``output`` is the first element returned by ``encoder_q``,
            * ``predicted_scores`` is ``softmax(output) * Y_ori`` renormalised
              along dim 1 and stacked twice along dim 0,
            * ``all_max_cls_conf`` is the maximum unmasked softmax probability
              along dim 1.

            The last two are computed without gradient tracking.
        """
        output, _ = self.encoder_q(img_q)
        if eval_only:
            return output

        with torch.no_grad():  # the derived scores must not receive gradients
            probs = torch.softmax(output, dim=1)
            all_max_cls_conf = probs.max(dim=1).values

            masked = probs * Y_ori
            row_sum = masked.sum(dim=1, keepdim=True)
            # A row whose masked probabilities sum to exactly zero would give
            # 0/0 = NaN; divide by one instead so such rows stay all-zero.
            row_sum = torch.where(row_sum == 0, torch.ones_like(row_sum), row_sum)
            scores = masked / row_sum

            # Both halves are derived from the same logits, so they are
            # identical; the doubled batch dimension is kept for callers.
            predicted_scores = torch.cat((scores, scores), dim=0)

        return output, predicted_scores, all_max_cls_conf



# utils
@torch.no_grad()
def concat_all_gather(tensor):
    """
    Collect ``tensor`` from every process and concatenate the results.

    One placeholder buffer shaped like ``tensor`` is allocated per process
    (``torch.distributed.get_world_size()`` of them), filled by a blocking
    ``all_gather``, and the buffers are joined along dim 0.

    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    dist = torch.distributed
    world_size = dist.get_world_size()

    gathered = []
    for _ in range(world_size):
        gathered.append(torch.ones_like(tensor))

    dist.all_gather(gathered, tensor, async_op=False)
    return torch.cat(gathered, dim=0)

@torch.no_grad()
def replace_invalid_values(tensor):
    """Zero out, in place, every entry that is infinite, NaN or negative.

    The input tensor itself is modified and returned.
    """
    # All three conditions map to the same replacement value, so a single
    # combined mask yields the same result as three successive passes.
    invalid = torch.isinf(tensor) | torch.isnan(tensor) | (tensor < 0)
    tensor.masked_fill_(invalid, 0.)
    return tensor