import torch
import random
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.beta import Beta

# Names exported by ``from <this module> import *``: only the model class.
__all__ = ['ProtoNetLinear']

def get_embedding(dataset):
    """Build the MLP embedding network for ``dataset``.

    The network is an ``nn.Sequential`` of three stages: two
    (Linear -> BatchNorm1d -> LeakyReLU) blocks followed by a final Linear,
    mapping ``in_features``-dim input vectors to 500-dim embeddings.  The
    stages are kept as separate top-level entries so they can be indexed
    individually (``emb[0]``, ``emb[1]``, ``emb[2]``).

    Args:
        dataset: Dataset key.  Only ``'esc50'`` (640-dim inputs) is supported.

    Returns:
        The ``nn.Sequential`` embedding module.

    Raises:
        NotImplementedError: If no embedding is configured for ``dataset``.
    """
    # Input feature width for each supported dataset.
    in_features_by_dataset = {'esc50': 640}
    if dataset not in in_features_by_dataset:
        # Previously formatted an undefined ``name`` and raised NameError instead.
        raise NotImplementedError('Embedding function for {} not implemented'.format(dataset))
    in_features = in_features_by_dataset[dataset]
    hidden = 500

    emb = nn.Sequential(
        nn.Sequential(
            nn.Linear(in_features=in_features, out_features=hidden),
            nn.BatchNorm1d(num_features=hidden),
            nn.LeakyReLU(negative_slope=0.01, inplace=True),
        ),
        nn.Sequential(
            nn.Linear(in_features=hidden, out_features=hidden),
            nn.BatchNorm1d(num_features=hidden),
            nn.LeakyReLU(negative_slope=0.01, inplace=True),
        ),
        nn.Linear(in_features=hidden, out_features=hidden),
    )
    return emb

