import logging
from typing import List

import torch
import torch.nn.functional as F
from smart_open import open

from coarsebind_public.coarsebind.pdb_utils import (
    frames_to_pdb,
    assign_pdb_template,
)
from coarsebind_public.coarsebind.io_schema import IOSchemaCoarseBind
from coarsebind_public.coarsebind.io_transforms import mask_1d_cropper

# Module-level logger named after this module.
logger = logging.getLogger(name=__name__)


def _collate_for_coarse_cofold(io_schema_list: List[IOSchemaCoarseBind], device: str):

    # Collate and pad
    batch_size = len(io_schema_list)

    max_n_cutoff = max([len(io.disto_output.res_type) for io in io_schema_list])
    potency_ligand_mask = torch.zeros(
        (batch_size, max_n_cutoff),
        dtype=torch.bool,
    )
    bin_probs = torch.zeros(
        (
            batch_size,
            max_n_cutoff,
            max_n_cutoff,
            io_schema_list[0].disto_output.bin_probs.shape[-1],
        ),
        dtype=torch.float32,
    )
    valid_mask = torch.zeros(
        (batch_size, max_n_cutoff),
        dtype=torch.bool,
    )
    pw_distances = torch.zeros(
        (batch_size, max_n_cutoff, max_n_cutoff),
        dtype=torch.float32,
    )

    template_coords = torch.zeros(
        (batch_size, max_n_cutoff, 3),
        dtype=torch.float32,
    )

    for i, io in enumerate(io_schema_list):

        _system_size = len(io.disto_output.res_type)
        potency_ligand_mask[i, :_system_size] = torch.from_numpy(
            io.disto_output.potency_ligand_mask
        )
        bin_probs[i, :_system_size, :_system_size] = torch.from_numpy(io.disto_output.bin_probs)
        valid_mask[i, :_system_size] = True
        pw_distances[i, :_system_size, :_system_size] = torch.from_numpy(
            io.disto_output.pw_distances
        )
        if io.disto_output.coarse_cofold_template_coords is not None:
            template_coords[i, :_system_size] = torch.from_numpy(
                io.disto_output.coarse_cofold_template_coords
            )

    pair_mask = valid_mask[:, None, :] * valid_mask[:, :, None]
    # make diagonal 0
    false_diag = ~torch.eye(max_n_cutoff, dtype=torch.bool)[None, :, :]
    pair_mask = pair_mask & false_diag

    output = {
        "potency_ligand_mask": potency_ligand_mask.to(device),
        "bin_probs": bin_probs.to(device),
        "pair_mask": pair_mask.to(device),
        "valid_mask": valid_mask.to(device),
        "pw_distances": pw_distances.to(device),
        "template_coords": template_coords.to(device),
    }

    return output


def calculate_loss_l2(
    coords_padded: torch.Tensor,
    ref_pw_distances: torch.Tensor,
    upweight_tensor: torch.Tensor,
    pair_mask: torch.Tensor,
    gamma: float,
) -> torch.Tensor:
    """Masked, weighted squared error between model and reference distances.

    Pairwise Euclidean distances of ``coords_padded`` (``torch.cdist``, p=2)
    are compared with ``ref_pw_distances``. Each squared error is multiplied
    by ``pair_mask`` and ``upweight_tensor``, summed, and divided by the
    number of True/nonzero entries in ``pair_mask`` (the weights therefore
    scale the sum, not the denominator). A small epsilon keeps an all-False
    mask from dividing by zero; the loss is then 0.

    Args:
        coords_padded: Batched coordinates, e.g. (B, N, 3).
        ref_pw_distances: Target distances broadcastable to (B, N, N).
        upweight_tensor: Per-pair weights broadcastable to (B, N, N).
        pair_mask: Pairs that contribute to the loss (bool or 0/1).
        gamma: Currently unused; kept for call-site compatibility.

    Returns:
        Scalar loss tensor (differentiable w.r.t. ``coords_padded``).
    """
    dists_matrix = torch.cdist(coords_padded, coords_padded, p=2.0)

    per_pair_loss = (dists_matrix - ref_pw_distances) ** 2
    masked_loss = per_pair_loss * pair_mask * upweight_tensor

    num_real_pairs = pair_mask.sum()
    return masked_loss.sum() / (num_real_pairs + 1e-8)


