from typing import Dict, List, Callable, Tuple

import torch
import numpy as np

from coarsebind_public.coarsebind.io_schema import (
    IOSchemaCoarseBind,
    OutputSchemaDisto,
)


def to_numpy(io: IOSchemaCoarseBind) -> IOSchemaCoarseBind:
    """
    Converts all torch tensors in the IOSchema to numpy arrays, in place.

    A value is treated as a tensor when it exposes ``detach``, ``cpu`` and
    ``numpy`` (duck-typed), so numpy arrays and plain Python values are left
    untouched. Each dataclass field is handled as follows:

    * tensor fields are replaced by ``val.detach().cpu().numpy()``;
    * nested dataclass fields are converted recursively;
    * dict fields have their tensor *values* converted. Only one level is
      visited; dicts or dataclasses nested inside a dict are not converted.

    Args:
        io: A dataclass instance (e.g. an IOSchema). It is mutated in place.

    Returns:
        The same ``io`` object, for call chaining.
    """

    def _is_tensor(x) -> bool:
        # duck-typed check so this works without importing a specific tensor class
        return hasattr(x, "detach") and hasattr(x, "cpu") and hasattr(x, "numpy")

    def _as_numpy(x):
        # detach from autograd and move to host memory before converting
        return x.detach().cpu().numpy()

    for attr in io.__dataclass_fields__:
        val = getattr(io, attr)
        if _is_tensor(val):
            setattr(io, attr, _as_numpy(val))
        # work on nested dataclasses
        elif hasattr(val, "__dataclass_fields__"):
            setattr(io, attr, to_numpy(val))
        # work on (one level of) nested dictionaries
        elif isinstance(val, dict):
            for k, v in val.items():
                if _is_tensor(v):
                    # replacing values of existing keys is safe during iteration
                    val[k] = _as_numpy(v)
            setattr(io, attr, val)

    return io


def flatten_dict(d: Dict, sep: str = ".") -> Dict:
    """
    Flattens a nested dictionary into a single-level dictionary with keys joined by a separator.

    Values that are ``dict`` instances are flattened recursively; every other
    value is copied as-is. Nested keys are built with an f-string, so they
    become strings (e.g. ``{1: {2: "x"}}`` -> ``{"1.2": "x"}``). Top-level keys
    with non-dict values keep their original type. An empty nested dict
    contributes no keys. If two paths produce the same flattened key, the one
    visited later wins.

    Args:
        d (Dict): The dictionary to flatten. It is not modified.
        sep (str, optional): Separator to use when joining nested keys. Defaults to '.'.

    Returns:
        Dict: A new flattened dictionary.
    """
    flat_dict = {}

    for key, value in d.items():
        if isinstance(value, dict):
            flat_dict.update(
                {
                    f"{key}{sep}{nested_key}": nested_value
                    for nested_key, nested_value in flatten_dict(value, sep=sep).items()
                }
            )
        else:
            flat_dict[key] = value

    return flat_dict


def mask_1d_cropper(
    io: OutputSchemaDisto,
    mask: np.ndarray,
) -> OutputSchemaDisto:
    """
    Crop the per-residue and pairwise outputs of ``io`` in place with a 1D mask.

    The mask may be a numpy array or a ``torch.Tensor``. If it is a tensor, it is
    moved to numpy only for indexing ``res_name``.

    Cropping rules:

    * per-residue fields are indexed with ``[mask]``;
    * pairwise fields (``bin_probs``, ``norm_bin_entropy``, ``pw_distances``,
      each entry of ``save_layer_reps``) are indexed with ``[mask][:, mask]``;
    * ``within_cutoff_pair_mask`` is indexed only along its second axis;
    * ``save_s_reps`` / ``save_layer_reps`` are cropped only when truthy;
    * template fields are cropped only when all three are not None.

    Args:
        io: Model outputs to crop; mutated in place.
        mask: 1D boolean-style selector over residues. Presumably bool and of
            length N (number of residues); TODO confirm against callers.

    Returns:
        The same ``io`` object.
    """
    per_residue_fields = (
        "res_type",
        "entity_id",
        "sym_id",
        "asym_id",
        "res_num",
        "potency_ligand_mask",
    )
    pairwise_fields = ("bin_probs", "norm_bin_entropy", "pw_distances")

    for field_name in per_residue_fields:
        setattr(io, field_name, getattr(io, field_name)[mask])

    for field_name in pairwise_fields:
        setattr(io, field_name, getattr(io, field_name)[mask][:, mask])

    io.within_cutoff_mask = io.within_cutoff_mask[mask]
    # NOTE(review): only the second axis is cropped here, unlike the pairwise
    # features above - confirm the expected shape of within_cutoff_pair_mask.
    io.within_cutoff_pair_mask = io.within_cutoff_pair_mask[:, mask]

    if io.res_name is not None:
        # res_name is indexed with a numpy mask when the given mask is a tensor
        name_mask = mask.cpu().numpy() if isinstance(mask, torch.Tensor) else mask
        io.res_name = io.res_name[name_mask]

    single_reps = io.save_s_reps
    if single_reps:
        io.save_s_reps = {name: rep[mask] for name, rep in single_reps.items()}

    layer_reps = io.save_layer_reps
    if layer_reps:
        io.save_layer_reps = {name: rep[mask][:, mask] for name, rep in layer_reps.items()}

    template_fields = ("coarse_cofold_template_coords", "template_res_idxs", "template_mask")
    if all(getattr(io, field_name) is not None for field_name in template_fields):
        # crop the template coords together with their indices and mask
        for field_name in template_fields:
            setattr(io, field_name, getattr(io, field_name)[mask])

    return io


