import jax
from jax import random
from neural_tangents_v1 import stax
import jax.numpy as jnp
import re
import torch

def _same_conv(out_chan, kernel, stride):
    """Return a square-kernel ``stax.Conv`` layer using 'SAME' padding."""
    return stax.Conv(
        out_chan=out_chan,
        filter_shape=(kernel, kernel),
        strides=(stride, stride),
        padding='SAME',
    )


def _residual_block(conv_specs, stride):
    """Return a residual block: conv stack + 1x1 projection, summed, then ReLU.

    Args:
        conv_specs: Non-empty sequence of ``(out_chan, kernel, stride)``
            tuples describing the convolutions of the main branch, in order.
            A ReLU is inserted between consecutive convolutions (none after
            the last one). The last entry's ``out_chan`` is also the width of
            the projection branch so that both branches can be summed.
        stride: Stride of the 1x1 projection (shortcut) convolution.
    """
    main_branch = []
    for out_chan, kernel, conv_stride in conv_specs:
        if main_branch:
            main_branch.append(stax.Relu())
        main_branch.append(_same_conv(out_chan, kernel, conv_stride))

    block_width = conv_specs[-1][0]
    shortcut = stax.serial(_same_conv(block_width, 1, stride))

    return stax.serial(
        stax.FanOut(2),
        stax.parallel(stax.serial(*main_branch), shortcut),
        stax.FanInSum(),
        stax.Relu(),
    )


def _bottleneck_block(mid_chan, out_chan, kernel, stride):
    """Return a 1x1 -> KxK (strided) -> 1x1 residual block."""
    return _residual_block(
        [(mid_chan, 1, 1), (mid_chan, kernel, stride), (out_chan, 1, 1)],
        stride,
    )


def get_staxNet(structure_str, num_classes=100):
    """Build a stax network from a structure string.

    The string is a concatenation of ``Name(p0,p1,...)`` segments. It is split
    on ``')'`` and whatever follows the last ``')'`` is discarded. Integers
    embedded in ``Name`` (e.g. kernel sizes) are extracted with a regex, and
    the comma-separated values inside the parentheses are read by position.

    Segment names are matched by substring, in this order:

    * ``SuperResIDWE<E>K<K>`` with ``(p0, out, stride, bottleneck, repeats)``:
      per repeat, two bottleneck blocks of inner width ``bottleneck * E`` and
      kernel ``K``; the first outputs ``bottleneck`` channels, the second
      ``out`` channels.
    * ``SuperResK1K<K>...`` with the same five values: per repeat, two
      bottleneck blocks of inner width ``bottleneck``, kernel ``K`` and
      ``out`` output channels.
    * ``SuperResK<K0>K<K1>`` with the same five values: per repeat, one
      residual block ``Conv(bottleneck, K0, stride) -> ReLU -> Conv(out, K1)``.
    * ``SuperConvK<K>`` with ``(p0, out, stride, repeats)``: per repeat, a
      ``Conv(out, K, stride)`` followed by a ReLU.

    ``p0`` is never read (presumably the input channel count -- confirm with
    the producer of the string). Segments whose name matches none of the
    above contribute no layers. The network always ends with global average
    pooling, flatten and a ``num_classes``-wide dense layer.

    Args:
        structure_str: Architecture description as outlined above.
        num_classes: Width of the final dense layer.

    Returns:
        The result of ``stax.serial`` applied to all generated layers.

    Raises:
        IndexError: If a segment has no ``'('`` or provides fewer values or
            embedded integers than its layer type reads.
        ValueError: If a value that is read is not an integer literal.
    """
    segments = structure_str.split(')')
    segments.pop()  # drop the remainder after the final ')'

    model_layers = []

    for segment in segments:
        fields = segment.split('(')
        name = fields[0]
        parameter = fields[1].split(',')
        numbers = re.findall(r'\d+', name)

        # NOTE(review): in every residual branch below the stride is applied
        # in each generated block, so a stage with several blocks downsamples
        # several times -- confirm this is intended.
        if 'SuperResIDWE' in name:
            out_chan, stride = int(parameter[1]), int(parameter[2])
            bottleneck, repeats = int(parameter[3]), int(parameter[4])
            expansion, kernel = int(numbers[0]), int(numbers[1])
            mid_chan = bottleneck * expansion
            for _ in range(repeats):
                model_layers.append(
                    _bottleneck_block(mid_chan, bottleneck, kernel, stride))
                model_layers.append(
                    _bottleneck_block(mid_chan, out_chan, kernel, stride))
        elif 'SuperResK1K' in name:
            out_chan, stride = int(parameter[1]), int(parameter[2])
            bottleneck, repeats = int(parameter[3]), int(parameter[4])
            kernel = int(numbers[1])
            for _ in range(repeats):
                model_layers.extend(
                    _bottleneck_block(bottleneck, out_chan, kernel, stride)
                    for _ in range(2)
                )
        elif 'SuperResK' in name:
            out_chan, stride = int(parameter[1]), int(parameter[2])
            bottleneck, repeats = int(parameter[3]), int(parameter[4])
            first_kernel, second_kernel = int(numbers[0]), int(numbers[1])
            for _ in range(repeats):
                model_layers.append(_residual_block(
                    [(bottleneck, first_kernel, stride),
                     (out_chan, second_kernel, 1)],
                    stride,
                ))
        elif 'SuperConvK' in name:
            out_chan, stride = int(parameter[1]), int(parameter[2])
            repeats = int(parameter[3])
            kernel = int(numbers[0])
            for _ in range(repeats):
                # NOTE(review): unlike every other convolution here, no
                # padding is passed, so stax.Conv's default applies --
                # confirm this is intended.
                model_layers.append(stax.Conv(
                    out_chan=out_chan,
                    filter_shape=(kernel, kernel),
                    strides=(stride, stride),
                ))
                model_layers.append(stax.Relu())

    model_layers.append(stax.GlobalAvgPool())
    model_layers.append(stax.Flatten())
    model_layers.append(stax.Dense(num_classes))
    return stax.serial(*model_layers)


def compute_nas_score(structure_str, gpu, resolution, batch_size):
    """Score an architecture by the condition number of its NTK.

    A network is built from ``structure_str``, its 'ntk' kernel is evaluated
    on a fixed random batch, and the ratio of the largest to the smallest
    eigenvalue of that kernel matrix is computed.

    Args:
        structure_str: Architecture description passed to ``get_staxNet``.
        gpu: Currently unused; kept so existing callers keep working.
        resolution: Height and width of the square random inputs.
        batch_size: Number of random inputs.

    Returns:
        The negated ratio ``largest_eigenvalue / smallest_eigenvalue`` as a
        JAX scalar.
    """
    # Only the kernel function is needed: it is evaluated on the inputs
    # alone and no parameters are ever passed to it, so none are initialised.
    _, _, kernel_fn = get_staxNet(structure_str=structure_str)

    # Fixed seed, so the same inputs are drawn on every call.
    inputs = random.normal(
        random.PRNGKey(1), (batch_size, resolution, resolution, 3))

    ntk = kernel_fn(inputs, inputs, 'ntk')

    # eigvalsh returns real eigenvalues in ascending order, so the extremes
    # sit at the two ends without any further sorting.
    eigenvalues = jnp.linalg.eigvalsh(ntk)
    condition_number = eigenvalues[-1] / eigenvalues[0]

    return -condition_number

