from pathlib import Path
import pandas as pd
import string
from typing import Union, Optional, Tuple, List, Dict, Callable, NamedTuple
import numpy as np
import torch
import torch.nn.functional as F
from scipy.spatial.distance import squareform, pdist
from Bio import SeqIO
from tqdm import tqdm
import contextlib
from torch.utils.data import Dataset


# Three-letter residue names (in ``str.capitalize()`` form) mapped to their
# one-letter codes, including the ambiguity codes Asx/Glx, Sec and Xaa.
IUPAC_CODES = dict(
    Ala="A",
    Arg="R",
    Asn="N",
    Asp="D",
    Cys="C",
    Gln="Q",
    Glu="E",
    Gly="G",
    His="H",
    Ile="I",
    Leu="L",
    Lys="K",
    Met="M",
    Phe="F",
    Pro="P",
    Ser="S",
    Thr="T",
    Trp="W",
    Val="V",
    Tyr="Y",
    Asx="B",
    Sec="U",
    Xaa="X",
    Glx="Z",
)
# Token alphabet: 20 residue letters followed by the gap character "-".
ALPHABET = "ARNDCQEGHILKMFPSTWYV-"
# Character -> integer index into ALPHABET.
A2N = dict(zip(ALPHABET, range(len(ALPHABET))))
# Unknown residues share index 20 with the gap character.
A2N["X"] = 20


def parse_PDB(x, atoms=("N", "CA", "C"), chain=None):
    """
    Extract per-residue atom coordinates and the sequence from a PDB file.

    input:  x = PDB filename
            atoms = names of the atoms to extract for every residue (optional)
            chain = chain identifier to keep; None keeps every chain (optional)
    output: coords   = float array of shape (length, len(atoms), 3); NaN where
                       an atom or a whole residue is missing
            sequence = one-letter string; "X" for unknown residue names and for
                       residue numbers missing inside the observed range
            resn     = sorted array of the residue numbers (PDB number minus
                       one) that have at least one ATOM record
    raises: ValueError if no ATOM record matches the requested chain

    Selenomethionine HETATM records (MSE) are treated as regular MET atoms.
    Residues that share a number but differ by insertion code are all kept,
    ordered by insertion code.
    """
    xyz, seq = {}, {}
    # The context manager guarantees the handle is closed even if a record
    # fails to parse.
    with open(x, "rb") as handle:
        for raw in handle:
            line = raw.decode("utf-8", "ignore").rstrip()

            if line[:6] == "HETATM" and line[17 : 17 + 3] == "MSE":
                line = line.replace("HETATM", "ATOM  ")
                line = line.replace("MSE", "MET")

            if line[:4] != "ATOM":
                continue
            if chain is not None and line[21:22] != chain:
                continue

            atom = line[12 : 12 + 4].strip()
            resi = line[17 : 17 + 3]
            resn = line[22 : 22 + 5].strip()
            coord = np.array([float(line[i : (i + 8)]) for i in [30, 38, 46]])

            # A trailing letter is an insertion code; numbering is shifted by
            # one so that PDB residue 1 becomes index 0.
            if resn[-1].isalpha():
                resa, resn = resn[-1], int(resn[:-1]) - 1
            else:
                resa, resn = "", int(resn) - 1

            # Only the first occurrence of a residue name / atom is kept.
            seq.setdefault(resn, {}).setdefault(resa, resi)
            xyz.setdefault(resn, {}).setdefault(resa, {}).setdefault(atom, coord)

    if not xyz:
        selected = "" if chain is None else f" for chain {chain!r}"
        raise ValueError(f"No ATOM records found in {x!r}{selected}")

    # convert to numpy arrays, fill in missing values
    seq_, xyz_ = [], []
    for resn in range(min(xyz), max(xyz) + 1):
        if resn not in xyz:
            seq_.append("X")
            xyz_.extend(np.full(3, np.nan) for _ in atoms)
            continue
        for resa in sorted(xyz[resn]):
            seq_.append(IUPAC_CODES.get(seq[resn][resa].capitalize(), "X"))
            residue = xyz[resn][resa]
            xyz_.extend(residue.get(atom, np.full(3, np.nan)) for atom in atoms)

    valid_resn = np.array(sorted(xyz))
    return np.array(xyz_).reshape(-1, len(atoms), 3), "".join(seq_), valid_resn


