import torch
import torch.nn as nn
import torch.nn.functional as F
import math

from .genotypes import *    # dont use in Testing this file

# Registry of candidate two-input fusion operators, keyed by primitive name.
# Every factory is called as ``factory(C, L, args)`` and returns an nn.Module whose
# ``forward(x, y)`` combines two inputs. The class names are resolved lazily (only
# when a lambda is called) because the classes are defined further down in this file.
# The 'Sparse-GLU' / 'Sparse-Concat' / 'Sparse-SELayer' entries map to the variants
# that use MatrixSparseCodingLayer where the plain variants use nn.Conv1d.
# 'Sparse-Attn' maps to ScaledDotAttn, which has no learned projection at all.
STEP_STEP_OPS = {
    'Sum': lambda C, L, args: Sum(),
    'Sparse-Attn': lambda C, L, args: ScaledDotAttn(C, L),
    'GLU': lambda C, L, args: NormalLinearGLU(C, args),
    'Concat': lambda C, L, args: NormalConcatFC(C, args),
    'SELayer': lambda C, L, args: SELayer(C),
    'Sparse-GLU': lambda C, L, args: LinearGLU(C, args),
    'Sparse-Concat': lambda C, L, args: ConcatFC(C, args),
    'Sparse-SELayer': lambda C, L, args: SparseSELayer(C),    
}

class NormalLinearGLU(nn.Module):
    """Fuse two inputs with a pointwise Conv1d, batch norm and a GLU gate.

    This differs from a gated multimodal unit (GMU): no adaptive weighting of
    the two inputs is learned here.

    Args:
        C: channel count of each input; the concatenation has ``2 * C`` channels.
        args: configuration object; only ``args.drpt`` (dropout probability) is read.
    """

    def __init__(self, C, args):
        super().__init__()
        fused_channels = 2 * C
        # Kernel size 1 / stride 1: a per-position linear map over channels.
        self.conv = nn.Conv1d(
            in_channels=fused_channels,
            out_channels=fused_channels,
            kernel_size=1,
            stride=1,
        )
        self.bn = nn.BatchNorm1d(fused_channels)
        self.dropout = nn.Dropout(p=args.drpt)

    def forward(self, x, y):
        """Return ``dropout(glu(bn(conv(cat(x, y)))))``.

        ``x`` and ``y`` are concatenated along dim 1; the GLU then halves that
        dimension again, so the result has ``C`` channels.
        """
        stacked = torch.cat((x, y), dim=1)
        normalized = self.bn(self.conv(stacked))
        gated = F.glu(normalized, dim=1)
        return self.dropout(gated)

class NormalConcatFC(nn.Module):
    """Fuse two inputs by concatenation, pointwise Conv1d, batch norm and ReLU.

    Args:
        C: channel count of each input; the concatenation has ``2 * C`` channels
            and the convolution projects it back to ``C``.
        args: configuration object; only ``args.drpt`` (dropout probability) is read.
    """

    def __init__(self, C, args):
        super().__init__()
        # Kernel size 1 / stride 1: a per-position linear map from 2C to C channels.
        self.conv = nn.Conv1d(in_channels=2 * C, out_channels=C, kernel_size=1, stride=1)
        self.bn = nn.BatchNorm1d(C)
        self.dropout = nn.Dropout(p=args.drpt)

    def forward(self, x, y):
        """Return ``dropout(relu(bn(conv(cat(x, y)))))`` with ``C`` channels."""
        stacked = torch.cat((x, y), dim=1)
        activated = F.relu(self.bn(self.conv(stacked)))
        return self.dropout(activated)
    
