import torch
import random
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.bernoulli import Bernoulli
from torch.distributions.multinomial import Multinomial
from torch.distributions.relaxed_bernoulli import RelaxedBernoulli
from torch.distributions.one_hot_categorical import OneHotCategorical
from torch.distributions.relaxed_categorical import RelaxedOneHotCategorical

from models import MAB, RandomSelection

# Public API for `from <module> import *`: only the SSS model is exported.
__all__ = ['SSS']

def select_random_k_batch(mask, k, thres=0.5):
    """Randomly pick up to ``k`` active elements in every set of a batch.

    Element ``(b, s)`` counts as active when any entry of ``mask[b, s, :]`` is
    ``>= thres``. For each batch row, ``min(k, #active)`` active elements are
    drawn uniformly at random without replacement. This is vectorised over the
    batch: no Python loop and no host/device round-trips.

    Args:
        mask: Tensor of shape (B, S, H).
        k: Maximum number of elements to select per row. Non-positive values
            select nothing.
        thres: Threshold used to binarise ``mask``.

    Returns:
        float32 tensor of shape (B, S, H) on ``mask.device``. All H channels of
        a selected element are 1.0; every other entry is 0.0.
    """
    B, S, H = mask.size()
    active = mask.ge(thres).any(dim=2)                                  # (B, S) bool
    selected = torch.zeros(B, S, dtype=torch.float32, device=mask.device)
    k = min(int(k), S)
    if k > 0:
        # Random scores; inactive elements get -inf so topk prefers active ones.
        scores = torch.rand(B, S, device=mask.device).masked_fill(~active, float('-inf'))
        chosen = scores.topk(k, dim=1).indices
        selected.scatter_(1, chosen, 1.0)
        # Rows with fewer than k active elements also picked some inactive ones; drop them.
        selected = selected * active.float()
    return selected.unsqueeze(2).repeat(1, 1, H)

