from typing import Dict
import torch

from coarsebind_public.coarsebind.io_schema import OutputSchemaDisto
from coarsebind_public.coarsebind.io_transforms import (
    mask_1d_cropper,
    affinity_cropper_2,
    lite_output,
)


class BaseWriter:
    """Base interface for writers that post-process model outputs.

    Subclasses are expected to override :meth:`process`; calling it on this
    base class is an error.
    """

    def __init__(self):
        pass

    def process(self, output_features: Dict[str, torch.Tensor], **kwargs):
        """Process model outputs. Must be overridden by subclasses.

        Args:
            output_features: Model outputs keyed by name.
            **kwargs: Extra, subclass-specific arguments.

        Raises:
            NotImplementedError: Always, since the base class has no
                implementation. (The previous code *returned* the exception
                object, so a missing override failed silently.)
        """
        raise NotImplementedError("process method not implemented")


class BasePairformerWriter(BaseWriter):
    """Shared post-processing of pairformer distogram outputs.

    Turns the ``pdistogram`` logits of one batch item into expected pairwise
    distances and per-pair normalized bin entropies. It also builds cutoff
    masks relative to the tokens whose ``asym_id`` is 0. Everything is packed
    into an ``OutputSchemaDisto``. Subclasses override :meth:`process` to
    apply extra cropping on top of :meth:`base_process`.
    """

    def __init__(self, pw_distance_cutoff: float = 15.0):
        """
        Args:
            pw_distance_cutoff: Expected-distance threshold used when building
                ``within_cutoff_pair_mask`` / ``within_cutoff_mask``.
        """
        super().__init__()
        self.pw_distance_cutoff = pw_distance_cutoff

        # The 65 bin edges are: 1.0, then 63 evenly spaced edges on [2, 22],
        # then 27.0 (22 + 5). Averaging neighbouring edges gives 64 bin
        # centres. Those centres weight the bin probabilities into an
        # expected distance. The distogram's last dimension presumably has 64
        # bins to broadcast against these centres.
        inner_edges = torch.linspace(2, 22.0, 63)
        all_edges = torch.cat(
            (torch.tensor([1.0]), inner_edges, torch.tensor([22.0 + 5.0]))
        )
        self.mid_points = (all_edges[:-1] + all_edges[1:]) / 2

    def process(
        self,
        output: Dict[str, torch.Tensor],
        batch_features: Dict[str, torch.Tensor],
        output_idx: int,
    ) -> OutputSchemaDisto:
        """Run :meth:`base_process` for one batch item, without extra cropping."""
        return self.base_process(output, batch_features, output_idx)

    def base_process(
        self,
        output: Dict[str, torch.Tensor],
        batch_features: Dict[str, torch.Tensor],
        output_idx: int,
    ) -> OutputSchemaDisto:
        """Build the distogram schema for item ``output_idx`` of a batch.

        Steps:
        - Softmax the distogram logits into per-pair bin probabilities.
        - Derive the expected distance and the entropy per pair. The entropy
          is normalized by the entropy of a uniform distribution over the bins.
        - Zero the diagonal.
        - Compute which tokens are near the ``asym_id == 0`` tokens.
        - Crop the result with ``mask_1d_cropper`` using ``token_pad_mask``.

        Tensors are handed to ``OutputSchemaDisto`` as-is. Any numpy
        conversion would happen inside that schema or its croppers (not
        visible here).

        Args:
            output: Model outputs for the whole batch. Reads ``pdistogram``,
                ``z`` and ``s_inputs``.
            batch_features: Batch input features. Reads ``token_pad_mask``,
                ``res_type``, ``asym_id``, ``entity_id``, ``sym_id``,
                ``res_num`` and ``res_name``.
            output_idx: Index of the item to process within the batch.

        Returns:
            OutputSchemaDisto: The processed output, after ``mask_1d_cropper``.
        """
        cutoff = self.pw_distance_cutoff

        # Per-pair probability over distance bins. The per-item distogram is
        # assumed to be (N_i, N_j, n_bins) -- TODO confirm against the model.
        probs = output["pdistogram"][output_idx].softmax(dim=-1)

        # Shannon entropy per pair. It is divided by the entropy of a uniform
        # distribution over the same number of bins.
        n_bins = probs.shape[-1]
        entropy = -(probs * (probs + 1e-8).log()).sum(dim=-1)
        uniform_entropy = -(torch.tensor(1.0 / n_bins) + 1e-8).log()
        entropy_norm = entropy / uniform_entropy

        # Expected distance = sum over bins of p(bin) * bin centre.
        bin_centres = self.mid_points.to(probs.device)
        distances = (probs * bin_centres).sum(dim=-1)

        pad_mask = batch_features["token_pad_mask"][output_idx]
        res_type = batch_features["res_type"][output_idx]
        ligand_mask = batch_features["asym_id"][output_idx] == 0

        # Representations stored alongside the distogram, upcast to float32.
        layer_reps: Dict[str, torch.Tensor] = {
            "z": output["z"][output_idx].float()
        }
        s_reps: Dict[str, torch.Tensor] = {
            "s_inputs": output["s_inputs"][output_idx].float()
        }

        # Mask out self-pairs (the diagonal) in every per-pair quantity.
        off_diag = torch.logical_not(
            torch.eye(
                distances.shape[0],
                distances.shape[1],
                dtype=bool,
                device=distances.device,
            )
        )
        distances = distances * off_diag
        probs = probs * off_diag[..., None]
        entropy_norm = entropy_norm * off_diag

        # Rows are the ligand (asym_id == 0) tokens. A pair (ligand i, token
        # j) counts if j is closer than the cutoff, or j is itself a ligand
        # token that the pad mask marks as valid. Self-pairs are excluded.
        close_to_ligand = distances[ligand_mask] < cutoff
        valid_ligand_cols = ligand_mask[None, :] & pad_mask
        pair_mask = (close_to_ligand | valid_ligand_cols) & off_diag[ligand_mask]
        # A token is kept when at least one ligand row selects it.
        token_mask = pair_mask.sum(dim=0) > 0

        schema = OutputSchemaDisto(
            pw_distance_cutoff=cutoff,
            res_type=res_type,
            entity_id=batch_features["entity_id"][output_idx],
            sym_id=batch_features["sym_id"][output_idx],
            asym_id=batch_features["asym_id"][output_idx],
            res_num=batch_features["res_num"][output_idx],
            res_name=batch_features["res_name"][output_idx],
            potency_ligand_mask=ligand_mask,
            bin_probs=probs,
            norm_bin_entropy=entropy_norm,
            pw_distances=distances,
            save_layer_reps=layer_reps,
            save_s_reps=s_reps,
            within_cutoff_mask=token_mask,
            within_cutoff_pair_mask=pair_mask,
        )

        return mask_1d_cropper(schema, pad_mask)


