import copy
import pickle
import random

import numpy as np
import torch
import torch.nn as nn
from torch import Tensor

# Initialization code used (kept as a record of the hyper-parameters).
# NOTE(review): `baseNN` is not imported or defined in the visible part of this
# file, and this call passes keyword arguments whereas `Improved.__init__`
# below takes a single `cfg` object -- presumably `baseNN.improved` is a
# factory that packs these options into that config; confirm before running.
net = baseNN.improved(input_shape = 512+32, output_shape = 12, hidden_dims = [402,280,195,136,94], domain = 58, use_bn=False,use_routing2=4,routing_weight=0.333,routing_depth=1,routing_side_depth=3,rout_input=True,rout_network_input=True,routing_side_num=1,routing_side_train=[1],base_weight_net=[80,40,'auto'],routing_width_x=1.5*0.5,rout_weight_net=[80,40,'auto'],rout_init_val=-2)

#Class for DPNet
class Improved(nn.Module):
    def __init__(self, cfg):
        """Build the backbone MLP and cache the routing options found on ``cfg``.

        Args:
            cfg: Configuration object read through attribute access.
                ``input_shape``, ``hidden_dims``, ``output_shape`` and
                ``use_bn`` are required. Every other option read below is
                optional and falls back to the default given in the matching
                ``getattr`` call.

        Only the backbone ``self.layers`` is created here; the routing
        sub-networks are created separately by ``init_routing``.

        Optional settings are read with ``getattr(cfg, name, default)`` so that
        only a *missing* attribute selects the default (the same rule the
        ``activation`` lookup already followed); any other error raised while
        reading ``cfg`` is no longer silently swallowed.
        """
        super().__init__()
        self.cfg = cfg
        self.switch_list_already_loaded = False
        self.channels = [cfg.input_shape, *cfg.hidden_dims, cfg.output_shape]

        # Backbone: one Linear (+ optional BatchNorm1d) + activation per
        # consecutive pair of widths, including the output layer.
        self.layers = nn.ModuleList()
        # `get_activation` is defined outside this block; its result is placed
        # in nn.Sequential, so it has to be an nn.Module. The same instance is
        # shared by every backbone layer.
        activation = get_activation(getattr(cfg, 'activation', 'elu'))
        for in_dim, out_dim in zip(self.channels[:-1], self.channels[1:]):
            block = [nn.Linear(in_dim, out_dim)]
            # Explicit comparison kept on purpose: only a value equal to False
            # disables batch norm, exactly as before.
            if not (self.cfg.use_bn == False):  # noqa: E712
                block.append(nn.BatchNorm1d(out_dim))
            block.append(activation)
            self.layers.append(nn.Sequential(*block))

        # Hidden activation used inside every routing / gating network.
        self.rout_activation = get_activation(getattr(cfg, 'rout_activation', 'relu'))

        # Stored here, consumed outside this method.
        self.use_residual = getattr(cfg, 'use_residual', False)
        # Number of routing levels to build (init_routing compares it with 0..3).
        self.routing_val = getattr(cfg, 'use_routing2', 0)
        # Scale applied to routing outputs in routing_layer_forward.
        self.routing_weight = getattr(cfg, 'routing_weight', 0.1)
        # Depths handed to create_sequential for the main / side networks.
        self.routing_depth = getattr(cfg, 'routing_depth', 1)
        self.routing_side_depth = getattr(cfg, 'routing_side_depth', 0)
        if self.routing_side_depth != 0:
            # Main and side outputs are both scaled by routing_weight and then
            # summed, so the weight is halved when side networks are enabled.
            self.routing_weight = self.routing_weight / 2
        # Std used by init_routing to re-initialise side-network weights.
        self.init_side_std = getattr(cfg, 'init_side_std', None)
        # When not False, routing networks take the previous layer's width as
        # input (init_routing uses channels[i-1] instead of channels[i]).
        self.rout_input = getattr(cfg, 'rout_input', False)
        # When True, init_routing also builds the rout_netinput* networks.
        self.rout_network_input = getattr(cfg, 'rout_network_input', False)
        # Added to the layer range of routing levels 1..3 in init_routing.
        self.rout_to_output = getattr(cfg, 'rout_to_output', 0)
        # Passed as `P` to switch_entries; None disables the switch lists.
        self.switch_percent = getattr(cfg, 'switch_percent', None)
        # Number of parallel side networks per layer (None -> a single one).
        self.routing_side_num = getattr(cfg, 'routing_side_num', None)
        self.routing_side_train = getattr(cfg, 'routing_side_train', None)
        # Hidden widths of the gating networks built by create_gating_net.
        self.gate_dims = getattr(cfg, 'gate_dims', None)
        self.gate_top_k = getattr(cfg, 'gate_top_k', None)
        self.random_gate_percent = getattr(cfg, 'random_gate_percent', None)
        # Width specs of the per-layer weight networks built in init_routing;
        # a string entry stands for "width of the next layer".
        self.base_weight_net = getattr(cfg, 'base_weight_net', None)
        self.rout_weight_net = getattr(cfg, 'rout_weight_net', None)
        # Bias given to the last Linear of every rout_weight_net.
        self.rout_init_val = getattr(cfg, 'rout_init_val', 5)
        # When set, init_routing creates the `block_rout_weight` parameter.
        self.rout_topk = getattr(cfg, 'rout_topk', None)
        self.rout_rand_gate_percent = getattr(cfg, 'rout_rand_gate_percent', None)
        # When not False, parallel side networks get different widths/depths.
        self.diff_moe_branch = getattr(cfg, 'diff_moe_branch', False)
        self.moe_residual = getattr(cfg, 'moe_residual', False)
        # Multiplier applied to a layer's width to get the routing hidden width.
        self.routing_width_x = getattr(cfg, 'routing_width_x', 1)
        # Output activation of the weight networks: 'ReLU', 'GeLU', anything
        # else selects Sigmoid. The fallback used to be the integer 1, which
        # disagreed with the declared default; both select Sigmoid.
        self.rout_out_act = getattr(cfg, 'rout_out_act', 'sigmoid')
    def create_sequential(self,input_dim,hidden_dim,output_dim,hidden_act,depth):
        """Return an MLP of ``depth`` Linear layers that ends in ``nn.Tanh``.

        Args:
            input_dim: Width of the input.
            hidden_dim: Width of every hidden layer.
            output_dim: Width of the output.
            hidden_act: Activation module inserted after each hidden Linear
                (the same instance is reused for every hidden layer).
            depth: Number of Linear layers. ``depth == 1`` gives a single
                ``Linear(input_dim, output_dim)``. For ``depth < 1`` the result
                is a single ``Linear(hidden_dim, output_dim)``.

        Returns:
            nn.Sequential: the stacked layers followed by ``nn.Tanh``.
        """
        modules = []
        fan_in = input_dim
        for _ in range(depth - 1):
            modules.extend((nn.Linear(fan_in, hidden_dim), hidden_act))
            fan_in = hidden_dim
        # Only a depth of exactly 1 connects the input straight to the output;
        # every other depth (including non-positive ones) feeds the last layer
        # from `hidden_dim`.
        last_in = input_dim if depth == 1 else hidden_dim
        modules.extend((nn.Linear(last_in, output_dim), nn.Tanh()))
        return nn.Sequential(*modules)
    def create_gating_net(self,input_dim,hidden_dims,output_dim,hidden_act):
        """Return an MLP ``input_dim -> hidden_dims... -> output_dim`` ending in Softmax.

        Args:
            input_dim: Width of the input.
            hidden_dims: Non-empty sequence of hidden widths (an empty sequence
                raises ``IndexError``).
            output_dim: Number of outputs; Softmax is taken over the last axis.
            hidden_act: Activation module placed after every hidden Linear.

        Returns:
            nn.Sequential: the gating network.
        """
        widths = [input_dim, *hidden_dims]
        modules = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            modules.extend((nn.Linear(fan_in, fan_out), hidden_act))
        modules.extend((nn.Linear(hidden_dims[-1], output_dim), nn.Softmax(dim=-1)))
        return nn.Sequential(*modules)
    def init_routing(self,device):
        rout_block_num=0
        if self.base_weight_net!=None:
            self.base_weight_net_list=nn.ModuleList()
            for i in range(len(self.channels)-2):
                cur_dim=copy.deepcopy(self.base_weight_net)
                cur_dim.insert(0,self.channels[i])
                base_weight_net = nn.ModuleList()
                for j in range(len(cur_dim) - 1):
                    if isinstance(cur_dim[j+1],str)==False:
                        base_weight_net.append(nn.Linear(cur_dim[j], cur_dim[j+1]))
                    else:
                        base_weight_net.append(nn.Linear(cur_dim[j], self.channels[i+1]))
                    if j<len(cur_dim) - 2:
                        base_weight_net.append(nn.ReLU())
                    else:
                        if self.rout_out_act=='ReLU':
                            base_weight_net.append(nn.ReLU())
                        elif self.rout_out_act=='GeLU':
                            base_weight_net.append(nn.GELU())
                        else:
                            base_weight_net.append(nn.Sigmoid())
                for j, layer in enumerate(base_weight_net):
                    if isinstance(layer, nn.Linear):
                        nn.init.normal_(layer.weight, mean=0.0, std=0.01)
                        if j == len(base_weight_net) - 2:
                            nn.init.constant_(layer.bias, 5.0)
                        else:
                            nn.init.constant_(layer.bias, 0.0)
                self.base_weight_net_list.append(nn.Sequential(*base_weight_net))
            self.base_weight_net_list=self.base_weight_net_list.to(device)
        if self.rout_weight_net!=None:
            self.rout_weight_net_list=nn.ModuleList()
            for i in range(len(self.channels)-2):
                cur_dim=copy.deepcopy(self.rout_weight_net)
                cur_dim.insert(0,self.channels[i])
                rout_weight_net = nn.ModuleList()
                for j in range(len(cur_dim) - 1):
                    if isinstance(cur_dim[j+1],str)==False:
                        rout_weight_net.append(nn.Linear(cur_dim[j], cur_dim[j+1]))
                    else:
                        rout_weight_net.append(nn.Linear(cur_dim[j], self.channels[i+1]))
                    if j<len(cur_dim) - 2:
                        rout_weight_net.append(nn.ReLU())
                    else:
                        if self.rout_out_act=='ReLU':
                            rout_weight_net.append(nn.ReLU())
                        elif self.rout_out_act=='GeLU':
                            rout_weight_net.append(nn.GELU())
                        else:
                            rout_weight_net.append(nn.Sigmoid())
                for j, layer in enumerate(rout_weight_net):
                    if isinstance(layer, nn.Linear):
                        nn.init.normal_(layer.weight, mean=0.0, std=0.01)
                        if j == len(rout_weight_net) - 2:
                            nn.init.constant_(layer.bias, self.rout_init_val)
                        else:
                            nn.init.constant_(layer.bias, 0.0)
                self.rout_weight_net_list.append(nn.Sequential(*rout_weight_net))
            self.rout_weight_net_list=self.rout_weight_net_list.to(device)
        self.routing_layer_list=[]
        if self.switch_percent!=None:
            self.switch_lists=[]
            self.switch_lists1=[]
            self.switch_lists2=[]
            self.switch_lists3=[]
        else:
            self.switch_lists=None
            self.switch_lists1=None
            self.switch_lists2=None
            self.switch_lists3=None
        if self.rout_network_input==True:
            if self.routing_val>0:
                i=0
                self.rout_netinput=self.create_sequential(input_dim=self.channels[i],hidden_dim=int(self.channels[i]*self.routing_width_x),output_dim=self.channels[i],hidden_act=self.rout_activation,depth=self.routing_depth)
                self.rout_netinput=self.rout_netinput.to(device)
                self.routing_layer_list.append(self.rout_netinput)
                self.rout_netinput_side=None
                self.rout_netinput_side=self.create_sequential(input_dim=self.channels[i],hidden_dim=int(self.channels[i]*self.routing_width_x),output_dim=self.channels[i],hidden_act=self.rout_activation,depth=self.routing_side_depth)
                self.rout_netinput_side=self.rout_netinput_side.to(device)
                self.routing_layer_list.append(self.rout_netinput_side)
            if self.routing_val>1:
                i=0
                self.rout_netinput1=self.create_sequential(input_dim=self.channels[i],hidden_dim=int(self.channels[i]*self.routing_width_x),output_dim=self.channels[i+1],hidden_act=self.rout_activation,depth=self.routing_depth)
                self.rout_netinput1=self.rout_netinput1.to(device)
                self.routing_layer_list.append(self.rout_netinput1)
                self.rout_netinput_side1=None
                self.rout_netinput_side1=self.create_sequential(input_dim=self.channels[i],hidden_dim=int(self.channels[i]*self.routing_width_x),output_dim=self.channels[i+1],hidden_act=self.rout_activation,depth=self.routing_side_depth)
                self.rout_netinput_side1=self.rout_netinput_side1.to(device)
                self.routing_layer_list.append(self.rout_netinput_side1)
                if self.rout_topk!=None:
                    rout_block_num+=1
            if self.routing_val>2:
                i=0
                self.rout_netinput2=self.create_sequential(input_dim=self.channels[i],hidden_dim=int(self.channels[i]*self.routing_width_x),output_dim=self.channels[i+2],hidden_act=self.rout_activation,depth=self.routing_depth)
                self.rout_netinput2=self.rout_netinput2.to(device)
                self.routing_layer_list.append(self.rout_netinput2)
                self.rout_netinput_side2=None
                self.rout_netinput_side2=self.create_sequential(input_dim=self.channels[i],hidden_dim=int(self.channels[i]*self.routing_width_x),output_dim=self.channels[i+2],hidden_act=self.rout_activation,depth=self.routing_side_depth)
                self.rout_netinput_side2=self.rout_netinput_side2.to(device)
                self.routing_layer_list.append(self.rout_netinput_side2)
                if self.rout_topk!=None:
                    rout_block_num+=1
            if self.routing_val>3:
                i=0
                self.rout_netinput3=self.create_sequential(input_dim=self.channels[i],hidden_dim=int(self.channels[i]*self.routing_width_x),output_dim=self.channels[i+3],hidden_act=self.rout_activation,depth=self.routing_depth)
                self.rout_netinput3=self.rout_netinput3.to(device)
                self.routing_layer_list.append(self.rout_netinput3)
                self.rout_netinput_side3=None
                self.rout_netinput_side3=self.create_sequential(input_dim=self.channels[i],hidden_dim=int(self.channels[i]*self.routing_width_x),output_dim=self.channels[i+3],hidden_act=self.rout_activation,depth=self.routing_side_depth)
                self.rout_netinput_side3=self.rout_netinput_side3.to(device)
                self.routing_layer_list.append(self.rout_netinput_side3)
                if self.rout_topk!=None:
                    rout_block_num+=1
        if self.routing_val>0:
            self.routing_layers = nn.ModuleList()
            for i in range(1,len(self.cfg.hidden_dims)+1,1):
                if self.rout_input==False:
                    input_dim=self.channels[i]
                else:
                    input_dim=self.channels[i-1]
                self.routing_layers.append(self.create_sequential(input_dim=input_dim,hidden_dim=int(self.channels[i]*self.routing_width_x),output_dim=self.channels[i],hidden_act=self.rout_activation,depth=self.routing_depth))
                if self.switch_percent!=None and self.switch_list_already_loaded==False:
                    self.switch_lists.append(np.arange(0,self.channels[i],1))
            self.routing_side_layers = nn.ModuleList()
            for i in range(1,len(self.cfg.hidden_dims)+1,1):
                if self.rout_input==False:
                    input_dim=self.channels[i]
                else:
                    input_dim=self.channels[i-1]
                self.routing_side_layers.append(self.create_sequential(input_dim=input_dim,hidden_dim=int(self.channels[i]*self.routing_width_x),output_dim=self.channels[i],hidden_act=self.rout_activation,depth=self.routing_side_depth))
            self.routing_side_layers=self.routing_side_layers.to(device)
            if self.init_side_std!=None:
                for seq in self.routing_side_layers:
                    for layer in seq:
                        if isinstance(layer, nn.Linear):
                            nn.init.normal_(layer.weight, mean=0.0, std=self.init_side_std)
                            if layer.bias is not None:
                                nn.init.constant_(layer.bias, 0)
            self.routing_layers=self.routing_layers.to(device)
            self.routing_layer_list.append(self.routing_layers)
        if self.routing_val>1:
            self.routing_layers1 = nn.ModuleList()
            for i in range(1,len(self.cfg.hidden_dims)-1+1+self.rout_to_output,1):
                if self.rout_input==False:
                    input_dim=self.channels[i]
                else:
                    input_dim=self.channels[i-1]
                self.routing_layers1.append(self.create_sequential(input_dim=input_dim,hidden_dim=int(self.channels[i]*self.routing_width_x),output_dim=self.channels[i+1],hidden_act=self.rout_activation,depth=self.routing_depth))
                if self.rout_topk!=None:
                    rout_block_num+=1
                if self.switch_percent!=None and self.switch_list_already_loaded==False:
                    self.switch_lists1.append(np.arange(0,self.channels[i+1],1))
            self.routing_layers1_side = nn.ModuleList()
            if self.gate_dims!=None:
                self.side1_gating=nn.ModuleList()
            for i in range(1,len(self.cfg.hidden_dims)-1+1+self.rout_to_output,1):
                if self.rout_input==False:
                    input_dim=self.channels[i]
                else:
                    input_dim=self.channels[i-1]
                if self.routing_side_num==None:
                    self.routing_layers1_side.append(self.create_sequential(input_dim=input_dim,hidden_dim=int(self.channels[i]*self.routing_width_x),output_dim=self.channels[i+1],hidden_act=self.rout_activation,depth=self.routing_side_depth))
                else:
                    seq_layer_list=nn.ModuleList()
                    for j in range(self.routing_side_num):
                        if self.diff_moe_branch==False:
                            seq_layer_list.append(self.create_sequential(input_dim=input_dim,hidden_dim=int(self.channels[i]*self.routing_width_x),output_dim=self.channels[i+1],hidden_act=self.rout_activation,depth=self.routing_side_depth))
                        else:
                            hidden_dim_cur=int(self.channels[i]*0.5+self.channels[i]*0.5*(j/self.routing_side_num))
                            depth_cur=int(1+(self.routing_side_depth-1)*(j/self.routing_side_num))
                            seq_layer_list.append(self.create_sequential(input_dim=input_dim,hidden_dim=hidden_dim_cur,output_dim=self.channels[i+1],hidden_act=self.rout_activation,depth=depth_cur))
                    self.routing_layers1_side.append(seq_layer_list)
                    if self.gate_dims!=None:
                        self.side1_gating.append(self.create_gating_net(input_dim=input_dim,hidden_dims=self.gate_dims,output_dim=self.routing_side_num,hidden_act=self.rout_activation))
            self.routing_layers1_side=self.routing_layers1_side.to(device)
            if self.gate_dims!=None:
                self.side1_gating=self.side1_gating.to(device)
            if self.routing_side_num==None:
                if self.init_side_std!=None:
                    for seq in self.routing_layers1_side:
                        for layer in seq:
                            if isinstance(layer, nn.Linear):
                                nn.init.normal_(layer.weight, mean=0.0, std=self.init_side_std)
                                if layer.bias is not None:
                                    nn.init.constant_(layer.bias, 0)
                else:
                    if self.init_side_std!=None:
                        for seq_layer_list in self.routing_layers1_side:
                            for seq in seq_layer_list:
                                for layer in seq:
                                    if isinstance(layer, nn.Linear):
                                        nn.init.normal_(layer.weight, mean=0.0, std=self.init_side_std)
                                        if layer.bias is not None:
                                            nn.init.constant_(layer.bias, 0)
            self.routing_layers1=self.routing_layers1.to(device)
            self.routing_layer_list.append(self.routing_layers1)
        if self.routing_val>2:
            self.routing_layers2 = nn.ModuleList()
            for i in range(1,len(self.cfg.hidden_dims)-2+1+self.rout_to_output,1):
                if self.rout_input==False:
                    input_dim=self.channels[i]
                else:
                    input_dim=self.channels[i-1]
                self.routing_layers2.append(self.create_sequential(input_dim=input_dim,hidden_dim=int(self.channels[i]*self.routing_width_x),output_dim=self.channels[i+2],hidden_act=self.rout_activation,depth=self.routing_depth))
                if self.rout_topk!=None:
                    rout_block_num+=1
                if self.switch_percent!=None and self.switch_list_already_loaded==False:
                    self.switch_lists2.append(np.arange(0,self.channels[i+2],1))
            self.routing_layers2_side = nn.ModuleList()
            if self.gate_dims!=None:
                self.side2_gating=nn.ModuleList()
            for i in range(1,len(self.cfg.hidden_dims)-2+1+self.rout_to_output,1):
                if self.rout_input==False:
                    input_dim=self.channels[i]
                else:
                    input_dim=self.channels[i-1]
                if self.routing_side_num==None:
                    self.routing_layers2_side.append(self.create_sequential(input_dim=input_dim,hidden_dim=int(self.channels[i]*self.routing_width_x),output_dim=self.channels[i+2],hidden_act=self.rout_activation,depth=self.routing_side_depth))
                else:
                    seq_layer_list=nn.ModuleList()
                    for j in range(self.routing_side_num):
                        if self.diff_moe_branch==False:
                            seq_layer_list.append(self.create_sequential(input_dim=input_dim,hidden_dim=int(self.channels[i]*self.routing_width_x),output_dim=self.channels[i+2],hidden_act=self.rout_activation,depth=self.routing_side_depth))
                        else:
                            hidden_dim_cur=int(self.channels[i]*0.5+self.channels[i]*0.5*(j/self.routing_side_num))
                            depth_cur=int(1+(self.routing_side_depth-1)*(j/self.routing_side_num))
                            seq_layer_list.append(self.create_sequential(input_dim=input_dim,hidden_dim=hidden_dim_cur,output_dim=self.channels[i+2],hidden_act=self.rout_activation,depth=depth_cur))
                    self.routing_layers2_side.append(seq_layer_list)
                    if self.gate_dims!=None:
                        self.side2_gating.append(self.create_gating_net(input_dim=input_dim,hidden_dims=self.gate_dims,output_dim=self.routing_side_num,hidden_act=self.rout_activation))
            self.routing_layers2_side=self.routing_layers2_side.to(device)
            if self.gate_dims!=None:
                self.side2_gating=self.side2_gating.to(device)
            if self.routing_side_num==None:
                if self.init_side_std!=None:
                    for seq in self.routing_layers2_side:
                        for layer in seq:
                            if isinstance(layer, nn.Linear):
                                nn.init.normal_(layer.weight, mean=0.0, std=self.init_side_std)
                                if layer.bias is not None:
                                    nn.init.constant_(layer.bias, 0)
            else:
                if self.init_side_std!=None:
                    for seq_layer_list in self.routing_layers2_side:
                        for seq in seq_layer_list:
                            for layer in seq:
                                if isinstance(layer, nn.Linear):
                                    nn.init.normal_(layer.weight, mean=0.0, std=self.init_side_std)
                                    if layer.bias is not None:
                                        nn.init.constant_(layer.bias, 0)
            self.routing_layers2=self.routing_layers2.to(device)
            self.routing_layer_list.append(self.routing_layers2)
        if self.routing_val>3:
            self.routing_layers3 = nn.ModuleList()
            for i in range(1,len(self.cfg.hidden_dims)-3+1+self.rout_to_output,1):
                if self.rout_input==False:
                    input_dim=self.channels[i]
                else:
                    input_dim=self.channels[i-1]
                self.routing_layers3.append(self.create_sequential(input_dim=input_dim,hidden_dim=int(self.channels[i]*self.routing_width_x),output_dim=self.channels[i+3],hidden_act=self.rout_activation,depth=self.routing_depth))
                if self.rout_topk!=None:
                    rout_block_num+=1
                if self.switch_percent!=None and self.switch_list_already_loaded==False:
                    self.switch_lists3.append(np.arange(0,self.channels[i+3],1))
            self.routing_layers3_side = nn.ModuleList()
            if self.gate_dims!=None:
                self.side3_gating=nn.ModuleList()
            for i in range(1,len(self.cfg.hidden_dims)-3+1+self.rout_to_output,1):
                if self.rout_input==False:
                    input_dim=self.channels[i]
                else:
                    input_dim=self.channels[i-1]
                if self.routing_side_num==None:
                    self.routing_layers3_side.append(self.create_sequential(input_dim=input_dim,hidden_dim=int(self.channels[i]*self.routing_width_x),output_dim=self.channels[i+3],hidden_act=self.rout_activation,depth=self.routing_side_depth))
                else:
                    seq_layer_list=nn.ModuleList()
                    for j in range(self.routing_side_num):
                        if self.diff_moe_branch==False:
                            seq_layer_list.append(self.create_sequential(input_dim=input_dim,hidden_dim=int(self.channels[i]*self.routing_width_x),output_dim=self.channels[i+3],hidden_act=self.rout_activation,depth=self.routing_side_depth))
                        else:
                            hidden_dim_cur=int(self.channels[i]*0.5+self.channels[i]*0.5*(j/self.routing_side_num))
                            depth_cur=int(1+(self.routing_side_depth-1)*(j/self.routing_side_num))
                            seq_layer_list.append(self.create_sequential(input_dim=input_dim,hidden_dim=hidden_dim_cur,output_dim=self.channels[i+3],hidden_act=self.rout_activation,depth=depth_cur))
                    self.routing_layers3_side.append(seq_layer_list)
                    if self.gate_dims!=None:
                        self.side3_gating.append(self.create_gating_net(input_dim=input_dim,hidden_dims=self.gate_dims,output_dim=self.routing_side_num,hidden_act=self.rout_activation))
            self.routing_layers3_side=self.routing_layers3_side.to(device)
            if self.gate_dims!=None:
                self.side3_gating=self.side3_gating.to(device)
            if self.routing_side_num==None:
                if self.init_side_std!=None:
                    for seq in self.routing_layers3_side:
                        for layer in seq:
                            if isinstance(layer, nn.Linear):
                                nn.init.normal_(layer.weight, mean=0.0, std=self.init_side_std)
                                if layer.bias is not None:
                                    nn.init.constant_(layer.bias, 0)
            else:
                if self.init_side_std!=None:
                    for seq_layer_list in self.routing_layers3_side:
                        for seq in self.routing_layers3_side:
                            for layer in seq:
                                if isinstance(layer, nn.Linear):
                                    nn.init.normal_(layer.weight, mean=0.0, std=self.init_side_std)
                                    if layer.bias is not None:
                                        nn.init.constant_(layer.bias, 0)
            self.routing_layers3=self.routing_layers3.to(device)
            self.routing_layer_list.append(self.routing_layers3)
        self.all_switch_lists=[self.switch_lists,
                                self.switch_lists1,
                                self.switch_lists2,
                                self.switch_lists3]
        if self.rout_topk!=None:
            self.block_rout_weight = nn.Parameter(torch.randn(rout_block_num, ))
    def switch_entries(self,lst, P=0):
        """Randomly swap disjoint pairs of entries of ``lst`` in place.

        ``int((len(lst) * P * 100) // 200)`` pairs are swapped, i.e. roughly a
        fraction ``P`` of the entries change position. No index takes part in
        more than one swap.

        Args:
            lst: Mutable sequence supporting item assignment; it is modified
                in place.
            P: Fraction of entries to move. Values that produce no pairs
                (including 0 and negatives) leave ``lst`` untouched.

        Returns:
            The same ``lst`` object, after the swaps.

        Raises:
            ValueError: If ``P`` asks for more distinct entries than ``lst``
                holds. The loop below could never finish in that case and
                used to spin forever.
        """
        n = len(lst)
        num_switches = int((n * P*100) // 200)
        if num_switches * 2 > n:
            raise ValueError(
                f"cannot swap {num_switches} disjoint pairs in a sequence of "
                f"length {n} (P={P})"
            )
        # A set gives O(1) membership tests. The sequence of random draws is
        # the same as with the former list, so seeded runs are unchanged.
        used = set()
        while len(used) < num_switches * 2:
            i, j = random.sample(range(n), 2)
            if i in used or j in used:
                continue
            lst[i], lst[j] = lst[j], lst[i]
            used.update((i, j))
        return lst
    def switch_order(self,x,order_list):
        """Reorder the feature axis of ``x`` according to ``order_list``.

        Args:
            x: Array-like with a ``shape``. A 2-D input is indexed on axis 1;
                any other rank is indexed on axis 2.
            order_list: Index sequence applied to that axis, or ``None`` to
                return ``x`` unchanged.

        Returns:
            ``x`` itself when ``order_list`` is ``None``, else the reordered
            result of the indexing operation.
        """
        if order_list is None:
            return x
        if len(x.shape) == 2:
            return x[:, order_list]
        return x[:, :, order_list]
    def reorder_switch_list(self):
        for i in range(len(self.all_switch_lists)):
            for j in range(len(self.all_switch_lists[i])):
                self.all_switch_lists[i][j]=self.switch_entries(self.all_switch_lists[i][j],self.switch_percent)
    def save_switch_list(self,directory):
        with open(directory, 'wb') as file:
            pickle.dump(self.all_switch_lists, file)
    def load_switch_list_function(self,directory):
        with open(directory, 'rb') as file:
            self.all_switch_lists = pickle.load(file)
        self.switch_lists=self.all_switch_lists[0]
        self.switch_lists1=self.all_switch_lists[1]
        self.switch_lists2=self.all_switch_lists[2]
        self.switch_lists3=self.all_switch_lists[3]
        self.switch_list_already_loaded=True
    def routing_layer_forward(self,x,routing_layer,routing_side_layer=None,x_in=None,order_list=None):
        if self.rout_input==False:
            x_=x_in
        else:
            x_=x
        if routing_side_layer==None:
            x=x_+self.switch_order(routing_layer(x)*self.routing_weight,order_list)
        else:
            if self.rout_input==True:
                x=x_in+self.switch_order(routing_layer(x_)*self.routing_weight + routing_side_layer(x_)*self.routing_weight,order_list)
            else:
                x=x_+self.switch_order(routing_layer(x_)*self.routing_weight + routing_side_layer(x_)*self.routing_weight,order_list)
        return x
    def top_k_gating(self, tensor, gate_top_k, random_gate_percent=0.0, routing_side_train=None):
        """Zero every entry of ``tensor`` except ``gate_top_k`` per row.

        Rows are taken along the last dimension. For most rows the kept
        entries are the largest ones; a randomly chosen share of rows keeps
        random entries instead.

        Args:
            tensor: Input tensor of shape (..., feature_dim) with at least two
                dimensions. It does not need to be contiguous.
            gate_top_k: Number of entries to keep per row.
            random_gate_percent: Fraction (0.0-1.0) of rows whose kept entries
                are drawn at random instead of by value.
            routing_side_train: Optional sequence of 0/1 flags of length
                feature_dim; only entries flagged 1 may be kept.

        Returns:
            Tensor of the same shape as ``tensor`` with the non-selected
            entries set to zero. When fewer than ``gate_top_k`` entries are
            eligible in a row, only the eligible ones are kept (ineligible
            entries used to leak through in that case).

        Raises:
            ValueError: If ``tensor`` has fewer than two dimensions.
        """
        if tensor.dim() < 2:
            raise ValueError("Input tensor must have at least 2 dimensions.")
        device = tensor.device
        shape = tensor.shape
        # reshape (not view) so that non-contiguous inputs are accepted too.
        tensor_2d = tensor.reshape(-1, shape[-1])
        num_rows = tensor_2d.size(0)

        if routing_side_train is not None:
            routing_mask = torch.tensor(routing_side_train, dtype=torch.bool, device=device).unsqueeze(0).expand(num_rows, -1)
        else:
            routing_mask = torch.ones((num_rows, shape[-1]), dtype=torch.bool, device=device)

        # Selection scores: the values themselves, except for the randomly
        # chosen rows, which get uniform noise. The random draws are the same
        # calls in the same order as before, so seeded runs pick the same rows.
        scores = tensor_2d.detach().clone()
        num_random = int(num_rows * random_gate_percent)
        if num_random > 0:
            random_rows = torch.randperm(num_rows, device=device)[:num_random]
            scores[random_rows] = torch.rand_like(tensor_2d[random_rows])
        scores[~routing_mask] = -float('inf')

        mask = torch.zeros_like(tensor_2d)
        if num_rows > 0:
            # One batched top-k replaces the per-subset Python bookkeeping.
            _, keep = torch.topk(scores, k=gate_top_k, dim=-1)
            mask.scatter_(-1, keep, 1.0)
        # top-k falls back to -inf slots when a row has fewer than gate_top_k
        # eligible entries; drop those so ineligible entries stay gated off.
        mask = mask * routing_mask
        gated = tensor_2d * mask
        return gated.reshape(shape)
    def seq_forward(self,seq,input,moe_residual=False):
        """
        Apply ``seq`` to ``input``, optionally with an internal skip connection.

        Args:
            seq: Callable container of layers (indexable length is only needed
                when ``moe_residual`` is enabled).
            input: Tensor fed to the first layer.
            moe_residual: ``False`` runs ``seq(input)`` unchanged. ``1`` adds the
                output of the first two layers (the "stem") to the output of the
                middle layers right before the last two layers. Any other value
                does the same and additionally accumulates the middle-layer
                output into the stem after every even-indexed middle layer.
        Returns:
            Output of the last layer. Containers with four or fewer layers are
            always run as plain ``seq(input)``.
        """
        if moe_residual==False:
            return seq(input)
        depth=len(seq)
        if depth<=4:
            return seq(input)
        # Mode 1 keeps the stem fixed; every other mode keeps adding to it.
        accumulate_stem=not (moe_residual==1)
        tail_start=depth-2
        stem=input
        branch=None
        out=None
        for idx, module in enumerate(seq):
            if idx<2:
                # First two layers build the stem.
                stem=module(stem)
                continue
            if idx<tail_start:
                # Middle layers form a chain that starts from the stem.
                branch=module(stem if branch is None else branch)
                if accumulate_stem and idx%2==0:
                    stem=stem+branch
                continue
            if idx==tail_start:
                # Merge chain and stem before the final two layers.
                out=module(branch+stem)
            else:
                out=module(out)
        return out
    def forward(self, x):
        """
        Run the base layers and add the outputs of the enabled routing branches.

        A 3-D input is flattened over its first two dimensions for the pass and
        the result is reshaped back to ``(d0, d1, out_features)``; inputs of any
        other rank are processed as given.

        Each entry of ``self.layers[:-1]`` is applied in order, scaled by the
        matching ``base_weight_net_list`` entry when ``base_weight_net`` is set,
        and with a residual connection when consecutive widths match and
        ``use_residual`` is True. ``routing_val`` selects the active branches:

        * ``> 0``: ``routing_layer_forward`` combines the layer output with the
          same-depth routing layer (and side layer if ``routing_side_depth != 0``).
        * ``> 1`` / ``> 2`` / ``> 3``: branch outputs go through FIFO queues of
          length 2 / 3 / 4, so each output is added to ``x`` one / two / three
          queue shifts after it was produced (scaled by ``cur_rout_weight``
          inside the loop).

        When ``rout_topk`` is set, every higher-order branch output is multiplied
        by its entry of a top-k mask computed from ``block_rout_weight``.
        Only the first sub-module of ``self.layers[-1]`` is applied at the end;
        with ``rout_to_output == 1`` the pending queue heads are added afterwards.

        Args:
            x: Input tensor (2-D, or 3-D to be flattened as described above).
        Returns:
            Network output; reshaped to 3-D if the input was 3-D.
        """
        # `reshaped` holds the original 3-D shape, or False for no reshaping.
        reshaped=False
        if len(x.shape)==3:
            reshaped=x.shape
            x=torch.reshape(x,[reshaped[0]*reshaped[1],reshaped[2]])
        counter=0
        # FIFO delay queues: the head is consumed, the tail receives new output.
        routing1_out=[None,None]
        routing2_out=[None,None,None]
        routing3_out=[None,None,None,None]
        if self.rout_topk is not None:
            block_rout_mask=self.top_k_gating(tensor=torch.reshape(self.block_rout_weight,[1,self.block_rout_weight.shape[0]]), gate_top_k=self.rout_topk, random_gate_percent=self.rout_rand_gate_percent, routing_side_train=None)
            # Index of the next mask entry; advanced once per masked branch output.
            block_rout_index=0
            # NOTE(review): unconditional debug print on every forward pass.
            print('rout_block_num',self.block_rout_weight.shape)
        x_routed=x
        x_in=x
        # Routing branches fed directly from the network input.
        if self.rout_network_input==True:
            if self.routing_val>0:
                if self.routing_side_depth==0:
                    x=self.routing_layer_forward(x_routed,self.rout_netinput,x_in=x_in)
                else:
                    x=self.routing_layer_forward(x_routed,self.rout_netinput,self.rout_netinput_side,x_in=x_in)
            if self.routing_val>1:
                routing1_out.pop(0)
                routing1_out.append(None)
                if self.routing_side_depth==0:
                    xr1=self.rout_netinput1(x_routed)*self.routing_weight
                else:
                    xr1=self.rout_netinput1(x_routed)*self.routing_weight+self.rout_netinput_side1(x_routed)*self.routing_weight
                if self.rout_topk is not None:
                    xr1=xr1*block_rout_mask[0,block_rout_index]
                    block_rout_index+=1
                routing1_out[1]=xr1
            if self.routing_val>2:
                routing2_out.pop(0)
                routing2_out.append(None)
                if self.routing_side_depth==0:
                    xr2=self.rout_netinput2(x_routed)*self.routing_weight
                else:
                    xr2=self.rout_netinput2(x_routed)*self.routing_weight+self.rout_netinput_side2(x_routed)*self.routing_weight
                if self.rout_topk is not None:
                    xr2=xr2*block_rout_mask[0,block_rout_index]
                    block_rout_index+=1
                routing2_out[2]=xr2
            if self.routing_val>3:
                routing3_out.pop(0)
                routing3_out.append(None)
                if self.routing_side_depth==0:
                    xr3=self.rout_netinput3(x_routed)*self.routing_weight
                else:
                    xr3=self.rout_netinput3(x_routed)*self.routing_weight+self.rout_netinput_side3(x_routed)*self.routing_weight
                if self.rout_topk is not None:
                    xr3=xr3*block_rout_mask[0,block_rout_index]
                    block_rout_index+=1
                routing3_out[3]=xr3
        layer_counter=0
        for layer in self.layers[:-1]:
            x_ori=x
            # Optional learned per-layer scales for the base path and routed additions.
            cur_base_weight=1
            if self.base_weight_net is not None:
                cur_base_weight=self.base_weight_net_list[layer_counter](x)
                # NOTE(review): unconditional debug print (computes four reductions per layer).
                print(layer_counter,'cur_base_weight',torch.min(cur_base_weight),torch.mean(cur_base_weight),torch.max(cur_base_weight),torch.std(cur_base_weight),cur_base_weight.shape)
            cur_rout_weight=1
            if self.rout_weight_net is not None:
                cur_rout_weight=self.rout_weight_net_list[layer_counter](x)
                # NOTE(review): unconditional debug print (computes four reductions per layer).
                print(layer_counter,'cur_rout_weight',torch.min(cur_rout_weight),torch.mean(cur_rout_weight),torch.max(cur_rout_weight),torch.std(cur_rout_weight),cur_rout_weight.shape)
            layer_counter+=1
            if self.channels[counter]==self.channels[counter+1] and self.use_residual==True:
                x_in = cur_base_weight*layer(x)+x
            else:
                x_in = cur_base_weight*layer(x)
            if self.routing_val==0:
                x=x_in
            # Higher-order branches read either the layer input or the layer output.
            if self.rout_input==True:
                x_routed=x_ori
            else:
                x_routed=x_in
            if self.routing_val>0:
                if self.switch_lists is None:
                    switch_list_applied=None
                else:
                    switch_list_applied=self.switch_lists[counter]
                if self.routing_side_depth==0:
                    # Uses the None-safe value computed above; indexing
                    # self.switch_lists directly fails when it is None.
                    x=self.routing_layer_forward(x_ori,self.routing_layers[counter],x_in=x_in,order_list=switch_list_applied)
                else:
                    x=self.routing_layer_forward(x_ori,self.routing_layers[counter],self.routing_side_layers[counter],x_in=x_in,order_list=switch_list_applied)
            if self.routing_val>1:
                # Shift the queue, then store this layer's branch output at the tail.
                routing1_out.pop(0)
                routing1_out.append(None)
                if counter<len(self.cfg.hidden_dims)-1+self.rout_to_output:
                    if self.switch_lists1 is None:
                        switch_list_applied=None
                    else:
                        switch_list_applied=self.switch_lists1[counter]
                    if self.routing_side_depth==0:
                        xr1=self.switch_order(self.routing_layers1[counter](x_routed)*self.routing_weight,switch_list_applied)
                    else:
                        if self.routing_side_num is None:
                            xr1=self.switch_order(self.routing_layers1[counter](x_routed)*self.routing_weight+self.seq_forward(seq=self.routing_layers1_side[counter],input=x_routed,moe_residual=self.moe_residual)*self.routing_weight,switch_list_applied)
                        else:
                            # Several side networks: sum them, optionally gated per network.
                            feature=self.routing_layers1[counter](x_routed)*self.routing_weight
                            list_length=len(self.routing_layers1_side[counter])
                            if self.gate_dims is not None:
                                cur_MOE_weight=self.side1_gating[counter](x_routed)
                                cur_MOE_weight=self.top_k_gating(tensor=cur_MOE_weight,gate_top_k=self.gate_top_k,random_gate_percent=self.random_gate_percent,routing_side_train=self.routing_side_train)
                            for index_temp in range(list_length):
                                routing_side_train_val=1
                                if self.routing_side_train is not None:
                                    routing_side_train_val=self.routing_side_train[index_temp]
                                rout_out_cur=self.seq_forward(seq=self.routing_layers1_side[counter][index_temp],
                                                                input=x_routed,
                                                                moe_residual=self.moe_residual)
                                side_out=rout_out_cur*self.routing_weight*routing_side_train_val
                                if self.gate_dims is not None:
                                    if len(feature.shape)==2:
                                        side_out=side_out*cur_MOE_weight[:,index_temp:index_temp+1]
                                    else:
                                        side_out=side_out*cur_MOE_weight[:,:,index_temp:index_temp+1]
                                feature=feature+side_out
                            xr1=self.switch_order(feature,switch_list_applied)
                    if self.rout_topk is not None:
                        xr1=xr1*block_rout_mask[0,block_rout_index]
                        block_rout_index+=1
                else:
                    xr1=None
                routing1_out[1]=xr1
                if routing1_out[0] is not None:
                    x=x+routing1_out[0]*cur_rout_weight
            if self.routing_val>2:
                routing2_out.pop(0)
                routing2_out.append(None)
                if counter<len(self.cfg.hidden_dims)-2+self.rout_to_output:
                    if self.switch_lists2 is None:
                        switch_list_applied=None
                    else:
                        switch_list_applied=self.switch_lists2[counter]
                    if self.routing_side_depth==0:
                        xrr1=self.switch_order(self.routing_layers2[counter](x_routed)*self.routing_weight,switch_list_applied)
                    else:
                        if self.routing_side_num is None:
                            xrr1=self.switch_order(self.routing_layers2[counter](x_routed)*self.routing_weight + self.seq_forward(seq=self.routing_layers2_side[counter],input=x_routed,moe_residual=self.moe_residual)*self.routing_weight,switch_list_applied)
                        else:
                            feature=self.routing_layers2[counter](x_routed)*self.routing_weight
                            list_length=len(self.routing_layers2_side[counter])
                            if self.gate_dims is not None:
                                cur_MOE_weight=self.side2_gating[counter](x_routed)
                                cur_MOE_weight=self.top_k_gating(tensor=cur_MOE_weight,gate_top_k=self.gate_top_k,random_gate_percent=self.random_gate_percent,routing_side_train=self.routing_side_train)
                            for index_temp in range(list_length):
                                routing_side_train_val=1
                                if self.routing_side_train is not None:
                                    routing_side_train_val=self.routing_side_train[index_temp]
                                rout_out_cur=self.seq_forward(seq=self.routing_layers2_side[counter][index_temp],
                                                                input=x_routed,
                                                                moe_residual=self.moe_residual)
                                side_out=rout_out_cur*self.routing_weight*routing_side_train_val
                                if self.gate_dims is not None:
                                    if len(feature.shape)==2:
                                        side_out=side_out*cur_MOE_weight[:,index_temp:index_temp+1]
                                    else:
                                        side_out=side_out*cur_MOE_weight[:,:,index_temp:index_temp+1]
                                feature=feature+side_out
                            xrr1=self.switch_order(feature,switch_list_applied)
                    if self.rout_topk is not None:
                        xrr1=xrr1*block_rout_mask[0,block_rout_index]
                        block_rout_index+=1
                else:
                    xrr1=None
                routing2_out[2]=xrr1
                if routing2_out[0] is not None:
                    x=x+routing2_out[0]*cur_rout_weight
            if self.routing_val>3:
                routing3_out.pop(0)
                routing3_out.append(None)
                if counter<len(self.cfg.hidden_dims)-3+self.rout_to_output:
                    if self.switch_lists3 is None:
                        switch_list_applied=None
                    else:
                        switch_list_applied=self.switch_lists3[counter]
                    if self.routing_side_depth==0:
                        xrrr1=self.switch_order(self.routing_layers3[counter](x_routed)*self.routing_weight,switch_list_applied)
                    else:
                        if self.routing_side_num is None:
                            xrrr1=self.switch_order(self.routing_layers3[counter](x_routed)*self.routing_weight + self.seq_forward(seq=self.routing_layers3_side[counter],input=x_routed,moe_residual=self.moe_residual)*self.routing_weight,switch_list_applied)
                        else:
                            feature=self.routing_layers3[counter](x_routed)*self.routing_weight
                            list_length=len(self.routing_layers3_side[counter])
                            if self.gate_dims is not None:
                                cur_MOE_weight=self.side3_gating[counter](x_routed)
                                cur_MOE_weight=self.top_k_gating(tensor=cur_MOE_weight,gate_top_k=self.gate_top_k,random_gate_percent=self.random_gate_percent,routing_side_train=self.routing_side_train)
                            for index_temp in range(list_length):
                                routing_side_train_val=1
                                if self.routing_side_train is not None:
                                    routing_side_train_val=self.routing_side_train[index_temp]
                                rout_out_cur=self.seq_forward(seq=self.routing_layers3_side[counter][index_temp],
                                                                input=x_routed,
                                                                moe_residual=self.moe_residual)
                                side_out=rout_out_cur*self.routing_weight*routing_side_train_val
                                if self.gate_dims is not None:
                                    if len(feature.shape)==2:
                                        side_out=side_out*cur_MOE_weight[:,index_temp:index_temp+1]
                                    else:
                                        side_out=side_out*cur_MOE_weight[:,:,index_temp:index_temp+1]
                                feature=feature+side_out
                            xrrr1=self.switch_order(feature,switch_list_applied)
                    if self.rout_topk is not None:
                        xrrr1=xrrr1*block_rout_mask[0,block_rout_index]
                        block_rout_index+=1
                else:
                    xrrr1=None
                routing3_out[3]=xrrr1
                if routing3_out[0] is not None:
                    x=x+routing3_out[0]*cur_rout_weight
            counter+=1
        # Final layer: only its first sub-module is applied.
        x = self.layers[-1][0](x)
        if self.rout_to_output==1:
            # Shift each queue once more and add the new head to the output
            # (not scaled by cur_rout_weight here).
            if self.routing_val>1:
                if routing1_out[0] is not None:
                    routing1_out.pop(0)
                    routing1_out.append(None)
                    x=x+routing1_out[0]
            if self.routing_val>2:
                if routing2_out[0] is not None:
                    routing2_out.pop(0)
                    routing2_out.append(None)
                    x=x+routing2_out[0]
            if self.routing_val>3:
                if routing3_out[0] is not None:
                    routing3_out.pop(0)
                    routing3_out.append(None)
                    x=x+routing3_out[0]
        if reshaped is not False:
            x=torch.reshape(x,[reshaped[0],reshaped[1],x.shape[1]])
        return x
