import torch
import stork
from mpmath import zeros
from stork.connections import Connection
from stork.initializers import (
    FluctuationDrivenCenteredNormalInitializer,
    DistInitializer,
    Initializer,
)
from stork.layers import Layer

# Custom additions to stork
from .custom.readout import CustomReadoutGroup, AverageReadouts
from .custom.lif import CustomLIFGroup
from .custom.custom_regularizer import LowerBoundL2, UpperBoundL2
from .custom.models import CustomRecurrentSpikingModel
from .custom.custom_connections import (ChannelAttentionConnection_multiHead,
                                        Connection_with_VS_shortcut_withBatchNorm,
                                        Channel1dConvConnectionLayer,
                                        LinearLayer_of_shortcut,
                                        LinearLayer_with_shortcut,
                                        Connection_with_VS_shortcut,
                                        ChannelAttentionConnection,
                                        Connection_withBatchNorm,
                                        Connection_Identity,
                                        Connection_withBatchNorm_with_multi_src,
                                        Connection_withBatchNorm_with_GradientReverse,
                                        ConvConnection_withBatchNorm,
                                        )

from .custom.custom_inputGroup import custom_InputGroup as InputGroup
from .custom.custom_inputGroup import custom_fake_InputGroup, custom_feedback_InputGroup
# from stork.nodes import InputGroup
from torch import nn


# from .data import compute_input_firing_rates
from .data_foundationDemo1 import compute_input_firing_rates


import logging

logger = logging.getLogger(__name__)


def get_regularizers(cfg):
    """Build the activity regularizers used for the hidden layers.

    Args:
        cfg: Configuration object. Reads ``cfg.training.LB_L2_strength``,
            ``cfg.training.LB_L2_thresh``, ``cfg.training.UB_L2_strength``
            and ``cfg.training.UB_L2_thresh``.

    Returns:
        A two-element list ``[lower_bound, upper_bound]`` holding a
        ``stork.regularizers.LowerBoundL2`` (created with ``dims=False``)
        followed by a ``stork.regularizers.UpperBoundL2`` (created with
        ``dims=1``).
    """
    training = cfg.training

    lower_bound = stork.regularizers.LowerBoundL2(
        training.LB_L2_strength,
        threshold=training.LB_L2_thresh,
        dims=False,
    )
    upper_bound = stork.regularizers.UpperBoundL2(
        training.UB_L2_strength,
        threshold=training.UB_L2_thresh,
        dims=1,
    )

    # NOTE: the stork implementations are used here on purpose; the
    # project-local LowerBoundL2 / UpperBoundL2 imported from
    # .custom.custom_regularizer are currently not used by this function.
    return [lower_bound, upper_bound]


def get_actfn(cfg):
    """Configure and return the spiking activation function.

    The attributes are assigned directly on ``stork.activations.CustomSpike``
    (no new object is created here), so the configuration is shared by every
    user of that object.

    Args:
        cfg: Configuration object. Reads ``cfg.model.stochastic`` and
            ``cfg.training.SG_beta``.

    Returns:
        ``stork.activations.CustomSpike`` with ``escape_noise_type``,
        ``escape_noise_params``, ``surrogate_type`` and ``surrogate_params``
        set.
    """
    spike_fn = stork.activations.CustomSpike

    # Stochastic models use the sigmoid escape noise, all others the step.
    spike_fn.escape_noise_type = "sigmoid" if cfg.model.stochastic else "step"
    spike_fn.escape_noise_params = {"beta": cfg.training.SG_beta}

    # The surrogate gradient uses the same beta as the escape noise.
    spike_fn.surrogate_type = "SuperSpike"
    spike_fn.surrogate_params = {"beta": cfg.training.SG_beta}

    return spike_fn


class ZeroInitializer(Initializer):
    """Initializer that sets every weight of a connection to zero."""

    def __init__(self):
        super().__init__()

    def _get_weights(self, connection):
        """Return an all-zero tensor shaped like ``connection.op.weight``.

        The tensor is created with ``torch.zeros`` defaults, i.e. it does not
        copy the dtype or device of the existing weight.
        """
        return torch.zeros(connection.op.weight.shape)

    def initialize_connection(self, connection):
        """Set the weights of ``connection`` to zero.

        The zero tensor is handed to the base class'
        ``_set_weights_and_bias``, which applies it to the connection.
        """
        self._set_weights_and_bias(connection, self._get_weights(connection))

