import math
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

from models import SSS, RandomSelection

# Public API: only the top-level subset-selection model is exported by `import *`.
__all__ = ['ANPSubsetSelect']

class DeterministicANP(nn.Module):
    """
    Attentive Neural Process whose latent path is used directly (no sampling).

    The latent encoder's set summary is broadcast to every target, the
    deterministic encoder produces a per-target representation via cross
    attention, and the decoder turns both into a Gaussian over y.
    """
    def __init__(self, encoder_num_layers, decoder_num_layers, x_dim, y_dim, hidden_dim):
        """
        :param encoder_num_layers: accepted for interface compatibility; currently unused
            (the encoders hard-code two attention layers each).
        :param decoder_num_layers: accepted for interface compatibility; currently unused
            (the decoder hard-codes three hidden layers).
        :param x_dim: dimension of x
        :param y_dim: dimension of y
        :param hidden_dim: hidden width shared by all sub-modules
        """
        super(DeterministicANP, self).__init__()
        # Construction order is kept so parameter initialisation is unchanged.
        self.DeterministicEncoder = DeterministicEncoder(x_dim, y_dim, hidden_dim, hidden_dim)
        self.LatentEncoder = LatentEncoder(x_dim, y_dim, hidden_dim, hidden_dim)
        self.Decoder = Decoder(x_dim, y_dim, hidden_dim)

    def forward(self, context_x, context_y, target_x, target_y=None, mask=None):
        """
        :param context_x: context locations, (B, S_ctx, x_dim)
        :param context_y: context values, (B, S_ctx, y_dim)
        :param target_x: target locations, (B, S_tgt, x_dim)
        :param target_y: optional target values; when given, their log-likelihood is returned
        :param mask: per-context-point weights of shape (B, S_ctx, 1); None means "use all points"
        :return: (log_p or None, mu, sigma)
        """
        if mask is None:
            # Previously None crashed inside the encoders; an all-ones mask selects every point.
            mask = context_x.new_ones(context_x.size(0), context_x.size(1), 1)
        num_targets = target_x.size(1)
        latent = self.LatentEncoder(context_x, context_y, mask)
        # Broadcast the set summary to every target without materialising copies.
        z = latent.unsqueeze(1).expand(-1, num_targets, -1)
        r = self.DeterministicEncoder(context_x, context_y, target_x, mask)
        dist, mu, sigma = self.Decoder(r, z, target_x)
        log_p = dist.log_prob(target_y) if target_y is not None else None
        return log_p, mu, sigma

class Linear(nn.Module):
    """
    Wrapper around ``nn.Linear`` whose weight is Xavier-uniform initialised.

    The gain is ``nn.init.calculate_gain(w_init)``, so ``w_init`` names the
    nonlinearity that follows the layer (e.g. ``'linear'`` or ``'relu'``).
    The bias, when present, keeps ``nn.Linear``'s default initialisation.
    """
    def __init__(self, in_dim, out_dim, bias=True, w_init='linear'):
        """
        :param in_dim: size of the last input axis
        :param out_dim: size of the last output axis
        :param bias: whether the layer has an additive bias
        :param w_init: nonlinearity name passed to ``calculate_gain``
        """
        super(Linear, self).__init__()
        layer = nn.Linear(in_dim, out_dim, bias=bias)
        gain = nn.init.calculate_gain(w_init)
        nn.init.xavier_uniform_(layer.weight, gain=gain)
        self.linear_layer = layer

    def forward(self, x):
        """Apply the affine map to the last axis of ``x``."""
        return self.linear_layer(x)