class PairformerWriter(BasePairformerWriter):
    """Pairformer writer that crops the output to ``within_cutoff_mask``.

    The constructor is inherited unchanged from ``BasePairformerWriter``
    (``pw_distance_cutoff: float = 15.0``).
    """

    def process(
        self,
        output: Dict[str, torch.Tensor],
        batch_features: Dict[str, torch.Tensor],
        output_idx: int,
    ) -> OutputSchemaDisto:
        """Run the shared processing, then crop with the result's own
        ``within_cutoff_mask`` via ``mask_1d_cropper``.

        Args:
            output: Model outputs for the whole batch.
            batch_features: Batch input features.
            output_idx: Index of the item to process within the batch.

        Returns:
            OutputSchemaDisto: The cropped result.
        """
        processed = self.base_process(output, batch_features, output_idx)
        return mask_1d_cropper(processed, processed.within_cutoff_mask)


class LitePairformerWriter(BasePairformerWriter):
    """Variant of the cutoff-cropping writer whose result also goes through
    ``lite_output``.

    The constructor is inherited unchanged from ``BasePairformerWriter``
    (``pw_distance_cutoff: float = 15.0``).
    """

    def process(
        self,
        output: Dict[str, torch.Tensor],
        batch_features: Dict[str, torch.Tensor],
        output_idx: int,
    ) -> OutputSchemaDisto:
        """Run the shared processing, crop to ``within_cutoff_mask``, then
        apply ``lite_output``.

        Args:
            output: Model outputs for the whole batch.
            batch_features: Batch input features.
            output_idx: Index of the item to process within the batch.

        Returns:
            OutputSchemaDisto: Whatever ``lite_output`` returns for the
            cropped result. Its exact contents are defined outside this file.
        """
        processed = self.base_process(output, batch_features, output_idx)
        cropped = mask_1d_cropper(processed, processed.within_cutoff_mask)
        return lite_output(cropped)


class AffinityPairformerWriter(BasePairformerWriter):
    """Pairformer writer whose result is post-processed by
    ``affinity_cropper_2``, not by the cutoff mask.

    The constructor is inherited unchanged from ``BasePairformerWriter``
    (``pw_distance_cutoff: float = 15.0``).
    """

    def process(
        self,
        output: Dict[str, torch.Tensor],
        batch_features: Dict[str, torch.Tensor],
        output_idx: int,
    ) -> OutputSchemaDisto:
        """Run the shared processing and hand the result to
        ``affinity_cropper_2``.

        Args:
            output: Model outputs for the whole batch.
            batch_features: Batch input features.
            output_idx: Index of the item to process within the batch.

        Returns:
            OutputSchemaDisto: The value returned by ``affinity_cropper_2``.
            Its cropping rule is defined outside this file.
        """
        return affinity_cropper_2(
            self.base_process(output, batch_features, output_idx)
        )
