import copy
import importlib
import ml_collections as mlc


#. define the value of Inf used in the model
def set_inf(c, inf):
    """Overwrite every ``"inf"`` entry in a nested config with ``inf``.

    The config is modified in place. Values that are ``mlc.ConfigDict``
    instances are descended into recursively; any other value stored under
    the key ``"inf"`` is replaced.

    Args:
        c: Config mapping to update (typically an ``mlc.ConfigDict``).
        inf: Replacement value for each ``"inf"`` entry.
    """
    for key, value in c.items():
        # Sub-configs are walked first; a nested ConfigDict is never replaced,
        # even if it happens to be stored under the key "inf".
        if isinstance(value, mlc.ConfigDict):
            set_inf(value, inf)
            continue
        if key == "inf":
            c[key] = inf


def enforce_config_constraints(config):
    """Reject configurations that enable two mutually exclusive options.

    Args:
        config: Nested mapping that can be indexed with string keys at every
            level (e.g. an ``mlc.ConfigDict``).

    Raises:
        ValueError: If both settings of any exclusive pair are truthy.
    """
    #. pairs of dotted paths whose values must not both be truthy
    exclusive_pairs = (
        (
            "model.template.average_templates",
            "model.template.offload_templates",
        ),
        (
            "globals.use_lma",
            "globals.use_flash",
        ),
    )

    def lookup(dotted_path):
        """Follow the dot-separated keys of ``dotted_path`` through ``config``."""
        node = config
        for key in dotted_path.split('.'):
            node = node[key]
        return node

    for first_path, second_path in exclusive_pairs:
        # Resolve both paths up front so a missing key fails the same way
        # regardless of the other setting's value.
        first_value = lookup(first_path)
        second_value = lookup(second_path)
        if not first_value or not second_value:
            continue
        raise ValueError(
            f"Only one of {first_path} and {second_path} may be set at a time"
        )

