using LinearAlgebra
using Transformers.HuggingFace
import Transformers.HuggingFace: load_model, joinname, getweight, weight_init, one_init, zero_init
using Transformers.Layers
using Transformers.Layers: CompositeEmbedding, Seq2Seq, MultiheadQKVAttenOp, CausalMultiheadQKVAttenOp, WithScore

# Assemble a complete `TGBart` from the entries of `state_dict` under `prefix`.
#
# Sub-models are read from `<prefix>.embed`, `<prefix>.seq2seq`, `<prefix>.stage1` and
# `<prefix>.stage2`. The stage-1 sub-loader returns its block with no label projection;
# that projection is assembled here because it needs the label embedding. It is a
# LayerNorm (`<prefix>.stage1.label_proj.layers.0`) followed by an `EmbedDecoder` that
# reuses the label embedding matrix, scaled by `1 / sqrt(size(matrix, 1))`.
function HuggingFace.load_model(_type::Type{<:TGBart}, cfg, state_dict, prefix)
    subkey(name) = joinname(prefix, name)

    embed = load_model(_type, CompositeEmbedding, cfg, state_dict, subkey("embed"))
    seq2seq = load_model(_type, Seq2Seq, cfg, state_dict, subkey("seq2seq"))

    # stage 1: keep only the transformer block from the sub-loader, then attach a label
    # projection tied to the label embedding matrix
    stage1_block = load_model(_type, TGStage1, cfg, state_dict, subkey("stage1")).block
    label_norm = load_model(_type, Layers.LayerNorm, cfg, state_dict,
                            subkey("stage1.label_proj.layers.0"))
    tied = embed.label_embed.embeddings
    tied_scale = convert(eltype(tied), inv(sqrt(size(tied, 1))))
    label_proj = Layers.Chain(label_norm, Layers.EmbedDecoder(Embed(tied_scale, tied)))
    stage1 = TGStage1(stage1_block, label_proj)

    stage2 = load_model(_type, TGStage2, cfg, state_dict, subkey("stage2"))
    return TGBart(embed, seq2seq, stage1, stage2)
end

# Load the transformer block of stage 1 from `prefix`.
#
# The label projection is returned as `nothing`; the top-level `TGBart` loader attaches
# it. The self-attention sublayer (causal) is always wrapped in `IDPreNormResidual`,
# whatever `cfg[:pre_norm]` says. The feed-forward sublayer uses `PreNormResidual` or
# `PostNormResidual` according to `cfg[:pre_norm]`. Dropout `cfg[:dropout]` is applied
# to both sublayers (`nothing` when it is zero).
function HuggingFace.load_model(_type::Type{<:TGBart}, ::Type{<:TGStage1}, cfg, state_dict, prefix)
    dropout = Float64(cfg[:dropout])
    drop_p = iszero(dropout) ? nothing : dropout
    ResidualNorm = cfg[:pre_norm] ? Layers.PreNormResidual : Layers.PostNormResidual
    subkey(name) = joinname(prefix, name)

    attn = load_model(_type, Layers.SelfAttention{CausalMultiheadQKVAttenOp},
                      cfg, state_dict, subkey("attention.layer"))
    attn_norm = load_model(_type, Layers.LayerNorm, cfg, state_dict, subkey("attention.norm"))
    attn_block = IDPreNormResidual(Layers.DropoutLayer(attn, drop_p), attn_norm)

    ffn = load_model(_type, Layers.Chain{Tuple{Layers.Dense, Layers.Dense}},
                     cfg, state_dict, subkey("feedforward.layer"))
    ffn_norm = load_model(_type, Layers.LayerNorm, cfg, state_dict, subkey("feedforward.norm"))
    ffn_block = ResidualNorm(Layers.DropoutLayer(ffn, drop_p), ffn_norm)

    return TGStage1(TransformerBlock(attn_block, ffn_block), nothing)
end