def get_initializers(cfg, nu=None, dtype=torch.float32):
    """Create the hidden-layer and readout weight initializers.

    Args:
        cfg: Configuration object. Reads ``cfg.initializer.compute_nu``,
            ``cfg.initializer.nu``, ``cfg.initializer.sigma_u``,
            ``cfg.initializer.alpha`` and ``cfg.data.dt``.
        nu: Input firing rate(s) used by the fluctuation-driven
            initializers. Either a scalar or a ``list`` with 2 or 3 entries.
            When ``None``, or when ``cfg.initializer.compute_nu`` is
            ``False``, ``cfg.initializer.nu`` is used instead.
        dtype: dtype forwarded to every stork initializer created here.

    Returns:
        A tuple ``(hidden_init, readout_init)``.

        ``hidden_init`` is a single
        ``FluctuationDrivenCenteredNormalInitializer`` when ``nu`` is not a
        list. When ``nu`` is a list it is a list of initializers:

        * ``[init(nu[0]), init(nu[1]), init(mean(nu[0], nu[1]))]`` in the
          general case,
        * ``[init(nu[0]), ZeroInitializer(), init(nu[0])]`` when ``nu[1]``
          is ``0`` or ``None``,
        * plus a trailing ``init(nu[2])`` when ``nu`` has three entries.

        ``readout_init`` is a ``DistInitializer`` drawing from a standard
        normal distribution with ``"1/sqrt(k)"`` scaling.

    Raises:
        ValueError: If ``nu`` is a list whose length is neither 2 nor 3.
    """
    if nu is None or cfg.initializer.compute_nu is False:
        logger.info("using un-caculated nu")
        nu = cfg.initializer.nu

    def make_hidden_init(rate):
        # Every fluctuation-driven initializer shares all settings except nu.
        return FluctuationDrivenCenteredNormalInitializer(
            sigma_u=cfg.initializer.sigma_u,
            nu=rate,
            timestep=cfg.data.dt,
            alpha=cfg.initializer.alpha,
            dtype=dtype,
        )

    if isinstance(nu, list):
        # Reject unsupported lengths explicitly: previously a shorter list
        # failed with an IndexError while formatting the log message and a
        # longer list silently produced an empty initializer list.
        if len(nu) not in (2, 3):
            raise ValueError(
                "nu must be a scalar or a list with 2 or 3 entries, "
                f"got a list with {len(nu)} entries"
            )
        logger.info(f"Initializing with nu = {nu[0]} and {nu[1]}")

        if nu[1] is None or nu[1] == 0:
            # No usable second rate: zero-initialize the second slot and
            # reuse the first rate for the third one.
            hidden_init = [
                make_hidden_init(nu[0]),
                ZeroInitializer(),
                make_hidden_init(nu[0]),
            ]
        else:
            # The third slot uses the mean of the first two rates.
            hidden_init = [
                make_hidden_init(nu[0]),
                make_hidden_init(nu[1]),
                make_hidden_init(sum(nu[:2]) / 2),
            ]

        if len(nu) == 3:
            hidden_init.append(make_hidden_init(nu[2]))
    else:
        logger.info(f"Initializing with nu = {nu}")
        hidden_init = make_hidden_init(nu)

    readout_init = DistInitializer(
        dist=torch.distributions.Normal(0, 1),
        scaling="1/sqrt(k)",
        dtype=dtype,
    )

    return hidden_init, readout_init

