import sys
sys.path.append('..')
import torch
from torch import autograd, optim, nn
from torch.autograd import Variable
from torch.nn import functional as F
import random
from protonets.models import register_model
import torchvision.models as models

from .utils import euclidean_dist

class ProtoParamModule(nn.Module):
    '''
    Produces a (center, radius) pair for a class tag.

    The center is recomputed from the support embeddings on every call.
    The radius is initialised from the first support set seen for a tag,
    wrapped in an ``nn.Parameter`` and (optionally) stored in ``radius_d``
    so that later calls with the same tag reuse it.
    '''

    def __init__(self, dot=False):
        '''
        dot: when True, ``__dist__`` returns a dot product instead of a
             squared Euclidean distance.
        '''
        nn.Module.__init__(self)
        # center_d is declared but never written by this class.
        self.center_d = nn.ParameterDict({})
        self.radius_d = nn.ParameterDict({})
        self.dot = dot

    def __dist__(self, x, y, dim):
        '''Dot product (dot=True) or squared Euclidean distance, reduced over `dim`.'''
        if self.dot:
            return (x * y).sum(dim)
        # Squared distance on purpose: no sqrt is taken.
        return torch.pow(x - y, 2).sum(dim)

    def __init_proto_center__(self, support_embeddings):
        '''Mean of the support embeddings over dimension 0.'''
        return torch.mean(support_embeddings, 0)

    def __init_proto_radius__(self, proto_center, support_embeddings):
        '''
        Initial radius: half of the mean distance from the support
        embeddings to the center, detached from the graph.

        Falls back to 10.0 when that value is exactly zero.
        Raises ValueError when it is negative.
        '''
        mean_dist_to_c = torch.mean(
            self.__dist__(support_embeddings, proto_center, dim=1)
        ).detach()
        proto_radius = mean_dist_to_c / 2
        if proto_radius == 0:
            # full_like keeps the dtype and device of the computed radius;
            # a bare torch.tensor(10.0) would always be a float32 CPU tensor.
            proto_radius = torch.full_like(proto_radius, 10.0)
        elif proto_radius < 0:
            raise ValueError('negative proto_radius')
        # Only a NaN radius can reach this point without being positive.
        assert proto_radius > 0, 'proto_radius must be positive'
        return nn.Parameter(proto_radius)

    def forward(self, tag, support_embeddings, save_proto=True):
        '''
        given a tag and support sample embeddings
        return the proto parameter (center, radius)

        tag: key used to look the radius up in ``radius_d``.
        support_embeddings: reduced over dim 0 for the center and over
            dim 1 for the distances, so presumably (n_support, embed_dim).
        save_proto: when True, a freshly initialised radius is stored
            under `tag`.

        Both outputs stay on the device of `support_embeddings`; nothing is
        moved based on global CUDA availability.
        '''
        proto_center = self.__init_proto_center__(support_embeddings)
        if tag in self.radius_d:
            return proto_center, self.radius_d[tag]

        proto_radius = self.__init_proto_radius__(proto_center, support_embeddings)
        if save_proto:
            self.radius_d[tag] = proto_radius
        return proto_center, proto_radius