# Load stage 2 from `prefix`. Stage 2 is a transformer block (causal self-attention plus
# feed-forward, both wrapped per `cfg[:pre_norm]` with dropout `cfg[:dropout]`) followed
# by the head/tail projection:
#
#     Chain(DropoutLayer(Dense(act, …)), Fork(head Dense, tail Dense))
#
# The hidden Dense has bias width `2 * cfg[:d_model]`. The activation named by
# `cfg[:proj_head_act]` maps `tanh` to `NNlib.tanh_fast`; any other name is looked up
# in `HuggingFace.ACT2FN`. Biases are loaded only when `cfg[:proj_head_bias]` is set.
function HuggingFace.load_model(_type::Type{<:TGBart}, ::Type{<:TGStage2}, cfg, state_dict, prefix)
    dropout = Float64(cfg[:dropout])
    drop_p = iszero(dropout) ? nothing : dropout
    ResidualNorm = cfg[:pre_norm] ? Layers.PreNormResidual : Layers.PostNormResidual
    subkey(name) = joinname(prefix, name)

    attn = load_model(_type, Layers.SelfAttention{CausalMultiheadQKVAttenOp},
                      cfg, state_dict, subkey("attention.layer"))
    attn_norm = load_model(_type, Layers.LayerNorm, cfg, state_dict, subkey("attention.norm"))
    ffn = load_model(_type, Layers.Chain{Tuple{Layers.Dense, Layers.Dense}},
                     cfg, state_dict, subkey("feedforward.layer"))
    ffn_norm = load_model(_type, Layers.LayerNorm, cfg, state_dict, subkey("feedforward.norm"))
    block = TransformerBlock(ResidualNorm(Layers.DropoutLayer(attn, drop_p), attn_norm),
                             ResidualNorm(Layers.DropoutLayer(ffn, drop_p), ffn_norm))

    # --- head/tail projection ---
    dims = cfg[:d_model]
    init_std = Float32(cfg[:init_std])
    act_name = Symbol(cfg[:proj_head_act])
    # NOTE(review): `NNlib` is not imported in the visible part of this file; presumably
    # it is in scope in the enclosing module — confirm.
    act = act_name == :tanh ? NNlib.tanh_fast : HuggingFace.ACT2FN[act_name]
    with_bias = cfg[:proj_head_bias]
    in_init = weight_init(dims, 2dims, init_std / sqrt(dims))
    out_init = weight_init(2dims, dims, init_std / sqrt(2dims))
    param(init, key) = getweight(init, Array, state_dict, subkey(key))

    in_w = param(in_init, "head_tail_proj.layers.0.weight")
    head_w = param(out_init, "head_tail_proj.layers.1.layers.0.weight")
    tail_w = param(out_init, "head_tail_proj.layers.1.layers.1.weight")
    if with_bias
        out_b_init = zero_init(dims)
        in_b = param(zero_init(2dims), "head_tail_proj.layers.0.bias")
        head_b = param(out_b_init, "head_tail_proj.layers.1.layers.0.bias")
        tail_b = param(out_b_init, "head_tail_proj.layers.1.layers.1.bias")
    else
        in_b = head_b = tail_b = nothing
    end

    head_tail_proj = Layers.Chain(
        Layers.DropoutLayer(Layers.Dense(act, in_w, in_b), drop_p),
        Layers.Fork(Layers.Dense(head_w, head_b), Layers.Dense(tail_w, tail_b)))
    return TGStage2(block, head_tail_proj)
end