def model_config(
    name,
    train=False,
    low_prec=False,
    extra=None,
):
    """Build the configuration for the preset ``name``.

    A deep copy of the module-level ``config`` is modified and returned; the
    module-level object itself is never mutated.

    Args:
        name: Preset identifier. A name that matches none of the branches
            below leaves the defaults unchanged (no error is raised).
        train: If True, apply training overrides: ``globals.blocks_per_ckpt``
            becomes 1 when it is still None, ``globals.chunk_size`` is set to
            None, and ``use_lma`` / ``offload_inference`` are disabled.
        low_prec: If True, set ``globals.eps`` to 1e-4 and every ``"inf"``
            entry to 1e4.
        extra: Optional extra option name, or a collection of names. Only
            ``"preprocessed_msa"`` is recognised.

    Returns:
        The configured copy.

    Raises:
        ValueError: If the resulting config enables mutually exclusive
            settings (see ``enforce_config_constraints``).
    """
    print(name)
    c = copy.deepcopy(config)
    # TRAINING PRESETS
    if name == "initial_training":
        # AF2 Suppl. Table 4, "initial training" setting
        pass
    #. keep ExtraMSAStack, templates off
    # NOTE(review): an earlier comment said n_recycling is reduced to 2, but
    # the value assigned here is 3 (same as the default) -- confirm intent.
    elif name == "test-training-1":
        # c.data.train.crop_size = 256 default
        # c.data.train.max_extra_msa = 1024  default
        # c.data.predict.max_extra_msa = 1024 default

        c.model.template.enabled = False
        # c.model.extra_msa.enabled = True default

        c.data.common.max_recycling_iters = 3

        c.globals.blocks_per_ckpt = 1
    elif name == "test-training-1-1":
        # c.data.train.crop_size = 256 default
        # c.data.train.max_extra_msa = 1024  default
        # c.data.predict.max_extra_msa = 1024 default

        c.model.template.enabled = False
        c.model.extra_msa.enabled = False

        c.data.common.max_recycling_iters = 3

        c.globals.blocks_per_ckpt = 1
    elif name == "test-training-1-1-fix_before_strcture":
        print("Using test-training-1-1-fix_before_strcture")
        c.data.train.max_extra_msa = 1   #. need to be at least 1 to act as placeholder
        c.data.eval.max_extra_msa = 1
        c.data.predict.max_extra_msa = 1
        c.model.template.enabled = False
        c.model.extra_msa.enabled = True

        c.model.extra_msa.extra_msa_stack.no_extra_msa = True
        c.globals.fix_before_structure = True
    elif name == "model_3_no_extra":
        # AF2 Suppl. Table 5, Model 1.2.1
        c.data.train.max_extra_msa = 1   #. need to be at least 1 to act as placeholder
        c.data.eval.max_extra_msa = 1
        c.data.predict.max_extra_msa = 1
        c.model.template.enabled = False

        c.model.extra_msa.extra_msa_stack.no_extra_msa = True
    elif name == "model_3_less_extra":
        # AF2 Suppl. Table 5, Model 1.2.1
        c.data.train.max_extra_msa = 128
        c.data.eval.max_extra_msa = 128
        c.data.predict.max_extra_msa = 128
        c.model.template.enabled = False

        c.model.extra_msa.extra_msa_stack.no_extra_msa = False

    elif name == "test-training-1-2":
        # c.data.train.crop_size = 256 default
        # c.data.train.max_extra_msa = 1024  default
        # c.data.predict.max_extra_msa = 1024 default

        c.model.template.enabled = False
        # c.model.extra_msa.enabled = True default

        c.data.common.max_recycling_iters = 1

        c.globals.blocks_per_ckpt = 1
    #. shut down ExtraMSAStack
    # NOTE need to also implement a version that only keeps pair activation in the ExtraMSAStack
    elif name == "test-training-2":
        # c.data.train.crop_size = 256 default
        # c.data.train.max_extra_msa = 1024  default
        # c.data.predict.max_extra_msa = 1024 default

        c.model.template.enabled = False
        c.model.extra_msa.enabled = False

        # c.data.common.max_recycling_iters = 3 default

        # Write to the copy, not the shared module-level `config`, so one
        # call cannot leak state into later calls.
        c.globals.blocks_per_ckpt = None
    elif name == "finetuning":
        # AF2 Suppl. Table 4, "finetuning" setting
        c.data.train.crop_size = 384
        c.data.train.max_extra_msa = 5120
        c.data.train.max_msa_clusters = 512
        c.loss.violation.weight = 1.
        c.loss.experimentally_resolved.weight = 0.01
    elif name == "model_3":
        # AF2 Suppl. Table 5, Model 1.2.1
        c.data.train.max_extra_msa = 5120
        c.data.predict.max_extra_msa = 5120
        c.model.template.enabled = False
    elif name == 'model_3_msa_1024':
        # AF2 Suppl. Table 5, Model 1.2.1
        c.data.train.max_extra_msa = 1024
        c.data.predict.max_extra_msa = 1024
        c.model.template.enabled = False
    elif name == 'model_3_msa_None':
        c.data.train.max_extra_msa = None  # NOTE decide later
        c.data.predict.max_extra_msa = None # NOTE decide later
        c.model.template.enabled = False
    elif name == "model_3_no_tri_1":
        c.data.train.max_extra_msa = 5120
        c.data.predict.max_extra_msa = 5120
        c.model.template.enabled = False

        c.model.evoformer_stack.no_triangular_attention = True
        c.model.evoformer_stack.no_triangular_multiplication = True

    if train:
        # Only fill in a checkpoint interval when the preset left it unset.
        if c.globals.blocks_per_ckpt is None:
            c.globals.blocks_per_ckpt = 1
        c.globals.chunk_size = None
        c.globals.use_lma = False
        c.globals.offload_inference = False
    # Set on every call (not only when training): the base `model.template`
    # section defines only `enabled`, and enforce_config_constraints reads
    # both of these keys.
    c.model.template.average_templates = False
    c.model.template.offload_templates = False

    if extra is not None:
        # Accept a single option name as well as a collection of names.
        if isinstance(extra, str):
            extra = [extra]

        if "preprocessed_msa" in extra:
            # NOTE pre-shuffle MSA in dataset
            c.data.common.sample_msa.pre_shuffled = True
            # c.data.common.masked_msa.skip = True
            c.data.predict.masked_msa_replace_fraction = 0.
            c.data.eval.masked_msa_replace_fraction = 0.
            c.data.train.masked_msa_replace_fraction = 0.

    # if 'evoformer_embedding_test' in extra:
    #     del c.data.common.feat.atom14_alt_gt_exists
    #     del c.data.common.feat.atom14_alt_gt_positions
    #     del c.data.common.feat.atom14_atom_exists
    #     del c.data.common.feat.atom14_atom_is_ambiguous
    #     del c.data.common.feat.atom14_gt_exists
    #     del c.data.common.feat.atom14_gt_positions
    #     del c.data.common.feat.atom37_atom_exists
    #     del c.data.common.feat.backbone_rigid_mask

    if low_prec:
        c.globals.eps = 1e-4
        # If we want exact numerical parity with the original, inf can't be
        # a global constant
        set_inf(c, 1e4)

    enforce_config_constraints(c)

    return c


