import torch 
import torch.nn as nn 
import torch.nn.functional as F
import numpy as np
from aggutils import *
from user_init import user_normal_



class Linear_iea_nonlinear_nom(nn.Module):
    """Sum of ``m`` parallel ReLU-activated linear branches (mixing weight 1).

    With ``domms=True`` (the default) ``forward`` returns
    ``sum_i minv[i] * relu(mms[i](x))`` where every ``minv[i]`` is a frozen
    parameter equal to 1.0. ``applyCong`` / ``applyCong2`` write a weighted
    combination of the branch parameters into the single layer ``sublin``;
    with ``domms=False`` ``forward`` returns ``relu(sublin(x))`` instead.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        bias: Whether the branch layers and ``sublin`` have a bias.
        m: Number of parallel branches.
        nonlinea: Stored as ``self.nonline``; ``forward`` does not use it and
            always applies ``F.relu``.
    """

    def __init__(self, in_features, out_features, bias=True, m=3, nonlinea=F.tanh):
        super().__init__()
        self.mms = nn.ModuleList(
            [nn.Linear(in_features, out_features, bias=bias) for _ in range(m)]
        )
        self.nonline = nonlinea
        self.m = m
        # Frozen per-branch mixing weights, all 1.0. The torch.rand draw is
        # immediately overwritten; it is kept so seeded runs consume the
        # global RNG in the same order as before.
        self.minv = nn.ParameterList(
            [nn.Parameter(torch.rand(1).fill_(1.0), requires_grad=False) for _ in range(m)]
        )
        self.sublin = nn.Linear(in_features, out_features, bias=bias)
        self.domms = True  # True: use the m branches; False: use sublin
        self.bias = bias

    def calc_loss(self):
        """Return the sum of ``orth_loss`` over every branch weight matrix.

        ``orth_loss`` is not defined here; presumably it comes from the
        ``from aggutils import *`` at the top of the module.
        """
        # Accumulate out of place: an in-place ``+=`` would modify the tensor
        # returned by orth_loss, which autograd rejects at backward time if
        # that tensor was saved for gradient computation.
        orth_l = orth_loss(self.mms[0].weight)
        for i in range(1, self.m):
            orth_l = orth_l + orth_loss(self.mms[i].weight)
        return orth_l

    def apply_user_init(self):
        """Re-initialize branch weights with ``user_normal_`` and zero branch biases."""
        for lin in self.mms:
            user_normal_(lin.weight, m=self.m)
            if self.bias:
                lin.bias.data.fill_(0)

    def applyCong2(self):
        """Merge the branches into ``sublin`` with inverse-variance weights.

        Branch ``k`` gets coefficient ``(1 / var(W_k)) / sum_j (1 / var(W_j))``,
        where ``var`` is taken over all entries of the branch weight matrix.
        ``sublin.weight`` (and ``sublin.bias`` when ``bias`` is set) are
        overwritten with the coefficient-weighted sum of the branch weights
        (biases).
        """
        # no_grad: autograd otherwise rejects in-place writes to a leaf
        # Parameter that requires grad.
        with torch.no_grad():
            inv_vars = [1 / lin.weight.var() for lin in self.mms]
            var_inv_sum = sum(inv_vars)
            self.sublin.weight.zero_()
            if self.bias:
                self.sublin.bias.zero_()
            for lin, inv_var in zip(self.mms, inv_vars):
                coeff = inv_var / var_inv_sum
                self.sublin.weight.add_(coeff * lin.weight)
                if self.bias:
                    self.sublin.bias.add_(coeff * lin.bias)

    def applyCong(self):
        """Merge the branches into ``sublin`` with normalized mean/std weights.

        Branch ``k`` gets coefficient ``r_k / sum_j r_j`` with
        ``r_k = mean(W_k) / std(W_k)`` over all entries of its weight matrix.
        ``sublin.weight`` (and ``sublin.bias`` when ``bias`` is set) are
        overwritten with the coefficient-weighted sum of the branch weights
        (biases).

        NOTE(review): ``r_k`` is signed, so ``sum_j r_j`` can be near zero,
        producing very large (or inf/nan) coefficients; confirm intended.
        """
        # no_grad: autograd otherwise rejects in-place writes to a leaf
        # Parameter that requires grad.
        with torch.no_grad():
            ratios = np.asarray(
                [(lin.weight.mean() / lin.weight.std()).item() for lin in self.mms]
            )
            coeffs = [float(c) for c in ratios / ratios.sum()]
            self.sublin.weight.copy_(
                sum(c * lin.weight for c, lin in zip(coeffs, self.mms))
            )
            if self.bias:
                self.sublin.bias.copy_(
                    sum(c * lin.bias for c, lin in zip(coeffs, self.mms))
                )

    def forward(self, x):
        """Return ``sum_i minv[i] * relu(mms[i](x))``, or ``relu(sublin(x))`` if ``domms`` is False."""
        if not self.domms:
            return F.relu(self.sublin(x))
        x_0 = self.minv[0] * F.relu(self.mms[0](x))
        for i in range(1, self.m):
            x_0 = x_0 + self.minv[i] * F.relu(self.mms[i](x))
        return x_0