# Load the composite token embedding stored under `prefix` and return a `TokenEmbedding`.
#
# It consists of:
# - a label embedding (`label_embed.weight`);
# - a type embedding initialised for 2 types (`type_embed.weight`);
# - four Dense projections initialised as `d_model x d_model` (`pos_proj`,
#   `segment_proj`, `head_proj`, `tail_proj`);
# - a LayerNorm (`ln`);
# - a sin/cos position embedding.
#
# When `label_embed.weight` is missing, it is initialised randomly. If
# `cfg[:t5_emb_init]` is set, it is instead taken from the bundled `t5_embeddings.bin`
# (first `vocab_size` columns after `adjoint`), with the last `num_special_tokens`
# columns replaced by fresh random vectors.
function HuggingFace.load_model(_type::Type{<:TGBart}, ::Type{<:CompositeEmbedding}, cfg, state_dict, prefix)
    vocab_size = cfg[:vocab_size]
    dims = cfg[:d_model]
    factor = Float32(cfg[:init_std])
    n_special = cfg[:num_special_tokens]
    use_t5_emb_init = get(cfg, :t5_emb_init, false)

    # fallback initialiser, only called by `getweight` when the key is absent
    function init_label_weight()
        use_t5_emb_init || return weight_init(vocab_size, dims, factor)()
        t5_embeddings = HuggingFace.Pickle.Torch.THload(joinpath(@__DIR__, "t5_embeddings.bin")) |> adjoint
        weight = t5_embeddings[:, 1:vocab_size]
        # the trailing special-token columns are re-drawn randomly
        weight[:, (end - n_special + 1):end] .= weight_init(n_special, dims, factor)()
        return weight
    end
    label_weight = getweight(init_label_weight, Layers.Embed, state_dict,
                             joinname(prefix, "label_embed.weight"))
    type_weight = getweight(weight_init(2, dims), Layers.Embed, state_dict,
                            joinname(prefix, "type_embed.weight"))

    proj_init = weight_init(dims, dims)
    bias_init = zero_init(dims)
    # weight is fetched before bias (argument evaluation is left to right)
    dense(wkey, bkey) = Layers.Dense(getweight(proj_init, Array, state_dict, joinname(prefix, wkey)),
                                     getweight(bias_init, Array, state_dict, joinname(prefix, bkey)))
    pos_proj = dense("pos_proj.weight", "pos_proj.bias")
    segment_proj = dense("segment_proj.weight", "segment_proj.bias")
    head_proj = dense("head_proj.weight", "head_proj.bias")
    tail_proj = dense("tail_proj.weight", "tail_proj.bias")
    ln = load_model(_type, Layers.LayerNorm, cfg, state_dict, joinname(prefix, "ln"))

    # NOTE(review): `nlog` is not defined in the visible code; presumably a project
    # helper passed as the position-embedding function — confirm.
    id_embed = SinCosPositionEmbed(nlog, dims, true)
    return TokenEmbedding(Embed(label_weight), Embed(type_weight), id_embed,
                          pos_proj, segment_proj, head_proj, tail_proj, ln)
end

# Load the encoder (`<prefix>.encoder`) and decoder (`<prefix>.decoder`) stacks.
# The encoder is loaded first on purpose: with weight sharing enabled, the encoder
# loader writes its shared blocks back into `state_dict`, and the decoder loader then
# reads them from the encoder prefix.
function HuggingFace.load_model(_type::Type{<:TGBart}, ::Type{<:Seq2Seq}, cfg, state_dict, prefix)
    load_part(T, name) = load_model(_type, T, cfg, state_dict, joinname(prefix, name))
    encoder = load_part(TransformerBlock, "encoder")
    decoder = load_part(TransformerDecoderBlock, "decoder")
    return Seq2Seq(encoder, decoder)
end

# Load a LayerNorm of width `cfg[:d_model]` with epsilon `cfg[:layer_norm_eps]`.
# The legacy `gamma`/`beta` keys are used when present in `state_dict`; otherwise the
# `weight`/`bias` keys are used. Missing parameters default to ones (scale) and zeros
# (shift).
function HuggingFace.load_model(_type::Type{<:TGBart}, ::Type{<:Layers.LayerNorm}, cfg, state_dict, prefix)
    dims = cfg[:d_model]
    ϵ = Float32(cfg[:layer_norm_eps])
    resolve(legacy, current) = let key = joinname(prefix, legacy)
        haskey(state_dict, key) ? key : joinname(prefix, current)
    end
    scale_key = resolve("gamma", "weight")
    shift_key = resolve("beta", "bias")
    scale = getweight(one_init(dims), Array, state_dict, scale_key)
    shift = getweight(zero_init(dims), Array, state_dict, shift_key)
    return Layers.LayerNorm(scale, shift, ϵ)