class ProtoParamModule_d(nn.Module):
    '''
    Stores both the center and the radius of each class tag as parameters.

    The first support set seen for a tag initialises both values; later
    calls with the same tag return the stored parameters and ignore the
    embeddings passed in.
    '''

    def __init__(self, dot=False):
        '''
        dot: when True, ``__dist__`` returns a dot product instead of a
             squared Euclidean distance.
        '''
        nn.Module.__init__(self)
        self.center_d = nn.ParameterDict({})
        self.radius_d = nn.ParameterDict({})
        self.dot = dot

    def __dist__(self, x, y, dim):
        '''Dot product (dot=True) or squared Euclidean distance, reduced over `dim`.'''
        if self.dot:
            return (x * y).sum(dim)
        # Squared distance on purpose: no sqrt is taken.
        return torch.pow(x - y, 2).sum(dim)

    def __init_proto_value__(self, support_embeddings):
        '''
        Build the initial (center, radius) parameters from a support set.

        center: detached mean of the embeddings over dim 0.
        radius: half of the mean distance from the embeddings to the
                center, detached; 10.0 when that value is exactly zero.

        Raises ValueError when the computed radius is negative.
        '''
        proto_center = torch.mean(support_embeddings, 0).detach()
        mean_dist_to_c = torch.mean(
            self.__dist__(support_embeddings, proto_center, dim=1)
        ).detach()
        proto_radius = mean_dist_to_c / 2
        if proto_radius == 0:
            # Keep the fallback on the same device/dtype as the center; a
            # bare torch.tensor(10.0) would be a float32 CPU tensor even
            # when the embeddings live elsewhere.
            proto_radius = torch.full_like(proto_radius, 10.0)
        elif proto_radius < 0:
            raise ValueError('negative proto_radius')
        # Only a NaN radius can reach this point without being positive.
        assert proto_radius > 0, 'proto_radius must be positive'
        return nn.Parameter(proto_center), nn.Parameter(proto_radius)

    def forward(self, tag, support_embeddings, save_proto=True):
        '''
        given a tag and support sample embeddings
        return the proto parameter (center, radius)

        tag: key used to look the parameters up in ``center_d`` / ``radius_d``.
        save_proto: when True, freshly initialised parameters are stored
            under `tag`.
        '''
        if tag in self.center_d:
            return self.center_d[tag], self.radius_d[tag]

        c, r = self.__init_proto_value__(support_embeddings)
        if save_proto:
            self.center_d[tag] = c
            self.radius_d[tag] = r
        return c, r



class Flatten(nn.Module):
    '''Collapse every dimension after the first into a single one.'''

    def forward(self, x):
        # Dimension 0 is kept; the remaining size is inferred by view().
        leading = x.shape[0]
        return x.view(leading, -1)

