# Adapted from OpenFold
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
import torch
import torch.nn as nn


def permute_final_dims(tensor, *inds):
    """Permute only the trailing ``len(inds)`` dimensions of ``tensor``.

    Leading dimensions keep their order. ``inds`` gives the new order of the
    trailing block, counted from the start of that block (e.g.
    ``permute_final_dims(t, 1, 0)`` swaps the last two dimensions).
    """
    shift = len(inds)
    # Slice with ``-shift`` (instead of computing ``dim() - shift``) so that an
    # empty ``inds`` yields no leading dims, exactly as before.
    n_leading = len(tensor.shape[:-shift])
    order = list(range(n_leading))
    order.extend(j - shift for j in inds)
    return tensor.permute(*order)


def flatten_final_dims(tensor, no_dims):
    """Collapse the trailing ``no_dims`` dimensions of ``tensor`` into one.

    Note: for ``no_dims == 0`` the slice ``shape[:-0]`` is empty, so the
    whole tensor is flattened to 1-D.
    """
    target_shape = (*tensor.shape[:-no_dims], -1)
    return tensor.reshape(target_shape)


def masked_mean(mask, value, dim, eps=1e-10):
    """Mean of ``value`` along ``dim``, weighting each element by ``mask``.

    ``mask`` is broadcast (via ``expand``) to ``value``'s shape. ``eps`` is
    added to the mask total so a fully-masked slice yields 0 instead of a
    division by zero.
    """
    full_mask = mask.expand(*value.shape)
    numerator = (full_mask * value).sum(dim=dim)
    denominator = full_mask.sum(dim=dim) + eps
    return numerator / denominator


def pts_to_distogram(pts, min_bin=2.3125, max_bin=21.6875, no_bins=64):
    """Bucketize all pairwise Euclidean distances between points.

    ``pts`` is indexed as (..., N, D): the last axis holds coordinates and the
    second-to-last enumerates points. ``no_bins - 1`` evenly spaced boundaries
    from ``min_bin`` to ``max_bin`` split distances into ``no_bins`` buckets.
    Returns integer bucket indices in ``[0, no_bins - 1]`` of shape (..., N, N).
    """
    edges = torch.linspace(min_bin, max_bin, no_bins - 1, device=pts.device)
    # (..., N, 1, D) - (..., 1, N, D) -> all pairwise coordinate offsets.
    offsets = pts.unsqueeze(-2) - pts.unsqueeze(-3)
    squared = (offsets ** 2).sum(dim=-1)
    return torch.bucketize(torch.sqrt(squared), edges)


def stack_tensor_dicts(dicts):
    """Stack a list of identically-keyed (possibly nested) tensor dicts.

    The key set and nesting are taken from ``dicts[0]``. Each tensor leaf
    becomes ``torch.stack`` of the corresponding leaves across all dicts,
    adding a new leading dimension.
    """
    def merge(key):
        # Gather this key's value from every dict, then recurse or stack.
        column = [d[key] for d in dicts]
        if type(dicts[0][key]) is dict:
            return stack_tensor_dicts(column)
        return torch.stack(column)

    return {key: merge(key) for key in dicts[0]}


def one_hot(x, v_bins):
    """One-hot encode each element of ``x`` by its nearest value in ``v_bins``.

    Nearness is absolute difference; on ties ``argmin`` picks the first
    (lowest-index) bin. Returns a float tensor of shape
    ``x.shape + (len(v_bins),)``. ``v_bins`` is presumably 1-D (it is viewed
    with ``len(v_bins)`` as its only non-unit size).
    """
    n_bins = len(v_bins)
    # Reshape bins to (1, ..., 1, n_bins) so they broadcast against x[..., None].
    bins = v_bins.view(*([1] * x.dim()), n_bins)
    nearest = (x.unsqueeze(-1) - bins).abs().argmin(dim=-1)
    return nn.functional.one_hot(nearest, num_classes=n_bins).float()