class SparseSELayer(nn.Module):
    """Squeeze-and-excitation fusion whose output projection is a sparse-coding layer.

    The two inputs are concatenated on dim 1, re-weighted per channel by a
    squeeze-and-excitation gate, and mapped to ``channel`` channels by a
    ``MatrixSparseCodingLayer``. The output therefore has the same channel
    count as each individual input (assuming each input has ``channel`` channels).

    Args:
        channel: channel count of each input.
        reduction: bottleneck ratio of the excitation MLP.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        fused = 2 * channel
        bottleneck = fused // reduction
        # 1-D pooling; a 2-D variant would use AdaptiveAvgPool2d.
        self.avg_pool = nn.AdaptiveAvgPool1d(1)
        # Open question from the author (240920): how to cope with 2C vs C in this MLP?
        self.fc = nn.Sequential(
            nn.Linear(fused, bottleneck, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(bottleneck, fused, bias=False),
            nn.Sigmoid(),
        )
        # forward passes the gated tensor twice and the sparse-coding layer
        # concatenates its two arguments, hence 2 * (2 * channel) input channels.
        self.conv1x1 = MatrixSparseCodingLayer(n_channel=2 * fused, dict_size=channel)

    def forward(self, x, y):
        """Gate the concatenation of ``x`` and ``y`` channel-wise, then sparse-code it."""
        stacked = torch.cat((x, y), dim=1)
        # The three-way unpack requires a 3-D tensor (batch, channels, length).
        batch, channels, _length = stacked.size()
        squeezed = self.avg_pool(stacked).view(batch, channels)
        gate = self.fc(squeezed).view(batch, channels, 1)
        gated = stacked * gate.expand_as(stacked)
        return self.conv1x1(gated, gated)

class SeparableConv1d(nn.Module):
    """Two-input separable 1-D convolution: grouped conv followed by a pointwise conv.

    NOTE: a class with the same name is defined again further down in this
    module; that later definition is the one bound to the name after import.

    Args:
        in_channels: channel count of each input; the concatenation has
            ``2 * in_channels`` channels.
        out_channels: channel count of the result.
        kernel_size: kernel size of the grouped convolution.
        stride: stride of the grouped convolution.
        padding: zero padding of the grouped convolution. Previously this
            argument was ignored and a padding of 1 was always used.
        dilation: dilation of the grouped convolution.
        bias: whether both convolutions carry a bias term.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, bias=True):
        super().__init__()
        fused = 2 * in_channels
        # NOTE(review): groups=in_channels over 2*in_channels channels yields two
        # channels per group, i.e. a grouped rather than strictly depthwise
        # convolution. Kept as-is so existing parameter shapes do not change.
        self.depthwise = nn.Conv1d(
            fused,
            fused,
            kernel_size,
            stride,
            padding=padding,
            dilation=dilation,
            groups=in_channels,
            bias=bias,
        )
        self.pointwise = nn.Conv1d(fused, out_channels, kernel_size=1, bias=bias)

    def forward(self, x, y):
        """Concatenate ``x`` and ``y`` on dim 1 and apply both convolutions."""
        out = torch.cat([x, y], dim=1)
        out = self.depthwise(out)
        out = self.pointwise(out)
        return out

class Sum(nn.Module):
    """Parameter-free fusion operator: the element-wise sum of its two inputs."""

    def __init__(self):
        super(Sum, self).__init__()

    def forward(self, x, y):
        """Return ``x + y``."""
        total = x + y
        return total

class LinearGLU(nn.Module):
    """GLU fusion whose linear map is a ``MatrixSparseCodingLayer`` instead of a Conv1d.

    Args:
        C: channel count of each input; the concatenation has ``2 * C`` channels.
        args: configuration object; only ``args.drpt`` (dropout probability) is read.
    """

    def __init__(self, C, args):
        super().__init__()
        gate_width = 2 * C
        # forward hands the concatenation to the sparse-coding layer twice and
        # that layer concatenates its two arguments again: 2 * (2 * C) channels in.
        self.conv = MatrixSparseCodingLayer(n_channel=2 * gate_width, dict_size=gate_width)
        self.bn = nn.BatchNorm1d(gate_width)
        self.dropout = nn.Dropout(p=args.drpt)

    def forward(self, x, y):
        """Return ``dropout(glu(bn(sparse_code(cat(x, y)))))``; the GLU acts on dim 1."""
        stacked = torch.cat((x, y), dim=1)
        codes = self.conv(stacked, stacked)
        gated = F.glu(self.bn(codes), dim=1)
        return self.dropout(gated)

class ConcatFC(nn.Module):
    """Concat-and-project fusion using a ``MatrixSparseCodingLayer`` as the projection.

    Args:
        C: channel count of each input; the sparse codes have ``C`` channels.
        args: configuration object; only ``args.drpt`` (dropout probability) is read.
    """

    def __init__(self, C, args):
        super().__init__()
        # forward hands the concatenation to the sparse-coding layer twice and
        # that layer concatenates its two arguments again: 2 * (2 * C) channels in.
        self.conv = MatrixSparseCodingLayer(n_channel=4 * C, dict_size=C)
        self.bn = nn.BatchNorm1d(C)
        self.dropout = nn.Dropout(p=args.drpt)

    def forward(self, x, y):
        """Return ``dropout(relu(bn(sparse_code(cat(x, y)))))``."""
        stacked = torch.cat((x, y), dim=1)
        codes = self.conv(stacked, stacked)
        activated = F.relu(self.bn(codes))
        return self.dropout(activated)