class LatentEncoder(nn.Module):
    """
    Set-level encoder: self-attends over the (x, y) context pairs and returns
    their mask-weighted mean, one vector per batch element.
    """
    def __init__(self, x_dim, y_dim, num_hidden, num_latent):
        super(LatentEncoder, self).__init__()
        self.input_projection = Linear(x_dim + y_dim, num_hidden)
        self.self_attentions = nn.ModuleList(Attention(num_hidden) for _ in range(2))
        # NOTE(review): penultimate_layer is never used in forward and num_latent is
        # ignored; the layer is kept so parameter init order and state_dict keys are unchanged.
        self.penultimate_layer = Linear(num_hidden, num_hidden, w_init='relu')

    def forward(self, x, y, mask):
        """
        :param x: context locations, concatenated with ``y`` on the last axis
        :param y: context values
        :param mask: per-point weights, assumed (B, S, 1) — TODO confirm
        :return: mask-weighted mean over dim 1 of the self-attended features
        """
        # Project the concatenated (x, y) of width x_dim + y_dim to num_hidden.
        feats = self.input_projection(torch.cat([x, y], dim=-1))
        for block in self.self_attentions:
            feats, _ = block(feats, feats, feats, mask)
        # 1e-10 keeps the division finite when a row's mask sums to zero.
        return (feats * mask).sum(1) / (mask.sum(1) + 1e-10)
    
class DeterministicEncoder(nn.Module):
    """
    Deterministic path of the ANP: one representation ``r`` per target point.

    The (x, y) context pairs are self-attended. Then each projected target
    location cross-attends to them, using the projected context locations as
    keys and the attended pairs as values.
    """
    def __init__(self, x_dim, y_dim, num_hidden, num_latent):
        super(DeterministicEncoder, self).__init__()
        # num_latent is accepted for signature parity with LatentEncoder but unused.
        # Construction order is kept so parameter initialisation is unchanged.
        self.self_attentions = nn.ModuleList(Attention(num_hidden) for _ in range(2))
        self.cross_attentions = nn.ModuleList(Attention(num_hidden) for _ in range(2))
        self.input_projection = Linear(x_dim + y_dim, num_hidden)
        self.context_projection = Linear(x_dim, num_hidden)
        self.target_projection = Linear(x_dim, num_hidden)

    def forward(self, context_x, context_y, target_x, mask):
        """
        :param context_x: context locations
        :param context_y: context values
        :param target_x: target locations (the queries)
        :param mask: per-context-point weights forwarded to every attention layer
        :return: per-target representation from the last cross-attention layer
        """
        pairs = torch.cat([context_x, context_y], dim=-1)
        values = self.input_projection(pairs)
        for block in self.self_attentions:
            values, _ = block(values, values, values, mask)

        rep = self.target_projection(target_x)
        keys = self.context_projection(context_x)
        for block in self.cross_attentions:
            rep, _ = block(keys, values, rep, mask)
        return rep
    
class Decoder(nn.Module):
    """
    MLP decoder mapping (r, z, target_x) to a diagonal Gaussian over y.
    """
    def __init__(self, x_dim, y_dim, num_hidden):
        """
        :param x_dim: dimension of target locations
        :param y_dim: dimension of y; the final layer emits 2 * y_dim (mean and raw scale)
        :param num_hidden: width of r, z and the projected target_x
        """
        super(Decoder, self).__init__()
        self.target_projection = Linear(x_dim, num_hidden)
        # The MLP input is [r, z, projected target_x], hence 3 * num_hidden.
        self.linears = nn.ModuleList([Linear(num_hidden * 3, num_hidden * 3, w_init='relu') for _ in range(3)])
        self.final_projection = Linear(num_hidden * 3, y_dim * 2)
        self._y_dim = y_dim

    def forward(self, r, z, target_x):
        """
        :param r: per-target deterministic representation, assumed (B, T, num_hidden)
        :param z: per-target latent summary, assumed (B, T, num_hidden)
        :param target_x: target locations, (B, T, x_dim)
        :return: (Normal distribution, mu, sigma), mu and sigma of shape (B, T, y_dim)
        """
        # Project target_x (x_dim -> num_hidden) and join it with r and z.
        hidden = torch.cat([r, z, self.target_projection(target_x)], dim=-1)

        for linear in self.linears:
            hidden = torch.relu(linear(hidden))

        mu, raw_sigma = torch.split(self.final_projection(hidden), self._y_dim, dim=-1)
        # softplus >= 0, so sigma is bounded below by 0.1.
        sigma = 0.1 + 0.9 * F.softplus(raw_sigma)
        dist = torch.distributions.normal.Normal(loc=mu, scale=sigma)
        return dist, mu, sigma