class Linear_iea_nonlinear(nn.Module):
    """Average of ``m`` parallel ReLU-activated linear branches.

    With ``domms=True`` (the default) ``forward`` returns
    ``sum_i minv[i] * relu(mms[i](x))`` where every ``minv[i]`` is a frozen
    parameter equal to ``1/m``. ``applyCong`` / ``applyCong2`` write a
    weighted combination of the branch parameters into the single layer
    ``sublin``; with ``domms=False`` ``forward`` returns ``relu(sublin(x))``.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        bias: Whether the branch layers and ``sublin`` have a bias.
        m: Number of parallel branches.
        nonlinea: Stored as ``self.nonline``; ``forward`` does not use it and
            always applies ``F.relu``.
    """

    def __init__(self, in_features, out_features, bias=True, m=3, nonlinea=F.tanh):
        super().__init__()
        self.mms = nn.ModuleList(
            [nn.Linear(in_features, out_features, bias=bias) for _ in range(m)]
        )
        self.nonline = nonlinea
        self.m = m
        # Frozen per-branch mixing weights, all 1/m. The torch.rand draw is
        # immediately overwritten; it is kept so seeded runs consume the
        # global RNG in the same order as before.
        self.minv = nn.ParameterList(
            [nn.Parameter(torch.rand(1).fill_(1.0 / self.m), requires_grad=False) for _ in range(m)]
        )
        self.sublin = nn.Linear(in_features, out_features, bias=bias)
        self.domms = True  # True: use the m branches; False: use sublin
        self.bias = bias

    def calc_loss(self):
        """Return the sum of ``orth_loss`` over every branch weight matrix.

        ``orth_loss`` is not defined here; presumably it comes from the
        ``from aggutils import *`` at the top of the module.
        """
        # Accumulate out of place: an in-place ``+=`` would modify the tensor
        # returned by orth_loss, which autograd rejects at backward time if
        # that tensor was saved for gradient computation.
        orth_l = orth_loss(self.mms[0].weight)
        for i in range(1, self.m):
            orth_l = orth_l + orth_loss(self.mms[i].weight)
        return orth_l

    def apply_user_init(self):
        """Re-initialize branch weights with ``user_normal_`` and zero branch biases."""
        for lin in self.mms:
            user_normal_(lin.weight, m=self.m)
            if self.bias:
                lin.bias.data.fill_(0)

    def applyCong2(self):
        """Merge the branches into ``sublin`` with inverse-variance weights.

        Branch ``k`` gets coefficient ``(1 / var(W_k)) / sum_j (1 / var(W_j))``,
        where ``var`` is taken over all entries of the branch weight matrix.
        ``sublin.weight`` (and ``sublin.bias`` when ``bias`` is set) are
        overwritten with the coefficient-weighted sum of the branch weights
        (biases).
        """
        # no_grad: autograd otherwise rejects in-place writes to a leaf
        # Parameter that requires grad.
        with torch.no_grad():
            inv_vars = [1 / lin.weight.var() for lin in self.mms]
            var_inv_sum = sum(inv_vars)
            self.sublin.weight.zero_()
            if self.bias:
                self.sublin.bias.zero_()
            for lin, inv_var in zip(self.mms, inv_vars):
                coeff = inv_var / var_inv_sum
                self.sublin.weight.add_(coeff * lin.weight)
                if self.bias:
                    self.sublin.bias.add_(coeff * lin.bias)

    def applyCong(self):
        """Merge the branches into ``sublin`` with normalized mean/std weights.

        Branch ``k`` gets coefficient ``r_k / sum_j r_j`` with
        ``r_k = mean(W_k) / std(W_k)`` over all entries of its weight matrix.
        ``sublin.weight`` (and ``sublin.bias`` when ``bias`` is set) are
        overwritten with the coefficient-weighted sum of the branch weights
        (biases).

        NOTE(review): ``r_k`` is signed, so ``sum_j r_j`` can be near zero,
        producing very large (or inf/nan) coefficients; confirm intended.
        """
        # no_grad: autograd otherwise rejects in-place writes to a leaf
        # Parameter that requires grad.
        with torch.no_grad():
            ratios = np.asarray(
                [(lin.weight.mean() / lin.weight.std()).item() for lin in self.mms]
            )
            coeffs = [float(c) for c in ratios / ratios.sum()]
            self.sublin.weight.copy_(
                sum(c * lin.weight for c, lin in zip(coeffs, self.mms))
            )
            if self.bias:
                self.sublin.bias.copy_(
                    sum(c * lin.bias for c, lin in zip(coeffs, self.mms))
                )

    def forward(self, x):
        """Return ``sum_i minv[i] * relu(mms[i](x))``, or ``relu(sublin(x))`` if ``domms`` is False."""
        if not self.domms:
            return F.relu(self.sublin(x))
        x_0 = self.minv[0] * F.relu(self.mms[0](x))
        for i in range(1, self.m):
            x_0 = x_0 + self.minv[i] * F.relu(self.mms[i](x))
        return x_0

