from enum import Enum
from dataclasses import dataclass
from functools import partial
import numpy as np
import torch
from typing import Union, List


# Prefix that ``_process_translations_dict`` prepends to every flattened key
# produced from a nested entry of the top-level translation dict.
_NPZ_KEY_PREFIX = "alphafold/alphafold_iteration/"


class _Transformation:
    """Plain callable wrapper around a weight-transformation function.

    ``Enum`` turns descriptors found in its class body into methods instead
    of members. Bare functions are descriptors, and ``functools.partial``
    is becoming one as well (``FutureWarning`` on Python 3.13), so wrapping
    the function in an object without ``__get__`` keeps every entry below a
    real enum member on all Python versions.
    """

    __slots__ = ("_fn",)

    def __init__(self, fn):
        self._fn = fn

    def __call__(self, w):
        return self._fn(w)


def _swap_last_two(w):
    """Return ``w`` with its last two axes exchanged."""
    return w.transpose(-1, -2)


def _merge_last_two(w):
    """Return ``w`` with its last two axes collapsed into a single axis."""
    return w.reshape(*w.shape[:-2], -1)


def _merge_leading_pair_then_swap(w):
    """Collapse axes -3 and -2 into one, keep the last axis, then swap the
    final two axes of the result."""
    return w.reshape(*w.shape[:-3], -1, w.shape[-1]).transpose(-1, -2)


class ParamType(Enum):
    """Kinds of parameters, each carrying the transformation applied to the
    source weights before they are copied into the destination tensor.

    The callable is available as ``member.transformation`` (and, as before,
    ``member.value`` is itself callable).
    """

    # Swap the last two axes.
    LinearWeight = _Transformation(_swap_last_two)
    # Merge the last two axes, then swap the last two axes of the result.
    LinearWeightMHA = _Transformation(
        lambda w: _swap_last_two(_merge_last_two(w))
    )
    # Merge axes -3 and -2, then swap the last two axes of the result.
    LinearMHAOutputWeight = _Transformation(_merge_leading_pair_then_swap)
    # Merge the last two axes; no transpose.
    LinearBiasMHA = _Transformation(_merge_last_two)
    # Same reshaping as LinearMHAOutputWeight. A separate wrapper instance
    # keeps this a distinct member rather than an alias.
    LinearWeightOPM = _Transformation(_merge_leading_pair_then_swap)
    # Identity: the weights are used unchanged.
    Other = _Transformation(lambda w: w)

    def __init__(self, fn):
        # ``fn`` is the member's value; expose it under a descriptive name.
        self.transformation = fn


@dataclass
class Param:
    """Destination tensor(s) together with the way source weights must be
    transformed before being copied into them (see ``assign``)."""

    # Destination tensor, or a list of destination tensors when ``stacked``
    # is True.
    param: Union[torch.Tensor, List[torch.Tensor]]
    # Selects the transformation applied to the source weights before the copy.
    param_type: ParamType = ParamType.Other
    # When True, ``assign`` splits the source weights along dim 0 and copies
    # one slice into each entry of ``param``.
    stacked: bool = False


def _process_translations_dict(d, top_layer=True):
    """Flatten a nested translation dict into a single-level dict.

    Nested keys are joined with "/". Leaf keys below the top level get an
    extra leading "/", so the final path component ends up separated by
    "//". Paths that start in a nested entry of the top-level dict are
    additionally prefixed with ``_NPZ_KEY_PREFIX``. For example
    ``{"a": {"b": {"c": x}}}`` becomes ``{_NPZ_KEY_PREFIX + "a/b//c": x}``.

    Args:
        d: Mapping whose values are either nested dicts or leaf values.
        top_layer: True for the outermost call; recursive calls pass False.

    Returns:
        A new flat dict mapping path strings to the leaf values of ``d``.
    """
    flat = {}
    for k, v in d.items():
        if isinstance(v, dict):
            # Only the outermost level contributes the archive-wide prefix.
            prefix = _NPZ_KEY_PREFIX if top_layer else ""
            nested = _process_translations_dict(v, top_layer=False)
            for k_prime, v_prime in nested.items():
                flat[prefix + "/".join([k, k_prime])] = v_prime
        else:
            # Leaves below the top level carry a leading "/", which yields
            # the "//" separator once the parent joins the path.
            flat[k if top_layer else "/" + k] = v

    return flat