class MultiheadAttention(nn.Module):
    """
    Masked scaled dot-product attention over already head-split tensors.

    Scores are squashed with a sigmoid (not a softmax), multiplied by the mask,
    and renormalised to sum to one over the key axis.
    """
    def __init__(self, num_hidden_k):
        """
        :param num_hidden_k: per-head key width, used for the 1/sqrt(d) scaling
        """
        super(MultiheadAttention, self).__init__()

        self.num_hidden_k = num_hidden_k

    def forward(self, key, value, query, mask):
        """
        :param key: (N, S_k, d)
        :param value: (N, S_k, d_v)
        :param query: (N, S_q, d)
        :param mask: key-major weights, (N, S_k, 1) or (N, S_k, S_q)
        :return: (context (N, S_q, d_v), attention weights (N, S_q, S_k))
        """
        scores = torch.bmm(query, key.transpose(1, 2)) / math.sqrt(self.num_hidden_k)

        # The mask is key-major; swap its last two axes to line up with scores (N, S_q, S_k).
        # (A .view here only coincides with a transpose when the last mask axis is 1.)
        weights = torch.sigmoid(scores) * mask.transpose(1, 2)
        # 1e-10 keeps the normalisation finite when every key is masked out.
        weights = weights / (torch.sum(weights, 2, keepdim=True) + 1e-10)

        return torch.bmm(weights, value), weights


class Attention(nn.Module):
    """
    Multi-head attention block with residual concat, dropout and layer norm.

    Key, value and query go through bias-free projections and are split into
    ``h`` heads of width ``num_hidden // h``. The heads are attended with
    ``MultiheadAttention`` and concatenated back together. ``final_linear`` then
    fuses ``[query, heads]``, followed by dropout, a residual add of the query,
    and ``LayerNorm``.
    """
    def __init__(self, num_hidden, h=4):
        """
        :param num_hidden: model width (input and output size)
        :param h: number of heads
        """
        super(Attention, self).__init__()

        self.num_hidden = num_hidden
        self.num_hidden_per_attn = num_hidden // h
        self.h = h

        # Construction order is kept so parameter initialisation / state_dict order is unchanged.
        self.key = Linear(num_hidden, num_hidden, bias=False)
        self.value = Linear(num_hidden, num_hidden, bias=False)
        self.query = Linear(num_hidden, num_hidden, bias=False)
        self.multihead = MultiheadAttention(self.num_hidden_per_attn)
        self.residual_dropout = nn.Dropout(p=0.1)
        self.final_linear = Linear(num_hidden * 2, num_hidden)
        self.layer_norm = nn.LayerNorm(num_hidden)

    def _split_heads(self, projected, n_batch, length):
        """Reshape (B, L, h*d) to (h*B, L, d), with the head index outermost."""
        heads = projected.view(n_batch, length, self.h, self.num_hidden_per_attn)
        return heads.permute(2, 0, 1, 3).contiguous().view(-1, length, self.num_hidden_per_attn)

    def forward(self, key, value, query, mask):
        """
        :param key: (B, S_k, num_hidden)
        :param value: (B, S_k, num_hidden)
        :param query: (B, S_q, num_hidden); also the residual branch
        :param mask: per-key weights, repeated once per head along dim 0
        :return: (output (B, S_q, num_hidden), attention weights (h*B, S_q, S_k))
        """
        n_batch = key.size(0)
        len_k = key.size(1)
        len_q = query.size(1)
        residual = query

        k_heads = self._split_heads(self.key(key), n_batch, len_k)
        v_heads = self._split_heads(self.value(value), n_batch, len_k)
        q_heads = self._split_heads(self.query(query), n_batch, len_q)

        # Heads live on the batch axis (head-major), so the mask is tiled the same way.
        context, attns = self.multihead(k_heads, v_heads, q_heads, mask.repeat(self.h, 1, 1))

        # (h*B, S_q, d) -> (B, S_q, h*d)
        context = context.view(self.h, n_batch, len_q, self.num_hidden_per_attn)
        context = context.permute(1, 2, 0, 3).contiguous().view(n_batch, len_q, -1)

        fused = self.final_linear(torch.cat([residual, context], dim=-1))
        out = self.layer_norm(self.residual_dropout(fused) + residual)
        return out, attns

