'''
Reference:
https://github.com/hshustc/CVPR19_Incremental_Learning/blob/master/cifar100-class-incremental/modified_linear.py
'''
import math
import torch
from torch import nn
from torch.nn import functional as F
from typing import Union, Optional, Dict
from abc import ABCMeta, abstractmethod
import torch.nn.init as init
import math
import einops as ein

class DNMLayer(nn.Module):
    """Dendritic neuron model (DNM) layer.

    Processing steps, in the order ``forward`` applies them:

    1. optional ``LayerNorm`` over the input features (``args["sn"]``);
    2. synapse: element-wise ``sw * x + sb`` followed by an optional activation;
    3. dendrite: sum over the feature axis, optional activation, then an
       optional ``LayerNorm`` over the branch axis (``args["dn"]``);
    4. soma: sum over the branch axis followed by an optional activation.

    Args:
        in_features: size of the last dimension of the input.
        out_features: number of output units.
        args: configuration mapping. The keys read here are ``"num_branch"``,
            ``"sn"``, ``"dn"``, ``"synapse_activation"``,
            ``"dendritic_activation"`` and ``"soma"``.

    Raises:
        ValueError: if one of the activation names is not supported.
    """

    # Name -> module class for every activation this layer knows about.
    _ACTIVATIONS = {
        'none': None,
        'relu': nn.ReLU,
        'sigmoid': nn.Sigmoid,
        'tanh': nn.Tanh,
    }

    def __init__(self, in_features, out_features, args):
        super(DNMLayer, self).__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.num_branch = args["num_branch"]

        # Optional normalisation of the raw input features.
        self.sn0 = nn.LayerNorm(in_features) if args["sn"] else None

        # Synapse weight is per (output unit, branch, input feature);
        # the synapse bias is shared by all output units.
        self.sw = nn.Parameter(torch.empty((out_features, self.num_branch, in_features)))
        self.sb = nn.Parameter(torch.empty((self.num_branch, in_features)))

        # Optional normalisation across the branches of each output unit.
        self.dn = nn.LayerNorm(self.num_branch) if args["dn"] else None

        self.set_activations(args)
        self.reset_parameters()

    def _make_activation(self, role, name, allowed):
        """Return a fresh activation module for ``name`` (``None`` for 'none').

        Raises ``ValueError`` when ``name`` is not one of ``allowed``.
        """
        if name not in allowed:
            raise ValueError(f"Please set {role} activation: {'; '.join(allowed)}")
        factory = self._ACTIVATIONS[name]
        return None if factory is None else factory()

    def set_activations(self, args):
        """Create ``self.sa``, ``self.da`` and ``self.soma`` from ``args``.

        A real exception is raised for unknown names (raising a plain string
        is itself a ``TypeError`` and hides the message), and the 'tanh'
        synapse activation is stored in ``self.sa`` rather than ``self.soma``.
        """
        self.sa = self._make_activation(
            'synapse', args["synapse_activation"], ('none', 'relu', 'sigmoid', 'tanh'))
        self.da = self._make_activation(
            'dendritic', args["dendritic_activation"], ('none', 'relu', 'sigmoid'))
        self.soma = self._make_activation(
            'soma', args["soma"], ('none', 'relu', 'sigmoid', 'tanh'))

    def forward(self, x):
        """Run the layer.

        Args:
            x: either ``(b, in_features)`` or ``(b, num_branch, in_features)``.

        Returns:
            dict with ``'logits'`` of shape ``(b, out_features)`` and ``'sa_x'``,
            the post-synapse tensor of shape
            ``(b, out_features, num_branch, in_features)``.

        Raises:
            ValueError: if ``x`` is neither 2-D nor 3-D.
        """
        if self.sn0 is not None:
            x = self.sn0(x)  # normalise over the feature axis

        # Insert singleton axes so that x broadcasts against sw (o, m, d).
        # Broadcasting yields the same values as explicitly repeating x and sw
        # over the batch, without materialising the copies.
        if x.dim() == 2:
            x = x[:, None, None, :]  # (b, d)    -> (b, 1, 1, d)
        elif x.dim() == 3:
            x = x[:, None, :, :]     # (b, m, d) -> (b, 1, m, d)
        else:
            raise ValueError('Please check input data')

        # Synapse: element-wise affine map, result is (b, o, m, d).
        x = self.sw * x
        x = x + self.sb

        if self.sa is not None:
            x = self.sa(x)
        sa_x = x

        # Dendrite: collapse the feature axis, (b, o, m, d) -> (b, o, m).
        x = x.sum(dim=3)

        if self.da is not None:
            x = self.da(x)

        if self.dn is not None:
            x = self.dn(x)  # normalise over the branch axis

        # Soma: collapse the branch axis, (b, o, m) -> (b, o).
        x = x.sum(dim=2)

        if self.soma is not None:
            x = self.soma(x)

        return {'logits': x,
                'sa_x': sa_x}

    def reset_parameters(self):
        """Initialise ``sw`` with Kaiming-uniform and ``sb`` uniformly in ±1/sqrt(fan_in)."""
        init.kaiming_uniform_(self.sw, a=math.sqrt(5))
        fan_in, _ = init._calculate_fan_in_and_fan_out(self.sw)
        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
        init.uniform_(self.sb, -bound, bound)

    def extra_repr(self) -> str:
        return (f'in_features={self.in_features}, out_features={self.out_features},num_branch={self.num_branch}, '
                f'synapse_activation={self.sa}, dendritic_activation={self.da}, soma={self.soma}, sn={self.sn0 is not None}, '
                f'dn={self.dn is not None}'
                )
    

   
    
