import torch
import torch.nn as nn
import torch.nn.functional as F
from inclearn.lib import distance as distance_lib


# +
def reduce_proxies(similarities, proxy_per_class):
    """Merge per-proxy similarity scores into a single score per class.

    Columns of ``similarities`` are read as consecutive groups of
    ``proxy_per_class`` proxies, one group per class.  Inside each group the
    scores are softmax-weighted by themselves and summed, giving a soft
    maximum over the proxies of that class.

    Args:
        similarities: tensor of shape (batch_size, n_classes * proxy_per_class).
        proxy_per_class: number of proxies per class.

    Returns:
        Tensor of shape (batch_size, n_classes).

    Raises:
        AssertionError: if the column count is not a multiple of
            ``proxy_per_class``.
    """
    n_classes, leftover = divmod(similarities.shape[1], proxy_per_class)
    assert leftover == 0, (similarities.shape[1], proxy_per_class)

    grouped = similarities.view(similarities.shape[0], n_classes, proxy_per_class)
    proxy_weights = F.softmax(grouped, dim=-1)
    return torch.sum(grouped * proxy_weights, dim=-1)
            
class LSCLinear(torch.nn.Module):
    """Cosine-similarity classifier with K learnable proxies per class.

    ``weight`` stores ``K * out_features`` proxy vectors of length
    ``in_features``.  ``forward`` L2-normalises the inputs and the proxies,
    scores them with the negated ``distance_lib.stable_cosine_distance``, and,
    when K > 1, merges each class's K proxy scores via ``reduce_proxies``.
    """

    def __init__(self, in_features, out_features, K):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.K = K

        # Uninitialised storage that kaiming_normal_ fills in right below.
        proxies = torch.empty(K * out_features, in_features)
        self.weight = torch.nn.Parameter(proxies)
        nn.init.kaiming_normal_(self.weight, nonlinearity="linear")

    def forward(self, x):
        # NOTE(review): assumes stable_cosine_distance returns pairwise distances
        # shaped (batch, K * out_features), i.e. one column per proxy row of
        # self.weight -- confirm against distance_lib.
        x_unit = F.normalize(x, dim=-1)
        proxies_unit = F.normalize(self.weight, dim=-1)
        scores = -distance_lib.stable_cosine_distance(x_unit, proxies_unit)
        if self.K <= 1:
            # A single proxy per class: the per-proxy scores are already per-class.
            return scores
        return reduce_proxies(scores, self.K)
    
