import itertools
import random
import re
from typing import List, Literal, Tuple

import biotite.structure.io as strucio
import numpy as np
import pandas as pd
import torch
from loguru import logger

from openfold.np.residue_constants import (
    atom_order,
    atom_types,
    restype_3to1,
    restype_num,
    restype_order,
)
from proteinfoundation.utils.constants import SIDECHAIN_TIP_ATOMS
from proteinfoundation.utils.coors_utils import ang_to_nm
from proteinfoundation.utils.align_utils import mean_w_mask


def _select_motif_atoms(
    available_atoms: List[int],
    atom_selection_mode: Literal["ca", "bb3o", "all_atom", "tip_atoms"] = "ca",
    residue_name: str = None,
) -> List[int]:

    backbone_atoms = [0, 1, 2, 4]
    ca_index = 1

    if atom_selection_mode == "ca":

        return [ca_index] if ca_index in available_atoms else []

    elif atom_selection_mode == "bb3o":

        return [i for i in backbone_atoms if i in available_atoms]

    elif atom_selection_mode == "all_atom":

        return available_atoms

    elif atom_selection_mode == "tip_atoms":

        if residue_name is None:
            raise ValueError("residue_name must be provided for tip_atoms mode")

        tip_atom_names = SIDECHAIN_TIP_ATOMS.get(residue_name, [])
        tip_atom_indices = []
        for atom_name in tip_atom_names:
            if atom_name in atom_order:
                atom_idx = atom_order[atom_name]
                if atom_idx in available_atoms:
                    tip_atom_indices.append(atom_idx)
        return tip_atom_indices

    else:
        raise ValueError(
            f"Unknown atom selection mode: {atom_selection_mode}. Supported modes: ca, bb3o, all_atom, tip_atoms"
        )


def generate_combinations(min_cost, max_cost, ranges):
    """Enumerate every choice of values whose sum lies in [min_cost, max_cost].

    Args:
        min_cost: Smallest admissible sum (inclusive).
        max_cost: Largest admissible sum (inclusive).
        ranges: One entry per slot; an ``int`` fixes the slot to that value,
            anything else is indexed as ``(low, high)`` and expands to every
            integer from ``low`` to ``high`` inclusive.

    Returns:
        A list of lists, one value per slot, in ``itertools.product`` order.
    """
    candidate_values = []
    for spec in ranges:
        if isinstance(spec, int):
            candidate_values.append([spec])
        else:
            candidate_values.append(range(spec[0], spec[1] + 1))

    # product() always yields one value per slot, so no padding is required.
    return [
        list(choice)
        for choice in itertools.product(*candidate_values)
        if min_cost <= sum(choice) <= max_cost
    ]