class DNMLayer1(nn.Module):
    """Branch-wise linear layer whose branch outputs are summed and squashed.

    The input features are split into ``num_branch`` equal chunks. Every chunk
    goes through the same linear map (``weight``/``bias`` are shared by all
    branches), the branch outputs are summed and a sigmoid is applied.

    Args:
        in_features: input width; must be divisible by ``num_branch``.
        out_features: nominal output width. The shared weight has
            ``out_features // num_branch`` rows, which is the actual width of
            the returned logits.
        num_branch: number of branches the input is split into.
        synapse_activation, dendritic_activation, soma: stored on the module
            but not applied in ``forward``.
        bias: whether to learn a bias for the shared linear map.
        gain: whether to create the per-branch ``gain`` parameter (currently
            not used in ``forward``).
    """
    def __init__(self, in_features, out_features, num_branch=1, synapse_activation=nn.Sigmoid,
                 dendritic_activation=nn.Sigmoid, soma=nn.ReLU, bias=True, gain=True):
        super(DNMLayer1, self).__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.num_branch = num_branch

        # synapse init: one linear map shared by every branch
        self.weight = nn.Parameter(torch.empty((out_features // num_branch, in_features // num_branch)))

        # Always register bias/gain (possibly as None) so that the
        # ``is not None`` checks below work when they are disabled.
        if bias:
            self.bias = nn.Parameter(torch.empty((out_features // num_branch)))
        else:
            self.register_parameter('bias', None)

        self.synapse_activation = synapse_activation

        # dendritic init
        if gain:
            self.gain = nn.Parameter(torch.empty((num_branch)))
        else:
            self.register_parameter('gain', None)
        self.dendritic_activation = dendritic_activation

        self.soma = soma
        self.reset_parameters()

    def forward(self, x):
        """Map ``x`` of shape ``(B, in_features)`` to sigmoid-activated logits.

        Returns:
            dict with ``'logits'`` of shape ``(B, out_features // num_branch)``.

        Raises:
            ValueError: if the input width is not divisible by ``num_branch``.
        """
        B, N = x.shape
        if N % self.num_branch != 0:
            raise ValueError(
                f'input width {N} is not divisible by num_branch={self.num_branch}')

        # [B, N] --> [B, M, N//M]
        x = x.reshape(B, self.num_branch, N // self.num_branch)

        # Shared linear map applied to every branch: [B, M, N//M] --> [B, M, Out//M].
        # Computed from x directly so the result lives on x's device and stays
        # differentiable (no in-place writes into a pre-allocated buffer).
        y = F.linear(x, self.weight, self.bias)

        # NOTE(review): the previous code accumulated a running sum over the
        # branch axis and returned the last entry, i.e. the sum of all branch
        # outputs; that is what is computed here. Confirm this is the intent.
        out = y.sum(dim=1)

        out = torch.sigmoid(out)

        return {'logits': out}

    def reset_parameters(self):
        """Kaiming-uniform weight, bias in ±1/sqrt(fan_in), gain in [-1, 1]."""
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0

        if self.bias is not None:
            init.uniform_(self.bias, -bound, bound)
        if self.gain is not None:
            init.uniform_(self.gain, -1, 1)

    def extra_repr(self) -> str:
        return (f'in_features={self.in_features}, out_features={self.out_features},num_branch={self.num_branch},'
                f'synapse_activation={self.synapse_activation}, dendritic_activation={self.dendritic_activation},'
                f'bias={self.bias is not None},gain={self.gain is not None} ')
    
        
    
class DNMLayer2(nn.Module):
    """Branch-wise linear layer whose branch outputs are summed and squashed.

    The input features are split into ``num_branch`` equal chunks. Every chunk
    goes through the same linear map (``weight``/``bias`` are shared by all
    branches), the branch outputs are summed and a sigmoid is applied.

    Args:
        in_features: input width; must be divisible by ``num_branch``.
        out_features: nominal output width. The shared weight has
            ``out_features // num_branch`` rows, which is the actual width of
            the returned logits.
        num_branch: number of branches the input is split into.
        synapse_activation, dendritic_activation, soma: stored on the module
            but not applied in ``forward``.
        bias: whether to learn a bias for the shared linear map.
        gain: whether to create the per-branch ``gain`` parameter (currently
            not used in ``forward``).
    """
    def __init__(self, in_features, out_features, num_branch=1, synapse_activation=nn.Sigmoid,
                 dendritic_activation=nn.Sigmoid, soma=nn.ReLU, bias=True, gain=True):
        # Must name this class: super(<other class>, self) raises TypeError
        # because self is not an instance of that class.
        super(DNMLayer2, self).__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.num_branch = num_branch

        # synapse init: one linear map shared by every branch
        self.weight = nn.Parameter(torch.empty((out_features // num_branch, in_features // num_branch)))

        # Always register bias/gain (possibly as None) so that the
        # ``is not None`` checks below work when they are disabled.
        if bias:
            self.bias = nn.Parameter(torch.empty((out_features // num_branch)))
        else:
            self.register_parameter('bias', None)

        self.synapse_activation = synapse_activation

        # dendritic init
        if gain:
            self.gain = nn.Parameter(torch.empty((num_branch)))
        else:
            self.register_parameter('gain', None)
        self.dendritic_activation = dendritic_activation

        self.soma = soma
        self.reset_parameters()

    def forward(self, x):
        """Map ``x`` of shape ``(B, in_features)`` to sigmoid-activated logits.

        Returns:
            dict with ``'logits'`` of shape ``(B, out_features // num_branch)``.

        Raises:
            ValueError: if the input width is not divisible by ``num_branch``.
        """
        B, N = x.shape
        if N % self.num_branch != 0:
            raise ValueError(
                f'input width {N} is not divisible by num_branch={self.num_branch}')

        # [B, N] --> [B, M, N//M]
        x = x.reshape(B, self.num_branch, N // self.num_branch)

        # Shared linear map applied to every branch: [B, M, N//M] --> [B, M, Out//M].
        # Computed from x directly so the result lives on x's device and stays
        # differentiable (no in-place writes into a pre-allocated buffer).
        y = F.linear(x, self.weight, self.bias)

        # NOTE(review): the previous code accumulated a running sum over the
        # branch axis and returned the last entry, i.e. the sum of all branch
        # outputs; that is what is computed here. Confirm this is the intent.
        out = y.sum(dim=1)

        out = torch.sigmoid(out)

        return {'logits': out}

    def reset_parameters(self):
        """Kaiming-uniform weight, bias in ±1/sqrt(fan_in), gain in [-1, 1]."""
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0

        if self.bias is not None:
            init.uniform_(self.bias, -bound, bound)
        if self.gain is not None:
            init.uniform_(self.gain, -1, 1)

    def extra_repr(self) -> str:
        return (f'in_features={self.in_features}, out_features={self.out_features},num_branch={self.num_branch},'
                f'synapse_activation={self.synapse_activation}, dendritic_activation={self.dendritic_activation},'
                f'bias={self.bias is not None},gain={self.gain is not None} ')
    
        


class SimpleLinear(nn.Module):
    '''
    Fully connected layer that returns its output wrapped in a dict.

    The weight is initialised with Kaiming-uniform (linear gain) and the bias,
    when present, with zeros.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        bias: when False no bias is learned and ``self.bias`` is ``None``.

    Reference:
    https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/linear.py
    '''
    def __init__(self, in_features, out_features, bias=True):
        super(SimpleLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.empty(out_features, in_features))

        if bias:
            self.bias = nn.Parameter(torch.empty(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight, nonlinearity='linear')
        # The bias is None when the layer was built with bias=False;
        # initialising it unconditionally crashes the constructor.
        if self.bias is not None:
            nn.init.constant_(self.bias, 0)

    def forward(self, input):
        """Return ``{'logits': input @ weight.T + bias}``."""
        return {'logits': F.linear(input, self.weight, self.bias)}

    def extra_repr(self) -> str:
        return (f'in_features={self.in_features}, out_features={self.out_features} ')

class TagFex_SimpleLinear(nn.Module):
    '''
    Plain fully connected layer that returns the raw output tensor.

    The weight uses Kaiming-uniform initialisation with linear gain; the
    optional bias starts at zero.

    Reference:
    https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/linear.py
    '''
    __constants__ = ['in_features', 'out_features']
    in_features: int
    out_features: int
    weight: torch.Tensor

    def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        # Placement/precision options forwarded to every tensor created here.
        tensor_opts = dict(device=device, dtype=dtype)
        self.weight = nn.Parameter(torch.empty(out_features, in_features, **tensor_opts))
        if not bias:
            self.register_parameter('bias', None)
        else:
            self.bias = nn.Parameter(torch.empty(out_features, **tensor_opts))
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw the weight and zero the bias (if there is one)."""
        nn.init.kaiming_uniform_(self.weight, nonlinearity='linear')
        if self.bias is None:
            return
        nn.init.zeros_(self.bias)

    def forward(self, input) -> torch.Tensor:
        """Apply the affine map ``input @ weight.T + bias``."""
        output = F.linear(input, self.weight, self.bias)
        return output

    def extra_repr(self) -> str:
        has_bias = self.bias is not None
        return f'in_features={self.in_features}, out_features={self.out_features}, bias={has_bias}'

class CosineLinear(nn.Module):
    """Cosine-similarity classifier head.

    Both the input rows and the weight rows are L2-normalised before the
    matrix product, so each raw score is a cosine similarity. Every class owns
    ``nb_proxy`` weight rows; with ``to_reduce`` the per-proxy scores are
    merged by ``reduce_proxies``. An optional learnable scalar ``sigma``
    (initialised to 1) scales the result.
    """

    def __init__(self, in_features, out_features, nb_proxy=1, to_reduce=False, sigma=True):
        super().__init__()
        self.in_features = in_features
        # One weight row per (class, proxy) pair.
        self.out_features = out_features * nb_proxy
        self.nb_proxy = nb_proxy
        self.to_reduce = to_reduce
        self.weight = nn.Parameter(torch.Tensor(self.out_features, in_features))
        if not sigma:
            self.register_parameter('sigma', None)
        else:
            self.sigma = nn.Parameter(torch.Tensor(1))
        self.reset_parameters()

    def reset_parameters(self):
        """Draw the weight uniformly in ±1/sqrt(in_features) and set sigma to 1."""
        limit = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-limit, limit)
        if self.sigma is None:
            return
        self.sigma.data.fill_(1)

    def forward(self, input):
        """Return ``{'logits': scores}`` for a 2-D batch ``input``."""
        unit_input = F.normalize(input, p=2, dim=1)
        unit_weight = F.normalize(self.weight, p=2, dim=1)
        scores = F.linear(unit_input, unit_weight)

        if self.to_reduce:
            # Merge the proxies belonging to the same class.
            scores = reduce_proxies(scores, self.nb_proxy)

        if self.sigma is not None:
            scores = self.sigma * scores

        return {'logits': scores}

    def extra_repr(self) -> str:
        return (f'in_features={self.in_features}, out_features={self.out_features} ')


class SplitCosineLinear(nn.Module):
    """Two cosine heads side by side, sharing one optional scale ``sigma``.

    ``fc1`` and ``fc2`` are ``CosineLinear`` heads built without their own
    sigma and without proxy reduction; this module concatenates their scores,
    reduces the proxies once and applies the shared ``sigma``.
    """

    def __init__(self, in_features, out_features1, out_features2, nb_proxy=1, sigma=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = (out_features1 + out_features2) * nb_proxy
        self.nb_proxy = nb_proxy
        self.fc1 = CosineLinear(in_features, out_features1, nb_proxy, to_reduce=False, sigma=False)
        self.fc2 = CosineLinear(in_features, out_features2, nb_proxy, to_reduce=False, sigma=False)
        if not sigma:
            self.register_parameter('sigma', None)
        else:
            self.sigma = nn.Parameter(torch.Tensor(1))
            self.sigma.data.fill_(1)

    def forward(self, x):
        """Return the combined logits plus the per-head reduced scores.

        The returned dict has the keys ``'old_scores'`` (from ``fc1``),
        ``'new_scores'`` (from ``fc2``) and ``'logits'`` (both heads, reduced
        and scaled by ``sigma`` when it exists).
        """
        first = self.fc1(x)['logits']
        second = self.fc2(x)['logits']

        # Concatenate along the channel axis, then merge proxies per class.
        combined = torch.cat((first, second), dim=1)
        combined = reduce_proxies(combined, self.nb_proxy)

        if self.sigma is not None:
            combined = self.sigma * combined

        result = {
            'old_scores': reduce_proxies(first, self.nb_proxy),
            'new_scores': reduce_proxies(second, self.nb_proxy),
            'logits': combined
        }
        return result


class AnalyticLinear(torch.nn.Linear, metaclass=ABCMeta):
    """
    Abstract base of the linear heads used by analytic continual learning [1-3].

    The weight is kept as a buffer of shape ``(in_features [+1], n_classes)``
    that starts with zero columns; subclasses provide ``fit`` to update it.
    When ``bias`` is true, a column of ones is appended to the features so the
    last weight row takes the role of the bias.

    Based on the official implementation at
    https://github.com/ZHUANGHP/Analytic-continual-learning.

    References:
    [1] Zhuang, Huiping, et al.
        "ACIL: Analytic class-incremental learning with absolute memorization and privacy protection."
        Advances in Neural Information Processing Systems 35 (2022): 11602-11614.
    [2] Zhuang, Huiping, et al.
        "GKEAL: Gaussian Kernel Embedded Analytic Learning for Few-Shot Class Incremental Task."
        Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2023.
    [3] Zhuang, Huiping, et al.
        "DS-AL: A Dual-Stream Analytic Learning for Exemplar-Free Class-Incremental Learning."
        Proceedings of the AAAI Conference on Artificial Intelligence. Vol. 38. No. 15. 2024.
    """

    def __init__(
        self,
        in_features: int,
        gamma: float = 1e-1,
        bias: bool = False,
        device: Optional[Union[torch.device, str, int]] = None,
        dtype=torch.double,
    ) -> None:
        # Deliberately bypass nn.Linear.__init__ and initialise nn.Module only,
        # so no trainable weight/bias parameters are created.
        super(torch.nn.Linear, self).__init__()
        self.gamma: float = gamma
        self.bias: bool = bias
        self.dtype = dtype

        # One extra input row when the bias is folded into the weight.
        n_rows = in_features + 1 if bias else in_features
        empty_weight = torch.zeros((n_rows, 0), device=device, dtype=dtype)
        self.register_buffer("weight", empty_weight)

    @torch.no_grad()
    def forward(self, X: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Return ``{"logits": X @ weight}`` after casting ``X`` like the weight."""
        feats = X.to(self.weight)
        if self.bias:
            ones_column = torch.ones(feats.shape[0], 1).to(feats)
            feats = torch.cat((feats, ones_column), dim=-1)
        return {"logits": feats @ self.weight}

    @property
    def in_features(self) -> int:
        # The folded-in bias row is not counted as an input feature.
        n_rows = self.weight.shape[0]
        return n_rows - 1 if self.bias else n_rows

    @property
    def out_features(self) -> int:
        return self.weight.shape[1]

    def reset_parameters(self) -> None:
        """Drop every learned column, returning to a weight with zero classes."""
        # Following the equation (4) of ACIL, self.weight is set to \hat{W}_{FCN}^{-1}
        n_rows = self.weight.shape[0]
        self.weight = torch.zeros((n_rows, 0)).to(self.weight)

    @abstractmethod
    def fit(self, X: torch.Tensor, Y: torch.Tensor) -> None:
        """Update the weight from features ``X`` and targets ``Y`` (subclass hook)."""
        raise NotImplementedError()

    def after_task(self) -> None:
        """Check that the weight contains only finite values."""
        weight_is_finite = torch.isfinite(self.weight).all()
        assert weight_is_finite, (
            "Pay attention to the numerical stability! "
            "A possible solution is to increase the value of gamma. "
            "Setting self.dtype=torch.double also helps."
        )


class RecursiveLinear(AnalyticLinear):
    """
    Recursively updated ridge-regression head for analytic continual learning [1-3].

    Besides the weight, the module keeps the buffer ``R`` (initialised to
    ``I / gamma``), which ``fit`` updates batch by batch.

    Based on the official implementation at
    https://github.com/ZHUANGHP/Analytic-continual-learning.

    References:
    [1] Zhuang, Huiping, et al.
        "ACIL: Analytic class-incremental learning with absolute memorization and privacy protection."
        Advances in Neural Information Processing Systems 35 (2022): 11602-11614.
    [2] Zhuang, Huiping, et al.
        "GKEAL: Gaussian Kernel Embedded Analytic Learning for Few-Shot Class Incremental Task."
        Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2023.
    [3] Zhuang, Huiping, et al.
        "DS-AL: A Dual-Stream Analytic Learning for Exemplar-Free Class-Incremental Learning."
        Proceedings of the AAAI Conference on Artificial Intelligence. Vol. 38. No. 15. 2024.
    """

    def __init__(
        self,
        in_features: int,
        gamma: float = 1e-1,
        bias: bool = False,
        device: Optional[Union[torch.device, str, int]] = None,
        dtype=torch.double,
    ) -> None:
        super().__init__(in_features, gamma, bias, device, dtype)

        # Regularized Feature Autocorrelation Matrix (RFAuM)
        self.R: torch.Tensor
        n_rows = self.weight.shape[0]
        initial_R = torch.eye(n_rows, device=device, dtype=dtype) / self.gamma
        self.register_buffer("R", initial_R)

    def update_fc(self, nb_classes: int) -> None:
        """Append zero columns to the weight until it has ``nb_classes`` columns."""
        n_new = nb_classes - self.out_features
        assert n_new >= 0, "The number of classes should be increasing."
        n_rows = self.weight.shape[0]
        zero_columns = torch.zeros((n_rows, n_new)).to(self.weight)
        self.weight = torch.cat((self.weight, zero_columns), dim=1)

    @torch.no_grad()
    def fit(self, X: torch.Tensor, Y: torch.Tensor) -> None:
        """Update ``R`` and the weight in place from one mini-batch (core of ACIL [1]).

        The update is written differently from the equations in the paper but
        is equivalent to them, and it supports mini-batch learning.
        """
        X = X.to(self.weight)
        Y = Y.to(self.weight)
        if self.bias:
            ones_column = torch.ones(X.shape[0], 1).to(X)
            X = torch.cat((X, ones_column), dim=-1)

        # ACIL
        # Please update your PyTorch & CUDA if the `cusolver error` occurs.
        # If you insist on using this version, doing the `torch.inverse` on CPUs might help.
        # >>> K_inv = torch.eye(X.shape[0]).to(X) + X @ self.R @ X.T
        # >>> K = torch.inverse(K_inv.cpu()).to(self.weight.device)
        identity = torch.eye(X.shape[0]).to(X)
        K = torch.inverse(identity + X @ self.R @ X.T)
        # Equation (10) of ACIL
        self.R -= self.R @ X.T @ K @ X @ self.R
        # Equation (9) of ACIL
        residual = Y - X @ self.weight
        self.weight += self.R @ X.T @ residual


def reduce_proxies(out, nb_proxy):
    """Collapse per-proxy scores into one score per class.

    ``out`` is viewed as ``(batch, n_classes, nb_proxy)``. For each class the
    proxy scores are combined as a weighted sum whose weights are the softmax
    of those same scores. With ``nb_proxy == 1`` the input is returned as is.

    Raises:
        AssertionError: if the second dimension of ``out`` is not a multiple
            of ``nb_proxy``.
    """
    if nb_proxy == 1:
        return out

    batch_size, width = out.shape[0], out.shape[1]
    nb_classes, leftover = divmod(width, nb_proxy)
    assert leftover == 0, 'Shape error'

    per_class = out.view(batch_size, nb_classes, nb_proxy)
    weights = F.softmax(per_class, dim=-1)
    weighted = weights * per_class

    return weighted.sum(-1)


'''
class CosineLinear(nn.Module):
    def __init__(self, in_features, out_features, sigma=True):
        super(CosineLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
        if sigma:
            self.sigma = nn.Parameter(torch.Tensor(1))
        else:
            self.register_parameter('sigma', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.sigma is not None:
            self.sigma.data.fill_(1)

    def forward(self, input):
        out = F.linear(F.normalize(input, p=2, dim=1), F.normalize(self.weight, p=2, dim=1))
        if self.sigma is not None:
            out = self.sigma * out
        return {'logits': out}


class SplitCosineLinear(nn.Module):
    def __init__(self, in_features, out_features1, out_features2, sigma=True):
        super(SplitCosineLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features1 + out_features2
        self.fc1 = CosineLinear(in_features, out_features1, False)
        self.fc2 = CosineLinear(in_features, out_features2, False)
        if sigma:
            self.sigma = nn.Parameter(torch.Tensor(1))
            self.sigma.data.fill_(1)
        else:
            self.register_parameter('sigma', None)

    def forward(self, x):
        out1 = self.fc1(x)
        out2 = self.fc2(x)

        out = torch.cat((out1['logits'], out2['logits']), dim=1)  # concatenate along the channel
        if self.sigma is not None:
            out = self.sigma * out

        return {
            'old_scores': out1['logits'],
            'new_scores': out2['logits'],
            'logits': out
        }
'''