def get_snntorch_model(cfg, nb_inputs, dtype, data=None, stateFlag="default"):
    """Assemble a ``CustomRecurrentSpikingModel`` from the configuration.

    The model consists of one input group, ``cfg.model.nb_hidden`` hidden
    layers of ``CustomLIFGroup`` neurons, and either a single
    ``CustomReadoutGroup`` or several custom readout groups combined by an
    ``AverageReadouts`` group.

    Args:
        cfg: Configuration object. Reads entries under ``cfg.data``,
            ``cfg.training``, ``cfg.model`` and ``cfg.initializer`` as well
            as ``cfg.device`` (or ``cfg.gpu_ids[0]`` when ``cfg.multi_cuda``
            exists and is truthy).
        nb_inputs: Number of input units.
        dtype: dtype forwarded to the model, the initializers and the
            readout connections.
        data: Optional dataset. When given it is passed to
            ``compute_input_firing_rates`` and the first returned value is
            used as ``nu`` for the hidden-layer initializer.
        stateFlag: Selects the batch size: ``"default"`` uses
            ``cfg.training.batchsize``, ``"pretrain"`` uses
            ``cfg.training.batchsize_pretrain`` and ``"fine-tune"`` uses
            ``cfg.training.batchsize_finetuning``.

    Returns:
        The assembled model.

    Raises:
        ValueError: If ``stateFlag`` is not one of the supported values, or
            if ``cfg.model.multiple_readouts`` is set but no custom readout
            group is configured.
    """
    batchsize_attr_by_state = {
        "default": "batchsize",
        "pretrain": "batchsize_pretrain",
        "fine-tune": "batchsize_finetuning",
    }
    # Fail early with a clear message; previously an unknown flag left the
    # batch size unbound and surfaced as an UnboundLocalError further down.
    if stateFlag not in batchsize_attr_by_state:
        raise ValueError(
            f"Unknown stateFlag {stateFlag!r}; expected one of "
            f"{sorted(batchsize_attr_by_state)}"
        )

    nb_time_steps = int(cfg.data.sample_duration / cfg.data.dt)
    nb_outputs = cfg.data.nb_outputs

    if hasattr(cfg, "multi_cuda") and cfg.multi_cuda:
        device_ = torch.device(cfg.gpu_ids[0])
    else:
        device_ = cfg.device

    batchsize_ = getattr(cfg.training, batchsize_attr_by_state[stateFlag])
    model = CustomRecurrentSpikingModel(
        batchsize_,
        nb_time_steps=nb_time_steps,
        nb_inputs=nb_inputs,
        device=device_,
        dtype=dtype,
    )

    # Activation function, configured from cfg
    act_fn = get_actfn(cfg)

    # Regularizer list shared by all hidden layers
    regs = get_regularizers(cfg)

    # Mean input firing rates for the initializer (only if data is given)
    if data is not None:
        mean1, mean2 = compute_input_firing_rates(data, cfg, nb_inputs)
    else:
        mean1 = None
        mean2 = None

    # NOTE(review): get_initializers returns a *list* of initializers when
    # nu is a list; the loop below calls hidden_init.initialize(...) and
    # therefore only works with a scalar nu -- confirm this is intended.
    hidden_init, readout_init = get_initializers(cfg, mean1, dtype)

    # INPUT LAYER
    # # # # # # # #

    input_group = model.add_group(InputGroup(nb_inputs,
                                             dropout_p=cfg.model.dropout_p))

    # HIDDEN LAYERS
    # # # # # # # #
    current_src_grp = input_group

    # Neuron parameters shared by all hidden layers
    hidden_neuron_kwargs = {
        "tau_mem": cfg.model.tau_mem,
        "tau_syn": cfg.model.tau_syn,
        "activation": act_fn,
        "dropout_p": cfg.model.dropout_p,
        "het_timescales": cfg.model.het_timescales,
        "learn_timescales": cfg.model.learn_timescales,
        "is_delta_syn": cfg.model.delta_synapses,
    }

    # Stack the hidden layers; each one takes the previous layer's output
    # group as its input and is initialized right after creation.
    for i in range(cfg.model.nb_hidden):

        hidden_layer = Layer(
            name="hidden",
            model=model,
            size=cfg.model.hidden_size[i],
            input_group=current_src_grp,
            recurrent=cfg.model.recurrent[i],
            regs=regs,
            neuron_class=CustomLIFGroup,
            neuron_kwargs=hidden_neuron_kwargs,
            connection_kwargs={},
        )

        current_src_grp = hidden_layer.output_group

        # initialize
        hidden_init.initialize(hidden_layer)

        # For exactly 192 inputs, rescale the first layer's weights at
        # index 96 and above (second weight dimension) by mean1 / mean2.
        # NOTE(review): presumably the input is two 96-channel feature sets
        # with different mean rates -- confirm against the data pipeline.
        if i == 0 and nb_inputs == 192 and data is not None:
            with torch.no_grad():
                hidden_layer.connections[0].op.weight[:, 96:] /= mean2 / mean1

    # READOUT LAYER
    # # # # # # # #

    if cfg.model.multiple_readouts:
        logger.info("Adding custom readout groups")
        custom_readouts = get_custom_readouts(cfg)
        # With an empty list the slice below would be groups[-0:], i.e.
        # *every* group of the model, which is never what is wanted.
        if not custom_readouts:
            raise ValueError(
                "cfg.model.multiple_readouts is set but no readout group "
                "is configured"
            )
        for g in custom_readouts:
            model.add_group(g)
            con_ro = model.add_connection(Connection(current_src_grp, g, dtype=dtype))
            readout_init.initialize(con_ro)

        model.add_group(AverageReadouts(model.groups[-len(custom_readouts) :]))
    else:
        logger.info("Adding single readout group")
        readout_group = model.add_group(
            CustomReadoutGroup(
                nb_outputs,
                tau_mem=cfg.model.tau_mem_readout,
                tau_syn=cfg.model.tau_syn_readout,
                het_timescales=cfg.model.het_timescales_readout,
                learn_timescales=cfg.model.learn_timescales_readout,
                initial_state=-1e-2,
                is_delta_syn=cfg.model.delta_synapses,
            )
        )

        con_ro = model.add_connection(
            Connection(current_src_grp, readout_group, dtype=dtype)
        )

        readout_init.initialize(con_ro)

    return model


