import torch
import torch.nn as nn
import math
import torch.nn.functional as F

from utils.utils import AverageMeter

def resize_pos_embed(posemb, posemb_new, height, width):
    """Resize a ViT position embedding to a new patch grid.

    Rescales the grid of position embeddings when loading from a state_dict
    whose patch grid differs from the model's. Adapted from
    https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224

    The first token of ``posemb`` is kept unchanged (treated as the class
    token). The remaining tokens are reshaped to a square ``gs_old x gs_old``
    grid and bilinearly interpolated to ``height x width``.

    Args:
        posemb: Pretrained embedding, expected shape (1, 1 + gs_old**2, dim).
            The leading dimension must be 1, because the grid is taken from
            ``posemb[0]`` and concatenated back with batch size 1.
        posemb_new: Target embedding; only its shape is used (for logging).
        height: Number of patch rows in the new grid.
        width: Number of patch columns in the new grid.

    Returns:
        Tensor of shape (1, 1 + height * width, dim).
    """
    print('Resized position embedding: {} to {}'.format(posemb.shape, posemb_new.shape))
    # Split off the leading (class) token; the rest form the square grid.
    posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:]
    gs_old = int(math.sqrt(len(posemb_grid)))

    print('Position embedding resize to height:{} width: {}'.format(height, width))
    # (N, D) -> (1, D, gs_old, gs_old) so that F.interpolate treats D as channels.
    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
    posemb_grid = F.interpolate(posemb_grid, size=(height, width), mode='bilinear', align_corners=False)
    # Back to token layout: (1, height * width, D).
    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, height * width, -1)
    return torch.cat([posemb_tok, posemb_grid], dim=1)