def batched_gather(data, inds, dim=0, no_batch_dims=0):
    """Index dimension ``dim`` of ``data`` with ``inds``, separately per batch element.

    The first ``no_batch_dims`` dimensions of ``data`` are batch dimensions.
    For each of them an ``arange`` index is built and shaped to broadcast
    against ``inds`` along that axis. ``inds`` is therefore expected to carry
    the batch dimensions as its leading axes (TODO confirm at call sites).
    All other dimensions are kept whole, except ``dim``, which is indexed by
    ``inds``.

    Args:
        data: tensor to gather from.
        inds: integer index tensor for dimension ``dim``.
        dim: dimension of ``data`` to gather along (may be negative).
        no_batch_dims: number of leading batch dimensions of ``data``.
    Returns:
        The result of advanced indexing ``data`` with the per-batch
        aranges plus ``inds``.
    """
    n_ind_dims = len(inds.shape)
    ranges = []
    for i, s in enumerate(data.shape[:no_batch_dims]):
        # Build on inds' device so every advanced index lives on one device.
        r = torch.arange(s, device=inds.device)
        # Shape (1, ..., 1, s, 1, ..., 1): size s at position i, broadcasting
        # against inds.
        r = r.view(*((1,) * i), -1, *((1,) * (n_ind_dims - i - 1)))
        ranges.append(r)

    remaining_dims = [
        slice(None) for _ in range(len(data.shape) - no_batch_dims)
    ]
    remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds
    ranges.extend(remaining_dims)
    # Index with a tuple: a list index is the deprecated "non-tuple sequence"
    # form, whose meaning is decided heuristically.
    return data[tuple(ranges)]


# Together with tree_map below: a minimal stand-in for JAX's tree_map.
def dict_map(fn, dic, leaf_type):
    """Apply ``fn`` to every ``leaf_type`` leaf inside the (nested) dict ``dic``.

    Nested dicts (exact type ``dict``) are handled here; every other value is
    delegated to ``tree_map``. Returns a new dict with the same keys in the
    same order.
    """
    return {
        key: (
            dict_map(fn, val, leaf_type)
            if type(val) is dict
            else tree_map(fn, val, leaf_type)
        )
        for key, val in dic.items()
    }


def tree_map(fn, tree, leaf_type):
    """Apply ``fn`` to each leaf of a pytree made of dicts, lists and tuples.

    Container and leaf checks use exact type identity (``type(x) is T``), so
    subclasses are not recognised. Lists stay lists and tuples stay tuples.

    Raises:
        ValueError: if a node is neither a supported container nor exactly
            ``leaf_type``.
    """
    kind = type(tree)
    if kind is dict:
        return dict_map(fn, tree, leaf_type)
    if kind is list or kind is tuple:
        mapped = [tree_map(fn, item, leaf_type) for item in tree]
        return mapped if kind is list else tuple(mapped)
    if kind is leaf_type:
        return fn(tree)
    raise ValueError("Not supported")

# tree_map specialised to tensor leaves: tensor_tree_map(fn, tree) applies fn
# to every tensor in a nested dict/list/tuple structure. Leaves must be exactly
# torch.Tensor (tree_map compares types with `is`, so Tensor subclasses raise).
tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor)