# Shared FieldReferences. Each object below is inserted at several places in
# the `config` tree, so all of those entries are bound to the same reference.
# Channel sizes, named after the config keys they are bound to
# (`c_z`, `c_m`, `c_t`, `c_e`, `c_s`).
c_z = mlc.FieldReference(128, field_type=int)
c_m = mlc.FieldReference(256, field_type=int)
c_t = mlc.FieldReference(64, field_type=int)
c_e = mlc.FieldReference(64, field_type=int)
c_s = mlc.FieldReference(384, field_type=int)
# None by default; model_config(train=True) sets it to 1 when still None.
blocks_per_ckpt = mlc.FieldReference(None, field_type=int)
# Bound to globals.chunk_size; model_config(train=True) sets it to None.
chunk_size = mlc.FieldReference(4, field_type=int)
# Used as `no_bins` by the distogram, tm and EvoformerEmbedding heads.
aux_distogram_bins = mlc.FieldReference(64, field_type=int)
# Shared by model.heads.tm.enabled and loss.tm.enabled.
tm_enabled = mlc.FieldReference(False, field_type=bool)
# Shared by globals.eps and most `eps`/`epsilon` entries in the model and
# loss sections; model_config(low_prec=True) sets globals.eps to 1e-4.
eps = mlc.FieldReference(1e-8, field_type=float)
# Shared by data.common.use_templates and model.template.enabled.
templates_enabled = mlc.FieldReference(True, field_type=bool)
# Bound to data.common.use_template_torsion_angles.
embed_template_torsion_angles = mlc.FieldReference(True, field_type=bool)
# Shared by the `tune_chunk_size` entries of extra_msa_stack and
# evoformer_stack.
tune_chunk_size = mlc.FieldReference(True, field_type=bool)

# Placeholder strings used inside the per-feature lists of data.common.feat.
# NOTE(review): presumably substituted with concrete dimension sizes by the
# data pipeline -- confirm against the code that consumes `feat`.
NUM_RES = "num residues placeholder"
NUM_MSA_SEQ = "msa placeholder"
NUM_EXTRA_SEQ = "extra msa placeholder"
NUM_TEMPLATES = "num templates placeholder"