def stacked(param_dict_list, out=None):
    """Combine a list of identically structured translation dicts into one.

    The first dict serves as the template for the key structure. For every
    key, nested dicts are merged recursively and ``Param`` leaves are
    replaced by a single ``Param`` whose ``param`` is the list of the
    individual ``param`` values, marked ``stacked=True`` and carrying the
    ``param_type`` of the first entry. Values of any other type are skipped.

    Args:
        param_dict_list: List of dicts sharing the template's keys.
        out: Optional dict to fill in; a new dict is created when omitted.

    Returns:
        ``out``, populated with the stacked entries. An empty
        ``param_dict_list`` leaves ``out`` untouched.

    Raises:
        KeyError: If a later dict lacks a key present in the first one.
    """
    if out is None:
        out = {}
    if not param_dict_list:
        # Nothing to stack; avoid indexing into an empty list.
        return out

    template = param_dict_list[0]
    for k in template:
        v = [d[k] for d in param_dict_list]
        first = v[0]
        if isinstance(first, dict):
            out[k] = stacked(v)
        elif isinstance(first, Param):
            out[k] = Param(
                param=[entry.param for entry in v],
                param_type=first.param_type,
                stacked=True,
            )

    return out


def assign(translation_dict, orig_weights):
    """Copy source weights into the destination tensors, in place.

    For every key of ``translation_dict`` the matching entry of
    ``orig_weights`` is converted to a tensor, transformed with the entry's
    ``param_type.transformation`` and copied into ``param`` via ``copy_``.
    Entries flagged ``stacked`` have their source tensor split along dim 0,
    one slice per destination tensor in ``param``.

    Args:
        translation_dict: Flat mapping from key to an object exposing
            ``param``, ``param_type`` and ``stacked``.
        orig_weights: Mapping from the same keys to array-like weights.

    Raises:
        KeyError: If a key is absent from ``orig_weights``.
        Exception: Any error raised while transforming or copying is
            re-raised after printing the key and the two shapes involved.
    """
    # Gradient tracking is disabled for the whole copy loop.
    with torch.no_grad():
        for k, param in translation_dict.items():
            weights = torch.as_tensor(orig_weights[k])
            ref, param_type = param.param, param.param_type
            if param.stacked:
                weights = torch.unbind(weights, 0)
            else:
                # Normalise the single-tensor case to the list form.
                weights = [weights]
                ref = [ref]

            try:
                weights = list(map(param_type.transformation, weights))
                for p, w in zip(ref, weights):
                    p.copy_(w)
            except Exception:
                # Print enough context to locate the offending entry, then
                # let the original error propagate.
                print(k)
                print(ref[0].shape)
                print(weights[0].shape)
                raise