def lite_output(
    io: OutputSchemaDisto,
) -> OutputSchemaDisto:
    """
    Strip the heavyweight outputs from ``io`` in place.

    Sets ``bin_probs``, ``norm_bin_entropy``, ``save_s_reps`` and
    ``save_layer_reps`` to ``None``; every other field is left untouched.

    Args:
        io: Model outputs to slim down; mutated in place.

    Returns:
        The same ``io`` object.
    """
    heavy_fields = ("bin_probs", "norm_bin_entropy", "save_s_reps", "save_layer_reps")
    for field_name in heavy_fields:
        setattr(io, field_name, None)
    return io


def affinity_cropper(
    io: OutputSchemaDisto,
) -> OutputSchemaDisto:
    """
    Crop ``io`` to residues within the cutoff, then keep ligand rows of the pairwise features.

    Steps:
      1. ``mask_1d_cropper(io, io.within_cutoff_mask)`` drops residues outside
         the cutoff.
      2. Each pairwise feature ``x`` (``bin_probs``, ``norm_bin_entropy``,
         ``pw_distances`` and every entry of ``save_layer_reps``) is
         symmetrized over its first two axes as ``(x + x.swapaxes(0, 1)) / 2``.
         Its rows are then selected with ``potency_ligand_mask``.

    ``save_layer_reps`` is skipped when it is None or empty, mirroring the
    guard in ``mask_1d_cropper``.

    Args:
        io: Model outputs; mutated in place.

    Returns:
        The same ``io`` object.
    """
    # truncate
    io = mask_1d_cropper(io, io.within_cutoff_mask)
    ligand_mask = io.potency_ligand_mask

    def _symmetrize_ligand_rows(x):
        # TODO.. this transpose stuff should be handled upstream (but model was not trained this way)
        # swapaxes(0, 1) swaps the first two axes for both torch tensors and
        # numpy arrays; numpy's transpose(0, 1) would instead be read as an
        # axes permutation (a no-op on 2-D arrays, an error on 3-D ones).
        return ((x + x.swapaxes(0, 1)) / 2)[ligand_mask]

    io.bin_probs = _symmetrize_ligand_rows(io.bin_probs)
    io.norm_bin_entropy = _symmetrize_ligand_rows(io.norm_bin_entropy)
    io.pw_distances = _symmetrize_ligand_rows(io.pw_distances)

    # layer representations may not have been saved (None / empty dict)
    if io.save_layer_reps:
        for k, v in io.save_layer_reps.items():
            # replacing values of existing keys is safe during iteration
            io.save_layer_reps[k] = _symmetrize_ligand_rows(v)

    return io