# Default configuration tree. `model_config` deep-copies this object and
# applies preset-specific overrides to the copy. Entries whose value is one of
# the module-level FieldReferences (c_z, eps, blocks_per_ckpt, ...) are bound
# to that shared reference rather than holding an independent value.
config = mlc.ConfigDict(
    {
        # Settings for the EvoformerEmbedding heads and their loss.
        "experiment":{
            "model": {
                "EvoformerEmbeddingHead":{
                    "c_z": c_z,
                    "no_bins": aux_distogram_bins,
                    "no_heads": 48,
                },
                "EvoformerEmbeddingHead2":{
                    "c_z": c_z,
                    "no_bins": aux_distogram_bins,
                    "no_heads": 48,
                },
                "EvoformerEmbeddingHead3":{
                    "c_z": c_z,
                    "no_bins": aux_distogram_bins,
                    "no_heads": 48,
                    "c_hidden": 128,
                }
            },
            "loss": {
                "EvoformerEmbeddingLoss": {
                    "min_bin": 2.3125,
                    "max_bin": 21.6875,
                    "no_bins": 64,
                    "eps": eps,  # 1e-6,
                    "weight": 0.3,
                },
            },
        },
        "data": {
            "common": {
                # One list per feature, built from the NUM_* placeholders.
                # NOTE(review): looks like a per-feature shape schema where
                # None marks a dimension without a placeholder -- confirm
                # with the consumer of this section.
                "feat": {
                    "aatype": [NUM_RES],
                    "all_atom_mask": [NUM_RES, None],
                    "all_atom_positions": [NUM_RES, None, None],
                    "alt_chi_angles": [NUM_RES, None],
                    "atom14_alt_gt_exists": [NUM_RES, None],
                    "atom14_alt_gt_positions": [NUM_RES, None, None],
                    "atom14_atom_exists": [NUM_RES, None],
                    "atom14_atom_is_ambiguous": [NUM_RES, None],
                    "atom14_gt_exists": [NUM_RES, None],
                    "atom14_gt_positions": [NUM_RES, None, None],
                    "atom37_atom_exists": [NUM_RES, None],
                    "backbone_rigid_mask": [NUM_RES],
                    "backbone_rigid_tensor": [NUM_RES, None, None],
                    "bert_mask": [NUM_MSA_SEQ, NUM_RES],
                    "chi_angles_sin_cos": [NUM_RES, None, None],
                    "chi_mask": [NUM_RES, None],
                    "extra_deletion_value": [NUM_EXTRA_SEQ, NUM_RES],
                    "extra_has_deletion": [NUM_EXTRA_SEQ, NUM_RES],
                    "extra_msa": [NUM_EXTRA_SEQ, NUM_RES],
                    "extra_msa_mask": [NUM_EXTRA_SEQ, NUM_RES],
                    "extra_msa_row_mask": [NUM_EXTRA_SEQ],
                    "is_distillation": [],
                    "msa_feat": [NUM_MSA_SEQ, NUM_RES, None],
                    "msa_mask": [NUM_MSA_SEQ, NUM_RES],
                    "msa_row_mask": [NUM_MSA_SEQ],
                    "no_recycling_iters": [],
                    "pseudo_beta": [NUM_RES, None],
                    "pseudo_beta_mask": [NUM_RES],
                    "residue_index": [NUM_RES],
                    "residx_atom14_to_atom37": [NUM_RES, None],
                    "residx_atom37_to_atom14": [NUM_RES, None],
                    "resolution": [],
                    "rigidgroups_alt_gt_frames": [NUM_RES, None, None, None],
                    "rigidgroups_group_exists": [NUM_RES, None],
                    "rigidgroups_group_is_ambiguous": [NUM_RES, None],
                    "rigidgroups_gt_exists": [NUM_RES, None],
                    "rigidgroups_gt_frames": [NUM_RES, None, None, None],
                    "seq_length": [],
                    "seq_mask": [NUM_RES],
                    "target_feat": [NUM_RES, None],
                    "template_aatype": [NUM_TEMPLATES, NUM_RES],
                    "template_all_atom_mask": [NUM_TEMPLATES, NUM_RES, None],
                    "template_all_atom_positions": [
                        NUM_TEMPLATES, NUM_RES, None, None,
                    ],
                    "template_alt_torsion_angles_sin_cos": [
                        NUM_TEMPLATES, NUM_RES, None, None,
                    ],
                    "template_backbone_rigid_mask": [NUM_TEMPLATES, NUM_RES],
                    "template_backbone_rigid_tensor": [
                        NUM_TEMPLATES, NUM_RES, None, None,
                    ],
                    "template_mask": [NUM_TEMPLATES],
                    "template_pseudo_beta": [NUM_TEMPLATES, NUM_RES, None],
                    "template_pseudo_beta_mask": [NUM_TEMPLATES, NUM_RES],
                    "template_sum_probs": [NUM_TEMPLATES, None],
                    "template_torsion_angles_mask": [
                        NUM_TEMPLATES, NUM_RES, None,
                    ],
                    "template_torsion_angles_sin_cos": [
                        NUM_TEMPLATES, NUM_RES, None, None,
                    ],
                    "true_msa": [NUM_MSA_SEQ, NUM_RES],
                    "use_clamped_fape": [],
                },
                # Set to True by model_config(extra="preprocessed_msa").
                "sample_msa":{
                    "pre_shuffled": False,
                },
                "masked_msa": {
                    "skip": False,
                    "profile_prob": 0.1,
                    "same_prob": 0.1,
                    "uniform_prob": 0.1,
                },
                "max_recycling_iters": 3,
                "msa_cluster_features": True,
                "reduce_msa_clusters_by_max_templates": False,
                "resample_msa_in_recycling": True,
                "template_features": [
                    "template_all_atom_positions",
                    "template_sum_probs",
                    "template_aatype",
                    "template_all_atom_mask",
                ],
                "unsupervised_features": [
                    "aatype",
                    "residue_index",
                    "msa",
                    "num_alignments",
                    "seq_length",
                    "between_segment_residues",
                    "deletion_matrix",
                    "no_recycling_iters",
                ],
                "use_templates": templates_enabled,
                "use_template_torsion_angles": embed_template_torsion_angles,
            },
            "supervised": {
                "clamp_prob": 0.9,
                "supervised_features": [
                    "all_atom_mask",
                    "all_atom_positions",
                    "resolution",
                    "use_clamped_fape",
                    "is_distillation",
                ],
            },
            # Per-mode data settings; several presets in model_config override
            # max_extra_msa, and "preprocessed_msa" zeroes
            # masked_msa_replace_fraction in all three modes.
            "predict": {
                "fixed_size": True,
                "subsample_templates": False,  # We want top templates.
                "masked_msa_replace_fraction": 0.15,
                "max_msa_clusters": 512,
                "max_extra_msa": 1024,
                "max_template_hits": 4,
                "max_templates": 4,
                "crop": False,
                "crop_size": None,
                "supervised": False,
                "uniform_recycling": False,
            },
            "eval": {
                "fixed_size": True,
                "subsample_templates": False,  # We want top templates.
                "masked_msa_replace_fraction": 0.15,
                "max_msa_clusters": 128,
                "max_extra_msa": 1024,
                "max_template_hits": 4,
                "max_templates": 4,
                "crop": False,
                "crop_size": None,
                "supervised": True,
                "uniform_recycling": False,
            },
            "train": {
                "fixed_size": True,
                "subsample_templates": True,
                "masked_msa_replace_fraction": 0.15,
                "max_msa_clusters": 128,
                "max_extra_msa": 1024,
                "max_template_hits": 4,
                "max_templates": 4,
                "shuffle_top_k_prefiltered": 20,
                "crop": True,
                "crop_size": 256,
                "supervised": True,
                "clamp_prob": 0.9,
                "max_distillation_msa_clusters": 1000,
                "uniform_recycling": True,
                "distillation_prob": 0.75,
            },
            "data_module": {
                "use_small_bfd": False,
                "data_loaders": {
                    "batch_size": 1,
                    "num_workers": 16,
                    "pin_memory": True,
                },
            },
        },
        # Recurring FieldReferences that can be changed globally here
        "globals": {
            "blocks_per_ckpt": blocks_per_ckpt,
            "chunk_size": chunk_size,
            # Use Staats & Rabe's low-memory attention algorithm. Mutually
            # exclusive with use_flash.
            "use_lma": False,
            # Use FlashAttention in selected modules. Mutually exclusive with
            # use_lma. Doesn't work that well on long sequences (>1000 residues).
            "use_flash": False,
            "offload_inference": False,
            "c_z": c_z,
            "c_m": c_m,
            "c_t": c_t,
            "c_e": c_e,
            "c_s": c_s,
            "eps": eps,
            "fix_before_structure": False,  #. Added to check the learning of structure module
            "get_evoformer_embedding": False,
            "get_all_evoformer_embedding": False,
            "get_all_structure": False,
        },
        "model": {
            "_mask_trans": False,
            "input_embedder": {
                "tf_dim": 22,
                "msa_dim": 49,
                "c_z": c_z,
                "c_m": c_m,
                "relpos_k": 32,
            },
            "recycling_embedder": {
                "c_z": c_z,
                "c_m": c_m,
                "min_bin": 3.25,
                "max_bin": 20.75,
                "no_bins": 15,
                # Every "inf" entry in the tree is overwritten with 1e4 by
                # model_config(low_prec=True) via set_inf.
                "inf": 1e8,
            },
            # Only `enabled` is defined here; model_config adds
            # `average_templates` and `offload_templates` (both False) to this
            # section on every call.
            "template": {
                "enabled": templates_enabled,
            },
            "extra_msa": {
                "extra_msa_embedder": {
                    "c_in": 25,
                    "c_out": c_e,
                },
                "extra_msa_stack": {
                    "c_m": c_e,
                    "c_z": c_z,
                    "c_hidden_msa_att": 8,
                    "c_hidden_opm": 32,
                    "c_hidden_mul": 128,
                    "c_hidden_pair_att": 32,
                    "no_heads_msa": 8,
                    "no_heads_pair": 4,
                    "no_blocks": 4,
                    "transition_n": 4,
                    "msa_dropout": 0.15,
                    "pair_dropout": 0.25,
                    "clear_cache_between_blocks": False,
                    "tune_chunk_size": tune_chunk_size,
                    "inf": 1e9,
                    "eps": eps,  # 1e-10,
                    # NOTE(review): `blocks_per_ckpt` is a FieldReference
                    # object at this point, so this expression is evaluated
                    # once at import time and is always True, independent of
                    # the reference's value -- confirm this is intended.
                    "ckpt": blocks_per_ckpt is not None,
                    'no_extra_msa': False,    #. Added to disable extra_msa
                },
                "enabled": True,
            },
            "evoformer_stack": {
                "c_m": c_m,
                "c_z": c_z,
                "c_hidden_msa_att": 32,
                "c_hidden_opm": 32,
                "c_hidden_mul": 128,
                "c_hidden_pair_att": 32,
                "c_s": c_s,
                "no_heads_msa": 8,
                "no_heads_pair": 4,
                "no_blocks": 48,
                "transition_n": 4,
                "msa_dropout": 0.15,
                "pair_dropout": 0.25,
                "blocks_per_ckpt": blocks_per_ckpt,
                "clear_cache_between_blocks": False,
                "tune_chunk_size": tune_chunk_size,
                "inf": 1e9,
                "eps": eps,  # 1e-10,
                # Both flags are set to True by the "model_3_no_tri_1" preset.
                "no_triangular_attention": False,
                "no_triangular_multiplication": False,
            },
            "structure_module": {
                "c_s": c_s,
                "c_z": c_z,
                "c_ipa": 16,
                "c_resnet": 128,
                "no_heads_ipa": 12,
                "no_qk_points": 4,
                "no_v_points": 8,
                "dropout_rate": 0.1,
                "no_blocks": 8,
                "no_transition_layers": 1,
                "no_resnet_blocks": 2,
                "no_angles": 7,
                "trans_scale_factor": 10,
                "epsilon": eps,  # 1e-12,
                "inf": 1e5,
            },
            "heads": {
                "lddt": {
                    "no_bins": 50,
                    "c_in": c_s,
                    "c_hidden": 128,
                },
                "distogram": {
                    "c_z": c_z,
                    "no_bins": aux_distogram_bins,
                },
                "tm": {
                    "c_z": c_z,
                    "no_bins": aux_distogram_bins,
                    "enabled": tm_enabled,
                },
                "masked_msa": {
                    "c_m": c_m,
                    "c_out": 23,
                },
                "experimentally_resolved": {
                    "c_s": c_s,
                    "c_out": 37,
                },
            },
        },
        "relax": {
            "max_iterations": 0,  # no max
            "tolerance": 2.39,
            "stiffness": 10.0,
            "max_outer_iterations": 20,
            "exclude_residues": [],
        },
        # Loss terms. The trailing numeric comments on `eps` lines are kept
        # from the original source; the live value is the shared `eps`
        # reference.
        "loss": {
            "distogram": {
                "min_bin": 2.3125,
                "max_bin": 21.6875,
                "no_bins": 64,
                "eps": eps,  # 1e-6,
                "weight": 0.3,
            },
            "experimentally_resolved": {
                "eps": eps,  # 1e-8,
                "min_resolution": 0.1,  #. filter out NMR samples
                "max_resolution": 3.0,
                "weight": 0.0,  # the "finetuning" preset raises this to 0.01
            },
            "fape": {
                "backbone": {
                    "clamp_distance": 10.0,
                    "loss_unit_distance": 10.0,
                    "weight": 0.5,
                },
                "sidechain": {
                    "clamp_distance": 10.0,
                    "length_scale": 10.0,
                    "weight": 0.5,
                },
                # Plain float, not the shared `eps` reference.
                "eps": 1e-4,
                "weight": 1.0,
            },
            "plddt_loss": {
                "min_resolution": 0.1,
                "max_resolution": 3.0,
                "cutoff": 15.0,
                "no_bins": 50,
                "eps": eps,  # 1e-10,
                "weight": 0.01,
            },
            "masked_msa": {
                "eps": eps,  # 1e-8,
                "weight": 2.0,
            },
            "supervised_chi": {
                "chi_weight": 0.5,
                "angle_norm_weight": 0.01,
                "eps": eps,  # 1e-6,
                "weight": 1.0,
            },
            "violation": {
                "violation_tolerance_factor": 12.0,
                "clash_overlap_tolerance": 1.5,
                "eps": eps,  # 1e-6,
                "weight": 0.0,  # the "finetuning" preset raises this to 1.
            },
            "tm": {
                "max_bin": 31,
                "no_bins": 64,
                "min_resolution": 0.1,
                "max_resolution": 3.0,
                "eps": eps,  # 1e-8,
                "weight": 0.,
                "enabled": tm_enabled,
            },
            "eps": eps,
            "cum_loss_scale": 1.0,
        },
        "ema": {"decay": 0.999},
    }
)