class CandidateSelection(nn.Module):
    """Stochastic per-element gate that proposes a candidate subset of a set.

    An MLP (BatchNorm1d / Linear / ReLU stack ending in BatchNorm1d) maps each
    element's feature vector to ``num_gates`` logits. A sigmoid turns these
    into inclusion probabilities.

    Args:
        element_dim: Feature size of each input element (last dim of ``D``).
        hidden_dim: Width of the hidden MLP layers.
        num_gates: Number of gate outputs per element. The ``topk`` path of
            ``forward`` squeezes dim 2, so it assumes ``num_gates == 1``.
        reg_scale: Multiplier applied to the regulariser returned by ``forward``.
        temperature: Temperature of the RelaxedBernoulli used in training mode.
        alpha: Prior inclusion probability used by the regulariser.
        thres: Stored for callers; not read anywhere in this class.
        construct_real_mask: If True, training mode also returns a hard
            Bernoulli sample drawn from the same probabilities.
    """

    def __init__(self, element_dim=256, hidden_dim=128, num_gates=1, reg_scale=0.01, temperature=0.05, alpha=1e-1, thres=0.499, construct_real_mask=True):
        super(CandidateSelection, self).__init__()
        self.alpha = alpha
        self.thres = thres
        self.num_gates = num_gates
        self.reg_scale = reg_scale
        self.hidden_dim = hidden_dim
        self.element_dim = element_dim
        self.temperature = temperature
        self.construct_real_mask = construct_real_mask

        self.probs = nn.Sequential(
                nn.BatchNorm1d(num_features=element_dim),
                nn.Linear(in_features=element_dim, out_features=hidden_dim, bias=False),
                nn.BatchNorm1d(num_features=hidden_dim),
                nn.ReLU(inplace=True),
                nn.Linear(in_features=hidden_dim, out_features=hidden_dim, bias=False),
                nn.BatchNorm1d(num_features=hidden_dim),
                nn.ReLU(inplace=True),
                nn.Linear(in_features=hidden_dim, out_features=num_gates, bias=False),
                nn.BatchNorm1d(num_features=num_gates)
                )
        self.probs.apply(self.init_weights)

    @staticmethod
    def init_weights(m):
        """Xavier-uniform initialise the weight of every ``nn.Linear`` module."""
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)

    def forward(self, D, topk=False, k=0):
        """Sample a candidate mask for every element of ``D``.

        Args:
            D: Tensor of shape (B, S, element_dim). It does not need to be
                contiguous.
            topk: If True, return ``k`` draws from a Multinomial over the S
                elements of each set. The result holds per-element counts,
                which may exceed 1.
            k: Number of Multinomial draws when ``topk`` is True.

        Returns:
            If ``topk`` is True: a tensor of shape (B, S, 1) with per-element
            draw counts.

            Otherwise the tuple ``(z, z_real, reg)``:

            * ``z`` has shape (B, S, num_gates). It is a relaxed Bernoulli
              sample in training mode and a hard Bernoulli sample in eval mode.
            * ``z_real`` is a hard Bernoulli sample in training mode when
              ``construct_real_mask`` is set, and None otherwise.
            * ``reg`` is the per-element regulariser multiplied by
              ``reg_scale``.
        """
        B, S, H = D.size()
        # reshape (not view) so non-contiguous inputs such as transposed tensors work.
        logits = self.probs(D.reshape(B * S, H))
        probs = torch.sigmoid(logits).view(B, S, self.num_gates)
        # Keep probabilities away from exactly 0 for the log terms below
        # (the upper-end squeeze is mostly lost to float32 rounding).
        probs = 1e-10 + (1 - 2e-10) * probs

        if topk:
            counts = Multinomial(total_count=k, probs=probs.squeeze(2)).sample()
            return counts.unsqueeze(2)

        if self.training:
            temperature = torch.tensor([self.temperature], dtype=probs.dtype, device=probs.device)
            z = RelaxedBernoulli(temperature, probs=probs).rsample()
        else:
            z = Bernoulli(probs=probs).sample()

        # Elementwise KL(Bernoulli(probs) || Bernoulli(alpha)); the 1e-10 terms guard the logs.
        reg = probs * torch.log(probs / self.alpha + 1e-10) \
            + (1 - probs) * torch.log((1 - probs) / (1 - self.alpha) + 1e-10)

        z_real = None
        if self.construct_real_mask and self.training:
            z_real = Bernoulli(probs=probs).sample().detach()
        return z, z_real, reg * self.reg_scale

