import logging
import copy
import ml_collections as mlc
import torch

logger = logging.getLogger("TopoDiff.config.config")

# Set the value of Inf used throughout the model config.
def set_inf(c, inf):
    """Recursively overwrite every ``"inf"`` entry of a config in place.

    Args:
        c: Mapping to walk. Values that are ``mlc.ConfigDict`` instances are
            descended into; every other value stored under the key ``"inf"``
            is replaced.
        inf: Replacement value written to each ``"inf"`` entry.

    Returns:
        None. ``c`` is modified in place.
    """
    for key, value in c.items():
        # Nested sections are handled by recursion; this check comes first,
        # so a sub-config stored under "inf" is descended into, not replaced.
        if isinstance(value, mlc.ConfigDict):
            set_inf(value, inf)
            continue
        if key == "inf":
            c[key] = inf

def model_config(name):
    """Return a configuration for the named model preset.

    The module-level ``config`` is deep-copied first, so the defaults are
    never modified; the preset's overrides are applied to the copy.

    Args:
        name: Preset identifier. Only ``'ckpt_neurips_workshop'`` is
            recognised.

    Returns:
        The copied config with the preset's overrides applied.

    Raises:
        ValueError: If ``name`` is not a known preset.
    """
    c = copy.deepcopy(config)

    # Reject anything that is not a known preset before touching the copy.
    if name != 'ckpt_neurips_workshop':
        raise ValueError(f'Unknown model name {name}')

    model = c.Model

    # Select the component variants used by this preset.
    selected = model.Global
    selected.Embedder = 'Embedder_v2'
    selected.Backbone = 'Backbone_v2'
    selected.Encoder = 'Encoder_v1'

    # Global.c_s / Global.c_z are populated from the shared references, so
    # these assignments are presumably meant to reach every section bound to
    # them - confirm against ml_collections' deepcopy behaviour.
    shared = c.Global
    shared.c_s = 256
    shared.c_z = 128

    # Encoder_v1
    encoder = model.Encoder_v1
    encoder.trainable = True
    encoder.temperature = 1.

    # Embedder_v2
    topo = model.Embedder_v2.topo_embedder
    topo.enabled = True
    topo.embed_dim = 32
    topo.type = 'continuous'

    # Backbone_v2
    model.Backbone_v2.reconstruct_CB = False

    # Diffuser
    model.Diffuser.SO3.hat_and_noise.noise_scale = (0, 5)

    return c

# Shared hyper-parameters. Each FieldReference below is bound to several
# fields of `config`, so the sections that use it stay in sync: assigning
# through a bound field (e.g. `c.Global.c_s = 256` in `model_config`) updates
# the shared reference rather than a single entry.
T = mlc.FieldReference(200, field_type=int)  # presumably the number of diffusion steps - confirm
c_z = mlc.FieldReference(128, field_type=int)  # presumably the pair-representation width - confirm
c_m = mlc.FieldReference(256, field_type=int)  # presumably the MSA-representation width - confirm
c_s = mlc.FieldReference(384, field_type=int)  # presumably the single-representation width - confirm
eps = mlc.FieldReference(1e-8, field_type=float)  # small constant shared by every `eps`/`epsilon` field