class ProtoNetLinear(nn.Module):
    """Prototypical network over flat feature vectors with set/task interpolation.

    Support and query features are passed through the first
    ``interpolator.layer + 1`` stages of ``self.embedding``, mixed by
    ``interpolator``, then passed through the remaining stages.  Class
    prototypes are per-class means of the support embeddings, and the loss is
    cross-entropy over negative squared Euclidean distances from each query
    embedding to every prototype.

    Args:
        alpha: Stored as ``self.alpha``; not read anywhere in this class.
        beta: Stored as ``self.beta``; not read anywhere in this class.
        name: Model name.  ``forward`` only supports names containing
            ``'settaskinterpolator'``.
        dataset: Dataset key forwarded to ``get_embedding``.
    """

    def __init__(self, alpha=None, beta=None, name='', dataset=''):
        super(ProtoNetLinear, self).__init__()
        self.alpha = alpha
        self.beta = beta
        self.name = name
        self.embedding = get_embedding(dataset=dataset)

    def get_num_samples(self, labels, num_classes, dtype=None):
        """Count the samples of each class in every task.

        Args:
            labels: 2-D integer tensor (B, N) of class ids in
                ``[0, num_classes)``.
            num_classes: Number of classes, i.e. columns of the result.
            dtype: dtype of the counts; ``None`` keeps ``labels.dtype``.

        Returns:
            Tensor (B, num_classes) of per-class counts, computed without
            gradient tracking.
        """
        B = labels.size(0)
        with torch.no_grad():
            ones = torch.ones_like(labels, dtype=dtype)
            num_samples = ones.new_zeros((B, num_classes))
            num_samples.scatter_add_(1, labels, ones)
        return num_samples

    def make_prototypes(self, embeddings, labels, nways):
        """Average the embeddings of each class into a prototype.

        Args:
            embeddings: Tensor (B, N, H).
            labels: Integer tensor (B, N) of class ids in ``[0, nways)``.
            nways: Number of classes per task.

        Returns:
            Tensor (B, nways, H).  A class with no samples gets a zero
            prototype, because its count is clamped to 1 before dividing.
        """
        B, _, H = embeddings.size()

        num_samples = self.get_num_samples(labels=labels, num_classes=nways, dtype=embeddings.dtype)
        num_samples.unsqueeze_(-1)
        # Clamp counts to >= 1 so empty classes do not divide by zero.
        num_samples = torch.max(num_samples, torch.ones_like(num_samples))

        prototypes = embeddings.new_zeros((B, nways, H))
        indices = labels.unsqueeze(-1).expand_as(embeddings)
        prototypes.scatter_add_(1, indices, embeddings).div_(num_samples)
        return prototypes

    def prototypical_loss(self, prototypes, embeddings, labels, **kwargs):
        """Cross-entropy over negative squared distances to the prototypes.

        Args:
            prototypes: Tensor (B, W, H).
            embeddings: Query embeddings (B, N, H).
            labels: Integer tensor (B, N) of target class ids.
            **kwargs: Forwarded to ``F.cross_entropy`` (e.g. ``reduction``).

        Returns:
            The result of ``F.cross_entropy`` (a scalar with the default
            ``reduction='mean'``).
        """
        # (B, W, 1, H) - (B, 1, N, H) -> (B, W, N): class scores sit on dim 1,
        # which is where F.cross_entropy expects them for K-dimensional input.
        distances = torch.sum((prototypes.unsqueeze(2) - embeddings.unsqueeze(1)) ** 2, dim=-1)
        return F.cross_entropy(-distances, labels, **kwargs)

    def get_accuracy(self, prototypes, embeddings, labels):
        """Nearest-prototype accuracy for each task.

        Args:
            prototypes: Tensor (B, W, H).
            embeddings: Query embeddings (B, N, H).
            labels: Integer tensor (B, N) of target class ids.

        Returns:
            Python list of B floats: the fraction of queries in each task
            whose closest prototype (squared Euclidean) matches ``labels``.
        """
        # (B, 1, W, H) - (B, N, 1, H) -> (B, N, W)
        sq_distances = torch.sum((prototypes.unsqueeze(1) - embeddings.unsqueeze(2)) ** 2, dim=-1)
        _, predictions = torch.min(sq_distances, dim=-1)
        return torch.mean(predictions.eq(labels).float(), 1).cpu().tolist()

    def additive_noise(self, support, query, lmbda=0.1, add_noise='sq'):
        """Add ``lmbda``-scaled uniform [0, 1) noise to support and/or query.

        Args:
            support: Tensor to perturb when ``'s'`` is in ``add_noise``.
            query: Tensor to perturb when ``'q'`` is in ``add_noise``.
            lmbda: Noise scale.
            add_noise: Selector string containing ``'s'`` and/or ``'q'``.

        Returns:
            ``(support, query)``; unselected inputs are returned as-is.
        """
        # NOTE(review): torch.rand_like is uniform on [0, 1), so this noise is
        # not zero-mean; confirm that torch.randn_like was not intended.
        if 's' in add_noise:
            support = support + lmbda * torch.rand_like(support)
        if 'q' in add_noise:
            query = query + lmbda * torch.rand_like(query)
        return support, query

    def interpolate(self, support, query, interpolator):
        """Apply ``interpolator`` to the support and query features.

        For each of support / query: if its letter (``'s'`` / ``'q'``) is in
        ``interpolator.mix``, the whole tensor is passed to ``interpolator``.
        Otherwise one index along dim 1 is drawn at random (the same index
        for every row), and only that slice is passed, with dim 1 kept
        as size 1.

        Args:
            support: 3-D tensor; callers in this class stack the sets along
                dim 1, i.e. (N, p, H).
            query: 3-D tensor laid out like ``support``.
            interpolator: Callable with a string attribute ``mix``.

        Returns:
            ``(support, query)`` as returned by ``interpolator``.
        """
        if 's' in interpolator.mix:
            support = interpolator(support)
        else:
            support = interpolator(support[:, torch.randperm(support.size(1))[0], :].unsqueeze(dim=1))

        if 'q' in interpolator.mix:
            query = interpolator(query)
        else:
            query = interpolator(query[:, torch.randperm(query.size(1))[0], :].unsqueeze(dim=1))
        return support, query

    def _check_layer(self, layer):
        """Validate an interpolation layer index against the embedding depth.

        ``-1`` mixes the raw input features; ``len(self.embedding) - 1``
        mixes the final embeddings.

        Raises:
            ValueError: If ``layer`` is outside ``[-1, len(self.embedding) - 1]``.
        """
        if not -1 <= layer < len(self.embedding):
            raise ValueError('interpolator.layer must be in [-1, {}], got {}'.format(
                len(self.embedding) - 1, layer))

    def _apply_stages(self, xs, start, stop):
        """Run every tensor in ``xs`` through embedding stages ``start``..``stop - 1``.

        Stages are applied stage-major: stage ``i`` is applied to every
        tensor, in the order given, before stage ``i + 1``.  Each tensor
        goes through a stage as its own batch, so BatchNorm statistics are
        never shared across tensors.

        Returns:
            A list with one output tensor per input tensor.
        """
        xs = list(xs)
        for i in range(start, stop):
            stage = self.embedding[i]
            xs = [stage(x) for x in xs]
        return xs

    def forward_set_task_single_set(self, support, slabel, query, qlabel, interpolator):
        """Embed one support/query set per task, interpolating at ``interpolator.layer``.

        Args:
            support: Tensor (B, s_ways, s_shots, H).
            slabel: Support labels, returned unchanged.
            query: Tensor (B, q_ways, q_shots, H).
            qlabel: Query labels, returned unchanged.
            interpolator: Callable with ``layer`` and ``mix`` attributes.  It
                receives tensors of shape (N, 1, H_).

        Returns:
            ``(support_emb, query_emb, slabel, qlabel)``.  The support
            embeddings are viewed as (B, s_ways*s_shots, -1) and the query
            embeddings as (B, q_ways*q_shots, -1).

        Raises:
            ValueError: If ``interpolator.layer`` is out of range.
        """
        B, s_ways, s_shots, H = support.size()
        _, q_ways, q_shots, _ = query.size()

        layer = interpolator.layer
        self._check_layer(layer)
        split = layer + 1  # number of embedding stages applied before mixing

        support, query = self._apply_stages(
            [support.view(B * s_ways * s_shots, H), query.view(B * q_ways * q_shots, H)], 0, split)

        H_ = support.size(-1)
        support, query = self.interpolate(support=support.unsqueeze(dim=1), query=query.unsqueeze(dim=1),
                                          interpolator=interpolator)
        support, query = self._apply_stages(
            [support.view(-1, H_), query.view(-1, H_)], split, len(self.embedding))

        return support.view(B, s_ways * s_shots, -1), query.view(B, q_ways * q_shots, -1), slabel, qlabel

    def forward_set_task_interpolation_double_forward(self, support, slabel, query, qlabel, interpolator):
        """Embed p sets per task and interpolate them at ``interpolator.layer``.

        Each of the p sets goes through the pre-mixing stages separately.
        The sets are then stacked along dim 1, optionally perturbed by
        ``additive_noise``, and mixed by ``interpolate``.  The result goes
        through the remaining stages.

        The incoming label values are ignored.  Only ``slabel.size(2)`` and
        ``slabel.size(3)`` are read, and fresh labels are returned instead:
        each task gets a random permutation of the way indices.

        Args:
            support: Tensor (B, p, s_ways, s_shots, H).
            slabel: Tensor whose dims 2 and 3 give the number of ways and
                shots for the generated labels.
            query: Tensor (B, p', q_ways, q_shots, H).
            qlabel: Unused.
            interpolator: Callable with ``layer``, ``mix`` and ``noise``
                attributes.

        Returns:
            A tuple ``(support_emb, query_emb, slabel, qlabel)``:

            - ``support_emb``: support embeddings viewed as
              (B, s_ways*s_shots, -1).
            - ``query_emb``: query embeddings viewed as
              (B, q_ways*q_shots, -1).
            - ``slabel``: the generated labels, shape (B, ways, shots).
            - ``qlabel``: the first shot of ``slabel``, repeated to
              (B, ways, q_shots).

            Both label tensors are on the embeddings' device.

        Raises:
            ValueError: If ``interpolator.layer`` is out of range.
        """
        B, p, s_ways, s_shots, H = support.size()
        _, _, q_ways, q_shots, _ = query.size()

        layer = interpolator.layer
        self._check_layer(layer)
        split = layer + 1  # number of embedding stages applied before mixing

        # One flat (B*ways*shots, H) tensor per set along dim 1.
        support_parts = [s.contiguous().view(B * s_ways * s_shots, H)
                         for s in torch.split(support, split_size_or_sections=1, dim=1)]
        query_parts = [q.contiguous().view(B * q_ways * q_shots, H)
                       for q in torch.split(query, split_size_or_sections=1, dim=1)]

        # Fresh labels: per task, a random permutation of way ids, constant
        # across shots.  Drawn before any other random op, as before.
        n_ways, n_shots = slabel.size(2), slabel.size(3)
        new_slabel = torch.stack(
            [torch.randperm(n_ways).view(n_ways, 1).repeat(1, n_shots) for _ in range(B)], dim=0)

        parts = self._apply_stages(support_parts + query_parts, 0, split)
        n_support = len(support_parts)
        # Stack the sets along dim 1: (B*ways*shots, p, H_).
        support = torch.stack(parts[:n_support], dim=1)
        query = torch.stack(parts[n_support:], dim=1)
        H_ = support.size(-1)

        support, query = self.additive_noise(support=support, query=query, add_noise=interpolator.noise)
        support, query = self.interpolate(support=support, query=query, interpolator=interpolator)
        support, query = self._apply_stages(
            [support.view(-1, H_), query.view(-1, H_)], split, len(self.embedding))

        support = support.view(B, s_ways * s_shots, -1)
        query = query.view(B, q_ways * q_shots, -1)
        out_slabel = new_slabel.to(support.device)
        out_qlabel = new_slabel[:, :, :1].repeat(1, 1, q_shots).to(support.device)
        return support, query, out_slabel, out_qlabel

    def forward(self, support, slabel, query, qlabel, interpolator=None):
        """Compute the prototypical loss and per-task accuracy for a batch of tasks.

        Args:
            support: (B, s_ways, s_shots, H) for a single set per task, or
                (B, p, s_ways, s_shots, H) for p sets per task.
            slabel: Support labels.  In the 5-D case only the shape is used
                (see ``forward_set_task_interpolation_double_forward``).
            query: Query features, with the same rank as ``support``.
            qlabel: Query labels.  Ignored in the 5-D case.
            interpolator: Required.  Must have ``name == 'double_forward'``
                plus the ``layer`` / ``mix`` (/ ``noise``) attributes used
                by the embedding paths.

        Returns:
            ``(loss, accuracy)``, where ``loss`` comes from
            ``prototypical_loss`` and ``accuracy`` is a list of B floats.

        Raises:
            NotImplementedError: If ``self.name`` lacks
                ``'settaskinterpolator'``, or the interpolator name is
                unsupported.
            ValueError: If ``interpolator`` is None or its layer is out of range.
        """
        if 'settaskinterpolator' not in self.name:
            raise NotImplementedError('{} not implemented'.format(self.name))
        if interpolator is None:
            raise ValueError('{} requires an interpolator'.format(self.name))
        if interpolator.name != 'double_forward':
            # Previously fell through with unbound locals (UnboundLocalError).
            raise NotImplementedError('interpolator {} not implemented'.format(interpolator.name))

        if support.dim() == 4:
            B, s_ways, s_shots, _ = support.size()
            _, q_ways, q_shots, _ = query.size()
            s_embedding, q_embedding, slabel, qlabel = self.forward_set_task_single_set(
                support=support, slabel=slabel, query=query, qlabel=qlabel, interpolator=interpolator)
        else:
            B, _, s_ways, s_shots, _ = support.size()
            _, _, q_ways, q_shots, _ = query.size()
            s_embedding, q_embedding, slabel, qlabel = self.forward_set_task_interpolation_double_forward(
                support=support, slabel=slabel, query=query, qlabel=qlabel, interpolator=interpolator)

        s_labels = slabel.view(B, s_ways * s_shots)
        q_labels = qlabel.view(B, q_ways * q_shots)
        prototypes = self.make_prototypes(embeddings=s_embedding, labels=s_labels, nways=s_ways)
        loss = self.prototypical_loss(prototypes=prototypes, embeddings=q_embedding, labels=q_labels)
        accuracy = self.get_accuracy(prototypes=prototypes, embeddings=q_embedding, labels=q_labels)
        return loss, accuracy