class BigProtonet(nn.Module):
    '''
    Prototypical network whose class prototypes carry a radius.

    Query scores are ``-(distance_to_center - radius)``; the loss is the
    cross-entropy of a softmax over those scores.
    '''

    # Class-level default so that __dist__ always has a value to read, even
    # on an instance whose __init__ did not set the attribute.
    dot = False

    def __init__(self, encoder, dot=False):
        '''
        encoder: module mapping a batch of inputs to embeddings.
        dot: when True, ``__dist__`` uses a dot product instead of a
             squared Euclidean distance (only ``__batch_dist__`` uses it;
             ``loss`` always goes through ``euclidean_dist``).
        '''
        super(BigProtonet, self).__init__()

        self.encoder = encoder
        self.proto_param = ProtoParamModule()
        self.dot = dot

    def __dist__(self, x, y, dim):
        '''Dot product (dot=True) or squared Euclidean distance, reduced over `dim`.'''
        if self.dot:
            return (x * y).sum(dim)
        return torch.pow(x - y, 2).sum(dim)

    def __batch_dist__(self, C, R, Q, q_mask):
        '''
        Score every unmasked query token against every prototype.

        C [class, embed_dim], Q [num_of_sent, num_of_tokens, embed_dim];
        q_mask must match the first two dimensions of Q, and only
        positions where it equals 1 are kept.
        Returns ``-(dist(C, token) - R)`` with one row per kept token.
        '''
        assert Q.size()[:2] == q_mask.size()
        Q = Q[q_mask == 1].view(-1, Q.size(-1))  # [num_of_all_text_tokens, embed_dim]
        dist_to_proto_ct = self.__dist__(C.unsqueeze(0), Q.unsqueeze(1), 2)
        return -(dist_to_proto_ct - R)

    def __get_all_protos__(self, support_embed, classes, save_proto=True):
        '''
        Query ``self.proto_param`` once per class and stack the results.

        support_embed[j] is passed as the support set of classes[j].
        Returns (stacked centers, stacked radii).
        '''
        proto_centers = []
        proto_radii = []
        for j, label in enumerate(classes):
            c, r = self.proto_param(label, support_embed[j], save_proto=save_proto)
            proto_centers.append(c)
            proto_radii.append(r)

        return torch.stack(proto_centers), torch.stack(proto_radii)

    def __get_all_avg_protos__(self, support_embed, classes):
        '''
        Compute centers and radii directly from the support embeddings.

        support_embed (n_class, n_support, z_dim)
        Returns centers (n_class, z_dim) and radii (n_class,), the radius
        being half of the mean squared distance to the center.
        `classes` is accepted for signature parity but not used.
        '''
        proto_centers = support_embed.mean(1)  # (n_class, z_dim)
        # Broadcasting over the support axis replaces an explicit expand().
        sq_dist = torch.pow(support_embed - proto_centers.unsqueeze(1), 2).sum(-1)
        proto_radii = sq_dist.mean(-1) / 2

        return proto_centers, proto_radii

    def init_proto(self):
        '''Placeholder; does nothing.'''
        pass

    def forward_full_supervised(self, batch):
        '''Placeholder; does nothing.'''
        pass

    @staticmethod
    def _episode_loss_and_acc(dists, n_class, n_query):
        '''
        Cross-entropy loss and accuracy for one episode.

        dists: (n_class * n_query, n_class) distances, rows grouped by true
               class in order, so row block i belongs to class i.
        Returns (loss, accuracy) as 0-dim tensors.
        '''
        target_inds = torch.arange(0, n_class, device=dists.device).long()
        target_inds = target_inds.view(n_class, 1, 1).expand(n_class, n_query, 1)

        log_p_y = F.log_softmax(-dists, dim=1).view(n_class, n_query, -1)
        loss_val = -log_p_y.gather(2, target_inds).reshape(-1).mean()

        _, y_hat = log_p_y.max(2)
        # Squeeze only the trailing axis: a bare squeeze() would also drop
        # the class/query axis when it has size 1, and torch.eq would then
        # broadcast to the wrong shape and report a wrong accuracy.
        acc_val = torch.eq(y_hat, target_inds.squeeze(2)).float().mean()
        return loss_val, acc_val

    def loss(self, sample, eval=False):
        '''
        Compute the episode loss and diagnostics.

        sample: dict with 'xs' (support) and 'xq' (query), both indexed as
                (n_class, n_examples, ...), and 'class', one label per class.
        eval: when False, prototypes come from ``self.proto_param``; when
              True they are computed from this episode's support set only.

        Returns (loss tensor, dict) where the dict holds 'loss', 'acc',
        'proto_radius', 'dists' (support-to-prototype distances minus the
        radius) and 'class'.
        '''
        xs = Variable(sample['xs'])  # support
        xq = Variable(sample['xq'])  # query

        n_class = xs.size(0)
        assert xq.size(0) == n_class
        n_support = xs.size(1)
        n_query = xq.size(1)

        # Encode support and query together in a single pass.
        x = torch.cat([xs.view(n_class * n_support, *xs.size()[2:]),
                       xq.view(n_class * n_query, *xq.size()[2:])], 0)
        z = self.encoder(x)

        z_dim = z.size(-1)
        z_support = z[:n_class * n_support]
        zq = z[n_class * n_support:]
        z_embedding = z_support.view(n_class, n_support, z_dim)
        assert z_embedding.size(0) == len(sample['class'])

        if not eval:
            proto_center, proto_radius = self.__get_all_protos__(z_embedding, sample['class'])
        else:
            proto_center, proto_radius = self.__get_all_avg_protos__(z_embedding, sample['class'])

        # Follow the encoder output's device instead of forcing CUDA whenever
        # it happens to be available on the host.
        proto_center = proto_center.to(z.device)
        proto_radius = proto_radius.to(z.device)
        assert len(sample['class']) == len(proto_radius)

        dists = euclidean_dist(zq, proto_center, do_sqrt=False) - proto_radius
        loss_val, acc_val = self._episode_loss_and_acc(dists, n_class, n_query)

        # Diagnostics only, so no graph is needed for them.
        with torch.no_grad():
            dists_support = euclidean_dist(z_support, proto_center, do_sqrt=False)
            dists_support = dists_support - proto_radius

        return loss_val, {
            'loss': loss_val.item(),
            'acc': acc_val.item(),
            'proto_radius': proto_radius.detach(),
            'dists': dists_support.detach(),
            'class': sample['class']
        }