class AutoregressiveSelection(nn.Module):
    """Select ``q`` further elements given already-chosen and candidate masks.

    Each element is scored by an MLP ending in Linear -> BatchNorm1d -> Sigmoid.
    Scores are then zeroed for non-candidates and for elements already in the
    subset, renormalised over the set, and ``q`` one-hot draws are taken.

    Args:
        num_layers: Number of Linear layers in the scoring MLP.
        in_features: Feature size of each input element.
        hidden_dim: Width of the hidden layers.
        temperature: Temperature of the RelaxedOneHotCategorical used in training.
        construct_real_mask: If True, training mode also returns hard one-hot samples.
    """

    def __init__(self, num_layers, in_features, hidden_dim, temperature=0.05, construct_real_mask=True):
        super(AutoregressiveSelection, self).__init__()
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.in_features = in_features
        self.temperature = temperature
        self.construct_real_mask = construct_real_mask

        layers = []
        for i in range(num_layers):
            # The first layer always consumes the raw element features, even
            # when it is also the last layer (num_layers == 1).
            layer_in = in_features if i == 0 else hidden_dim
            if i < (num_layers - 1):
                layers.append(nn.Linear(in_features=layer_in, out_features=hidden_dim))
                layers.append(nn.ReLU(inplace=True))
            else:
                layers.append(nn.Linear(in_features=layer_in, out_features=1))
                layers.append(nn.BatchNorm1d(num_features=1))
                layers.append(nn.Sigmoid())
        self.probs = nn.Sequential(*layers)

    def select_k(self, probs, k):
        """Draw ``k`` one-hot samples over the set and sum them per element.

        Elements drawn more than once are removed. In training they are set to
        1e-20; otherwise they are set to 0.

        Args:
            probs: Tensor of shape (B, S); each row is a distribution over elements.
            k: Number of one-hot draws.

        Returns:
            ``(z, z_real)``, each of shape (B, S, 1).

            * ``z`` is relaxed (differentiable) in training mode and hard in
              eval mode.
            * ``z_real`` is a hard sample in training mode when
              ``construct_real_mask`` is set, and None otherwise.
        """
        def hard_sample():
            # k i.i.d. one-hot draws in one call: (k, B, S) -> per-element counts (B, S, 1).
            counts = OneHotCategorical(probs=probs).sample((k,)).sum(dim=0).unsqueeze(2)
            counts[counts > 1.5] = 0.0      # Remove elements selected more than once.
            return counts

        if self.training:
            temperature = torch.tensor([self.temperature], dtype=probs.dtype, device=probs.device)
            z = RelaxedOneHotCategorical(temperature, probs=probs).rsample((k,)).sum(dim=0).unsqueeze(2)
            z[z > 1.5] = 1e-20              # Remove elements selected more than once.
            z_real = hard_sample() if self.construct_real_mask else None
        else:
            z = hard_sample()
            z_real = None
        return z, z_real

    def forward(self, D, candidate_mask, subset_mask, q):
        """Score ``D`` and select ``q`` new elements.

        Args:
            D: Tensor of shape (B, S, in_features). It does not need to be
                contiguous.
            candidate_mask: Mask multiplied into the scores to restrict the
                selectable elements; assumed (B, S, 1) — TODO confirm.
            subset_mask: Current subset; assumed (B, S, 1). In training it is
                binarised with ``>= 0.5`` before use.
            q: Number of draws passed to ``select_k``.

        Returns:
            ``(z, z_real)`` as returned by ``select_k``.
        """
        B, S, H = D.size()
        probs = self.probs(D.reshape(B * S, H)).view(B, S, 1)

        if self.training:
            already_selected = subset_mask.ge(0.5).float()
        else:
            already_selected = subset_mask
        selectable_mask = candidate_mask * (1.0 - already_selected)

        # Renormalise probabilities to eliminate unselectable elements.
        probs = selectable_mask * probs
        probs = (probs / (torch.sum(probs, dim=1, keepdim=True) + 1e-10)).squeeze(2)

        return self.select_k(probs=probs, k=q)
        