class Mish(nn.Module):
    """Mish activation: ``x * tanh(softplus(x))``, applied element-wise."""

    def __init__(self):
        super(Mish, self).__init__()

    def forward(self, x):
        """Return the Mish activation of ``x`` (same shape as ``x``)."""
        gate = torch.tanh(F.softplus(x))
        return x * gate

class CatConvMish(nn.Module):
    """Concat-and-project fusion with a sparse-coding projection and Mish activation.

    Args:
        C: channel count of each input; the sparse codes have ``C`` channels.
        args: configuration object; only ``args.drpt`` (dropout probability) is read.
    """

    def __init__(self, C, args):
        super().__init__()
        # forward hands the concatenation to the sparse-coding layer twice and
        # that layer concatenates its two arguments again: 2 * (2 * C) channels in.
        self.conv = MatrixSparseCodingLayer(n_channel=4 * C, dict_size=C)
        self.bn = nn.BatchNorm1d(C)
        self.dropout = nn.Dropout(p=args.drpt)
        self.mish = Mish()

    def forward(self, x, y):
        """Return ``dropout(mish(bn(sparse_code(cat(x, y)))))``."""
        stacked = torch.cat((x, y), dim=1)
        codes = self.conv(stacked, stacked)
        activated = self.mish(self.bn(codes))
        return self.dropout(activated)

class ScaledDotAttn(nn.Module):
    """Scaled dot-product attention between two inputs, without learned projections.

    Queries are taken from ``x``; keys and values are both taken from ``y``.
    The result is passed through dropout (p=0.1) and a LayerNorm over the
    trailing ``[C, L]`` dimensions.

    Args:
        C: size of dim 1 of the inputs (normalised by the LayerNorm).
        L: size of dim 2 of the inputs (normalised by the LayerNorm).
    """

    def __init__(self, C, L):
        super().__init__()
        self.dropout = nn.Dropout(p=0.1)
        self.ln = nn.LayerNorm(normalized_shape=[C, L])

    def forward(self, x, y):
        """Attend from ``x`` over ``y`` and return a tensor laid out like ``x``."""
        # Move dim 1 to the last position for queries and values; keys stay as given
        # so that queries @ keys contracts over dim 1 of the inputs.
        queries = x.transpose(1, 2)
        values = y.transpose(1, 2)

        scale = math.sqrt(queries.size(-1))
        weights = F.softmax(torch.matmul(queries, y) / scale, dim=-1)

        # Swap the last two dims back before regularisation and normalisation.
        attended = torch.matmul(weights, values).transpose(1, 2)
        return self.ln(self.dropout(attended))

class NodeMixedOp(nn.Module):
    """Weighted mixture of every candidate fusion operator.

    One operator is instantiated per name in ``STEP_STEP_PRIMITIVES`` (in that
    order) through the ``STEP_STEP_OPS`` registry.

    Args:
        C: forwarded to every operator factory.
        L: forwarded to every operator factory.
        args: forwarded to every operator factory.
    """

    def __init__(self, C, L, args):
        super().__init__()
        self._ops = nn.ModuleList(
            [STEP_STEP_OPS[name](C, L, args) for name in STEP_STEP_PRIMITIVES]
        )

    def forward(self, x, y, weights):
        """Return ``sum_i weights[i] * op_i(x, y)``.

        ``weights`` is paired with the operators via ``zip``, so surplus entries
        on either side are ignored. With no operators the result is the int 0.
        """
        mixed = 0
        for coefficient, op in zip(weights, self._ops):
            mixed = mixed + coefficient * op(x, y)
        return mixed
    