end

# Load a self-attention layer whose op type `A` is `MultiheadQKVAttenOp` or
# `CausalMultiheadQKVAttenOp`.
#
# When the fused `qkv_proj.weight` is missing, it is initialised block-wise: the Q block
# with scale `init_std / sqrt(d_model * d_head)`, and the K and V blocks with
# `init_std / sqrt(d_model)`. Biases are loaded only when `cfg[:use_bias]` is set. With
# `cfg[:output_attentions]`, the op is wrapped in `WithScore`.
function HuggingFace.load_model(_type::Type{<:TGBart}, ::Type{<:Layers.SelfAttention{A}}, cfg, state_dict, prefix) where {A <: Union{MultiheadQKVAttenOp, CausalMultiheadQKVAttenOp}}
    nhead = cfg[:num_attention_heads]
    dims = cfg[:d_model]
    head_dims = cfg[:d_head]
    attn_p = cfg[:attention_dropout]
    drop_p = iszero(attn_p) ? nothing : attn_p
    with_bias = cfg[:use_bias]
    with_score = cfg[:output_attentions]
    init_std = Float32(cfg[:init_std])
    attn_dims = nhead * head_dims
    subkey(name) = joinname(prefix, name)

    qkv_w = getweight(Array, state_dict, subkey("qkv_proj.weight")) do
        q_init = weight_init(dims, attn_dims, init_std / sqrt(dims * head_dims))
        kv_init = weight_init(dims, attn_dims, init_std / sqrt(dims))
        weight = vcat(q_init(), kv_init(), kv_init())
        @assert size(weight) == (3attn_dims, dims)
        return weight
    end
    o_w = getweight(weight_init(attn_dims, dims, init_std / sqrt(dims * head_dims)),
                    Array, state_dict, subkey("o_proj.weight"))
    qkv_b, o_b = if with_bias
        (getweight(zero_init(3attn_dims), Array, state_dict, subkey("qkv_proj.bias")),
         getweight(zero_init(dims), Array, state_dict, subkey("o_proj.bias")))
    else
        (nothing, nothing)
    end

    base_op = A <: MultiheadQKVAttenOp ? MultiheadQKVAttenOp(nhead, drop_p) :
                                         CausalMultiheadQKVAttenOp(nhead, drop_p)
    op = with_score ? WithScore(base_op) : base_op
    qkv_proj = Layers.NSplit(3, Layers.Dense(qkv_w, qkv_b))
    o_proj = Layers.Dense(o_w, o_b)
    return Layers.SelfAttention(op, qkv_proj, o_proj)
end

# Load a cross-attention layer (`MultiheadQKVAttenOp`) whose head count and head size
# come from `cfg[:num_cross_attention_heads]` and `cfg[:d_cross_head]`.
#
# The fused `kv_proj.weight`, when missing, is two blocks initialised with scale
# `init_std / sqrt(d_model)`. Biases are loaded only when `cfg[:use_bias]` is set. With
# `cfg[:output_attentions]`, the op is wrapped in `WithScore`.
function HuggingFace.load_model(_type::Type{<:TGBart}, ::Type{<:Layers.CrossAttention}, cfg, state_dict, prefix)
    nhead = cfg[:num_cross_attention_heads]
    dims = cfg[:d_model]
    head_dims = cfg[:d_cross_head]
    attn_p = cfg[:attention_dropout]
    drop_p = iszero(attn_p) ? nothing : attn_p
    with_score = cfg[:output_attentions]
    with_bias = cfg[:use_bias]
    init_std = Float32(cfg[:init_std])
    attn_dims = nhead * head_dims
    subkey(name) = joinname(prefix, name)

    q_w = getweight(weight_init(dims, attn_dims, init_std / sqrt(dims * head_dims)), Array,
                    state_dict, subkey("q_proj.weight"))
    kv_w = getweight(Array, state_dict, subkey("kv_proj.weight")) do
        kv_init = weight_init(dims, attn_dims, init_std / sqrt(dims))
        weight = vcat(kv_init(), kv_init())
        @assert size(weight) == (2attn_dims, dims)
        return weight
    end
    o_w = getweight(weight_init(attn_dims, dims, init_std / sqrt(dims * head_dims)), Array,
                    state_dict, subkey("o_proj.weight"))
    q_b, kv_b, o_b = if with_bias
        (getweight(zero_init(attn_dims), Array, state_dict, subkey("q_proj.bias")),
         getweight(zero_init(2attn_dims), Array, state_dict, subkey("kv_proj.bias")),
         getweight(zero_init(dims), Array, state_dict, subkey("o_proj.bias")))
    else
        (nothing, nothing, nothing)
    end

    op = MultiheadQKVAttenOp(nhead, drop_p)
    if with_score
        op = WithScore(op)
    end
    return Layers.CrossAttention(op, Layers.Dense(q_w, q_b),
                                 Layers.NSplit(2, Layers.Dense(kv_w, kv_b)),
                                 Layers.Dense(o_w, o_b))