class Linear_iea_nom(nn.Module):
    """Sum of ``m`` parallel linear branches (mixing weight 1, no activation).

    With ``domms=True`` (the default) ``forward`` returns
    ``sum_i minv[i] * mms[i](x)`` where every ``minv[i]`` is a frozen
    parameter equal to 1.0. ``applyCong`` / ``applyCong2`` write a weighted
    combination of the branch parameters into the single layer ``sublin``;
    with ``domms=False`` ``forward`` returns ``sublin(x)`` instead.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        bias: Whether the branch layers and ``sublin`` have a bias.
        m: Number of parallel branches.
        nonlinea: Stored as ``self.nonline``; ``forward`` does not use it.
    """

    def __init__(self, in_features, out_features, bias=True, m=3, nonlinea=F.tanh):
        super().__init__()
        self.mms = nn.ModuleList(
            [nn.Linear(in_features, out_features, bias=bias) for _ in range(m)]
        )
        self.nonline = nonlinea
        self.m = m
        # Frozen per-branch mixing weights, all 1.0. The torch.rand draw is
        # immediately overwritten; it is kept so seeded runs consume the
        # global RNG in the same order as before.
        self.minv = nn.ParameterList(
            [nn.Parameter(torch.rand(1).fill_(1.0), requires_grad=False) for _ in range(m)]
        )
        self.sublin = nn.Linear(in_features, out_features, bias=bias)
        self.domms = True  # True: use the m branches; False: use sublin
        self.bias = bias

    def calc_loss(self):
        """Return the sum of ``orth_loss`` over every branch weight matrix.

        ``orth_loss`` is not defined here; presumably it comes from the
        ``from aggutils import *`` at the top of the module.
        """
        # Accumulate out of place: an in-place ``+=`` would modify the tensor
        # returned by orth_loss, which autograd rejects at backward time if
        # that tensor was saved for gradient computation.
        orth_l = orth_loss(self.mms[0].weight)
        for i in range(1, self.m):
            orth_l = orth_l + orth_loss(self.mms[i].weight)
        return orth_l

    def apply_user_init(self):
        """Re-initialize branch weights with ``user_normal_`` and zero branch biases."""
        for lin in self.mms:
            user_normal_(lin.weight, m=self.m)
            if self.bias:
                lin.bias.data.fill_(0)

    def applyCong2(self):
        """Merge the branches into ``sublin`` with inverse-variance weights.

        Branch ``k`` gets coefficient ``(1 / var(W_k)) / sum_j (1 / var(W_j))``,
        where ``var`` is taken over all entries of the branch weight matrix.
        ``sublin.weight`` (and ``sublin.bias`` when ``bias`` is set) are
        overwritten with the coefficient-weighted sum of the branch weights
        (biases).
        """
        # no_grad: autograd otherwise rejects in-place writes to a leaf
        # Parameter that requires grad.
        with torch.no_grad():
            inv_vars = [1 / lin.weight.var() for lin in self.mms]
            var_inv_sum = sum(inv_vars)
            self.sublin.weight.zero_()
            if self.bias:
                self.sublin.bias.zero_()
            for lin, inv_var in zip(self.mms, inv_vars):
                coeff = inv_var / var_inv_sum
                self.sublin.weight.add_(coeff * lin.weight)
                if self.bias:
                    self.sublin.bias.add_(coeff * lin.bias)

    def applyCong(self):
        """Merge the branches into ``sublin`` with normalized mean/std weights.

        Branch ``k`` gets coefficient ``r_k / sum_j r_j`` with
        ``r_k = mean(W_k) / std(W_k)`` over all entries of its weight matrix.
        ``sublin.weight`` (and ``sublin.bias`` when ``bias`` is set) are
        overwritten with the coefficient-weighted sum of the branch weights
        (biases).

        NOTE(review): ``r_k`` is signed, so ``sum_j r_j`` can be near zero,
        producing very large (or inf/nan) coefficients; confirm intended.
        """
        # no_grad: autograd otherwise rejects in-place writes to a leaf
        # Parameter that requires grad.
        with torch.no_grad():
            ratios = np.asarray(
                [(lin.weight.mean() / lin.weight.std()).item() for lin in self.mms]
            )
            coeffs = [float(c) for c in ratios / ratios.sum()]
            self.sublin.weight.copy_(
                sum(c * lin.weight for c, lin in zip(coeffs, self.mms))
            )
            if self.bias:
                self.sublin.bias.copy_(
                    sum(c * lin.bias for c, lin in zip(coeffs, self.mms))
                )

    def forward(self, x):
        """Return ``sum_i minv[i] * mms[i](x)``, or ``sublin(x)`` if ``domms`` is False."""
        if not self.domms:
            return self.sublin(x)
        x_0 = self.minv[0] * self.mms[0](x)
        for i in range(1, self.m):
            x_0 = x_0 + self.minv[i] * self.mms[i](x)
        return x_0
            