def generate_motif_indices(
    contig: str,
    min_length: int,
    max_length: int,
    nsamples: int = 1,
) -> Tuple[List[int], List[List[int]], List[str]]:
    """Sample scaffold lengths for a contig and locate the motif in each sample.

    The contig is a "/"-separated list of segments. A segment that starts with
    a chain letter (``A10-20`` or ``A15``) is a motif piece whose length is
    fixed by its residue range. Any other segment is a scaffold piece, given
    either as a fixed length (``12``) or an inclusive length range (``5-10``).

    Example: for ``"5-10/A1-3/4"`` one possible sample is total length 14,
    motif positions ``[8, 9, 10]`` and output contig ``"7/A8-10/4"``.

    Args:
        contig: Contig string as described above.
        min_length: Minimum total length (motif + scaffold) of a sample.
        max_length: Maximum total length (motif + scaffold) of a sample.
        nsamples: Number of samples to draw. Samples are drawn with
            replacement, uniformly over all admissible scaffold-length
            combinations.

    Returns:
        ``(overall_lengths, motif_indices, out_strs)``, each holding
        ``nsamples`` entries: the total length of the sample, the 1-based
        positions of the motif residues inside it, and the contig rewritten
        with concrete scaffold lengths and motif pieces renumbered to their
        positions in the sample.

    Raises:
        ValueError: If a segment is empty or not parseable, a range is
            reversed, or no scaffold-length combination satisfies the length
            bounds.
    """
    # The original alphabet skipped "I", so motif pieces on chain I were
    # mistaken for scaffold lengths and failed in int().
    chain_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"

    # Parse the contig once. Motif pieces are stored as (chain, length);
    # scaffold pieces as (None, None) with their bounds collected in `ranges`.
    segments = []
    ranges = []
    motif_length = 0
    for part in contig.split("/"):
        if not part:
            raise ValueError(f"Empty segment in contig '{contig}'")
        if part[0] in chain_letters:
            if "-" in part:
                start, end = map(int, part[1:].split("-"))
            else:
                start = end = int(part[1:])
            if end < start:
                raise ValueError(
                    f"Motif segment '{part}' has its end before its start"
                )
            length = end - start + 1
            motif_length += length
            segments.append((part[0], length))
        else:
            if "-" in part:
                bounds = part.split("-")
                low, high = int(bounds[0]), int(bounds[-1])
                # Raised explicitly: an assert would vanish under `python -O`.
                if low > high:
                    raise ValueError(
                        f"Scaffold segment '{part}' has its minimum above its maximum"
                    )
                ranges.append((low, high))
            else:
                ranges.append(int(part))
            segments.append((None, None))

    # The motif length is fixed, so only the scaffold has to absorb the bounds.
    combinations = generate_combinations(
        min_length - motif_length, max_length - motif_length, ranges
    )
    if len(combinations) == 0:
        raise ValueError(
            "No Motif combinations to sample from please update the max and min lengths"
        )

    overall_lengths = []
    motif_indices = []
    out_strs = []
    for combo in random.choices(combinations, k=nsamples):
        scaffold_lengths = iter(combo)
        current_position = 1
        motif_index = []
        pieces = []
        for chain, length in segments:
            if chain is None:
                length = int(next(scaffold_lengths))
                pieces.append(str(length))
            else:
                last_position = current_position + length - 1
                motif_index.extend(range(current_position, last_position + 1))
                piece = f"{chain}{current_position}"
                if length > 1:
                    piece += f"-{last_position}"
                pieces.append(piece)
            current_position += length
        overall_lengths.append(current_position - 1)
        motif_indices.append(motif_index)
        out_strs.append("/".join(pieces))
    return (overall_lengths, motif_indices, out_strs)


def parse_motif_atom_spec(spec: str) -> List[Tuple[str, int, List[str]]]:
    """Parse an atom-level motif specification into per-residue atom lists.

    Every occurrence of ``<chain letter><residue number>: [<atom>, <atom>, ...]``
    in ``spec`` produces one entry; whatever separates the occurrences is
    ignored. Whitespace around the colon and around atom names is optional,
    so ``"A64: [O, CG]"`` and ``"A64:[O,CG]"`` parse identically.

    Args:
        spec: Specification string, e.g. ``"A64: [O, CG]; A65: [N, CA]"``.

    Returns:
        A list of ``(chain, res_id, atom_names)`` tuples in order of
        appearance. A residue listed twice yields two entries.
    """
    # Previously the pattern demanded exactly one space after the colon, which
    # silently dropped entries written as "A64:[O]" or "A64:  [O]".
    entry_pattern = r"([A-Za-z])(\d+)\s*:\s*\[([^\]]+)\]"

    motif_atoms = []
    for match in re.finditer(entry_pattern, spec):
        chain = match.group(1)
        res_id = int(match.group(2))
        # Skip blanks left by stray commas such as "[CA, ]".
        atoms = [name.strip() for name in match.group(3).split(",") if name.strip()]
        motif_atoms.append((chain, res_id, atoms))
    return motif_atoms


def extract_motif_atoms_from_pdb(
    pdb_path: str,
    motif_atom_spec: str,
):
    """Return the atoms of a structure that an atom-level motif spec names.

    Args:
        pdb_path: Structure file; its first model is loaded.
        motif_atom_spec: Specification understood by ``parse_motif_atom_spec``.

    Returns:
        The loaded structure restricted to atoms whose chain, residue id and
        atom name match some entry of the spec, in the structure's own order.
    """
    structure = strucio.load_structure(pdb_path, model=1)
    selected = np.zeros(len(structure), dtype=bool)
    for chain, res_id, atom_names in parse_motif_atom_spec(motif_atom_spec):
        on_chain = structure.chain_id == chain
        at_residue = structure.res_id == res_id
        has_name = np.isin(structure.atom_name, atom_names)
        # An atom is kept as soon as any spec entry matches it.
        selected |= on_chain & at_residue & has_name
    return structure[selected]