def extend(a, b, c, L, A, D):
    """
    input:  three reference coords (a,b,c), bond (L)ength, bond (A)ngle and
            (D)ihedral; the angles are passed straight to np.cos / np.sin
    output: coords of the fourth point, placed relative to c
    """

    def unit(v):
        # Scale every vector along the last axis to unit Euclidean length.
        length = np.linalg.norm(v, ord=2, axis=-1, keepdims=True)
        return v / length

    # Local orthogonal frame built from the three reference points.
    bc = unit(b - c)
    normal = unit(np.cross(b - a, bc))
    frame = (bc, np.cross(normal, bc), normal)

    # Components of the new point along each frame axis.
    coeffs = (
        L * np.cos(A),
        L * np.sin(A) * np.cos(D),
        -L * np.sin(A) * np.sin(D),
    )
    offset = sum(axis * coeff for axis, coeff in zip(frame, coeffs))
    return c + offset


def contacts_from_pdb(
    filename: Union[str, Path],
    distance_threshold: float = 8.0,
    chain: Optional[str] = None,
    distogram: bool = False,
) -> Tuple[np.ndarray, np.ndarray, str, np.ndarray]:
    """
    Build a residue-residue distance or contact map from a PDB file.

    The backbone N, CA and C atoms are read with ``parse_PDB`` and a C-beta
    position is placed for every residue with ``extend`` using a fixed bond
    length, angle and dihedral. Pairwise Euclidean distances between those
    positions form the map.

    Args:
        filename: path of the PDB file.
        distance_threshold: pairs strictly closer than this are contacts.
        chain: chain identifier forwarded to ``parse_PDB``.
        distogram: when True return the raw distance matrix instead of the
            thresholded contact map.

    Returns:
        ``(map, atoms, sequence, resn)`` where ``atoms``, ``sequence`` and
        ``resn`` are the values returned by ``parse_PDB``. ``map`` is the
        float distance matrix (NaN where coordinates are missing) when
        ``distogram`` is True; otherwise an int64 matrix with 1 for contact,
        0 for no contact and -1 where the distance is undefined.
    """
    atoms, sequence, resn = parse_PDB(filename, chain=chain)

    N = atoms[:, 0]
    CA = atoms[:, 1]
    C = atoms[:, 2]

    Cbeta = extend(C, N, CA, 1.522, 1.927, -2.143)
    dist = squareform(pdist(Cbeta))
    if distogram:
        return dist, atoms, sequence, resn

    # np.long was removed from NumPy (1.24); int64 matches what
    # torch.from_numpy turns into a torch.long tensor downstream.
    contacts = (dist < distance_threshold).astype(np.int64)
    # NaN distances compare False above; flag them explicitly as unknown.
    contacts[np.isnan(dist)] = -1
    return contacts, atoms, sequence, resn


class FastaOutput(NamedTuple):
    """Record headers and their sequences, as two parallel lists."""

    headers: List[str]
    seqs: List[str]


def parse_fasta(
    filename: Union[str, Path],
    remove_insertions: bool = False,
    remove_gaps: bool = False,
) -> FastaOutput:
    """
    Read every record of an alignment file with Biopython's ``SeqIO``.

    The parser is chosen from the file suffix: ``.sto`` is read as
    "stockholm"; ``.fas``, ``.fasta`` and ``.a3m`` are read as "fasta".

    Args:
        filename: path of the alignment file.
        remove_insertions: delete lowercase letters when True; otherwise
            convert them to uppercase.
        remove_gaps: additionally delete "-" characters.

    Returns:
        FastaOutput with the record descriptions and the cleaned sequences,
        in file order. "." and "*" are always removed from the sequences.

    Raises:
        ValueError: if the suffix is not one of the recognised ones.
    """
    filename = Path(filename)
    suffix = filename.suffix
    if suffix in (".fas", ".fasta", ".a3m"):
        form = "fasta"
    elif suffix == ".sto":
        form = "stockholm"
    else:
        raise ValueError(f"Unknown file format {filename.suffix}")

    # Lowercase letters are either dropped or upper-cased.
    translate_dict: Dict[str, Optional[str]]
    if remove_insertions:
        translate_dict = {ch: None for ch in string.ascii_lowercase}
    else:
        translate_dict = {ch: ch.upper() for ch in string.ascii_lowercase}

    # Characters that are deleted outright.
    dropped = ".*-" if remove_gaps else ".*"
    translate_dict.update(dict.fromkeys(dropped))
    translation = str.maketrans(translate_dict)

    headers: List[str] = []
    seqs: List[str] = []
    for record in SeqIO.parse(str(filename), form):
        headers.append(record.description)
        seqs.append(str(record.seq).translate(translation))
    return FastaOutput(headers, seqs)