def affinity_cropper_2(
    io: OutputSchemaDisto,
) -> OutputSchemaDisto:
    """
    Crop ``io`` to residues within the cutoff, then keep only pairs involving a ligand.

    After ``mask_1d_cropper(io, io.within_cutoff_mask)``, pair ``(i, j)`` is kept
    when ``i`` or ``j`` is flagged in ``potency_ligand_mask``. That covers
    ligand-ligand, residue-ligand and ligand-residue pairs; residue-residue
    pairs are dropped.

    Indexing with the 2D pair mask collapses the first two axes. As a result,
    ``bin_probs``, ``norm_bin_entropy``, ``pw_distances`` and every entry of
    ``save_layer_reps`` become ``(n_kept_pairs, ...)`` in row-major pair order.
    Unlike ``affinity_cropper``, no symmetrization is applied, so both
    ``(i, j)`` and ``(j, i)`` are kept.

    ``save_layer_reps`` is skipped when it is None or empty, mirroring the
    guard in ``mask_1d_cropper``.

    Args:
        io: Model outputs; mutated in place.

    Returns:
        The same ``io`` object.
    """
    # truncate to residues within cutoff only
    io = mask_1d_cropper(io, io.within_cutoff_mask)

    # assumes potency_ligand_mask is a boolean vector - TODO confirm
    ligand = io.potency_ligand_mask
    # (ligand & ligand) | (exactly one ligand) == at least one side is a ligand
    sparse_mask = ligand[:, None] | ligand[None, :]

    io.bin_probs = io.bin_probs[sparse_mask]
    io.norm_bin_entropy = io.norm_bin_entropy[sparse_mask]
    io.pw_distances = io.pw_distances[sparse_mask]

    # layer representations may not have been saved (None / empty dict)
    if io.save_layer_reps:
        for k, v in io.save_layer_reps.items():
            io.save_layer_reps[k] = v[sparse_mask]

    return io


def bins_to_uniform(
    orig_bin_centers: torch.Tensor, pairformer_output: Dict[str, torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Converts the binned distance predictions from the pairformer to uniform distance predictions.

    ``pairformer_output["logits"]`` is softmaxed over its last axis. Each original
    bin ``k`` (center ``orig_bin_centers[k]``) is mapped to a uniform bin using 63
    evenly spaced edges on [2, 22]. The target index is
    ``searchsorted(edges, center) - 1``, clamped at 0, and probability mass is
    summed into that index with ``scatter_add``.

    With this indexing, centers <= ``edges[1]`` share bin 0, centers above 22
    land in bin 62, and index 63 never receives mass.
    NOTE(review): confirm this offset is the intended binning.

    The output keeps the logits' trailing size, so the logits must have at
    least 63 bins. Presumably they have exactly 64; TODO confirm. Any number
    of leading axes is supported.

    Args:
        orig_bin_centers (torch.Tensor): The original bin centers used in the
            pairformer model, one per logit bin (flattened if not 1-D).
        pairformer_output (Dict[str, torch.Tensor]): The output from the
            pairformer model; only the ``"logits"`` entry is read.

    Returns:
        torch.Tensor: The uniform bin entropy, normalized by ``-log(1 / n_bins)``,
            with the logits' shape minus the last axis.
        torch.Tensor: The uniform bin probabilities, same shape as the logits.
    """
    bin_probs = torch.softmax(pairformer_output["logits"], dim=-1)

    # build the edges in the centers' float dtype so searchsorted compares like with like
    edge_dtype = orig_bin_centers.dtype if orig_bin_centers.is_floating_point() else None
    uniform_edges = torch.linspace(
        2, 22, 64 - 1, dtype=edge_dtype, device=orig_bin_centers.device
    )
    insertion_idxs = (torch.searchsorted(uniform_edges, orig_bin_centers) - 1).clamp(min=0)
    # broadcast the per-bin target index over every leading axis of bin_probs
    insertion_idxs = insertion_idxs.view(*([1] * (bin_probs.dim() - 1)), -1)
    uniform_bin_probs = torch.zeros_like(bin_probs).scatter_add(
        -1, insertion_idxs.expand_as(bin_probs), bin_probs
    )

    # 1e-8 keeps log finite for empty bins (0 * log(1e-8) == 0)
    uniform_bin_entropy = -torch.sum(
        uniform_bin_probs * torch.log(uniform_bin_probs + 1e-8), dim=-1
    )
    uniform_max_entropy = -torch.log(torch.tensor(1.0 / uniform_bin_probs.shape[-1]) + 1e-8)
    uniform_bin_entropy = uniform_bin_entropy / uniform_max_entropy

    return uniform_bin_entropy, uniform_bin_probs
