# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random

import torch

from data import data_transforms


def nonensembled_transform_fns(common_cfg, mode_cfg):
    """Input pipeline data transformers that are not ensembled.

    Args:
        common_cfg: shared pipeline config. Not read by this function; it is
            accepted so the signature matches ``ensembled_transform_fns``.
        mode_cfg: mode-specific config. The truthiness of
            ``mode_cfg.supervised`` decides whether the supervised-only
            (ground-truth) transforms are appended.

    Returns:
        A list of transforms to be applied in order. The always-on
        preprocessing steps come first. When ``mode_cfg.supervised`` is truthy,
        they are followed by the ground-truth feature builders.
    """
    always_on = [
        # NOTE(review): previously annotated as adding "token_type"; the name
        # suggests it casts integer features to 64-bit -- confirm.
        data_transforms.cast_to_64bit_ints,
        # Adds no new features.
        data_transforms.squeeze_features,
        # Adds "seq_mask" (per the author: torch.ones(NUM_RES)).
        data_transforms.make_seq_mask,
        # Adds "atom14_atom_exists", "residx_atom14_to_atomFull",
        # "residx_atomFull_to_atom14" and "atomFull_atom_exists".
        data_transforms.make_atom14_masks,
        # Uses the ligand centre as the global centre for positions.
        data_transforms.centering_positions,
    ]

    if not mode_cfg.supervised:
        return always_on

    # Supervision-only features: training targets, not model inputs.
    supervised_only = [
        # Adds "atom14_atom_exists", "atom14_gt_exists",
        # "atom14_gt_positions", "atom14_alt_gt_positions",
        # "atom14_alt_gt_exists" and "atom14_atom_is_ambiguous".
        data_transforms.make_atom14_positions,
        # Adds "rigidgroups_gt_frames", "rigidgroups_gt_exists",
        # "rigidgroups_group_exists", "rigidgroups_group_is_ambiguous"
        # and "rigidgroups_alt_gt_frames".
        data_transforms.atomFull_to_frames,
        # Adds "torsion_angles_sin_cos", "alt_torsion_angles_sin_cos"
        # and "torsion_angles_mask"; the factory is called with an empty
        # prefix string.
        data_transforms.atomFull_to_torsion_angles(""),
        # Adds "pseudo_beta" and "pseudo_beta_mask" (empty prefix string).
        data_transforms.make_pseudo_beta(""),
        # Adds "backbone_rigid_tensor" and "backbone_rigid_mask".
        data_transforms.get_backbone_frames,
    ]
    return always_on + supervised_only


def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
    """Input pipeline data transformers that can be ensembled and averaged.

    Args:
        common_cfg: shared pipeline config; a ``dict`` copy of
            ``common_cfg.feat`` is passed to ``make_fixed_size``.
        mode_cfg: mode-specific config; ``mode_cfg.crop_size`` is forwarded
            as ``num_res``.
        ensemble_seed: accepted for interface compatibility; not read here.

    Returns:
        A one-element list containing the transform built by
        ``data_transforms.make_fixed_size``. Presumably it pads/crops features
        to ``crop_size`` residues -- confirm against ``data_transforms``.
    """
    # Copy into a plain dict so the config object itself is not handed on.
    feat_schema = dict(common_cfg.feat)
    fixed_size_fn = data_transforms.make_fixed_size(
        feat_schema,
        num_res=mode_cfg.crop_size,
    )
    return [fixed_size_fn]


def process_tensors_from_config(tensors, common_cfg, mode_cfg):
    """Based on the config, apply filters and transformations to the data.

    Builds the non-ensembled transform list for this mode and applies it to
    ``tensors`` in order.

    Args:
        tensors: feature container passed to the first transform (presumably
            a dict of name -> torch.Tensor; TODO confirm against callers).
        common_cfg: shared pipeline config, forwarded to
            ``nonensembled_transform_fns``.
        mode_cfg: mode-specific config, forwarded to
            ``nonensembled_transform_fns`` (its ``supervised`` flag selects
            whether ground-truth features are built).

    Returns:
        Whatever the last transform in the list returns.
    """
    # NOTE: the ensembled stage (ensembled_transform_fns mapped over the
    # ensemble dimension via map_fn) is currently disabled, so this seed has
    # no consumer. The draw is kept only so the global ``random`` stream
    # advances exactly as before, which keeps seeded pipelines reproducible.
    random.randint(0, torch.iinfo(torch.int32).max)

    nonensembled = nonensembled_transform_fns(common_cfg, mode_cfg)
    return compose(nonensembled)(tensors)


@data_transforms.curry1
def compose(x, fs):
    """Thread ``x`` through each callable in ``fs``, left to right.

    In this module the decorated function is called as ``compose(fs)(x)``,
    i.e. the transform list first and the data second.

    Args:
        x: initial value handed to the first callable.
        fs: iterable of single-argument callables.

    Returns:
        The output of the last callable, or ``x`` itself if ``fs`` is empty.
    """
    value = x
    for step in fs:
        value = step(value)
    return value


def map_fn(fun, x):
    """Apply ``fun`` to every element of ``x`` and stack the results per key.

    Args:
        fun: callable mapping one element of ``x`` to a dict of tensors.
        x: iterable of elements (e.g. ensemble indices).

    Returns:
        A dict whose keys are those of the first result. Each value is
        ``torch.stack`` of that key across all results, along a new last
        dimension. An empty ``x`` yields an empty dict.

    Raises:
        KeyError: if a later result lacks a key present in the first one.
    """
    ensembles = [fun(elem) for elem in x]
    if not ensembles:
        # Nothing to stack; previously this raised an opaque IndexError.
        return {}
    return {
        feat: torch.stack([member[feat] for member in ensembles], dim=-1)
        for feat in ensembles[0]
    }