# Default configuration. `model_config` deep-copies this object and applies
# preset-specific overrides, so the values here are only the defaults.
config = mlc.ConfigDict(
    {
        # The shared references, exposed in one place.
        'Global': {
            'T' : T,
            "c_z": c_z,
            "c_m": c_m,
            "c_s": c_s,
            "eps": eps,
        },
        'Model':{
            # Names of the component variants to use.
            # NOTE(review): this dict has no 'Embedder_v1', 'Backbone_v1' or
            # 'Diffuser_v1' section; how these default names are resolved is
            # not visible from this file.
            'Global': {
                'infer_no_recyc': False,
                'Embedder': 'Embedder_v1',
                'Backbone': 'Backbone_v1',
                'Diffuser': 'Diffuser_v1',
                'Encoder' : None,
            },
            'Encoder_v1': {
                'feature_dim': 68,
                'hidden_dim': 128,
                'dropout': 0,
                'dropout_final': 0,
                # Defined once: a repeated key in a dict literal is silently
                # collapsed to its last value, which hides mistakes.
                'eps': eps,

                'n_layers': 6,
                'm_dim': 16,
                'hidden_egnn_dim': 64,
                'hidden_edge_dim': 256,

                'embedding_size': 64,
                'latent_dim': 32,

                'normalize': False,
                'final_init': True,
                'reduce': 'sum',

                'transformer': {
                    'enable': True,
                    'version': 2,
                    'n_heads': 4,
                    'n_layers': 2,
                    'dropout': 0,
                },

                'trainable': True,
                'temperature': 1.,
            },
            'Embedder_v2': {
                'c_s': c_s,
                'c_z': c_z,

                'tf_dim': 22,
                'pos_emb_dim': 32,
                'time_emb_dim': 32,
                'embed_fixed': True,

                'time_emb_max_positions': 10000,
                'pos_emb_max_positions': 2056,

                'recyc_struct': True,
                'recyc_struct_min_bin': 3.25,
                'recyc_struct_max_bin': 20.75,
                'recyc_struct_no_bin': 15,

                'eps': eps,
                'inf': 1e5,

                'msa_embedder': {
                    'enabled': False,
                    # NOTE(review): the key is spelled 'sepaeate_load'; it is
                    # kept as-is because renaming it would break any code
                    # that reads this key.
                    'sepaeate_load': True,
                    'pretrained_model_path' : None,
                    'src_model_name' : None,
                    # Literal widths, not bound to the shared c_m / c_z
                    # references, so they do not follow changes to Global.
                    'c_m': 256,
                    'c_z': 128,
                },

                'topo_embedder': {
                    'enabled': False,
                    'type': 'category',
                    'embed_dim': None,
                    'num_class': None,
                },
            },
            'Embedder': {
                # Switches for the sub-embedders configured below.
                'Global': {
                    'Position_embedder': True,
                    'Timestep_embedder': True,
                    'Recycling_embedder': True,  # self-conditioning
                },
                'Recycling_embedder':{
                    "c_z": c_z,
                    "c_m": c_m,
                    "min_bin": 3.25,
                    "max_bin": 20.75,
                    "no_bins": 15,
                    "inf": 1e8,
                    'recyc_seq': True,
                    'recyc_pair': True,
                    'recyc_struct': True,
                },
                'Position_embedder': {
                    'tf_dim': 22,
                    'c_z': c_z,
                    'c_m': c_m,
                    'relpos_k': 32,
                },
                'Timestep_embedder': {
                    'c_z': c_z,
                    'c_m': c_m,
                    'emb_dim': 16,
                    'max_positions': 10000
                },
            },
            'Backbone_v2': {
                'c_s': c_s,
                'c_z': c_z,
                'c_skip': 64,
                'no_blocks': 4,
                'seq_tfmr_num_heads': 4,
                'seq_tfmr_num_layers': 2,

                'no_seq_transition_layers': 1,
                'seq_transition_dropout_rate': 0.1,

                'predict_torsion': True,
                'angle_c_resnet': 128,
                'angle_no_resnet_blocks': 2,
                'no_angles': 1,
                'reconstruct_backbone_atoms': True,
                'reconstruct_CB': False,
                'torsion_index': 0,

                'trans_scale_factor': 10,

                # This section names the field 'epsilon'; the others use 'eps'.
                'epsilon': eps,
                'inf': 1e5,

                'ipa': {
                    'c_s': c_s,
                    'c_z': c_z,
                    'c_hidden': 256,
                    'no_heads': 8,
                    'no_qk_points': 8,
                    'no_v_points': 12,
                    'inf': 1e5,
                    'eps': eps,
                },
            },
            'Diffuser': {
                'Global':{
                    'T': T,
                    'trans_scale_factor': 4.,
                },
                'Cartesian': {
                    'alpha_1' : 0.99,
                    'alpha_T' : 0.93,
                    'T' : T,
                },
                'SO3': {
                    'cache_dir': None,
                    'suffix': '_log',
                    'schedule': 'linear',
                    'sigma_1' : 0.1,
                    'sigma_T' : 1.5,
                    'reverse_strategy': 'hat_and_noise',
                    'hat_and_noise':{
                        'noise_scale': (5, 5),
                        'noise_scale_schedule': 'linear',
                    },
                    'T' : T,
                }
            },
            'Aux_head': {
                'SC': {
                    'latent_dim': 32,
                    'ff_dim': 32,
                    'c_out': 1,
                    'dropout': 0.1,
                    'n_layers': 5,
                }
            },
        },
    }
)