class CAMEODataset(torch.utils.data.Dataset):
    """
    Dataset over a flat directory of ``<name>.pdb`` / ``<name>.a3m`` pairs.

    Only names that have both a structure and an alignment are indexed, so
    every index in ``range(len(dataset))`` can be loaded.

    Attributes set in ``__init__``:
        data_dir: the directory as a ``Path``.
        pdb: mapping name -> path of every ``*.pdb`` file found.
        a3m: mapping name -> path of every ``*.a3m`` file found.
        keys: names served by the dataset, in the order the PDB files were
            listed.
    """

    def __init__(self, data_dir):
        """
        Index the structure/alignment pairs found directly in ``data_dir``.

        Raises:
            FileNotFoundError: if the directory holds no ``*.pdb`` file, or
                none of them has a matching ``*.a3m`` file.
        """
        super().__init__()
        data_dir = Path(data_dir)
        self.data_dir = data_dir

        pdbs = {x.stem: x for x in data_dir.glob("*.pdb")}
        a3ms = {x.stem: x for x in data_dir.glob("*.a3m")}
        # Explicit exception instead of ``assert`` so the check survives -O.
        if not pdbs:
            raise FileNotFoundError(f"No .pdb files found in {data_dir}")

        self.pdb = pdbs
        self.a3m = a3ms

        # A structure without an alignment cannot be served by __getitem__;
        # skip it here, the same way CASP13Dataset skips incomplete entries.
        self.keys = [key for key in pdbs if key in a3ms]
        if not self.keys:
            raise FileNotFoundError(
                f"No .pdb file in {data_dir} has a matching .a3m file"
            )

    def __getitem__(self, idx):
        """
        Load one example.

        Returns a dict with:
            "id": the file stem.
            "msa": list of alignment sequences with lowercase letters removed.
            "contacts": contact map from ``contacts_from_pdb`` as a tensor.
            "valid_sequence_indices": every integer from the smallest to the
                largest residue index reported for the structure.
        """
        key = self.keys[idx]
        contacts, _, _, resn = contacts_from_pdb(self.pdb[key])
        msa = parse_fasta(self.a3m[key], remove_insertions=True)[1]
        contacts = torch.from_numpy(contacts)
        resn = torch.from_numpy(resn)
        valid_sequence_indices = torch.arange(resn.min(), resn.max() + 1)
        return {
            "id": key,
            "msa": msa,
            "contacts": contacts,
            "valid_sequence_indices": valid_sequence_indices,
        }

    def __len__(self):
        """Number of structure/alignment pairs."""
        return len(self.keys)


# Lookup table keyed by domain identifier (the stem of a domain's PDB file
# name, as used by CASP13Dataset.__getitem__), whose value is returned to the
# caller as "domain_length".
# NOTE(review): presumably the number of residues in each domain; confirm
# against the data source.
DOMAIN_TO_LENGTH = {
    "T0950-D1": 342,
    "T0953s2-D2": 111,
    "T0953s2-D3": 93,
    "T0957s1-D1": 108,
    "T0957s2-D1": 155,
    "T0960-D2": 84,
    "T0963-D2": 82,
    "T0968s1-D1": 119,
    "T0968s2-D1": 116,
    "T0969-D1": 354,
    "T0975-D1": 293,
    "T0980s1-D1": 105,
    "T0981-D2": 80,
    "T0986s2-D1": 155,
    "T0987-D1": 185,
    "T0987-D2": 207,
    "T0989-D1": 134,
    "T0989-D2": 112,
    "T0990-D1": 76,
    "T0990-D2": 231,
    "T0990-D3": 213,
    "T0991-D1": 111,
    "T0998-D1": 166,
    "T1000-D2": 431,
    "T1001-D1": 139,
    "T1010-D1": 210,
    "T1015s1-D1": 88,
    "T1017s2-D1": 128,
    "T1021s3-D1": 178,
    "T1021s3-D2": 101,
    "T1022s1-D1": 156,
}