end

# Load the two-layer feed-forward network stored under `<prefix>.layers.0` and
# `<prefix>.layers.1`. The first Dense applies `HuggingFace.ACT2FN[Symbol(cfg[:act])]`
# and has bias width `cfg[:d_ffn]`; the second Dense has no activation. Biases are
# loaded only when `cfg[:use_bias]` is set.
function HuggingFace.load_model(_type::Type{<:TGBart}, ::Type{<:Layers.Chain{Tuple{Layers.Dense, Layers.Dense}}},
                                cfg, state_dict, prefix)
    dims = cfg[:d_model]
    ff_dims = cfg[:d_ffn]
    init_std = Float32(cfg[:init_std])
    with_bias = cfg[:use_bias]
    act = HuggingFace.ACT2FN[Symbol(cfg[:act])]
    subkey(name) = joinname(prefix, name)

    w_in = getweight(weight_init(dims, ff_dims, init_std / sqrt(dims)), Array,
                     state_dict, subkey("layers.0.weight"))
    w_out = getweight(weight_init(ff_dims, dims, init_std / sqrt(ff_dims)), Array,
                      state_dict, subkey("layers.1.weight"))
    b_in, b_out = if with_bias
        (getweight(zero_init(ff_dims), Array, state_dict, subkey("layers.0.bias")),
         getweight(zero_init(dims), Array, state_dict, subkey("layers.1.bias")))
    else
        (nothing, nothing)
    end
    return Layers.Chain(Layers.Dense(act, w_in, b_in), Layers.Dense(w_out, b_out))
end

# Load the encoder stack: `cfg[:encoder_layers]` blocks under `<prefix>.blocks.<i-1>`.
# Each block is self-attention plus feed-forward, with dropout `cfg[:dropout]` and a
# pre-/post-norm residual wrapper chosen by `cfg[:pre_norm]`. Per-block outputs are
# collected with `Layers.collect_outputs` when attentions or hidden states are
# requested. No final LayerNorm is applied after the stack.
#
# Side effect: with `cfg[:share_weight]`, the first `cfg[:num_shared]` blocks are
# written back into `state_dict` (via `get_state_dict`). The decoder loader can then
# read the very same parameters, even freshly initialised ones, from the encoder prefix.
function HuggingFace.load_model(_type::Type{<:TGBart}, ::Type{<:TransformerBlock}, cfg, state_dict, prefix)
    nlayer = cfg[:encoder_layers]
    dropout = Float64(cfg[:dropout])
    drop_p = iszero(dropout) ? nothing : dropout
    pre_norm = cfg[:pre_norm]
    keep_outputs = cfg[:output_attentions] || cfg[:output_hidden_states]
    ResidualNorm = pre_norm ? Layers.PreNormResidual : Layers.PostNormResidual
    share_weight, num_shared = cfg[:share_weight], cfg[:num_shared]

    # `map` over a range calls the closure in order, so loading (and RNG use) stays sequential
    layers = map(1:nlayer) do i
        bprefix = joinname(prefix, :blocks, i - 1)
        attn = load_model(_type, Layers.SelfAttention{MultiheadQKVAttenOp},
                          cfg, state_dict, joinname(bprefix, "attention.layer"))
        attn_norm = load_model(_type, Layers.LayerNorm, cfg, state_dict, joinname(bprefix, "attention.norm"))
        ffn = load_model(_type, Layers.Chain{Tuple{Layers.Dense, Layers.Dense}},
                         cfg, state_dict, joinname(bprefix, "feedforward.layer"))
        ffn_norm = load_model(_type, Layers.LayerNorm, cfg, state_dict, joinname(bprefix, "feedforward.norm"))
        layer = TransformerBlock(ResidualNorm(Layers.DropoutLayer(attn, drop_p), attn_norm),
                                 ResidualNorm(Layers.DropoutLayer(ffn, drop_p), ffn_norm))
        if i <= num_shared && share_weight
            HuggingFace.get_state_dict(TGBart, layer, state_dict, bprefix)
        end
        return layer
    end
    return Transformer(Tuple(layers), keep_outputs ? Layers.collect_outputs : nothing)
