from neural_tangents_v1 import stax
import re
from jax import random

def _same_conv(out_chan, kernel_size, stride=1):
    """Return a square ``stax.Conv`` layer with 'SAME' padding.

    Args:
        out_chan: Number of output channels.
        kernel_size: Side length of the square filter.
        stride: Stride applied along both spatial axes.
    """
    return stax.Conv(
        out_chan=out_chan,
        filter_shape=(kernel_size, kernel_size),
        strides=(stride, stride),
        padding='SAME',
    )


def _residual(main_branch, out_chan, stride):
    """Return ``Relu(main_branch(x) + shortcut(x))`` as one ``stax`` layer.

    The shortcut is always a 1x1 projection convolution using the same stride
    as the main branch, so both paths produce ``out_chan`` channels at the same
    spatial resolution before they are summed.
    """
    shortcut = stax.serial(_same_conv(out_chan, 1, stride))
    return stax.serial(
        stax.FanOut(2),
        stax.parallel(main_branch, shortcut),
        stax.FanInSum(),
        stax.Relu(),
    )


def _bottleneck(mid_chan, out_chan, kernel_size, stride):
    """Return a 1x1 -> kxk (strided) -> 1x1 bottleneck main branch."""
    return stax.serial(
        _same_conv(mid_chan, 1),
        stax.Relu(),
        _same_conv(mid_chan, kernel_size, stride),
        stax.Relu(),
        _same_conv(out_chan, 1),
    )


def _res_stage_params(params):
    """Read ``(out_chan, stride, bottleneck_chan, sub_layers)`` from ``params[1:5]``."""
    return int(params[1]), int(params[2]), int(params[3]), int(params[4])


def get_staxNet(structure_str: str, num_classes: int = 100):
    """Build a neural-tangents ``stax`` network from a structure string.

    ``structure_str`` is a concatenation of ``Name(p0,p1,...)`` layer specs,
    e.g. ``"SuperConvK3BNRELU(3,32,2,1)SuperResK1K3K1(32,128,2,32,2)"``.
    Integers embedded in ``Name`` give kernel sizes (for ``SuperResIDWE`` the
    expansion ratio comes first, then the kernel size). ``p0`` is never read
    (presumably the input channel count -- unused here). Recognised layer
    families, tested in this order:

    * ``SuperResIDWE<e>K<k>(_, out, stride, mid, n)``: ``n`` repetitions of two
      bottleneck residual blocks with inner width ``mid * e``; the first block
      outputs ``mid`` channels, the second ``out`` channels.
    * ``SuperResK1K<k>...(_, out, stride, mid, n)``: ``n`` repetitions of two
      1x1 -> kxk -> 1x1 bottleneck residual blocks with inner width ``mid``
      (``k`` is the second integer in the name).
    * ``SuperResK<k1>K<k2>(_, out, stride, mid, n)``: ``n`` residual blocks
      whose main branch is k1xk1 (``mid`` channels) -> Relu -> k2xk2 (``out``).
    * ``SuperConvK<k>...(_, out, stride, n)``: ``n`` kxk conv + Relu pairs.

    Every residual block uses a 1x1 projection shortcut. Within a stage only
    the first block/conv applies ``stride``; all later ones use stride 1, so a
    stage changes resolution exactly once. Specs whose name matches none of
    the families above are skipped. The network always ends with
    GlobalAvgPool -> Flatten -> Dense(num_classes).

    Args:
        structure_str: The architecture description described above.
        num_classes: Output width of the final ``Dense`` layer.

    Returns:
        The result of ``stax.serial`` over all layers (in neural_tangents this
        is an ``(init_fn, apply_fn, kernel_fn)`` triple).

    Raises:
        ValueError: A layer spec has no ``'('``, or a recognised layer has a
            non-integer parameter.
        IndexError: A recognised layer has too few parameters or name digits.
    """
    layer_specs = structure_str.split(')')
    # Whatever follows the final ')' (normally '') is not a layer spec.
    layer_specs.pop()

    model_layers = []
    for layer_spec in layer_specs:
        name, has_paren, arg_str = layer_spec.partition('(')
        if not has_paren:
            raise ValueError(
                f"malformed layer spec {layer_spec!r}: expected 'Name(args'")
        params = arg_str.split(',')
        name_ints = [int(d) for d in re.findall(r'\d+', name)]

        # 'SuperResK1K' must be tested before the more general 'SuperResK'.
        if 'SuperResIDWE' in name:
            expansion, kernel_size = name_ints[0], name_ints[1]
            out_chan, stride, mid_chan, sub_layers = _res_stage_params(params)
            wide_chan = mid_chan * expansion
            for i in range(sub_layers):
                # Downsample only once per stage, in its very first block.
                block_stride = stride if i == 0 else 1
                model_layers.append(_residual(
                    _bottleneck(wide_chan, mid_chan, kernel_size, block_stride),
                    mid_chan, block_stride))
                model_layers.append(_residual(
                    _bottleneck(wide_chan, out_chan, kernel_size, 1),
                    out_chan, 1))
        elif 'SuperResK1K' in name:
            kernel_size = name_ints[1]
            out_chan, stride, mid_chan, sub_layers = _res_stage_params(params)
            for i in range(sub_layers):
                block_stride = stride if i == 0 else 1
                model_layers.append(_residual(
                    _bottleneck(mid_chan, out_chan, kernel_size, block_stride),
                    out_chan, block_stride))
                model_layers.append(_residual(
                    _bottleneck(mid_chan, out_chan, kernel_size, 1),
                    out_chan, 1))
        elif 'SuperResK' in name:
            first_kernel, second_kernel = name_ints[0], name_ints[1]
            out_chan, stride, mid_chan, sub_layers = _res_stage_params(params)
            for i in range(sub_layers):
                block_stride = stride if i == 0 else 1
                main_branch = stax.serial(
                    _same_conv(mid_chan, first_kernel, block_stride),
                    stax.Relu(),
                    _same_conv(out_chan, second_kernel),
                )
                model_layers.append(
                    _residual(main_branch, out_chan, block_stride))
        elif 'SuperConvK' in name:
            kernel_size = name_ints[0]
            out_chan, stride, sub_layers = (
                int(params[1]), int(params[2]), int(params[3]))
            for i in range(sub_layers):
                # 'SAME' like every other conv in this builder; without an
                # explicit padding stax.Conv would fall back to 'VALID'.
                model_layers.append(
                    _same_conv(out_chan, kernel_size, stride if i == 0 else 1))
                model_layers.append(stax.Relu())

    model_layers.append(stax.GlobalAvgPool())
    model_layers.append(stax.Flatten())
    model_layers.append(stax.Dense(num_classes))
    return stax.serial(*model_layers)