class CASP13Dataset(torch.utils.data.Dataset):
    """
    Dataset driven by an index file named ``FM`` inside ``data_dir``.

    Every non-blank line of ``FM`` holds two whitespace-separated names, a
    target and a domain. The alignment is looked up at
    ``data_dir/a3m/<target>.a3m`` and the structure at
    ``data_dir/natives/<domain>.pdb``; entries for which either file is
    missing are skipped.

    Attributes set in ``__init__``:
        data_dir: the directory as a ``Path``.
        keys: list of ``(a3m_path, pdb_path)`` pairs, in ``FM`` order.
        targets: target name of every entry in ``keys``.
        domains: domain name of every entry in ``keys``.
    """

    def __init__(self, data_dir):
        """
        Read the ``FM`` index and keep the entries whose files both exist.

        Raises:
            FileNotFoundError: if ``FM`` is missing, or no listed entry has
                both its alignment and its structure on disk.
            ValueError: if a non-blank ``FM`` line does not contain exactly
                two fields.
        """
        super().__init__()
        data_dir = Path(data_dir)
        self.data_dir = data_dir

        fm_file = self.data_dir / "FM"
        a3m_path = data_dir / "a3m"
        pdb_path = data_dir / "natives"

        def get_fnames(line: str) -> Tuple[Path, Path]:
            target, domain = line.split()
            return (
                (a3m_path / target).with_suffix(".a3m"),
                (pdb_path / domain).with_suffix(".pdb"),
            )

        with fm_file.open() as f:
            # Blank lines (e.g. a trailing newline at the end of the file)
            # carry no entry and would break the two-field unpacking.
            candidates = [get_fnames(line) for line in f if line.strip()]

        self.keys = [
            (target, domain)
            for target, domain in candidates
            if (target.exists() and domain.exists())
        ]
        # Explicit exception instead of ``assert`` so the check survives -O.
        if not self.keys:
            raise FileNotFoundError(
                f"No entry of {fm_file} has both an .a3m and a .pdb file"
            )
        self.targets = [target_fname.stem for target_fname, _ in self.keys]
        self.domains = [domain_fname.stem for _, domain_fname in self.keys]

    def get(self, key: str, identifier: str = "target"):
        """
        Return the first example whose target or domain name equals ``key``.

        Args:
            key: the name to look up.
            identifier: "target" or "domain", selecting which name list is
                searched.

        Raises:
            ValueError: if ``identifier`` is neither "target" nor "domain",
                or if ``key`` is not present in the selected list.
        """
        if identifier == "target":
            idx = self.targets.index(key)
            return self[idx]
        elif identifier == "domain":
            idx = self.domains.index(key)
            return self[idx]
        else:
            raise ValueError(f"Unknown identifier: {identifier}")

    def __getitem__(self, idx):
        """
        Load one example.

        Returns a dict with the domain name ("id" and "domain"), the target
        name, the alignment sequences with lowercase letters removed ("msa"),
        the contact map tensor, "valid_sequence_indices" (every integer from
        the smallest to the largest residue index of the structure), the
        backbone coordinates as returned by ``contacts_from_pdb`` ("atoms")
        and the ``DOMAIN_TO_LENGTH`` entry for the domain ("domain_length").

        Raises:
            KeyError: if the domain has no entry in ``DOMAIN_TO_LENGTH``.
        """
        target_fname, domain_fname = self.keys[idx]
        domain = domain_fname.stem
        contacts, atoms, _, resn = contacts_from_pdb(
            domain_fname, distogram=False
        )
        msa = parse_fasta(target_fname, remove_insertions=True)[1]
        contacts = torch.from_numpy(contacts)
        resn = torch.from_numpy(resn)
        valid_sequence_indices = torch.arange(resn.min(), resn.max() + 1)
        domain_length = DOMAIN_TO_LENGTH[domain]
        return {
            "id": domain,
            "domain": domain,
            "target": target_fname.stem,
            "msa": msa,
            "contacts": contacts,
            "valid_sequence_indices": valid_sequence_indices,
            "atoms": atoms,
            "domain_length": domain_length,
        }

    def __len__(self):
        """Number of indexed (alignment, structure) pairs."""
        return len(self.keys)


def lr_p_at_ell(pred, contacts, idx, minsep=24):
    """
    Count true contacts among the top-scoring residue pairs.

    Args:
        pred: (ell, ell) tensor of pair scores; higher means more likely.
        contacts: (n_valid, n_valid) contact labels; negative entries mark
            pairs to ignore.
        idx: (n_valid,) integer tensor of positions in ``pred`` that
            correspond to the rows/columns of ``contacts``.
        minsep: minimum separation ``|idx[i] - idx[j]|`` for a pair to be
            evaluated (default 24).

    Returns:
        ``(n_correct, n_selected)``: among the ``n_valid`` highest-scoring
        eligible pairs, the sum of their labels and how many pairs were
        selected. ``(0, 0)`` when no pair is eligible.
    """
    # Keep everything on the device of the predictions so indexing and
    # masking never mix devices.
    device = pred.device
    contacts = torch.as_tensor(contacts, device=device)
    idx = torch.as_tensor(idx, device=device)

    n_valid = len(contacts)
    topk = n_valid
    pred = pred[idx][:, idx]

    # Strict upper triangle, row-major: every unordered pair exactly once.
    rows, cols = torch.triu_indices(n_valid, n_valid, offset=1, device=device)
    predictions_upper = pred[rows, cols]
    targets_upper = contacts[rows, cols]

    separation = (idx[cols] - idx[rows]).abs()
    mask = (separation >= minsep) & (targets_upper >= 0)
    if not mask.any():
        return 0, 0
    predictions_upper = predictions_upper[mask]
    targets_upper = targets_upper[mask]

    # Slicing tolerates fewer than ``topk`` eligible pairs.
    indices = predictions_upper.argsort(descending=True)[:topk]
    topk_targets = targets_upper[indices]
    return topk_targets.sum().item(), len(topk_targets)