end

# Load the decoder stack: `cfg[:decoder_layers]` blocks under `<prefix>.blocks.<i-1>`.
# Each block has causal self-attention, cross-attention and a feed-forward network.
# Each sublayer gets dropout `cfg[:dropout]` and a pre-/post-norm residual wrapper
# chosen by `cfg[:pre_norm]`. Per-block outputs are collected with
# `Layers.collect_outputs` when attentions or hidden states are requested.
#
# Weight sharing: with `cfg[:share_weight]`, the first `cfg[:num_shared]` blocks read
# their self-attention and feed-forward parameters (and norms) from the matching encoder
# block. The encoder prefix is `prefix` with its trailing `decoder` path component
# replaced by `encoder`. Cross-attention is always read from the decoder's own keys.
# The encoder stack must already have been loaded (it writes shared blocks back into
# `state_dict`).
function HuggingFace.load_model(_type::Type{<:TGBart}, ::Type{<:TransformerDecoderBlock}, cfg, state_dict, prefix)
    n = cfg[:decoder_layers]
    dropout = Float64(cfg[:dropout])
    p = iszero(dropout) ? nothing : dropout
    pre_norm = cfg[:pre_norm]
    collect_output = cfg[:output_attentions] || cfg[:output_hidden_states]
    NormBlock = pre_norm ? Layers.PreNormResidual : Layers.PostNormResidual
    share_weight, num_shared = cfg[:share_weight], cfg[:num_shared]
    # Swap only the last path component. A blanket `replace(..., "decoder" => "encoder")`
    # would also rewrite any `decoder` occurring inside an outer model prefix and then
    # look the shared weights up under a wrong key.
    encoder_prefix = replace(prefix, r"(^|\.)decoder$" => s"\1encoder")
    blocks = []
    for i = 1:n
        lprefix = joinname(prefix, :blocks, i-1)
        is_shared = i <= num_shared && share_weight
        # shared blocks take self-attention/feed-forward (and norms) from the encoder block
        sprefix = is_shared ? joinname(encoder_prefix, :blocks, i-1) : lprefix
        sa = load_model(_type, Layers.SelfAttention{CausalMultiheadQKVAttenOp},
                        cfg, state_dict, joinname(sprefix, "attention.layer"))
        sa_ln = load_model(_type, Layers.LayerNorm, cfg, state_dict, joinname(sprefix, "attention.norm"))
        sa = NormBlock(Layers.DropoutLayer(sa, p), sa_ln)
        # cross-attention is never shared
        ca = load_model(_type, Layers.CrossAttention, cfg, state_dict, joinname(lprefix, "crossattention.layer"))
        ca_ln = load_model(_type, Layers.LayerNorm, cfg, state_dict, joinname(lprefix, "crossattention.norm"))
        ca = NormBlock(Layers.DropoutLayer(ca, p), ca_ln)
        ff = load_model(_type, Layers.Chain{Tuple{Layers.Dense, Layers.Dense}},
                        cfg, state_dict, joinname(sprefix, "feedforward.layer"))
        ff_ln = load_model(_type, Layers.LayerNorm, cfg, state_dict, joinname(sprefix, "feedforward.norm"))
        ff = NormBlock(Layers.DropoutLayer(ff, p), ff_ln)
        push!(blocks, TransformerDecoderBlock(sa, ca, ff))
    end
    collect_f = collect_output ? Layers.collect_outputs : nothing
    return Transformer(Tuple(blocks), collect_f)