class Linear_iea(nn.Module):
    """Average of ``m`` parallel linear branches (no activation).

    With ``domms=True`` (the default) ``forward`` returns
    ``sum_i minv[i] * mms[i](x)`` where every ``minv[i]`` is a frozen
    parameter equal to ``1/m``. ``applyCong`` / ``applyCong2`` write a
    weighted combination of the branch parameters into the single layer
    ``sublin``; with ``domms=False`` ``forward`` returns ``sublin(x)``.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        bias: Whether the branch layers and ``sublin`` have a bias.
        m: Number of parallel branches.
        nonlinea: Stored as ``self.nonline``; ``forward`` does not use it.
    """

    def __init__(self, in_features, out_features, bias=True, m=3, nonlinea=F.tanh):
        super().__init__()
        self.mms = nn.ModuleList(
            [nn.Linear(in_features, out_features, bias=bias) for _ in range(m)]
        )
        self.nonline = nonlinea
        self.m = m
        # Frozen per-branch mixing weights, all 1/m. The torch.rand draw is
        # immediately overwritten; it is kept so seeded runs consume the
        # global RNG in the same order as before.
        self.minv = nn.ParameterList(
            [nn.Parameter(torch.rand(1).fill_(1.0 / self.m), requires_grad=False) for _ in range(m)]
        )
        self.sublin = nn.Linear(in_features, out_features, bias=bias)
        self.domms = True  # True: use the m branches; False: use sublin
        self.bias = bias

    def apply_user_init(self):
        """Re-initialize branch weights with ``user_normal_``; biases are left untouched.

        NOTE(review): the other Linear_iea* variants in this module also zero
        the branch biases here; confirm that leaving them is intentional.
        """
        for lin in self.mms:
            user_normal_(lin.weight, m=self.m)

    def calc_loss(self):
        """Return the sum of ``orth_loss`` over every branch weight matrix.

        ``orth_loss`` is not defined here; presumably it comes from the
        ``from aggutils import *`` at the top of the module.
        """
        # Accumulate out of place: an in-place ``+=`` would modify the tensor
        # returned by orth_loss, which autograd rejects at backward time if
        # that tensor was saved for gradient computation.
        orth_l = orth_loss(self.mms[0].weight)
        for i in range(1, self.m):
            orth_l = orth_l + orth_loss(self.mms[i].weight)
        return orth_l

    def applyCong2(self):
        """Merge the branches into ``sublin`` with inverse-variance weights.

        Branch ``k`` gets coefficient ``(1 / var(W_k)) / sum_j (1 / var(W_j))``,
        where ``var`` is taken over all entries of the branch weight matrix.
        ``sublin.weight`` (and ``sublin.bias`` when ``bias`` is set) are
        overwritten with the coefficient-weighted sum of the branch weights
        (biases).
        """
        # no_grad: autograd otherwise rejects in-place writes to a leaf
        # Parameter that requires grad.
        with torch.no_grad():
            inv_vars = [1 / lin.weight.var() for lin in self.mms]
            var_inv_sum = sum(inv_vars)
            self.sublin.weight.zero_()
            if self.bias:
                self.sublin.bias.zero_()
            for lin, inv_var in zip(self.mms, inv_vars):
                coeff = inv_var / var_inv_sum
                self.sublin.weight.add_(coeff * lin.weight)
                if self.bias:
                    self.sublin.bias.add_(coeff * lin.bias)

    def applyCong(self):
        """Merge the branches into ``sublin`` with normalized mean/std weights.

        Branch ``k`` gets coefficient ``r_k / sum_j r_j`` with
        ``r_k = mean(W_k) / std(W_k)`` over all entries of its weight matrix.
        ``sublin.weight`` (and ``sublin.bias`` when ``bias`` is set) are
        overwritten with the coefficient-weighted sum of the branch weights
        (biases).

        NOTE(review): ``r_k`` is signed, so ``sum_j r_j`` can be near zero,
        producing very large (or inf/nan) coefficients; confirm intended.
        """
        # no_grad: autograd otherwise rejects in-place writes to a leaf
        # Parameter that requires grad.
        with torch.no_grad():
            ratios = np.asarray(
                [(lin.weight.mean() / lin.weight.std()).item() for lin in self.mms]
            )
            coeffs = [float(c) for c in ratios / ratios.sum()]
            self.sublin.weight.copy_(
                sum(c * lin.weight for c, lin in zip(coeffs, self.mms))
            )
            if self.bias:
                self.sublin.bias.copy_(
                    sum(c * lin.bias for c, lin in zip(coeffs, self.mms))
                )

    def forward(self, x):
        """Return ``sum_i minv[i] * mms[i](x)``, or ``sublin(x)`` if ``domms`` is False."""
        if not self.domms:
            return self.sublin(x)
        x_0 = self.minv[0] * self.mms[0](x)
        for i in range(1, self.m):
            x_0 = x_0 + self.minv[i] * self.mms[i](x)
        return x_0
            