### 240723 added
# Extend the candidate ops from [Sum, ScaledDotAttn, LinearGLU, ConcatFC] with
# [Relu (or CatConvMish), SELayer, MSC layer (MatrixSparseCodingLayer), depthwise-separable conv].
### 240920 implemented
# x and y are the inputs from two different modalities, so every operation has to combine them.
# Following `LinearGLU`, each op fuses them with `torch.cat([x, y], dim=1)`.
class Relu(nn.Module):
    """Fuse two inputs by concatenation, pointwise Conv1d, batch norm and ReLU.

    Args:
        C: channel count of each input; the concatenation has ``2 * C`` channels.
        args: configuration object; only ``args.drpt`` (dropout probability) is read.

    NOTE(review): the convolution keeps ``2 * C`` channels and ReLU does not
    halve them, so the output has ``2 * C`` channels, unlike the GLU/Concat
    operators in this file, which return ``C``. Confirm this is intended
    before mixing this operator with the others.
    """

    def __init__(self, C, args):
        super().__init__()
        # 1x1 conv1d: a per-position linear map over the concatenated channels.
        self.conv = nn.Conv1d(2*C, 2*C, 1, 1)
        self.bn = nn.BatchNorm1d(2*C)
        self.dropout = nn.Dropout(args.drpt)

    def forward(self, x, y):
        """Return ``dropout(relu(bn(conv(cat(x, y)))))``."""
        # concat on channels
        out = torch.cat([x, y], dim=1)
        out = self.conv(out)
        out = self.bn(out)
        # ReLU is element-wise and takes no ``dim`` argument; passing one
        # (as the code previously did) raises TypeError.
        out = F.relu(out)
        out = self.dropout(out)
        return out


class SELayer(nn.Module):
    """Squeeze-and-excitation fusion of two inputs with a pointwise output projection.

    The two inputs are concatenated on dim 1, re-weighted per channel by a
    squeeze-and-excitation gate, and projected to ``channel`` channels by a
    kernel-size-1 convolution. The output therefore has the same channel count
    as each individual input (assuming each input has ``channel`` channels).

    Args:
        channel: channel count of each input.
        reduction: bottleneck ratio of the excitation MLP.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        fused = 2 * channel
        bottleneck = fused // reduction
        # 1-D pooling; a 2-D variant would use AdaptiveAvgPool2d.
        self.avg_pool = nn.AdaptiveAvgPool1d(1)
        # Open question from the author (240920): how to cope with 2C vs C in this MLP?
        self.fc = nn.Sequential(
            nn.Linear(fused, bottleneck, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(bottleneck, fused, bias=False),
            nn.Sigmoid(),
        )
        self.conv1x1 = nn.Conv1d(fused, channel, kernel_size=1)

    def forward(self, x, y):
        """Gate the concatenation of ``x`` and ``y`` channel-wise, then project it."""
        stacked = torch.cat((x, y), dim=1)
        # The three-way unpack requires a 3-D tensor (batch, channels, length).
        batch, channels, _length = stacked.size()
        squeezed = self.avg_pool(stacked).view(batch, channels)
        gate = self.fc(squeezed).view(batch, channels, 1)
        gated = stacked * gate.expand_as(stacked)
        return self.conv1x1(gated)

class SeparableConv1d(nn.Module):
    """Two-input separable 1-D convolution: grouped conv followed by a pointwise conv.

    NOTE: a class with the same name is also defined earlier in this module;
    this later definition is the one bound to the name after import.

    Args:
        in_channels: channel count of each input; the concatenation has
            ``2 * in_channels`` channels.
        out_channels: channel count of the result.
        kernel_size: kernel size of the grouped convolution.
        stride: stride of the grouped convolution.
        padding: zero padding of the grouped convolution. Previously this
            argument was ignored and a padding of 1 was always used.
        dilation: dilation of the grouped convolution.
        bias: whether both convolutions carry a bias term.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, bias=True):
        super().__init__()
        fused = 2 * in_channels
        # NOTE(review): groups=in_channels over 2*in_channels channels yields two
        # channels per group, i.e. a grouped rather than strictly depthwise
        # convolution. Kept as-is so existing parameter shapes do not change.
        self.depthwise = nn.Conv1d(
            fused,
            fused,
            kernel_size,
            stride,
            padding=padding,
            dilation=dilation,
            groups=in_channels,
            bias=bias,
        )
        self.pointwise = nn.Conv1d(fused, out_channels, kernel_size=1, bias=bias)

    def forward(self, x, y):
        """Concatenate ``x`` and ``y`` on dim 1 and apply both convolutions."""
        out = torch.cat([x, y], dim=1)
        out = self.depthwise(out)
        out = self.pointwise(out)
        return out
    
    