end

# Serialise a full `TGBart` into `state_dict` under `prefix`. Sub-keys are `embed`,
# `seq2seq`, `stage1` and `stage2`, the same prefixes the top-level loader reads.
# Returns `state_dict`.
function HuggingFace.get_state_dict(m::TGBart, state_dict, prefix)
    parts = (("embed", m.embed), ("seq2seq", m.seq2seq), ("stage1", m.stage1), ("stage2", m.stage2))
    for (name, part) in parts
        HuggingFace.get_state_dict(TGBart, part, state_dict, joinname(prefix, name))
    end
    return state_dict
end

# Save each branch of a `Fork` under `<prefix>.layers.<k>`, where `k` is 0-based.
function HuggingFace.get_state_dict(p::Type{<:TGBart}, m::Layers.Fork, state_dict, prefix)
    for (k, branch) in zip(Iterators.countfrom(0), m.layers)
        HuggingFace.get_state_dict(p, branch, state_dict, joinname(prefix, :layers, k))
    end
    return state_dict
end

# Save only the layer wrapped by an `NSplit`, at the same prefix (no extra key level).
function HuggingFace.get_state_dict(p::Type{<:TGBart}, m::Layers.NSplit, state_dict, prefix)
    inner = m.layer
    return HuggingFace.get_state_dict(p, inner, state_dict, prefix)
end

# Save stage 1. The transformer block's keys go directly under `prefix` (no `block`
# level), and the label projection goes under `<prefix>.label_proj`.
function HuggingFace.get_state_dict(p::Type{<:TGBart}, m::TGStage1, state_dict, prefix)
    block, label_proj = m.block, m.label_proj
    HuggingFace.get_state_dict(p, block, state_dict, prefix)
    HuggingFace.get_state_dict(p, label_proj, state_dict, joinname(prefix, "label_proj"))
    return state_dict
end

# Save stage 2. The transformer block's keys go directly under `prefix`, and the
# head/tail projection goes under `<prefix>.head_tail_proj`.
function HuggingFace.get_state_dict(p::Type{<:TGBart}, m::TGStage2, state_dict, prefix)
    block, proj = m.block, m.head_tail_proj
    HuggingFace.get_state_dict(p, block, state_dict, prefix)
    HuggingFace.get_state_dict(p, proj, state_dict, joinname(prefix, "head_tail_proj"))
    return state_dict
end

# Save each layer of a `Chain` under `<prefix>.layers.<k>`, where `k` is 0-based
# (PyTorch `nn.Sequential`-style keys).
function HuggingFace.get_state_dict(p::Type{<:TGBart}, m::Layers.Chain, state_dict, prefix)
    for (k, layer) in zip(Iterators.countfrom(0), m.layers)
        HuggingFace.get_state_dict(p, layer, state_dict, joinname(prefix, :layers, k))
    end
    return state_dict
end

# Save the composite token embedding's parameters under `prefix`, one sub-key per part.
function HuggingFace.get_state_dict(p::Type{<:TGBart}, m::TokenEmbedding, state_dict, prefix)
    # `id_embed` is not serialised (presumably a parameter-free sin/cos embedding — confirm).
    named_parts = (
        ("label_embed", m.label_embed),
        ("type_embed", m.type_embed),
        ("pos_proj", m.pos_proj),
        ("segment_proj", m.segment_proj),
        ("head_proj", m.head_proj),
        ("tail_proj", m.tail_proj),
        ("ln", m.ln),
    )
    for (name, part) in named_parts
        HuggingFace.get_state_dict(p, part, state_dict, joinname(prefix, name))
    end
    return state_dict