class Algo(nn.Module):
    """Base module holding class prototypes, label statistics and loss meters.

    Also provides ViT-style helpers (``get_last_selfattention``,
    ``get_intermediate_layers``, ``load_param``) that use attributes this
    class does not define itself (``prepare_tokens``, ``blocks``, ``norm``,
    ``patch_embed``, ``pos_embed``). NOTE(review): these are presumably
    supplied by a subclass; confirm against the concrete model.
    """

    def __init__(self, cfg=None, args=None):
        super(Algo, self).__init__()
        self.normalizer = self._l2_normalize
        self.momem_proto = cfg.MODEL.PROTO_MOMEN
        self.proto_num = cfg.MODEL.PROTO_NUM
        # One L2-normalised prototype row per class, randomly initialised.
        # NOTE(review): the feature dimension (65536) is hard-coded.
        self.proto = self.normalizer(torch.randn((self.proto_num, 65536), dtype=torch.float)).to(args.device)
        self.label_stat = torch.zeros(self.proto_num, dtype=torch.int)

        self.loss_stat_init()

    @staticmethod
    def _l2_normalize(x, eps=1e-10):
        """L2-normalise ``x`` along its last dimension.

        ``eps`` is added to the norm (the denominator), so an all-zero vector
        maps to zeros instead of NaN.
        """
        return x / (torch.norm(x, p=2, dim=-1, keepdim=True) + eps)

    def __call__(self, *args, **kwargs):
        # Delegate to a wrapped submodule named ``module`` when one is attached.
        if hasattr(self, 'module'):
            return self.module(*args, **kwargs)
        return super(Algo, self).__call__(*args, **kwargs)

    def loss_stat_init(self):
        """(Re)create the running-average meters used for loss/accuracy logging."""
        self.ce_losses = AverageMeter('ce_loss', ':.4e')
        self.entropy_losses = AverageMeter('entropy_loss', ':.4e')
        self.cls_losses = AverageMeter('cls_losses', ':.4e')
        self.cl_losses = AverageMeter('cl_losses', ':.4e')
        self.simclr_losses = AverageMeter('simclr_loss', ':.4e')
        self.supcon_losses = AverageMeter('supcon_loss', ':.4e')
        self.semicon_losses = AverageMeter('semicon_loss', ':.4e')
        self.loss_record = AverageMeter('loss', ':.4e')
        self.train_acc_record = AverageMeter('train_acc', ':.4e')

    @torch.no_grad()
    def update_label_stat(self, label):
        """Add per-class counts of the integer tensor ``label`` to ``label_stat``."""
        self.label_stat += label.bincount(minlength=self.proto_num).to(self.label_stat.device)

    @torch.no_grad()
    def reset_stat(self):
        """Zero the label counts and recreate all loss meters."""
        self.label_stat = torch.zeros(self.proto_num, dtype=torch.int, device=self.label_stat.device)
        self.loss_stat_init()

    @torch.no_grad()
    def sync_prototype(self):
        """Publish the prototypes staged by ``update_prototype_lazy``.

        ``update_prototype_lazy`` must have been called first; otherwise
        ``proto_tmp`` does not exist and attribute access fails.
        """
        self.proto = self.proto_tmp

    def update_prototype_lazy(self, feat, label, weight=None, momemt=None):
        """Stage an EMA update of the prototypes into ``proto_tmp``.

        ``self.proto`` is left untouched until ``sync_prototype`` is called.

        Args:
            feat: Features, indexed row-wise in step with ``label``.
            label: Class index for each row of ``feat``.
            weight: Optional per-sample weight; defaults to ones. A weight of
                0 leaves the prototype unchanged.
            momemt: EMA momentum; defaults to ``self.momem_proto``.
        """
        if momemt is None:
            momemt = self.momem_proto
        self.proto_tmp = self.proto.clone().detach()
        if weight is None:
            weight = torch.ones_like(label)
        # Sequential on purpose: repeated labels update the same prototype
        # in order, each step building on the previous one.
        for i, l in enumerate(label):
            alpha = 1 - (1. - momemt) * weight[i]
            self.proto_tmp[l] = self.normalizer(alpha * self.proto_tmp[l] + (1. - alpha) * feat[i])

    def get_last_selfattention(self, x):
        """Run all blocks and return ``(normed tokens, attention)`` of the last one.

        Returns None if ``self.blocks`` is empty.
        """
        x = self.prepare_tokens(x)
        for i, blk in enumerate(self.blocks):
            if i < len(self.blocks) - 1:
                x = blk(x)
            else:
                # return attention of the last block
                x, attn = blk(x, return_attention=True)
                x = self.norm(x)
                return x, attn

    def get_intermediate_layers(self, x, n=1):
        """Return the normed outputs of the last ``n`` blocks, in block order."""
        x = self.prepare_tokens(x)
        output = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if len(self.blocks) - i <= n:
                output.append(self.norm(x))
        return output

    def load_un_param(self, trained_path):
        """Load encoder weights stored under ``module.encoder_q.`` in a checkpoint.

        Only keys under ``module.encoder_q.`` are kept, excluding its ``fc``
        sub-keys. That prefix is stripped from the kept keys. Keys that still
        contain ``'fc'`` or ``'head'`` are skipped; the rest are copied into
        this module's state.

        Raises:
            KeyError: if a retained key has no counterpart in this module.
        """
        param_dict = torch.load(trained_path, map_location='cpu')
        if 'state_dict' in param_dict:
            param_dict = param_dict['state_dict']
        prefix = 'module.encoder_q.'
        # Retain only encoder_q up to before the embedding layer, prefix removed.
        param_dict = {
            k[len(prefix):]: v
            for k, v in param_dict.items()
            if k.startswith(prefix) and not k.startswith(prefix + 'fc')
        }
        own_state = self.state_dict()
        for k, v in param_dict.items():
            if 'fc' in k or 'head' in k:
                continue
            own_state[k].copy_(v)

    def load_param(self, model_path):
        """Load pretrained weights from ``model_path`` into this module.

        Behaviour:
            - Unwraps a ``'model'`` and/or ``'state_dict'`` entry if present.
            - Skips keys containing ``'head'`` or ``'dist'``.
            - Strips ``'module.'`` from key names.
            - Reshapes linear patch-embedding weights to conv form.
            - Resizes ``pos_embed`` when its shape differs from this module's.

        Missing keys and shape mismatches are reported with ``print`` and
        skipped, so loading continues.
        """
        param_dict = torch.load(model_path, map_location='cpu')
        if 'model' in param_dict:
            param_dict = param_dict['model']

        if 'state_dict' in param_dict:
            param_dict = param_dict['state_dict']

        own_state = self.state_dict()
        for k, v in param_dict.items():
            if 'head' in k or 'dist' in k:
                continue
            k = k.replace('module.', '')
            if 'patch_embed.proj.weight' in k and len(v.shape) < 4:
                # For old models trained prior to conv based patchification
                out_ch, _, kh, kw = self.patch_embed.proj.weight.shape
                v = v.reshape(out_ch, -1, kh, kw)
            elif k == 'pos_embed' and v.shape != self.pos_embed.shape:
                print('load_param.upper() -- {}; {}'.format(v.shape, self.pos_embed.shape))
                # To resize pos embedding when using model at different size from pretrained weights
                if 'distilled' in model_path:
                    print('distill need to choose right cls token in the pth')
                    # Drop token index 1 (presumably the distillation token), keep the cls token.
                    v = torch.cat([v[:, 0:1], v[:, 2:]], dim=1)
                v = resize_pos_embed(v, self.pos_embed, self.patch_embed.num_y, self.patch_embed.num_x)
            if k not in own_state:
                print('===========================ERROR=========================')
                print('key {} from param_dict not found in self.state_dict()'.format(k))
                continue
            try:
                own_state[k].copy_(v)
            except RuntimeError:
                # copy_ raises RuntimeError when shapes cannot be matched/broadcast.
                print('===========================ERROR=========================')
                print('shape do not match in k :{}: param_dict{} vs self.state_dict(){}'.format(k, v.shape, own_state[k].shape))
