import torch
import torch.nn as nn
import torchvision.models as models
from torchvision.transforms import Normalize

class MoCo(nn.Module):
    """Momentum Contrast (MoCo) model built from a query and a key encoder.

    The query encoder is trained by back-propagation; the key encoder is a
    momentum (exponential moving average) copy of it and receives no
    gradients. A fixed-size queue of past keys supplies the negatives for the
    contrastive logits returned by ``forward``.

    Inputs to every public method are expected in the 0-255 range; they are
    divided by 255 and normalised with the ImageNet mean/std stored in
    ``img_norm``.
    """

    def __init__(
        self,
        obs_shape,
        ckpt_file=None,
        load_ckpt=False,
        base_encoder=models.__dict__['resnet50'],
        dim=128,
        K=65536,
        m=0.999,
        T=0.07,
        mlp=True,
    ):
        """
        obs_shape: shape of one observation; must have exactly 3 entries
        ckpt_file: path of a checkpoint to load into the query encoder
        load_ckpt: whether to load ``ckpt_file`` (requires ``ckpt_file``)
        base_encoder: callable building the backbone (default: resnet50)
        dim: feature dimension (default: 128)
        K: queue size; number of negative keys (default: 65536)
        m: moco momentum of updating key encoder (default: 0.999)
        T: softmax temperature (default: 0.07)
        mlp: replace the backbone's ``fc`` with a 2-layer MLP head that
            outputs ``dim`` features
        """
        super().__init__()
        assert len(obs_shape) == 3
        # NOTE(review): hard-coded, not derived from obs_shape or base_encoder;
        # presumably the flattened size returned by get_feature(spacial=True)
        # for the intended input resolution -- confirm against callers.
        self.repr_dim = 2048 * 3 * 3
        self.load_ckpt = load_ckpt
        self.ckpt_file = ckpt_file
        self.K = K
        self.m = m
        self.T = T

        # create the encoders
        # num_classes is the output fc dimension
        self.encoder_q = base_encoder(num_classes=1000, pretrained=True)
        self.encoder_k = base_encoder(num_classes=1000, pretrained=True)
        self.img_norm = Normalize(mean=torch.tensor([0.485, 0.456, 0.406]),
                                  std=torch.tensor([0.229, 0.224, 0.225]))

        if mlp:  # hack: brute-force replacement
            # NOTE: with mlp=False the encoders keep their 1000-way fc, so the
            # (dim, K) queue only matches the encoder output when dim == 1000.
            dim_mlp = self.encoder_q.fc.weight.shape[1]
            self.encoder_q.fc = nn.Sequential(
                nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), nn.Linear(dim_mlp, dim)
            )
            self.encoder_k.fc = nn.Sequential(
                nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), nn.Linear(dim_mlp, dim)
            )

        if load_ckpt:
            if ckpt_file is None:
                raise ValueError("load_ckpt=True requires ckpt_file to be set")
            # map_location='cpu' lets a checkpoint saved on GPU be restored on
            # any host; load_state_dict copies the values onto the module.
            # NOTE: torch.load unpickles the file -- only load trusted files.
            ckpt = torch.load(ckpt_file, map_location='cpu')
            if 'state_dict' in ckpt:
                ckpt = ckpt['state_dict']
            # Keep only the query-encoder weights, stripping the optional
            # DataParallel 'module.' prefix together with 'encoder_q.'.
            prefixes = ('module.encoder_q.', 'encoder_q.')
            encoder_q_state = {}
            for key, value in ckpt.items():
                for prefix in prefixes:
                    if key.startswith(prefix):
                        encoder_q_state[key[len(prefix):]] = value
                        break
            self.encoder_q.load_state_dict(encoder_q_state)

        for param_q, param_k in zip(
            self.encoder_q.parameters(), self.encoder_k.parameters()
        ):
            param_k.data.copy_(param_q.data)  # initialize
            param_k.requires_grad = False  # not update by gradient
        # Buffers (e.g. BatchNorm running statistics) are not parameters; copy
        # them too so the key encoder really starts as a clone of the query
        # encoder, including after a checkpoint load.
        for buf_q, buf_k in zip(
            self.encoder_q.buffers(), self.encoder_k.buffers()
        ):
            buf_k.data.copy_(buf_q.data)

        # create the queue: K random unit-norm columns of size dim
        self.register_buffer(
            "queue", nn.functional.normalize(torch.randn(dim, K), dim=0)
        )
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    def _preprocess(self, images):
        """Keep the first three channels, scale by 1/255 and apply img_norm."""
        return self.img_norm(images[:, :3] / 255.0)

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """
        Momentum update of the key encoder:
        param_k <- m * param_k + (1 - m) * param_q
        """
        for param_q, param_k in zip(
            self.encoder_q.parameters(), self.encoder_k.parameters()
        ):
            param_k.data = param_k.data * self.m + param_q.data * (1.0 - self.m)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        """
        Write ``keys`` (N x dim) into the queue as a ring buffer, overwriting
        the oldest entries, and advance ``queue_ptr`` by N modulo K.

        The write wraps around the end of the queue, so K does not have to be
        a multiple of the batch size. If N exceeds K only the newest K keys
        are kept.
        """
        batch_size = keys.shape[0]
        ptr = int(self.queue_ptr)
        next_ptr = (ptr + batch_size) % self.K  # move pointer

        if batch_size > self.K:
            # Only the last K keys can survive; they end just before next_ptr,
            # which (mod K) means they also start at next_ptr.
            keys = keys[-self.K:]
            ptr = next_ptr

        count = keys.shape[0]
        head = min(count, self.K - ptr)  # columns that fit before the end
        self.queue[:, ptr:ptr + head] = keys[:head].T
        if head < count:
            # remaining columns wrap around to the start of the queue
            self.queue[:, :count - head] = keys[head:].T
        self.queue_ptr[0] = next_ptr

    def get_feature(self, obs, spacial=True, normalize=True):
        """
        Encode ``obs`` with the query encoder and return flattened features.

        obs: batch of images in the 0-255 range; only the LAST three channels
            are used
        spacial: if True, run only the first 8 child modules of the encoder
            (for torchvision ResNets: everything before the global pooling
            and ``fc``) and flatten the result; if False, run the full
            encoder including its head
        normalize: whether to apply ``img_norm`` after scaling by 1/255

        NOTE(review): this method takes the last three channels while
        ``forward`` and ``forward_feature`` take the first three -- confirm
        this is intended.
        """
        obs = obs[:, -3:] / 255.0
        if normalize:
            obs = self.img_norm(obs)
        if not spacial:
            return self.encoder_q(obs).view(obs.shape[0], -1)

        h = obs
        for layer in list(self.encoder_q.children())[:8]:
            h = layer(h)
        return h.view(obs.shape[0], -1)

    def forward_feature(self, obs):
        """Return the full query-encoder output for the first three channels of ``obs``."""
        return self.encoder_q(self._preprocess(obs))

    def forward(self, im_q, im_k, get_features=False):
        """
        Input:
            im_q: a batch of query images
            im_k: a batch of key images (unused when get_features is True)
            get_features: if True, return only the un-normalised query
                encoder output for im_q; the key encoder and the queue are
                left untouched
        Output:
            logits, targets
        """
        # compute query features
        q = self.encoder_q(self._preprocess(im_q))  # queries: NxC
        if get_features:
            return q
        q = nn.functional.normalize(q, dim=1)

        # compute key features
        with torch.no_grad():  # no gradient to keys
            self._momentum_update_key_encoder()  # update the key encoder
            k = self.encoder_k(self._preprocess(im_k))  # keys: NxC
            k = nn.functional.normalize(k, dim=1)

        # compute logits
        # Einstein sum is more intuitive
        # positive logits: Nx1
        l_pos = torch.einsum("nc,nc->n", [q, k]).unsqueeze(-1)
        # negative logits: NxK
        l_neg = torch.einsum("nc,ck->nk", [q, self.queue.clone().detach()])

        # logits: Nx(1+K)
        logits = torch.cat([l_pos, l_neg], dim=1)

        # apply temperature
        logits /= self.T

        # labels: positive key indicators (the positive is always column 0);
        # created on the logits' device so CPU and multi-GPU runs both work
        labels = torch.zeros(
            logits.shape[0], dtype=torch.long, device=logits.device
        )

        # dequeue and enqueue
        self._dequeue_and_enqueue(k)

        return logits, labels

    def save_ckpt(self, file_name):
        """Save this module's full ``state_dict`` to ``file_name``."""
        torch.save(self.state_dict(), file_name)
        print('Ckpt ',file_name, ' saved!')