from typing import Dict, List, Optional

import numpy as np
import torch
from torch.nn.functional import pad

from coarsebind_public.coarsebind.io_schema import IOSchemaCoarseBind, OutputSchemaDisto


def pad_to_max(data: list[torch.Tensor], value: float = 0) -> tuple[torch.Tensor, torch.Tensor]:
    """Pad tensors to a common shape and stack them into one batch.

    Each tensor is extended at the end of every dimension until it reaches
    the largest size found for that dimension across ``data``. The number of
    dimensions is taken from the first tensor.

    Parameters
    ----------
    data : list[Tensor]
        Tensors to pad and stack.
    value : float
        Fill value written into the padded region of the data.

    Returns
    -------
    Tensor
        The padded tensors stacked along a new leading dimension.
    Tensor
        A mask of the same shape holding ones at original positions and
        zeros at padded positions. It is built with ``torch.ones_like``, so
        it keeps the dtype of the input tensors.

    """
    rank = len(data[0].shape)
    target_shape = [max(t.shape[axis] for t in data) for axis in range(rank)]

    # ``pad`` consumes (before, after) pairs starting from the LAST dimension,
    # so walk the axes in reverse and only ever grow the trailing side.
    pad_specs = [
        [amount for axis in reversed(range(rank)) for amount in (0, target_shape[axis] - t.shape[axis])]
        for t in data
    ]

    # The mask goes through the same padding as the data, but starts from
    # all-ones and is filled with zeros.
    mask_parts = [pad(torch.ones_like(t), spec, value=0) for t, spec in zip(data, pad_specs)]
    padded_parts = [pad(t, spec, value=value) for t, spec in zip(data, pad_specs)]

    mask = torch.stack(mask_parts, dim=0)
    stacked = torch.stack(padded_parts, dim=0)

    return stacked, mask


def global_potency_feats_collate_fn(
    batch: List[IOSchemaCoarseBind],
) -> Dict[str, Optional[torch.Tensor]]:
    """Collate the molecule-encoder embeddings of a batch into one tensor.

    Rows whose ``mol_enc_io`` is ``None`` are skipped, so the leading
    dimension of the result can be smaller than ``len(batch)``.

    Parameters
    ----------
    batch : list[IOSchemaCoarseBind]
        Rows to collate. Only ``row.mol_enc_io.smiles_embed`` is read; it is
        converted with ``torch.from_numpy`` and must therefore be a NumPy
        array.

    Returns
    -------
    dict
        A single-entry dict. ``"mol_enc"`` holds the embeddings stacked
        along a new leading dimension, or ``None`` when no row in the batch
        carries a molecule encoding.

    """
    embeds = [
        torch.from_numpy(row.mol_enc_io.smiles_embed)
        for row in batch
        if row.mol_enc_io is not None
    ]

    # torch.stack rejects an empty list, so "nothing to collate" is
    # reported as None instead.
    mol_enc = torch.stack(embeds, dim=0) if embeds else None

    return {"mol_enc": mol_enc}


def disto_potency_feats_collate_fn_2(
    batch: List[OutputSchemaDisto],
) -> Dict[str, torch.Tensor]:
    """Collate per-sample distogram outputs into padded batch tensors.

    Every field is padded with ``pad_to_max`` to the largest size found in
    the batch and stacked along a new leading dimension. Padded positions
    are filled with ``0`` (``False`` for the ligand mask).

    Parameters
    ----------
    batch : list[OutputSchemaDisto]
        Samples to collate. The fields read from each sample are
        ``save_layer_reps["z"]``, ``save_s_reps["s_inputs"]``, ``bin_probs``,
        ``norm_bin_entropy``, ``pw_distances`` and ``potency_ligand_mask``.

    Returns
    -------
    dict
        Batched tensors under the keys ``"z"``, ``"z_mask"``, ``"s_inputs"``,
        ``"bin_probs"``, ``"bin_entropies"``, ``"pw_distances"``,
        ``"potency_ligand_mask"`` and ``"valid_mask"``. ``"z_mask"`` is the
        padding mask of ``z`` reduced with ``any`` over its last dimension;
        ``"valid_mask"`` is the padding mask of ``potency_ligand_mask``.

    """
    pair_reps = [sample.save_layer_reps["z"] for sample in batch]
    z, z_padding = pad_to_max(pair_reps, value=0)
    # Collapse the last axis of the padding mask so one flag remains per
    # position (presumably the last axis of z is a feature/channel axis --
    # TODO confirm against the producer of save_layer_reps).
    z_mask = z_padding.any(dim=-1)

    single_reps = [sample.save_s_reps["s_inputs"] for sample in batch]
    s_inputs, _ = pad_to_max(single_reps, value=0)

    probs = [sample.bin_probs for sample in batch]
    bin_probs, _ = pad_to_max(probs, value=0)

    entropies = [sample.norm_bin_entropy for sample in batch]
    bin_entropies, _ = pad_to_max(entropies, value=0)

    distances = [sample.pw_distances for sample in batch]
    pw_distances, _ = pad_to_max(distances, value=0)

    # Cast to bool first so the returned padding mask (valid_mask) is
    # boolean as well; padded ligand entries are False.
    ligand_flags = [sample.potency_ligand_mask.bool() for sample in batch]
    potency_ligand_mask, valid_mask = pad_to_max(ligand_flags, value=False)

    return {
        "z": z,
        "z_mask": z_mask,
        "s_inputs": s_inputs,
        "bin_probs": bin_probs,
        "bin_entropies": bin_entropies,
        "pw_distances": pw_distances,
        "potency_ligand_mask": potency_ligand_mask,
        "valid_mask": valid_mask,
    }