@register_model('bigprotonet_conv')
def load_bigprotonet_conv(**kwargs):
    '''
    Build a BigProtonet on top of a four-stage convolutional encoder.

    Required kwargs:
        x_dim:   indexable; its first entry is the input channel count.
        hid_dim: output channels of the first three stages.
        z_dim:   output channels of the last stage.
    '''
    x_dim, hid_dim, z_dim = (kwargs[key] for key in ('x_dim', 'hid_dim', 'z_dim'))

    def _stage(c_in, c_out):
        # 3x3 conv (padding 1) -> batch norm -> ReLU -> 2x2 max pool
        layers = [
            nn.Conv2d(c_in, c_out, 3, padding=1),
            nn.BatchNorm2d(c_out),
            nn.ReLU(),
            nn.MaxPool2d(2),
        ]
        return nn.Sequential(*layers)

    # Channel widths at each stage boundary, input first.
    widths = [x_dim[0], hid_dim, hid_dim, hid_dim, z_dim]
    stages = [_stage(c_in, c_out) for c_in, c_out in zip(widths[:-1], widths[1:])]
    stages.append(Flatten())

    return BigProtonet(nn.Sequential(*stages))

@register_model('bigprotonet_resnet50')
def load_bigprotonet_resnet50(**kwargs):
    '''
    Build a BigProtonet around torchvision's ResNet-50 (pretrained=True).

    The network's `fc` head is replaced by Linear -> ReLU -> Linear ->
    Flatten, mapping the backbone features to `hid_dim` and then `z_dim`.
    The encoder is wrapped in nn.DataParallel.

    Required kwargs: x_dim (looked up but not otherwise used), hid_dim, z_dim.
    '''
    # Looked up only so that a missing key still raises KeyError.
    _ = kwargs['x_dim']
    hid_dim = kwargs['hid_dim']
    z_dim = kwargs['z_dim']

    backbone = models.resnet50(pretrained=True)
    feature_dim = backbone.fc.in_features
    head = [
        nn.Linear(feature_dim, hid_dim),
        nn.ReLU(),
        nn.Linear(hid_dim, z_dim),
        Flatten(),
    ]
    backbone.fc = nn.Sequential(*head)

    parallel_encoder = nn.DataParallel(backbone)
    return BigProtonet(parallel_encoder)

@register_model('bigprotonet_resnet18')
def load_bigprotonet_resnet18(**kwargs):
    '''
    Build a BigProtonet around torchvision's ResNet-18 (pretrained=True).

    The network's `fc` head is replaced by Linear -> ReLU -> Linear ->
    Flatten, mapping the backbone features to `hid_dim` and then `z_dim`.
    The encoder is wrapped in nn.DataParallel.

    Required kwargs: x_dim (looked up but not otherwise used), hid_dim, z_dim.
    '''
    # Looked up only so that a missing key still raises KeyError.
    _ = kwargs['x_dim']
    hid_dim = kwargs['hid_dim']
    z_dim = kwargs['z_dim']

    backbone = models.resnet18(pretrained=True)
    feature_dim = backbone.fc.in_features
    head = [
        nn.Linear(feature_dim, hid_dim),
        nn.ReLU(),
        nn.Linear(hid_dim, z_dim),
        Flatten(),
    ]
    backbone.fc = nn.Sequential(*head)

    parallel_encoder = nn.DataParallel(backbone)
    return BigProtonet(parallel_encoder)