def import_jax_weights_(model, npz_path, version="model_1"):
    """Load weights from an ``.npz`` archive into ``model`` in place.

    A nested translation table is built that maps archive key paths to the
    model's tensors (wrapped in ``Param`` together with the transformation
    each one needs). The table is flattened with
    ``_process_translations_dict`` and the weights are copied by ``assign``.

    Args:
        model: Object exposing the submodules referenced below
            (``input_embedder``, ``evoformer``, ``structure_module``,
            ``aux_heads``, ...).
        npz_path: Path (or file object) handed to ``np.load``.
        version: Preset name. The names listed in ``no_templ`` drop every
            evoformer entry whose key contains "template_"; names
            containing "_ptm" add the ``predicted_aligned_error_head``.

    Raises:
        AssertionError: If a translated key is absent from the archive.
            Archive keys that have no translation are ignored.
    """
    #######################
    # Leaf-level builders
    #######################
    LinearWeight = lambda l: (Param(l, param_type=ParamType.LinearWeight))

    LinearBias = lambda l: (Param(l))

    LinearWeightMHA = lambda l: (Param(l, param_type=ParamType.LinearWeightMHA))

    LinearBiasMHA = lambda b: (Param(b, param_type=ParamType.LinearBiasMHA))

    LinearWeightOPM = lambda l: (Param(l, param_type=ParamType.LinearWeightOPM))

    LinearParams = lambda l: {
        "weights": LinearWeight(l.weight),
        "bias": LinearBias(l.bias),
    }

    LayerNormParams = lambda l: {
        "scale": Param(l.weight),
        "offset": Param(l.bias),
    }

    #######################
    # Module-level builders
    #######################
    AttentionParams = lambda att: {
        "query_w": LinearWeightMHA(att.linear_q.weight),
        "key_w": LinearWeightMHA(att.linear_k.weight),
        "value_w": LinearWeightMHA(att.linear_v.weight),
        "output_w": Param(
            att.linear_o.weight,
            param_type=ParamType.LinearMHAOutputWeight,
        ),
        "output_b": LinearBias(att.linear_o.bias),
    }

    AttentionGatedParams = lambda att: dict(
        **AttentionParams(att),
        **{
            "gating_w": LinearWeightMHA(att.linear_g.weight),
            "gating_b": LinearBiasMHA(att.linear_g.bias),
        },
    )

    # Same as the gated variant, except that the key/value weights use the
    # plain LinearWeight transformation.
    GlobalAttentionParams = lambda att: dict(
        AttentionGatedParams(att),
        key_w=LinearWeight(att.linear_k.weight),
        value_w=LinearWeight(att.linear_v.weight),
    )

    TriAttParams = lambda tri_att: {
        "query_norm": LayerNormParams(tri_att.layer_norm),
        "feat_2d_weights": LinearWeight(tri_att.linear.weight),
        "attention": AttentionGatedParams(tri_att.mha),
    }

    TriMulOutParams = lambda tri_mul: {
        "layer_norm_input": LayerNormParams(tri_mul.layer_norm_in),
        "left_projection": LinearParams(tri_mul.linear_a_p),
        "right_projection": LinearParams(tri_mul.linear_b_p),
        "left_gate": LinearParams(tri_mul.linear_a_g),
        "right_gate": LinearParams(tri_mul.linear_b_g),
        "center_layer_norm": LayerNormParams(tri_mul.layer_norm_out),
        "output_projection": LinearParams(tri_mul.linear_z),
        "gating_linear": LinearParams(tri_mul.linear_g),
    }

    # Note: relative to TriMulOutParams the a/b modules are swapped between
    # the left and right entries.
    TriMulInParams = lambda tri_mul: {
        "layer_norm_input": LayerNormParams(tri_mul.layer_norm_in),
        "left_projection": LinearParams(tri_mul.linear_b_p),
        "right_projection": LinearParams(tri_mul.linear_a_p),
        "left_gate": LinearParams(tri_mul.linear_b_g),
        "right_gate": LinearParams(tri_mul.linear_a_g),
        "center_layer_norm": LayerNormParams(tri_mul.layer_norm_out),
        "output_projection": LinearParams(tri_mul.linear_z),
        "gating_linear": LinearParams(tri_mul.linear_g),
    }

    PairTransitionParams = lambda pt: {
        "input_layer_norm": LayerNormParams(pt.layer_norm),
        "transition1": LinearParams(pt.linear_1),
        "transition2": LinearParams(pt.linear_2),
    }

    MSAAttParams = lambda matt: {
        "query_norm": LayerNormParams(matt.layer_norm_m),
        "attention": AttentionGatedParams(matt.mha),
    }

    MSAColAttParams = lambda matt: {
        "query_norm": LayerNormParams(matt._msa_att.layer_norm_m),
        "attention": AttentionGatedParams(matt._msa_att.mha),
    }

    MSAGlobalAttParams = lambda matt: {
        "query_norm": LayerNormParams(matt.layer_norm_m),
        "attention": GlobalAttentionParams(matt.global_attention),
    }

    MSAAttPairBiasParams = lambda matt: dict(
        **MSAAttParams(matt),
        **{
            "feat_2d_norm": LayerNormParams(matt.layer_norm_z),
            "feat_2d_weights": LinearWeight(matt.linear_z.weight),
        },
    )

    IPAParams = lambda ipa: {
        "q_scalar": LinearParams(ipa.linear_q),
        "kv_scalar": LinearParams(ipa.linear_kv),
        "q_point_local": LinearParams(ipa.linear_q_points),
        "kv_point_local": LinearParams(ipa.linear_kv_points),
        "trainable_point_weights": Param(
            param=ipa.head_weights, param_type=ParamType.Other
        ),
        "attention_2d": LinearParams(ipa.linear_b),
        "output_projection": LinearParams(ipa.linear_out),
    }

    TemplatePairBlockParams = lambda b: {
        "triangle_attention_starting_node": TriAttParams(b.tri_att_start),
        "triangle_attention_ending_node": TriAttParams(b.tri_att_end),
        "triangle_multiplication_outgoing": TriMulOutParams(b.tri_mul_out),
        "triangle_multiplication_incoming": TriMulInParams(b.tri_mul_in),
        "pair_transition": PairTransitionParams(b.pair_transition),
    }

    MSATransitionParams = lambda m: {
        "input_layer_norm": LayerNormParams(m.layer_norm),
        "transition1": LinearParams(m.linear_1),
        "transition2": LinearParams(m.linear_2),
    }

    OuterProductMeanParams = lambda o: {
        "layer_norm_input": LayerNormParams(o.layer_norm),
        "left_projection": LinearParams(o.linear_1),
        "right_projection": LinearParams(o.linear_2),
        "output_w": LinearWeightOPM(o.linear_out.weight),
        "output_b": LinearBias(o.linear_out.bias),
    }

    def EvoformerBlockParams(b, is_extra_msa=False):
        # The extra-MSA variant uses a differently named (and differently
        # structured) column-attention entry; everything else is shared.
        if is_extra_msa:
            col_att_name = "msa_column_global_attention"
            msa_col_att_params = MSAGlobalAttParams(b.msa_att_col)
        else:
            col_att_name = "msa_column_attention"
            msa_col_att_params = MSAColAttParams(b.msa_att_col)

        d = {
            "msa_row_attention_with_pair_bias": MSAAttPairBiasParams(b.msa_att_row),
            col_att_name: msa_col_att_params,
            "msa_transition": MSATransitionParams(b.core.msa_transition),
            "outer_product_mean": OuterProductMeanParams(b.core.outer_product_mean),
            "triangle_multiplication_outgoing": TriMulOutParams(b.core.tri_mul_out),
            "triangle_multiplication_incoming": TriMulInParams(b.core.tri_mul_in),
            "triangle_attention_starting_node": TriAttParams(b.core.tri_att_start),
            "triangle_attention_ending_node": TriAttParams(b.core.tri_att_end),
            "pair_transition": PairTransitionParams(b.core.pair_transition),
        }

        return d

    ExtraMSABlockParams = partial(EvoformerBlockParams, is_extra_msa=True)

    FoldIterationParams = lambda sm: {
        "invariant_point_attention": IPAParams(sm.ipa),
        "attention_layer_norm": LayerNormParams(sm.layer_norm_ipa),
        "transition": LinearParams(sm.transition.layers[0].linear_1),
        "transition_1": LinearParams(sm.transition.layers[0].linear_2),
        "transition_2": LinearParams(sm.transition.layers[0].linear_3),
        "transition_layer_norm": LayerNormParams(sm.transition.layer_norm),
        "affine_update": LinearParams(sm.bb_update.linear),
        "rigid_sidechain": {
            "input_projection": LinearParams(sm.angle_resnet.linear_in),
            "input_projection_1": LinearParams(sm.angle_resnet.linear_initial),
            "resblock1": LinearParams(sm.angle_resnet.layers[0].linear_1),
            "resblock2": LinearParams(sm.angle_resnet.layers[0].linear_2),
            "resblock1_1": LinearParams(sm.angle_resnet.layers[1].linear_1),
            "resblock2_1": LinearParams(sm.angle_resnet.layers[1].linear_2),
            "unnormalized_angles": LinearParams(sm.angle_resnet.linear_out),
        },
    }

    #######################
    # Per-block stacks: one stacked Param per key, covering all blocks
    #######################
    tps_blocks = model.template_pair_stack.blocks
    tps_blocks_params = stacked([TemplatePairBlockParams(b) for b in tps_blocks])

    ems_blocks = model.extra_msa_stack.blocks
    ems_blocks_params = stacked([ExtraMSABlockParams(b) for b in ems_blocks])

    evo_blocks = model.evoformer.blocks
    evo_blocks_params = stacked([EvoformerBlockParams(b) for b in evo_blocks])

    translations = {
        "evoformer": {
            "preprocess_1d": LinearParams(model.input_embedder.linear_tf_m),
            "preprocess_msa": LinearParams(model.input_embedder.linear_msa_m),
            "left_single": LinearParams(model.input_embedder.linear_tf_z_i),
            "right_single": LinearParams(model.input_embedder.linear_tf_z_j),
            "prev_pos_linear": LinearParams(model.recycling_embedder.linear),
            "prev_msa_first_row_norm": LayerNormParams(
                model.recycling_embedder.layer_norm_m
            ),
            "prev_pair_norm": LayerNormParams(model.recycling_embedder.layer_norm_z),
            # NOTE(review): the spelling "pair_activiations" has to match the
            # key stored in the archive; presumably the typo is upstream.
            # Verify against a checkpoint before "correcting" it.
            "pair_activiations": LinearParams(model.input_embedder.linear_relpos),
            "template_embedding": {
                "single_template_embedding": {
                    "embedding2d": LinearParams(model.template_pair_embedder.linear),
                    "template_pair_stack": {
                        "__layer_stack_no_state": tps_blocks_params,
                    },
                    "output_layer_norm": LayerNormParams(
                        model.template_pair_stack.layer_norm
                    ),
                },
                "attention": AttentionParams(model.template_pointwise_att.mha),
            },
            "extra_msa_activations": LinearParams(model.extra_msa_embedder.linear),
            "extra_msa_stack": ems_blocks_params,
            "template_single_embedding": LinearParams(
                model.template_angle_embedder.linear_1
            ),
            "template_projection": LinearParams(model.template_angle_embedder.linear_2),
            "evoformer_iteration": evo_blocks_params,
            "single_activations": LinearParams(model.evoformer.linear),
        },
        "structure_module": {
            "single_layer_norm": LayerNormParams(model.structure_module.layer_norm_s),
            "initial_projection": LinearParams(model.structure_module.linear_in),
            "pair_layer_norm": LayerNormParams(model.structure_module.layer_norm_z),
            "fold_iteration": FoldIterationParams(model.structure_module),
        },
        "predicted_lddt_head": {
            "input_layer_norm": LayerNormParams(model.aux_heads.plddt.layer_norm),
            "act_0": LinearParams(model.aux_heads.plddt.linear_1),
            "act_1": LinearParams(model.aux_heads.plddt.linear_2),
            "logits": LinearParams(model.aux_heads.plddt.linear_3),
        },
        "distogram_head": {
            "half_logits": LinearParams(model.aux_heads.distogram.linear),
        },
        "experimentally_resolved_head": {
            "logits": LinearParams(model.aux_heads.experimentally_resolved.linear),
        },
        "masked_msa_head": {
            "logits": LinearParams(model.aux_heads.masked_msa.linear),
        },
    }

    # Presets for which the template-related evoformer entries are removed.
    no_templ = [
        "model_3",
        "model_4",
        "model_5",
        "model_3_ptm",
        "model_4_ptm",
        "model_5_ptm",
    ]
    if version in no_templ:
        evo_dict = translations["evoformer"]
        # Snapshot the keys first: the dict is mutated while being scanned.
        keys = list(evo_dict.keys())
        for k in keys:
            if "template_" in k:
                evo_dict.pop(k)

    if "_ptm" in version:
        translations["predicted_aligned_error_head"] = {
            "logits": LinearParams(model.aux_heads.tm.linear)
        }

    flat = _process_translations_dict(translations)

    # The archive object returned by np.load keeps the file open; the
    # context manager makes sure it is closed once the copy has finished.
    with np.load(npz_path) as data:
        available = set(data.keys())
        incorrect = [k for k in flat if k not in available]

        assert not incorrect, (
            "Translation keys not found in the parameter archive: "
            + ", ".join(incorrect)
        )

        assign(flat, data)