end
# Contribute no keys for an `EmbedDecoder` and return `state_dict` unchanged. The one
# built by the top-level `TGBart` loader wraps the label embedding matrix, which is
# already saved under `label_embed`.
function HuggingFace.get_state_dict(p::Type{<:TGBart}, m::Layers.EmbedDecoder, state_dict, prefix)
    return state_dict
end

# Save the encoder and decoder stacks under `<prefix>.encoder` and `<prefix>.decoder`.
function HuggingFace.get_state_dict(p::Type{<:TGBart}, m::Seq2Seq, state_dict, prefix)
    for (name, stack) in (("encoder", m.encoder), ("decoder", m.decoder))
        HuggingFace.get_state_dict(p, stack, state_dict, joinname(prefix, name))
    end
    return state_dict
end

# Save a self-attention layer: the fused QKV Dense (the layer inside the `NSplit`) as
# `qkv_proj` and the output projection as `o_proj`.
function HuggingFace.get_state_dict(p::Type{<:TGBart}, m::Layers.SelfAttention, state_dict, prefix)
    fused_qkv = m.qkv_proj.layer
    out_proj = m.o_proj
    HuggingFace.get_state_dict(p, fused_qkv, state_dict, joinname(prefix, "qkv_proj"))
    HuggingFace.get_state_dict(p, out_proj, state_dict, joinname(prefix, "o_proj"))
    return state_dict
end

# Save a cross-attention layer as `q_proj`, `kv_proj` (the fused Dense inside the
# `NSplit`) and `o_proj`.
function HuggingFace.get_state_dict(p::Type{<:TGBart}, m::Layers.CrossAttention, state_dict, prefix)
    projections = (("q_proj", m.q_proj), ("kv_proj", m.kv_proj.layer), ("o_proj", m.o_proj))
    for (name, proj) in projections
        HuggingFace.get_state_dict(p, proj, state_dict, joinname(prefix, name))
    end
    return state_dict
end

# Save an encoder-style block: the wrapped sublayer (`.layer`) and its norm (`.norm`) for
# both the attention and the feed-forward residual wrappers.
function HuggingFace.get_state_dict(p::Type{<:TGBart}, m::TransformerBlock, state_dict, prefix)
    entries = (
        ("attention.layer", m.attention.layer),
        ("attention.norm", m.attention.norm),
        ("feedforward.layer", m.feedforward.layer),
        ("feedforward.norm", m.feedforward.norm),
    )
    for (key, part) in entries
        HuggingFace.get_state_dict(p, part, state_dict, joinname(prefix, key))
    end
    return state_dict
end

# Save a decoder block: the wrapped sublayer (`.layer`) and its norm (`.norm`) for the
# self-attention, cross-attention and feed-forward residual wrappers. All keys are
# written under the decoder's own `prefix`, including for blocks whose weights were
# loaded from the encoder.
function HuggingFace.get_state_dict(p::Type{<:TGBart}, m::TransformerDecoderBlock, state_dict, prefix)
    entries = (
        ("attention.layer", m.attention.layer),
        ("attention.norm", m.attention.norm),
        ("crossattention.layer", m.crossattention.layer),
        ("crossattention.norm", m.crossattention.norm),
        ("feedforward.layer", m.feedforward.layer),
        ("feedforward.norm", m.feedforward.norm),
    )
    for (key, part) in entries
        HuggingFace.get_state_dict(p, part, state_dict, joinname(prefix, key))
    end
    return state_dict
end

# Save each block of a `Transformer` stack under `<prefix>.blocks.<k>`, where `k` is
# 0-based, matching the keys the stack loaders read.
function HuggingFace.get_state_dict(p::Type{<:TGBart}, m::Transformer, state_dict, prefix)
    for (k, blk) in zip(Iterators.countfrom(0), m.blocks)
        HuggingFace.get_state_dict(p, blk, state_dict, joinname(prefix, :blocks, k))
    end
    return state_dict
end