def extract_motif_from_pdb(
    position: str,
    pdb_path: str,
    motif_only: bool = False,
    motif_atom_spec: str = None,
    atom_selection_mode: Literal["ca", "bb3o", "all_atom", "tip_atoms"] = "ca",
    coors_to_nm: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Read motif residues from a structure file into per-residue tensors.

    Two ways of specifying the motif are supported:

    * ``motif_atom_spec`` is given: residues and atom names come from the spec
      (see ``parse_motif_atom_spec``). ``position``, ``motif_only`` and
      ``atom_selection_mode`` are ignored, and the coordinates are returned as
      read, i.e. they are NOT re-centred.
    * otherwise ``position`` is a "/"-separated contig. Segments that start
      with a chain letter (``A10-20`` or ``A15``) select non-hetero residues;
      all other segments are scaffold placeholders and are skipped. Atoms are
      picked per residue with ``_select_motif_atoms`` and the coordinates are
      shifted so the masked mean is at the origin.

    Args:
        position: Contig string; unused when ``motif_atom_spec`` is given.
        pdb_path: Structure file; its first model is loaded.
        motif_only: If True, the residue ranges in ``position`` are ignored
            and each chain named in it is taken in full, once.
        motif_atom_spec: Optional atom-level specification.
        atom_selection_mode: Atom subset to keep per residue (contig mode).
        coors_to_nm: If True, coordinates are passed through ``ang_to_nm``.

    Returns:
        ``(motif_mask, x_motif, residue_type)`` with shapes ``(n_res, 37)``
        bool, ``(n_res, 37, 3)`` float and ``(n_res,)`` int64. Residues whose
        name is not recognised get ``restype_num`` as their type.

    Raises:
        ValueError: In contig mode, if ``position`` has no motif segment.
    """
    if motif_atom_spec is not None:
        logger.info(f"Using atom-level motif specification: {motif_atom_spec[:100]}...")
        array = strucio.load_structure(pdb_path, model=1)

        # Merge atom names per residue; dict order keeps first appearance.
        requested = {}
        for chain, res_id, names in parse_motif_atom_spec(motif_atom_spec):
            requested.setdefault((chain, res_id), []).extend(names)

        n_res = len(requested)
        motif_mask = torch.zeros((n_res, 37), dtype=torch.bool)
        x_motif = torch.zeros((n_res, 37, 3), dtype=torch.float)
        residue_type = torch.ones((n_res), dtype=torch.int64) * restype_num
        for i, ((chain_id, res_id), atom_names) in enumerate(requested.items()):
            res_atoms = array[(array.chain_id == chain_id) & (array.res_id == res_id)]
            if len(res_atoms) == 0:
                # Keep the row (so the residue count matches the spec) but say
                # so instead of dropping the residue silently.
                logger.warning(
                    f"Residue {chain_id}{res_id} from motif_atom_spec not found in {pdb_path}; leaving its row empty"
                )
                continue
            res_type = restype_3to1.get(res_atoms[0].res_name, "UNK")
            residue_type[i] = restype_order.get(res_type, restype_num)
            for atom in res_atoms:
                if atom.atom_name in atom_names and atom.atom_name in atom_order:
                    atom37_idx = atom_order[atom.atom_name]
                    motif_mask[i, atom37_idx] = True
                    coord = torch.as_tensor(atom.coord)
                    x_motif[i, atom37_idx] = ang_to_nm(coord) if coors_to_nm else coord
        return motif_mask, x_motif, residue_type

    # The original alphabet skipped "I", so motif segments on chain I were
    # silently ignored here.
    chain_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    array = strucio.load_structure(pdb_path, model=1)
    motif_array = []
    chains_taken = set()
    for segment in position.split("/"):
        chain_id = segment[0]
        if chain_id not in chain_letters:
            continue  # scaffold placeholder, nothing to extract
        atom_mask = (array.chain_id == chain_id) & np.logical_not(array.hetero)
        if motif_only:
            # The whole chain is used, so take each chain only once.
            if chain_id in chains_taken:
                continue
            chains_taken.add(chain_id)
        else:
            residue_span = segment.replace(chain_id, "")
            if "-" not in residue_span:
                start = end = int(residue_span)
            else:
                start, end = residue_span.split("-")
                start, end = int(start), int(end)
            atom_mask = atom_mask & (array.res_id <= end) & (array.res_id >= start)
        motif_array.append(array[atom_mask])

    if not motif_array:
        raise ValueError(
            f"No motif segments found in contig '{position}'; expected at least one segment starting with a chain letter"
        )
    motif = motif_array[0]
    for piece in motif_array[1:]:
        motif += piece

    # One row per (chain, residue id), in order of first appearance.
    unique_residues = list(dict.fromkeys(zip(motif.chain_id, motif.res_id)))
    n_res = len(unique_residues)

    motif_mask = torch.zeros((n_res, 37), dtype=torch.bool)
    x_motif = torch.zeros((n_res, 37, 3), dtype=torch.float)
    residue_type = torch.ones((n_res), dtype=torch.int64) * restype_num

    for i, (chain_id, res_id) in enumerate(unique_residues):
        res_atoms = motif[(motif.chain_id == chain_id) & (motif.res_id == res_id)]
        res_name = res_atoms[0].res_name
        res_type = restype_3to1.get(res_name, "UNK")
        residue_type[i] = restype_order.get(res_type, restype_num)

        # Only atoms that have a slot in atom_order can be stored.
        known_atoms = [
            (atom_order[atom.atom_name], atom)
            for atom in res_atoms
            if atom.atom_name in atom_order
        ]
        if not known_atoms:
            continue

        selected_atom_indices = _select_motif_atoms(
            [atom37_idx for atom37_idx, _ in known_atoms],
            atom_selection_mode,
            res_name,
        )
        for atom37_idx, atom in known_atoms:
            if atom37_idx in selected_atom_indices:
                motif_mask[i, atom37_idx] = True
                coord = torch.as_tensor(atom.coord)
                x_motif[i, atom37_idx] = ang_to_nm(coord) if coors_to_nm else coord

    # Centre on the masked mean, then zero the unselected slots again because
    # the subtraction also shifted them.
    motif_center = mean_w_mask(
        x_motif.flatten(0, 1), motif_mask.flatten(0, 1)
    ).unsqueeze(0)
    x_motif = x_motif - motif_center
    x_motif = x_motif * motif_mask[..., None]
    return motif_mask, x_motif, residue_type


def pad_motif_to_full_length(
    motif_mask: torch.Tensor,
    x_motif: torch.Tensor,
    residue_type: torch.Tensor,
    contig_string: str,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Place motif tensors at their positions in a full-length sequence.

    Args:
        motif_mask: Per-residue atom mask of the motif, ``(n_motif, 37)``.
        x_motif: Motif coordinates, ``(n_motif, 37, 3)``.
        residue_type: Motif residue types, ``(n_motif,)``.
        contig_string: "/"-separated contig with concrete lengths. Segments
            starting with a chain letter (``A8-10`` or ``A8``) are motif
            pieces; every other segment must be a single integer scaffold
            length (length ranges such as ``5-10`` are not accepted here).

    Returns:
        ``(motif_mask_full, x_motif_full, residue_type_full)`` covering the
        whole contig length. Non-motif rows are all-False / zero, with
        ``restype_num`` as residue type.

    Raises:
        ValueError: If a segment cannot be parsed as described above.
    """
    # The original alphabet skipped "I", so motif pieces on chain I were
    # mistaken for scaffold lengths and failed in int().
    chain_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    current_position = 1
    motif_index = []
    for part in contig_string.split("/"):
        if part[0] in chain_letters:
            if "-" in part:
                start, end = map(int, part[1:].split("-"))
            else:
                start = end = int(part[1:])
            length = end - start + 1
            motif_index.extend(range(current_position, current_position + length))
        else:
            length = int(part)
        current_position += length

    actual_length = current_position - 1
    # Positions were collected 1-based; tensors are indexed 0-based.
    motif_index = torch.tensor(motif_index, dtype=torch.int64) - 1
    motif_mask_full = torch.zeros((actual_length, 37), dtype=torch.bool)
    x_motif_full = torch.zeros((actual_length, 37, 3), dtype=torch.float)
    residue_type_full = torch.ones((actual_length,), dtype=torch.int64) * restype_num
    motif_mask_full[motif_index] = motif_mask
    x_motif_full[motif_index] = x_motif
    residue_type_full[motif_index] = residue_type
    return motif_mask_full, x_motif_full, residue_type_full


def pad_motif_to_full_length_unindexed(
    motif_mask: torch.Tensor,
    x_motif: torch.Tensor,
    residue_type: torch.Tensor,
    gen_coors: torch.Tensor,
    gen_mask: torch.Tensor,
    gen_aa_type: torch.Tensor,
    match_aatype: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Scatter motif tensors onto a generated structure by greedy matching.

    Motif residues are handled in order. Each one is assigned the generated
    residue with the lowest RMSD over the atoms present in both masks, among
    residues not already assigned and, when ``match_aatype`` is set, having
    the same residue type. If any motif residue finds no partner, the whole
    assignment is discarded and the motif is placed on the first
    ``n_motif`` residues instead (a warning is logged).

    Args:
        motif_mask: Motif atom mask, ``(n_motif, 37)``.
        x_motif: Motif coordinates, ``(n_motif, 37, 3)``.
        residue_type: Motif residue types, ``(n_motif,)``.
        gen_coors: Generated coordinates; the first dimension is the number
            of generated residues.
        gen_mask: Atom mask of the generated structure, indexed like
            ``gen_coors``.
        gen_aa_type: Residue types of the generated structure.
        match_aatype: Require equal residue types for a match.

    Returns:
        ``(motif_mask_full, x_motif_full, residue_type_full)`` with one row
        per generated residue; unmatched rows are all-False / zero, with
        ``restype_num`` as residue type.
    """
    nres = gen_coors.shape[0]
    nres_motif = x_motif.shape[0]
    motif_index = []

    for i in range(nres_motif):
        motif_atoms = motif_mask[i]
        motif_xyz = x_motif[i]
        motif_residue = residue_type[i]

        closest_idx = None
        closest_rmsd = float("inf")

        for gen_idx in range(nres):
            gen_atoms = gen_mask[gen_idx]
            gen_xyz = gen_coors[gen_idx]
            gen_residue = gen_aa_type[gen_idx]

            shared_atoms = motif_atoms & gen_atoms
            if shared_atoms.sum() == 0:
                # Nothing to compare on: this residue cannot be a candidate.
                continue

            offsets = motif_xyz[shared_atoms] - gen_xyz[shared_atoms]
            rmsd = torch.sqrt(torch.sum(offsets**2, dim=1).mean())

            if not (rmsd < closest_rmsd):
                continue
            if gen_idx in motif_index:
                # Already claimed by an earlier motif residue.
                continue
            if match_aatype and not (motif_residue == gen_residue):
                continue

            closest_rmsd = rmsd
            closest_idx = gen_idx

        if closest_idx is None:
            logger.warning(
                f"No best match found for motif component {i} with match_aatype={match_aatype}"
            )

        motif_index.append(closest_idx)

    if None in motif_index:
        # A partial assignment cannot be used for indexing; fall back.
        motif_index = list(range(nres_motif))
        logger.warning(
            "\n\n\nError during matching, defaulting to the first n residues\n\n\n"
        )

    motif_mask_full = torch.zeros((nres, 37), dtype=torch.bool)
    x_motif_full = torch.zeros((nres, 37, 3), dtype=torch.float)
    residue_type_full = torch.ones((nres,), dtype=torch.int64) * restype_num
    motif_mask_full[motif_index] = motif_mask
    x_motif_full[motif_index] = x_motif
    residue_type_full[motif_index] = residue_type
    return motif_mask_full, x_motif_full, residue_type_full


def parse_motif(
    motif_pdb_path: str,
    contig_string: str = None,
    nsamples: int = 1,
    motif_only: bool = False,
    motif_min_length: int = None,
    motif_max_length: int = None,
    segment_order: str = None,
    motif_atom_spec: str = None,
    atom_selection_mode: Literal["ca", "bb3o", "all_atom", "tip_atoms"] = "ca",
) -> Tuple[
    List[int], List[torch.Tensor], List[torch.Tensor], List[torch.Tensor], List[str]
]:
    """Load a motif and embed it into sampled full-length layouts.

    With ``motif_atom_spec`` the motif is read at atom level and returned
    as a single unpadded sample. Otherwise the motif residues named in
    ``contig_string`` are extracted, ``nsamples`` scaffold layouts are sampled
    with ``generate_motif_indices``, and the motif tensors are scattered into
    each layout.

    Args:
        motif_pdb_path: Structure file containing the motif.
        contig_string: Contig describing motif and scaffold segments; required
            unless ``motif_atom_spec`` is given.
        nsamples: Number of layouts to sample (contig mode).
        motif_only: Forwarded to ``extract_motif_from_pdb``.
        motif_min_length: Minimum total sample length; required in contig mode.
        motif_max_length: Maximum total sample length; required in contig mode.
        segment_order: Currently unused by this function.
        motif_atom_spec: Optional atom-level motif specification. When given,
            all contig-related arguments and ``atom_selection_mode`` are
            ignored.
        atom_selection_mode: Atom subset kept per motif residue (contig mode).

    Returns:
        ``(lengths, motif_masks, x_motifs, residue_types, out_strs)`` with one
        entry per sample. In atom-spec mode each list has exactly one entry
        and ``out_strs`` is ``[None]``.

    Raises:
        ValueError: If ``atom_selection_mode`` is invalid, or a required
            contig-mode argument is missing.
        AssertionError: If the number of extracted motif residues does not
            match the number of motif positions implied by the contig.
    """
    if motif_atom_spec is not None:
        logger.info(f"Using atom-level motif specification: {motif_atom_spec[:100]}...")
        motif_mask, x_motif, residue_type = extract_motif_from_pdb(
            None, motif_pdb_path, motif_atom_spec=motif_atom_spec
        )
        n_res = motif_mask.shape[0]

        return [n_res], [motif_mask], [x_motif], [residue_type], [None]

    valid_modes = ["ca", "bb3o", "all_atom", "tip_atoms"]
    if atom_selection_mode not in valid_modes:
        raise ValueError(
            f"Invalid atom_selection_mode '{atom_selection_mode}'. "
            f"Must be one of: {valid_modes}"
        )

    # The defaults are None only so the atom-spec path can omit them; fail
    # here with a clear message rather than deep inside a helper.
    if not contig_string:
        raise ValueError(
            "contig_string is required when motif_atom_spec is not provided"
        )
    if motif_min_length is None or motif_max_length is None:
        raise ValueError(
            "motif_min_length and motif_max_length are required when motif_atom_spec is not provided"
        )

    logger.info(
        f"Using residue/range-based motif specification with atom_selection_mode='{atom_selection_mode}'"
    )
    logger.info(f"Contig string: {contig_string}")

    motif_mask, x_motif, residue_type = extract_motif_from_pdb(
        contig_string,
        motif_pdb_path,
        motif_only=motif_only,
        atom_selection_mode=atom_selection_mode,
    )

    lengths, motif_indices, out_strs = generate_motif_indices(
        contig_string, motif_min_length, motif_max_length, nsamples
    )
    motif_masks = []
    x_motifs = []
    residue_types = []

    for length, motif_index in zip(lengths, motif_indices):
        assert (
            len(motif_index) == motif_mask.shape[0] == x_motif.shape[0]
        ), f"motif_index: {len(motif_index)}, motif_mask: {motif_mask.shape[0]}, x_motif: {x_motif.shape[0]}, lengths don't match"
        # Positions come back 1-based; tensors are indexed 0-based.
        motif_index = torch.tensor(motif_index, dtype=torch.int64) - 1

        cur_mask = torch.zeros((length, 37), dtype=torch.bool)
        cur_mask[motif_index] = motif_mask

        cur_motif = torch.zeros((length, 37, 3), dtype=x_motif.dtype)
        cur_motif[motif_index] = x_motif

        cur_residue_type = torch.ones((length), dtype=torch.int64) * restype_num
        cur_residue_type[motif_index] = residue_type

        motif_masks.append(cur_mask)
        x_motifs.append(cur_motif)
        residue_types.append(cur_residue_type)

    return lengths, motif_masks, x_motifs, residue_types, out_strs


def save_motif_csv(pdb_path, motif_task_name, contigs, outpath=None, segment_order="A"):
    """Write one CSV row per sampled contig describing a motif task.

    Args:
        pdb_path: Path of the motif structure. The text after the last "/"
            and before the first "." becomes the ``pdb_name`` column.
        motif_task_name: Used only to build the default output file name.
        contigs: Contig strings; their position in the sequence becomes
            ``sample_num``.
        outpath: Destination file; defaults to
            ``./<motif_task_name>_motif_info.csv``.
        segment_order: Value written to the ``segment_order`` column of
            every row.
    """
    file_name = pdb_path.rsplit("/", 1)[-1]
    pdb_name = file_name.partition(".")[0]

    records = []
    for sample_num, contig in enumerate(contigs):
        records.append(
            {
                "pdb_name": pdb_name,
                "sample_num": sample_num,
                "contig": contig,
                "redesign_positions": " ",
                "segment_order": segment_order,
            }
        )

    destination = f"./{motif_task_name}_motif_info.csv" if outpath is None else outpath
    pd.DataFrame(records).to_csv(destination, index=False)