class CoarseCofoldInf(object):
    """Coarse co-folding by fitting 3D coordinates to predicted distances.

    For each system, coordinates are optimized with Adam so that their
    pairwise distances match ``disto_output.pw_distances``. Positions that
    received template coordinates are held fixed. The last optimization frame
    is converted to protein / ligand PDB strings.
    """

    def __init__(
        self,
        device: str,
        batch_size: int = 32,
    ):
        """
        Args:
            device: Torch device the optimization runs on.
            batch_size: Number of systems optimized together per batch.
        """
        self.device = device
        self.batch_size = batch_size

    def predict(self, io_schema: List[IOSchemaCoarseBind]) -> List[IOSchemaCoarseBind]:
        """Predict coarse co-fold coordinates for every non-errored entry.

        Preprocessing works as follows for each entry whose ``error`` is falsy:
        - If the entry has both ``disto_output.template_path`` and
          ``disto_output.chain_id``, ``assign_pdb_template`` is applied and
          the result is cropped to its ``template_mask``.
        - The entry's ``disto_output`` is then cropped to
          ``within_cutoff_mask``.

        The processed entry is stored back into ``io_schema[i]``.

        Optimization then runs only on entries whose ``error`` is still falsy.
        The last frame of each trajectory is written to
        ``disto_output.coarse_cofold_protein_pdb_str`` and
        ``disto_output.coarse_cofold_ligand_pdb_str``.

        Args:
            io_schema: Input entries. The list is updated in place.

        Returns:
            The same list object.
        """
        for i, _io in enumerate(io_schema):
            if _io.error:
                continue

            if _io.disto_output.template_path and _io.disto_output.chain_id:
                _io = assign_pdb_template(_io)
                _io = mask_1d_cropper(
                    _io,
                    _io.template_mask,
                )

            _io.disto_output = mask_1d_cropper(
                _io.disto_output, _io.disto_output.within_cutoff_mask
            )
            # The helpers above return (possibly new) objects; store the result
            # so the template assignment and crop are seen by the loops below,
            # which read from ``io_schema``.
            io_schema[i] = _io

        valid_entries = [(i, entry) for i, entry in enumerate(io_schema) if not entry.error]
        if not valid_entries:
            return io_schema

        valid_indices, valid_io_schema = zip(*valid_entries)

        for start in range(0, len(valid_io_schema), self.batch_size):
            batch_indices = valid_indices[start : start + self.batch_size]
            batch_io_schema = list(valid_io_schema[start : start + self.batch_size])

            coords_trajectories = self.coarse_cofold_template(batch_io_schema)

            for j, traj in zip(batch_indices, coords_trajectories):
                protein_pdb_str, ligand_pdb_str = frames_to_pdb(
                    io_schema=io_schema[j],
                    # use only the last frame
                    coords=traj[-1][None, :, :],
                )
                io_schema[j].disto_output.coarse_cofold_protein_pdb_str = protein_pdb_str
                io_schema[j].disto_output.coarse_cofold_ligand_pdb_str = ligand_pdb_str

        return io_schema

    def coarse_cofold_template(
        self,
        io_schema_list: List[IOSchemaCoarseBind],
        patience=20,
        tol=1e-5,
        max_iters=5000,
        lig_lig_weight=10.0,
        gamma=0.1,
    ):
        """Optimize coordinates of a padded batch against predicted distances.

        Args:
            io_schema_list: Systems optimized jointly, padded to a common length
                by ``_collate_for_coarse_cofold``.
            patience: Consecutive steps with ``|Δloss| < tol`` needed to stop early.
            tol: Absolute loss-change threshold for convergence.
            max_iters: Maximum number of Adam steps.
            lig_lig_weight: Loss multiplier for pairs where both positions are
                set in ``potency_ligand_mask``. All other valid pairs get 1.
            gamma: Forwarded to ``calculate_loss_l2``.

        Returns:
            One list per system. Each list holds, for every optimizer step
            taken, a numpy array of that system's unpadded coordinates, with
            shape (n_valid, 3).
        """
        batch_size = len(io_schema_list)
        batch = _collate_for_coarse_cofold(io_schema_list, self.device)

        potency_ligand_mask = batch["potency_ligand_mask"]
        pair_mask = batch["pair_mask"]
        valid_mask = batch["valid_mask"]
        pw_distances = batch["pw_distances"]
        template_coords = batch["template_coords"]

        # Template coordinates are not optimized. Rows without a template are
        # all zeros. Test each component rather than the sum, so a real
        # coordinate such as (1, -1, 0) is still recognized as a template.
        is_template_coord = (template_coords != 0.0).any(dim=-1, keepdim=True)

        # Weight lig-lig pairs by ``lig_lig_weight`` (others by 1) and zero the
        # weight of pairs excluded by ``pair_mask``.
        potency_ligand_pair_mask = potency_ligand_mask[:, None, :] & potency_ligand_mask[:, :, None]
        lig_lig_upweight = 1.0 + (lig_lig_weight - 1) * potency_ligand_pair_mask.float()
        lig_lig_upweight = lig_lig_upweight * pair_mask.float()

        # Random start for free positions; template positions start at the template.
        max_n_cutoff = potency_ligand_mask.shape[1]
        init_coords = torch.randn((batch_size, max_n_cutoff, 3), device=self.device)
        init_coords = init_coords * (~is_template_coord) + template_coords

        coords = init_coords.clone().detach().requires_grad_(True)
        optimizer = torch.optim.Adam([coords], lr=1)

        # Moved once so each step needs a single device->host copy for the batch.
        valid_mask_cpu = valid_mask.cpu()
        trajectories = [[] for _ in range(batch_size)]
        prev_loss = torch.full((batch_size,), float("inf"), device=self.device)
        stable_steps = torch.zeros(batch_size, dtype=torch.int, device=self.device)

        for step in range(max_iters):
            optimizer.zero_grad()

            loss = calculate_loss_l2(
                coords,
                ref_pw_distances=pw_distances,
                upweight_tensor=lig_lig_upweight,
                pair_mask=pair_mask,
                gamma=gamma,
            )
            loss_value = loss.detach()

            # NOTE: ``loss`` is one scalar for the whole batch, so every entry
            # of ``stable_steps`` moves in lockstep (batch-level convergence).
            delta_loss = torch.abs(prev_loss - loss_value)
            stable_steps = torch.where(
                delta_loss < tol, stable_steps + 1, torch.zeros_like(stable_steps)
            )

            if (stable_steps >= patience).all():
                logger.info(
                    "[Step %02d] Loss: %.4f All converged (Δloss < %s) for %s steps, stopping.",
                    step,
                    loss_value.item(),
                    tol,
                    patience,
                )
                break

            prev_loss = loss_value

            if step % 25 == 0:
                logger.info("[Step %02d] Loss: %.4f", step, loss_value.item())
            loss.backward()

            # Zero gradients of template positions so Adam never moves them.
            coords.grad = coords.grad * (~is_template_coord)

            optimizer.step()

            # Boolean indexing copies, so stored frames are unaffected by later steps.
            frame = coords.detach().cpu()
            for i in range(batch_size):
                trajectories[i].append(frame[i][valid_mask_cpu[i]].numpy())
        else:
            logger.info(
                "Reached max_iters=%s without convergence (Δloss < %s for %s steps).",
                max_iters,
                tol,
                patience,
            )

        return trajectories