class SSS(nn.Module):
    """Set subset-selection model built from a candidate and an autoregressive selector.

    ``stage`` chooses which parts are built and run:

    * ``'candidate'``: only the per-element candidate gate.
    * ``'autoregressive'``: only the autoregressive selector; every element is a candidate.
    * ``'randomautoregressive'``: the autoregressive selector over a candidate
      set produced by ``RandomSelection``.
    * ``'sss'``: the candidate gate followed by the autoregressive selector.

    NOTE(review): ``ln`` is stored, but every MAB is built with ``ln=True``.
    ``temperature`` is forwarded only to ``CandidateSelection``; the
    autoregressive selector keeps its own default. Both are left unchanged so
    existing checkpoints and training runs behave the same — confirm whether
    this is intended.
    """

    def __init__(self, num_layers, element_dim, hidden_dim, num_mab=1, num_heads=4, ln=True, construct_real_mask=True, \
            stage='sss', reg_scale=0.01, temperature=0.05, alpha=1e-1, thres=0.499, element_jump=5):
        super(SSS, self).__init__()
        self.ln = ln
        self.stage = stage
        self.num_mab = num_mab
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.temperature = temperature
        self.element_dim = element_dim
        self.element_jump = element_jump
        self.construct_real_mask = construct_real_mask

        if stage not in ['candidate', 'autoregressive', 'sss', 'randomautoregressive']:
            raise NotImplementedError('{} not implemented in models/sss.py'.format(stage))

        # Encoder (pooled into a set representation) and extractor (per-element
        # features) are built layer by layer in lock-step; this keeps the
        # parameter creation order unchanged.
        self.encoder, self.extractor = [], []
        for i in range(num_layers):
            in_features = element_dim if i == 0 else hidden_dim
            self.encoder.append(nn.Linear(in_features=in_features, out_features=hidden_dim))
            self.extractor.append(nn.Linear(in_features=in_features, out_features=hidden_dim))
            if i != (num_layers - 1):
                self.encoder.append(nn.ReLU(inplace=True))
                self.extractor.append(nn.ReLU(inplace=True))
        self.encoder = nn.Sequential(*self.encoder)
        self.extractor = nn.Sequential(*self.extractor)

        if stage in ['candidate', 'sss']:
            self.name = 'CandidateSelector' if stage == 'candidate' else 'SSS'
            self.candidateselector = CandidateSelection(element_dim=2*hidden_dim, hidden_dim=hidden_dim, reg_scale=reg_scale, \
                    temperature=temperature, alpha=alpha, thres=thres, construct_real_mask=construct_real_mask)

        if stage in ['autoregressive', 'sss', 'randomautoregressive']:
            self.name = 'AutoRegressiveSelector' if stage == 'autoregressive' else 'SSS'
            self.autoregressiveselector = AutoregressiveSelection(num_layers=num_layers, in_features=hidden_dim, \
                    hidden_dim=hidden_dim, construct_real_mask=construct_real_mask)

            self.pairwise_model = []
            for i in range(num_mab):
                dim_Q = hidden_dim
                dim_K = hidden_dim
                self.pairwise_model.append(MAB(dim_Q=dim_Q, dim_K=dim_K, dim_V=hidden_dim, num_heads=num_heads, ln=True))
            self.pairwise_model = nn.Sequential(*self.pairwise_model)

        if stage in ['randomautoregressive']:
            self.name = 'RandomAutoRegressiveSelector'
            self.randomselector = RandomSelection()

    def _candidate_features(self, D, D_hidden):
        """Concatenate per-element features with the mean-pooled set encoding of ``D``."""
        S = D.size(1)
        set_representation = torch.mean(self.encoder(D), dim=1, keepdim=True).repeat(1, S, 1)
        return torch.cat([D_hidden, set_representation], dim=2)

    def _pairwise(self, D_hidden, subset_mask):
        """Run the stacked MAB blocks (self-attention with Q = K) under ``subset_mask``."""
        D_hidden_s = D_hidden
        for block in self.pairwise_model:
            D_hidden_s = block(Q=D_hidden_s, K=D_hidden_s, mask=subset_mask)
        return D_hidden_s

    def _greedy_step(self, D_hidden, candidate_mask, k):
        """Run one greedy-training step and return ``(subset_mask, subset_mask_real)``.

        Running the autoregressive model iteratively during training would be
        very expensive. Instead, k' elements are drawn at random from the
        candidate set, and the autoregressive selector then picks
        ``element_jump`` more. ``subset_mask_real`` is None when the selector
        builds no hard mask (``construct_real_mask=False``).
        """
        B, S, _ = D_hidden.size()
        device = D_hidden.device
        if self.element_jump == 0:
            # A jump of 0 means "select all k elements in a single step". This
            # is resolved per call rather than by overwriting self.element_jump.
            jump, k_prime = k, 0
        else:
            jump = self.element_jump
            k_prime = int((k + jump - 1) / jump) * jump - jump
        assert k_prime >= 0, 'k_prime must be >= 0'

        if k_prime == 0:
            subset_mask        = torch.zeros(B, S, 1).to(device) + 1e-20
            subset_mask_real   = torch.zeros(B, S, 1).to(device) + 1e-20
        else:
            # Randomly select k_prime elements from the candidate set.
            subset_mask_real   = select_random_k_batch(mask=candidate_mask.clone().detach(), k=k_prime)
            subset_mask        = (subset_mask_real + 1e-20) * candidate_mask

        D_hidden_s = self._pairwise(D_hidden, subset_mask)
        autoregressive_mask, autoregressive_mask_real = self.autoregressiveselector(D=D_hidden_s, candidate_mask=candidate_mask, subset_mask=subset_mask, q=jump)
        subset_mask = subset_mask + autoregressive_mask
        if autoregressive_mask_real is None:
            subset_mask_real = None
        else:
            subset_mask_real = subset_mask_real + autoregressive_mask_real
        return subset_mask, subset_mask_real

    def _iterative_select(self, D_hidden, candidate_mask, k, element_jump):
        """Grow a subset from empty, adding ``element_jump`` elements per step for ``int(k / element_jump)`` steps."""
        B, S, _ = D_hidden.size()
        subset_mask = torch.zeros(B, S, 1).to(D_hidden.device)
        for _step in range(int(k / element_jump)):
            D_hidden_s = self._pairwise(D_hidden, subset_mask)
            autoregressive_mask, _ = self.autoregressiveselector(D=D_hidden_s, candidate_mask=candidate_mask, subset_mask=subset_mask, q=element_jump)
            subset_mask = subset_mask + autoregressive_mask
        return subset_mask

    def forward(self, D, k=20, element_jump=1):
        """Select a subset of each set in ``D`` (Batch x SetSize x ElementDimension).

        Args:
            D: Input sets of shape (B, S, element_dim).
            k: Target subset size. In training it determines k' for the greedy
                step. In eval, the loop runs ``int(k / element_jump)`` steps.
            element_jump: Elements added per autoregressive step in eval mode.
                Training uses ``self.element_jump`` instead.

        Returns, by stage and mode:
            * candidate, train: ``(candidate_mask, candidate_mask_real, candidate_reg)``
            * candidate, eval: ``candidate_mask`` (Multinomial counts over k draws)
            * autoregressive, train: ``(subset_mask, subset_mask_real)``
            * autoregressive, eval: ``subset_mask``
            * randomautoregressive, train: ``(random_mask, subset_mask, subset_mask_real)``
            * randomautoregressive, eval: ``(subset_mask, random_mask)``
            * sss, train: ``(candidate_mask, candidate_mask_real, candidate_reg, subset_mask, subset_mask_real)``
            * sss, eval: ``(subset_mask, candidate_mask)``
        """
        B, S, H = D.size()
        D_hidden = self.extractor(D)

        if self.stage == 'candidate':
            features = self._candidate_features(D, D_hidden)
            if self.training:
                return self.candidateselector(D=features)
            return self.candidateselector(D=features, topk=True, k=k)

        if self.stage == 'autoregressive':
            candidate_mask = torch.ones(B, S, 1).float().to(D.device)
            if self.training:
                return self._greedy_step(D_hidden, candidate_mask, k)
            return self._iterative_select(D_hidden, candidate_mask, k, element_jump)

        if self.stage == 'randomautoregressive':
            if self.training:
                random_mask = self.randomselector(D=D, k=int(S*0.20))
                subset_mask, subset_mask_real = self._greedy_step(D_hidden, random_mask, k)
                return random_mask, subset_mask, subset_mask_real
            random_mask = self.randomselector(D=D, k=int(S*0.10))
            subset_mask = self._iterative_select(D_hidden, random_mask, k, element_jump)
            return subset_mask, random_mask

        if self.stage == 'sss':
            # First run the candidate selection stage, then the autoregressive
            # stage, using the same greedy training routine as above.
            candidate_mask, candidate_mask_real, candidate_reg = self.candidateselector(D=self._candidate_features(D, D_hidden))
            if self.training:
                subset_mask, subset_mask_real = self._greedy_step(D_hidden, candidate_mask, k)
                return candidate_mask, candidate_mask_real, candidate_reg, subset_mask, subset_mask_real
            subset_mask = self._iterative_select(D_hidden, candidate_mask, k, element_jump)
            return subset_mask, candidate_mask

        raise NotImplementedError('{} not implemented in models/sss.py'.format(self.stage))