import torch._utils
import torch.nn.init as init
# from Lib.config import config  ##dont need in testing msc_layer.py
class elasnet_prox(nn.Module):
    r"""Proximal operator of the elastic-net penalty.

    NOTE: with ``mu = 0.0`` this reduces to the ell_1 proximal operator
    (plain soft-thresholding).

    The operator returns the minimiser of

        \argmin_{x} \lambda ||x||_1 + \mu / 2 ||x||_2^2 + 0.5 ||x - input||_2^2

    which is evaluated as a soft-shrinkage of the input scaled by
    ``1 / (1 + mu)`` with threshold ``lambd / (1 + mu)``.

    Args:
      lambd: the :math:`\lambda` weight of the ell_1 term. Default: 0.5
      mu:    the :math:`\mu` weight of the ell_2 term. Default: 0.0

    Shape:
      - Input: :math:`(N, *)` where `*` means any number of additional
        dimensions
      - Output: :math:`(N, *)`, same shape as the input

    """

    def __init__(self, lambd=0.5, mu=0.0):
        super().__init__()
        self.lambd = lambd
        # Stored as the reciprocal so forward only needs multiplications.
        self.scaling_mu = 1.0 / (1.0 + mu)

    def forward(self, input):
        # Both attributes are read on every call so later changes to them take effect.
        scale = self.scaling_mu
        threshold = self.lambd * scale
        return F.softshrink(input * scale, threshold)

    def extra_repr(self):
        return f'{self.lambd} {self.scaling_mu}'