def chunk_layer(layer, inputs, chunk_size, no_batch_dims):
    """
        Implements the "chunking" procedure described in section 1.11.8.

        Layer outputs and inputs are interpreted as simplified "pytrees,"
        consisting only of (nested) lists, tuples, and dicts with tensor
        leaves.

        The leading batch dimensions of every input are broadcast to a
        common shape (the per-dimension maximum over all inputs) and
        flattened into a single dimension. They are then passed to ``layer``
        ``chunk_size`` sub-batches at a time. Each chunk's output is copied
        into a preallocated buffer, which is finally reshaped back to the
        common batch shape.

        Args:
            layer:
                The layer to be applied chunk-wise. It is called with the
                chunked inputs as keyword arguments.
            inputs:
                A (nested) dictionary of keyworded inputs. All leaves must be
                tensors and must share the same batch dimensions (size-1
                batch dimensions are broadcast).
            chunk_size:
                The number of sub-batches per chunk. If multiple batch
                dimensions are specified, a "sub-batch" is defined as a single
                indexing of all batch dimensions simultaneously (s.t. the
                number of sub-batches is the product of the batch dimensions).
                Must be a positive integer.
            no_batch_dims:
                How many of the initial dimensions of each input tensor can
                be considered batch dimensions.
        Returns:
            The reassembled output of the layer on the inputs. It has the
            same (nested) dict/list/tuple structure as one call's output,
            with the flattened batch dimension restored to the broadcast
            batch shape.
        Raises:
            ValueError: if ``inputs`` is empty, ``chunk_size`` is not
                positive, the broadcast batch is empty, or an input or
                output contains an unsupported (non-tensor) leaf type.
    """
    if(not (len(inputs) > 0)):
        raise ValueError("Must provide at least one input")
    if(chunk_size < 1):
        raise ValueError("chunk_size must be a positive integer")

    def fetch_dims(tree):
        # Collect the shapes of all tensor leaves in the input pytree.
        shapes = []
        tree_type = type(tree)
        if(tree_type is dict):
            for v in tree.values():
                shapes.extend(fetch_dims(v))
        elif(tree_type is list or tree_type is tuple):
            for t in tree:
                shapes.extend(fetch_dims(t))
        elif(tree_type is torch.Tensor):
            shapes.append(tree.shape)
        else:
            raise ValueError("Not supported")

        return shapes

    initial_dims = [shape[:no_batch_dims] for shape in fetch_dims(inputs)]
    # Common batch shape: per-dimension maximum over all inputs.
    orig_batch_dims = [max(s) for s in zip(*initial_dims)]

    def prep_inputs(t):
        # Broadcast to the common batch shape, then merge the batch
        # dimensions into a single leading dimension.
        t = t.expand(*orig_batch_dims, *t.shape[no_batch_dims:])
        t = t.reshape(-1, *t.shape[no_batch_dims:])
        return t

    flattened_inputs = tensor_tree_map(prep_inputs, inputs)

    flat_batch_dim = 1
    for d in orig_batch_dims:
        flat_batch_dim *= d

    def allocate(t):
        # Full-size buffer with the chunk's dtype/device and trailing shape.
        return t.new_zeros(flat_batch_dim, *t.shape[1:])

    def assign(dst, src, start):
        # Recursively copy the chunk output ``src`` into rows
        # [start, start + chunk_size) of the matching buffers in ``dst``.
        # Handles arbitrarily nested dicts, lists and tuples.
        dst_type = type(dst)
        if(dst_type is dict):
            for k, v in dst.items():
                assign(v, src[k], start)
        elif(dst_type is list or dst_type is tuple):
            for d_leaf, s_leaf in zip(dst, src):
                assign(d_leaf, s_leaf, start)
        elif(dst_type is torch.Tensor):
            dst[start:start + chunk_size] = src
        else:
            raise ValueError("Not supported")

    out = None
    for i in range(0, flat_batch_dim, chunk_size):
        # Chunk the input. The lambda is consumed immediately by
        # tensor_tree_map, so capturing the loop variable ``i`` is safe.
        chunks = tensor_tree_map(
            lambda t: t[i:i + chunk_size], flattened_inputs
        )

        # Run the layer on the chunk
        output_chunk = layer(**chunks)

        # Allocate the full-size output once the output structure is known.
        if(out is None):
            out = tensor_tree_map(allocate, output_chunk)

        # Put the chunk in its pre-allocated space
        assign(out, output_chunk, i)

    if(out is None):
        # No chunk ran, so the output structure is unknown.
        raise ValueError("Cannot chunk inputs with an empty batch")

    reshape = lambda t: t.reshape(*orig_batch_dims, *t.shape[1:])
    return tensor_tree_map(reshape, out)