def get_custom_readouts(cfg, size=None, name=None):
    """Create one ``CustomReadoutGroup`` per entry of ``cfg.model["readouts"]``.

    Args:
        cfg: Configuration object. ``cfg.model["readouts"]`` maps a readout
            name to its specs. Each specs mapping must contain ``"type"``
            (``"default"`` or ``"delta"``) and may contain ``"tau_mem"`` and
            ``"tau_syn"``; missing time constants fall back to
            ``cfg.model.tau_mem_readout`` / ``cfg.model.tau_syn_readout``.
        size: Number of units per readout group. Defaults to
            ``cfg.data.nb_outputs``.
        name: Name passed to the ``CustomReadoutGroup`` constructor
            (``"Readout"`` when ``None``). Each group's ``set_name`` is
            afterwards called with its key from ``cfg.model["readouts"]``.

    Returns:
        List of readout groups in the iteration order of
        ``cfg.model["readouts"]``; empty if no readout is configured.

    Raises:
        ValueError: If a readout's ``"type"`` is neither ``"default"`` nor
            ``"delta"``.
    """
    if size is None:
        size = cfg.data.nb_outputs
    group_name = name if name is not None else "Readout"

    # The readout type only decides whether delta synapses are used.
    delta_syn_by_type = {"default": False, "delta": True}

    ro_list = []
    for ro, specs in cfg.model["readouts"].items():
        ro_type = specs["type"]
        # Reject unknown types explicitly. Previously such an entry either
        # raised an UnboundLocalError (first entry) or silently renamed and
        # re-appended the group built for the previous entry.
        if ro_type not in delta_syn_by_type:
            raise ValueError(
                f"Unknown readout type {ro_type!r} for readout {ro!r}; "
                f"expected one of {sorted(delta_syn_by_type)}"
            )

        # Per-readout time constants override the model-wide defaults.
        tau_mem = specs["tau_mem"] if "tau_mem" in specs else cfg.model.tau_mem_readout
        tau_syn = specs["tau_syn"] if "tau_syn" in specs else cfg.model.tau_syn_readout

        ro_group = CustomReadoutGroup(
            size,
            tau_mem=tau_mem,
            tau_syn=tau_syn,
            het_timescales=cfg.model.het_timescales_readout,
            learn_timescales=cfg.model.learn_timescales_readout,
            initial_state=-1e-2,
            is_delta_syn=delta_syn_by_type[ro_type],
            name=group_name,
        )

        ro_group.set_name(ro)
        ro_list.append(ro_group)

    return ro_list