class SplitLSCLinear(torch.nn.Module):
    """Two-head LSC classifier for class-incremental training.

    ``fc1`` holds the classes from earlier ``add_class`` calls and is frozen
    (``requires_grad = False``) once a second head exists; ``fc2`` holds the
    classes added by the most recent call.  Both heads are created lazily by
    ``add_class``.  ``out_features1`` / ``out_features2`` count proxies
    (``n_classes * K``), not classes.
    """

    def __init__(self, in_features, K, device):
        super().__init__()
        self.device = device
        self.in_features = in_features
        self.K = K

        # Heads are built on demand by add_class.
        self.fc1 = None
        self.fc2 = None
        self.out_features1 = None
        self.out_features2 = None

        # Learnable scalar; not referenced by forward in this class --
        # presumably read by the caller (verify).
        self.factor = torch.nn.Parameter(torch.tensor(1.))

        self.to(self.device)

    def add_class(self, n_classes):
        """Register ``n_classes`` new classes as the trainable ``fc2`` head.

        First call: the classes go into ``fc1`` (trainable).
        Second call: ``fc1`` is frozen and a new ``fc2`` is created.
        Later calls: ``fc1`` and ``fc2`` are merged into a new frozen ``fc1``
        (old proxies first), then a fresh ``fc2`` is created.
        """
        n_new_proxies = n_classes * self.K

        if self.fc1 is None:
            self.fc1 = LSCLinear(self.in_features, n_classes, self.K).to(self.device)
            self.out_features1 = n_new_proxies
            return

        if self.fc2 is None:
            self.fc1.weight.requires_grad = False
            self.fc2 = LSCLinear(self.in_features, n_classes, self.K).to(self.device)
            self.out_features2 = n_new_proxies
            return

        # Fold the current trainable head into the frozen one.
        n_old_proxies = self.out_features1
        n_merged_proxies = n_old_proxies + self.out_features2
        merged = LSCLinear(self.in_features, n_merged_proxies // self.K, self.K).to(self.device)
        merged.weight.data[:n_old_proxies] = self.fc1.weight.data
        merged.weight.data[n_old_proxies:] = self.fc2.weight.data
        merged.weight.requires_grad = False
        self.fc1 = merged

        self.fc2 = LSCLinear(self.in_features, n_classes, self.K).to(self.device)

        self.out_features1 = n_merged_proxies
        self.out_features2 = n_new_proxies

    def forward(self, x):
        """Return fc1 scores, followed along dim 1 by fc2 scores when fc2 exists."""
        old_scores = self.fc1(x)
        if self.fc2 is None:
            return old_scores
        return torch.cat((old_scores, self.fc2(x)), dim=1)


# -

class PLUMatrix(nn.Module):
    '''
    PLU decomposition for invertible matrix
    ref: https://arxiv.org/pdf/1807.03039.pdf section 3.2

    Parameterises a square matrix as ``W = P @ L @ (U + diag(s))`` where
    ``P`` is a fixed permutation matrix, ``L`` is unit lower-triangular and
    ``U`` is strictly upper-triangular.  The factors are initialised from the
    LU factorisation of a random orthogonal matrix.

    Args:
        dim: size of the square matrix.
        positive_s: if True, the diagonal is stored as ``log_s`` and recovered
            with ``exp`` (so it is always positive; the initial diagonal is
            ``|diag(U)|``).  If False, ``log_s`` stores the raw diagonal values
            despite its name.
        eps: added to ``|s|`` inside the log in ``logdet`` when
            ``positive_s`` is False.
    '''
    def __init__(self, dim, positive_s=False, eps=1e-8):
        super(PLUMatrix, self).__init__()

        self.positive_s = positive_s
        self.eps = eps
        self._initialize(dim)

    def _initialize(self, dim):
        w, P, L, U = self.sampling_W(dim)
        # P and I are plain tensor attributes (not parameters or buffers): they
        # are not converted by .to()/.double() and not stored in state_dict.
        # The accessors below therefore cast them to _L's device AND dtype.
        # NOTE(review): because P is not saved, loading a state_dict into a
        # freshly constructed module may pair the saved L/U/s with a different
        # random P; registering it as a buffer would fix that but changes the
        # checkpoint keys.
        self.P = P
        self._L = nn.Parameter(L)
        self._U = nn.Parameter(torch.triu(U, diagonal=1))
        if self.positive_s:
            self.log_s = nn.Parameter(torch.log(torch.abs(torch.diag(U))))
        else:
            self.log_s = nn.Parameter(torch.diag(U))

        self.I = torch.eye(dim)
        return

    def sampling_W(self, dim):
        """Sample a random orthogonal matrix and return ``(W, P, L, U)``, its LU factors."""
        W = torch.empty(dim, dim)
        torch.nn.init.orthogonal_(W)
        # LU with partial pivoting.  torch.linalg.lu_factor (torch >= 1.13)
        # replaces the deprecated torch.lu; fall back on older releases.
        linalg = getattr(torch, "linalg", None)
        if linalg is not None and hasattr(linalg, "lu_factor"):
            LU, pivot = linalg.lu_factor(W)
        else:
            LU, pivot = torch.lu(W)
        P, L, U = torch.lu_unpack(LU, pivot)
        return W, P, L, U

    def L(self):
        """Unit lower-triangular factor: strict lower part of ``_L`` plus the identity."""
        l_ = torch.tril(self._L, diagonal=-1)
        # .to(self._L) matches both device and dtype (e.g. after .double()).
        return l_ + self.I.to(self._L)

    def U(self):
        """Strictly upper-triangular factor (re-masked so only that part of ``_U`` is used)."""
        return torch.triu(self._U, diagonal=1)

    def W(self):
        """Assemble ``P @ L @ (U + diag(s))``."""
        if self.positive_s:
            s = torch.diag(torch.exp(self.log_s))
        else:
            s = torch.diag(self.log_s)
        # Cast P to the parameters' device and dtype; casting only the device
        # made the matmul fail after .double()/.half().
        return self.P.to(self._L) @ self.L() @ (self.U() + s)

    def inv_W(self):
        """Return ``W^{-1}`` via a dense inverse (the LU structure is not exploited yet)."""
        w = self.W()
        inv_w = torch.inverse(w)
        return inv_w

    def logdet(self):
        """Return ``log|det W|`` computed from the diagonal only.

        ``det L = 1`` and ``|det P| = 1``, so ``log|det W| = sum(log|s|)``.
        With ``positive_s`` False, ``eps`` guards against ``log(0)``.
        """
        if self.positive_s:
            return torch.sum(self.log_s)
        else:
            return torch.sum(torch.log(torch.abs(self.log_s) + self.eps))

class InvertibleLinear(nn.Module):
    '''
    Invertible Linear
    ref: https://arxiv.org/pdf/1807.03039.pdf section 3.2

    Bias-free linear map ``y = x @ W^T`` whose weight ``W`` is parameterised
    by a ``PLUMatrix`` so it can be inverted and its log-determinant read
    off cheaply.
    '''
    def __init__(self, dim, positive_s=False, eps=1e-8):
        super().__init__()
        self.mat = PLUMatrix(dim, positive_s=positive_s, eps=eps)

    def logdet(self, x):
        """Return the layer's log-determinant repeated once per row of ``x``."""
        batch_size = x.shape[0]
        return self.mat.logdet().repeat(batch_size)

    def forward(self, x):
        """Apply ``y = x @ W^T``."""
        return F.linear(x, self.mat.W())

    def inverse(self, y):
        """Undo ``forward``: ``x = y @ (W^{-1})^T``."""
        inv_weight = self.mat.inv_W()
        return F.linear(y, inv_weight)