class ANPSubsetSelect(nn.Module):
    """
    Subset-selection model scored by a ``DeterministicANP``.

    A selector produces masks over the input set: ``SSS`` for the learned stages
    and ``RandomSelection`` for ``'random'``. The ANP then reconstructs targets
    from the masked context.

    :param args: namespace-like config. Read here: ``stage``, ``x_dim``, ``y_dim``,
        ``element_jump``, ``max_output_points``, ``train_with_real_mask``, ``hidden_dim``,
        ``CNP_encoder_num_layers``, ``CNP_decoder_num_layers``. The SSS stages also read
        ``subset_encoder_num_layers``, ``reg_scale``, ``temperature``, ``alpha`` and ``thres``.
    :raises NotImplementedError: if ``args.stage`` is not a known stage.
    """

    # Stages whose selector is the learned SSS module.
    _SSS_STAGES = ('candidate', 'autoregressive', 'sss', 'randomautoregressive')
    # At eval time these stages' selectors return a single mask ...
    _EVAL_SINGLE_MASK = ('candidate', 'autoregressive', 'random')
    # ... and these return (mask used for decoding, auxiliary mask).
    _EVAL_PAIRED_MASK = ('sss', 'randomautoregressive')

    def __init__(self, args):
        super(ANPSubsetSelect, self).__init__()
        self.args = args
        self.stage = args.stage
        self.x_dim = args.x_dim
        self.y_dim = args.y_dim
        self.element_jump = args.element_jump
        self.max_output_points = args.max_output_points
        self.train_with_real_mask = args.train_with_real_mask

        self.decoder = DeterministicANP(args.CNP_encoder_num_layers, args.CNP_decoder_num_layers,
                                        args.x_dim, args.y_dim, args.hidden_dim)

        if args.stage in self._SSS_STAGES:
            self.sss = SSS(num_layers=args.subset_encoder_num_layers, element_dim=args.x_dim + args.y_dim,
                           hidden_dim=args.hidden_dim, construct_real_mask=args.train_with_real_mask,
                           stage=args.stage, reg_scale=args.reg_scale, temperature=args.temperature,
                           alpha=args.alpha, thres=args.thres, element_jump=args.element_jump)
        elif args.stage == 'random':
            self.sss = RandomSelection(x_dim=args.x_dim, y_dim=args.y_dim)
        else:
            raise NotImplementedError('Unknown stage: {!r}'.format(args.stage))

        self.name = self.sss.name

    def forward(self, context, target, subset_size=15):
        """
        Run one selection + reconstruction pass.

        Training mode: ``context`` and ``target`` are the x and y tensors of one set
        (concatenated on dim 2), and ``k`` is drawn uniformly from 1..max_output_points.
        The returned tuple depends on the stage (see ``_forward_train``).

        Eval mode: ``context`` and ``target`` are ``(x, y)`` pairs and ``k = subset_size``.
        The result is ``(log_p, [mu, sigma, masks...])`` (see ``_forward_eval``).
        """
        if self.training:
            # np.random.choice(n) draws from 0..n-1, so k lies in 1..max_output_points.
            k = np.random.choice(self.max_output_points) + 1
            return self._forward_train(context, target, k)
        return self._forward_eval(context, target, subset_size)

    def _reconstruct_with_masks(self, x, y, mask, mask_real):
        """Decode y at every x twice, first under ``mask`` then under ``mask_real``."""
        log_p, _, _ = self.decoder(context_x=x, context_y=y, target_x=x, target_y=y, mask=mask)
        log_p_real, _, _ = self.decoder(context_x=x, context_y=y, target_x=x, target_y=y, mask=mask_real)
        return log_p, log_p_real

    def _forward_train(self, x, y, k):
        """
        Training pass: select points of (x, y) and reconstruct y at every x.

        Returns, per stage:
          candidate:            (log_p, log_p_real, candidate_mask, candidate_mask_real, candidate_reg)
          autoregressive:       (log_p, log_p_real, subset_mask, subset_mask_real)
          sss:                  (log_p, log_p_real, candidate_mask, candidate_mask_real,
                                 candidate_reg, subset_mask, subset_mask_real)
          randomautoregressive: (log_p, log_p_real, random_mask, subset_mask, subset_mask_real)
          random:               (log_p, mask)
        """
        D = torch.cat([x, y], dim=2)
        stage = self.stage

        if stage == 'candidate':
            # The candidate selector is called without k.
            candidate_mask, candidate_mask_real, candidate_reg = self.sss(D=D)
            log_p, log_p_real = self._reconstruct_with_masks(x, y, candidate_mask, candidate_mask_real)
            return log_p, log_p_real, candidate_mask, candidate_mask_real, candidate_reg

        if stage == 'autoregressive':
            subset_mask, subset_mask_real = self.sss(D=D, k=k)
            log_p, log_p_real = self._reconstruct_with_masks(x, y, subset_mask, subset_mask_real)
            return log_p, log_p_real, subset_mask, subset_mask_real

        if stage == 'sss':
            (candidate_mask, candidate_mask_real, candidate_reg,
             subset_mask, subset_mask_real) = self.sss(D=D, k=k)
            log_p, log_p_real = self._reconstruct_with_masks(x, y, subset_mask, subset_mask_real)
            return (log_p, log_p_real, candidate_mask, candidate_mask_real, candidate_reg,
                    subset_mask, subset_mask_real)

        if stage == 'randomautoregressive':
            random_mask, subset_mask, subset_mask_real = self.sss(D=D, k=k)
            log_p, log_p_real = self._reconstruct_with_masks(x, y, subset_mask, subset_mask_real)
            return log_p, log_p_real, random_mask, subset_mask, subset_mask_real

        if stage == 'random':
            # The random selector returns context tensors (used as context_x / context_y) and a mask.
            context_subset, target_subset, mask = self.sss(D=D, k=k)
            log_p, _, _ = self.decoder(context_x=context_subset, context_y=target_subset,
                                       target_x=x, target_y=y, mask=mask)
            return log_p, mask

        raise NotImplementedError('Unknown stage: {!r}'.format(stage))

    def _forward_eval(self, context, target, k):
        """
        Evaluation pass: select k context points and predict the target set.

        Returns ``(log_p, [mu, sigma, mask, ...])``. The trailing entries are the
        selector outputs: one mask for candidate/autoregressive/random, or
        (decoding mask, auxiliary mask) for sss/randomautoregressive.
        """
        context_x, context_y = context
        target_x, target_y = target
        D = torch.cat([context_x, context_y], dim=2)

        if self.stage in self._EVAL_SINGLE_MASK:
            masks = [self.sss(D=D, k=k)]
        elif self.stage in self._EVAL_PAIRED_MASK:
            decode_mask, aux_mask = self.sss(D=D, k=k)
            masks = [decode_mask, aux_mask]
        else:
            raise NotImplementedError('Unknown stage: {!r}'.format(self.stage))

        log_p, mu, sigma = self.decoder(context_x=context_x, context_y=context_y,
                                        target_x=target_x, target_y=target_y, mask=masks[0])
        return log_p, [mu, sigma] + masks