class MatrixSparseCodingLayer(nn.Module):
    r"""Two-input sparse-coding layer solved with unrolled ISTA/FISTA iterations.

    The two inputs are concatenated on dim 1 and every resulting row ``v``
    (one per sample, or one per sample and position for 3-D inputs) is encoded as

        c = argmin_c lmbd * ||c||_1 + mu/2 * ||c||_2^2 + 1/2 * ||v - c @ weight||_2^2

    With ``square_noise=False`` the residual is divided by its L2 norm and the
    gradient is halved instead of using the squared-error gradient.

    Args:
        n_channel: channel count of the concatenated input ``cat([x, y], dim=1)``.
        dict_size: number of dictionary atoms, i.e. channel count of the output.
        mu: weight of the ell_2 penalty on the codes.
        lmbd: weight of the ell_1 penalty on the codes.
        n_dict: number of times the input row is repeated side by side before
            being matched against ``weight``.
        non_negative: apply a ReLU to the codes after every iteration.
        n_steps: number of unrolled iterations (must be at least 1).
        square_noise: use the squared-error data term (see above).
        step_size: gradient step size of the iterations.
        w_norm: re-normalise every dictionary atom to unit L2 norm on each
            forward pass (in training and in evaluation mode alike).
    """

    def __init__(self, n_channel, dict_size, mu=0.0, lmbd=0.1, n_dict=1, non_negative=True, n_steps=2, # model parameters
                 square_noise=True,  # optional model parameters
                 step_size=0.1, w_norm=True):  # training parameters
        super(MatrixSparseCodingLayer, self).__init__()

        self.mu = mu
        self.lmbd = lmbd  # LAMBDA
        self.n_dict = n_dict
        self.dict_size = dict_size
        self.n_channel = n_channel
        self.n_steps = n_steps
        self.w_norm = w_norm
        self.non_negative = non_negative
        # Cached dominant eigenvector estimate, used to warm-start power_iteration.
        self.v_max = None
        self.v_max_error = 0.
        self.square_noise = square_noise

        # One row per dictionary atom; torch.empty leaves the values to the initialiser below.
        self.weight = nn.Parameter(torch.empty(dict_size, n_channel * self.n_dict))

        with torch.no_grad():
            init.kaiming_uniform_(self.weight)

        # Proximal step shared by all ISTA/FISTA iterations.
        self.nonlinear = elasnet_prox(self.lmbd * step_size, self.mu * step_size)

        self.register_buffer('step_size', torch.tensor(step_size, dtype=torch.float))

    def _residual_gradient(self, target, c, weight):
        """Return the update direction for codes ``c`` given the rows ``target``."""
        # F.linear(c, weight.T) computes c @ weight, the reconstruction of the rows.
        residual = target - F.linear(c, weight.T, bias=None)
        if self.square_noise:
            return F.linear(residual, weight, bias=None)
        residual = residual / residual.norm(p=2, dim=1, keepdim=True).clamp_min(1e-12).detach()
        return F.linear(residual, weight) * 0.5

    def fista(self, x):
        """Run ``n_steps`` unrolled iterations on the rows of ``x``.

        Args:
            x: 2-D tensor ``[rows, n_channel]``.

        Returns:
            Tuple ``(c, weight)``: the codes ``[rows, dict_size]`` and the
            dictionary that was used.

        Raises:
            ValueError: if ``n_steps`` is smaller than 1 (no code would be produced).
        """
        if self.n_steps < 1:
            raise ValueError("n_steps must be at least 1, got {}".format(self.n_steps))

        weight = self.weight
        step_size = self.step_size
        # Loop-invariant: the input rows tiled n_dict times along the feature dim.
        target = x.repeat(1, self.n_dict)

        c = None
        c_pre = None
        t = None
        for i in range(self.n_steps):
            if i == 0:
                # First step starts from c = 0, so the gradient is just target @ weight.T.
                c = self.nonlinear(step_size * F.linear(target, weight, bias=None))
            else:
                if i == 1:
                    # Plain proximal-gradient step; also seeds the momentum coefficient.
                    start = c
                    t = (math.sqrt(5.0) + 1.0) / 2.0
                else:
                    t_pre = t
                    t = (math.sqrt(1.0 + 4.0 * t_pre * t_pre) + 1) / 2.0
                    # Momentum extrapolation between the last two iterates.
                    start = (t_pre + t - 1.0) / t * c + (1.0 - t_pre) / t * c_pre
                # NOTE(review): the gradient is evaluated at the current iterate c,
                # not at the extrapolated point; kept exactly as in the original code.
                gra = self._residual_gradient(target, c, weight)
                c_pre = c
                c = self.nonlinear(start + step_size * gra)

            if self.non_negative:
                c = F.relu(c)

        return c, weight

    def forward(self, x, y):
        '''
        Sparse-code the concatenation of ``x`` and ``y``.

        :param x: ``[bs, channels]`` or ``[bs, channels, seq_len]``
        :param y: tensor with the same number of dimensions as ``x``; it is
            concatenated to ``x`` along dim 1, and the concatenated channel
            count must equal ``n_channel``
        :return: codes of shape ``[bs, dict_size]`` or ``[bs, dict_size, seq_len]``;
            for 3-D inputs, ``out[b, :, s]`` is the code of position ``s`` of sample ``b``
        :raises NotImplementedError: for inputs that are neither 2-D nor 3-D
        '''
        out = torch.cat([x, y], dim=1)

        if out.dim() == 2:
            rows = out
        elif out.dim() == 3:
            bs, dim, seq_len = out.shape
            # Channels last, then one row per (sample, position): [bs * seq_len, dim].
            rows = out.permute(0, 2, 1).reshape(-1, dim)
        else:
            raise NotImplementedError("expected a 2-D or 3-D input, got {} dimensions".format(out.dim()))

        # Normalisation is applied regardless of self.training (author's choice, 240920).
        if self.w_norm:
            self.normalize_weight()

        # The reconstruction and the loss terms were computed here before but
        # never returned or stored, so they are no longer evaluated.
        c, _ = self.fista(rows)

        if out.dim() == 3:
            # Exact inverse of the flattening above. A bare view(bs, -1, seq_len)
            # would reinterpret the (position, atom) memory layout as
            # (atom, position) and mix codes across positions and channels.
            c = c.view(bs, seq_len, -1).permute(0, 2, 1).contiguous()

        return c

    def update_stepsize(self):
        """Set the step size to ``0.9 / lambda_max`` and rescale the proximal operator.

        ``lambda_max`` is the estimate returned by ``power_iteration`` for the
        current dictionary.
        """
        step_size = 0.9 / self.power_iteration(self.weight)
        # Arithmetic on the existing buffer keeps its device and dtype.
        self.step_size = self.step_size * 0. + step_size
        self.nonlinear.lambd = self.lmbd * step_size
        self.nonlinear.scaling_mu = 1.0 / (1.0 + self.mu * step_size)

    def normalize_weight(self):
        """Rescale every dictionary atom (row of ``weight``) to unit L2 norm, in place."""
        with torch.no_grad():
            w = self.weight.view(self.weight.size(0), -1)
            normw = w.norm(p=2, dim=1, keepdim=True).clamp_min(1e-12)
            w = w / normw
            self.weight.data = w.data

    def power_iteration(self, weight):
        """Estimate the largest eigenvalue of ``weight @ weight.T`` by power iteration.

        Runs at most 50 iterations, stopping early once the squared change of
        the unit-norm iterate drops to 1e-5 or below. The final vector is cached
        in ``self.v_max`` and used as the starting point of the next call.

        Returns:
            The eigenvalue estimate as a Python float.
        """
        max_iteration = 50
        v_max_error = 1.0e5
        tol = 1.0e-5
        k = 0

        with torch.no_grad():
            if self.v_max is None:
                v = torch.randn(size=(1, self.dict_size)).to(weight.device)
            else:
                v = self.v_max.clone()

            while k < max_iteration and v_max_error > tol:

                tmp = F.linear(v, weight.T, bias=None)
                v_ = F.linear(tmp, weight, bias=None)
                v_ = F.normalize(v_.view(-1), dim=0, p=2).view(v.size())
                v_max_error = torch.sum((v_ - v) ** 2)
                k += 1
                v = v_

            v_max = v.clone()
            Dv_max = F.linear(v_max, weight.T, bias=None)  # Dv
            lambda_max = torch.sum(Dv_max ** 2).item()  # vTDTDv / vTv, ignore the vTv since vTv = 1

        self.v_max = v_max
        return lambda_max
    