class Linear_maxout(nn.Module):
    """Maxout layer: element-wise maximum over ``m`` parallel linear maps.

    With ``domms=True`` (the default) the output is the element-wise maximum
    of ``mms[i](x)`` over all branches, folded left to right with
    ``torch.max``. With ``domms=False`` the separate layer ``sublin`` is used
    instead. ``applyCong`` is a no-op here, so this class never derives
    ``sublin`` from the branches.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        bias: Whether the branch layers and ``sublin`` have a bias.
        m: Number of parallel branches.
    """

    def __init__(self, in_features, out_features, bias=True, m=3):
        super().__init__()
        branches = [nn.Linear(in_features, out_features, bias=bias) for _ in range(m)]
        self.mms = nn.ModuleList(branches)
        self.m = m
        self.sublin = nn.Linear(in_features, out_features, bias=bias)
        self.domms = True  # True: maxout over branches; False: use sublin
        self.bias = bias

    def applyCong(self):
        """Do nothing (presumably present for interface parity with the Linear_iea* layers)."""
        return None

    def calc_loss(self):
        """Return the sum of ``orth_loss`` over every branch weight matrix.

        ``orth_loss`` is not defined here; presumably it comes from the
        ``from aggutils import *`` at the top of the module.
        """
        total = orth_loss(self.mms[0].weight)
        for branch in self.mms[1:]:
            total += orth_loss(branch.weight)
        return total

    def forward(self, x):
        """Return the element-wise max over branch outputs, or ``sublin(x)`` if ``domms`` is False."""
        if not self.domms:
            return self.sublin(x)
        result = self.mms[0](x)
        for branch in self.mms[1:]:
            result = torch.max(result, branch(x))
        return result
            


        