if __name__ == '__main__': # manual smoke tests for MatrixSparseCodingLayer, SELayer and SeparableConv1d
    # NOTE(review): the relative import at the top of this file presumably has to be
    # disabled before running the file directly as a script (see the comment on it).

    # The string literal below is an inert expression and is never executed. It records
    # the earlier single-input smoke test, from before the layers took two inputs.
    '''# code for single input is OK
    # use case
    # ------test each Layer------
    print("------test MSCLayer------")
    input_dim = 32
    output_dim = 64
    layer1 = MatrixSparseCodingLayer(n_channel=input_dim, dict_size=output_dim)

    x = torch.randn(1, 32)

    y = layer1(x)

    print(y)
    
    print("------test SELayer------")
    input_tensor = torch.rand(8,64,128) #batch_size=8, channels=64, length=128
    se_block = SELayer(channel=64, reduction=16)
    print(f"Input Shape: {input_tensor.shape}")
    output_tensor = se_block(input_tensor)
    print(f"Output Shape: {output_tensor.shape}")

    print("------test SeparableConv1d-------")
    x = torch.rand(2, 3, 10) #batch_size=2, channels=3, length=10
    separable_conv = SeparableConv1d(in_channels=3, out_channels=6, kernel_size=3, padding=1)
    output = separable_conv(x)
    print(output.shape)
    '''
    '''# code for two inputs'''
    # use case
    # ------test each Layer------
    print("------test MSCLayer------")
    print("#------2D input------#")
    input_dim = 32
    output_dim = 64
    # n_channel is 2*input_dim because forward concatenates its two inputs on dim 1.
    layer1 = MatrixSparseCodingLayer(n_channel=2*input_dim, dict_size=2*output_dim)
    x1 = torch.randn(8, 32)
    x2 = torch.randn(8, 32)
    
    y = layer1(x1,x2)
    print(y,y.shape)

    print("#------3D input------")
    input_dim = 192
    output_dim = 2*192
    layer1 = MatrixSparseCodingLayer(n_channel=2*input_dim, dict_size=2*output_dim)
    # Two inputs with dim 1 of size 192 each, so their concatenation matches n_channel.
    x1 = torch.randn(8, 192, 16)
    x2 = torch.randn(8, 192, 16)

    y = layer1(x1,x2)
    print(y,y.shape)
    
    print("------test SELayer------")
    input_tensor1 = torch.rand(8,64,128) #batch_size=8, channels=64, length=128
    input_tensor2 = torch.rand(8,64,128) #batch_size=8, channels=64, length=128
    se_block = SELayer(channel=64, reduction=16)
    output_tensor = se_block(input_tensor1, input_tensor2)
    print(f"Output Shape: {output_tensor.shape}")

    print("------test SeparableConv1d-------")
    x1 = torch.rand(2, 3, 10) #batch_size=2, channels=3, length=10
    x2 = torch.rand(2, 3, 10) #batch_size=2, channels=3, length=10
    # in_channels is the per-input channel count; the layer concatenates x1 and x2 itself.
    separable_conv = SeparableConv1d(in_channels=3, out_channels=6, kernel_size=3, padding=1)
    output = separable_conv(x1, x2)
    print(output.shape)
