import logging
from typing import Dict, Tuple

import cpdb
import pandas as pd
import torch
import torch.nn.functional as F
from torchdrug.data import Protein

from affinityenhancer.data.datasets.cdr_utils import annotate_cdrs_for_sequence
from affinityenhancer.data.utils.utils import get_noise

logger = logging.getLogger('gearnet_dataset')

# Element symbol -> atomic number for the protein heavy atoms handled here.
ATOMIC_NUMS: Dict[str, int] = dict(C=6, N=7, O=8, S=16)

# One-letter residue code -> integer id. The index order of the first 20
# letters matches torchdrug's `Protein.residue_symbol2id` (see the reference
# copy of the Protein class attributes below).
residue_symbol2id = {
    symbol: index for index, symbol in enumerate("GASPVTCILNDQKEMHFRYW")
}
# Extra 21st class "X" (presumably unknown / non-standard residue).
residue_symbol2id["X"] = 20

# Inverse table: integer id -> one-letter residue code.
id2residue_symbol = {index: symbol for symbol, index in residue_symbol2id.items()}

'''
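Reference copy of the class attributes of torchdrug's Protein object (kept for comparison only; this string is not used at runtime)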
https://torchdrug.ai/docs/api/data.html#torchdrug.data.PackedProtein
class Protein(Molecule):
    _meta_types = {"node", "edge", "residue", "graph",
                   "node reference", "edge reference", "residue reference", "graph reference"}
    dummy_protein = Chem.MolFromSequence("G")
    dummy_atom = dummy_protein.GetAtomWithIdx(0)

    # TODO: rdkit isn't compatible with X in the sequence
    residue2id = {"GLY": 0, "ALA": 1, "SER": 2, "PRO": 3, "VAL": 4, "THR": 5, "CYS": 6, "ILE": 7, "LEU": 8,
                  "ASN": 9, "ASP": 10, "GLN": 11, "LYS": 12, "GLU": 13, "MET": 14, "HIS": 15, "PHE": 16,
                  "ARG": 17, "TYR": 18, "TRP": 19}
    residue_symbol2id = {"G": 0, "A": 1, "S": 2, "P": 3, "V": 4, "T": 5, "C": 6, "I": 7, "L": 8, "N": 9,
                         "D": 10, "Q": 11, "K": 12, "E": 13, "M": 14, "H": 15, "F": 16, "R": 17, "Y": 18, "W": 19}
    atom_name2id = {"C": 0, "CA": 1, "CB": 2, "CD": 3, "CD1": 4, "CD2": 5, "CE": 6, "CE1": 7, "CE2": 8,
                    "CE3": 9, "CG": 10, "CG1": 11, "CG2": 12, "CH2": 13, "CZ": 14, "CZ2": 15, "CZ3": 16,
                    "N": 17, "ND1": 18, "ND2": 19, "NE": 20, "NE1": 21, "NE2": 22, "NH1": 23, "NH2": 24,
                    "NZ": 25, "O": 26, "OD1": 27, "OD2": 28, "OE1": 29, "OE2": 30, "OG": 31, "OG1": 32,
                    "OH": 33, "OXT": 34, "SD": 35, "SG": 36, "UNK": 37}
    alphabet2id = {c: i for i, c in enumerate(" " + string.ascii_uppercase + string.ascii_lowercase + string.digits)}
    id2residue = {v: k for k, v in residue2id.items()}
    id2residue_symbol = {v: k for k, v in residue_symbol2id.items()}
    id2atom_name = {v: k for k, v in atom_name2id.items()}
    id2alphabet = {v: k for k, v in alphabet2id.items()}
'''


def pdb_to_gearnet_protein(pdb_file, chain_map, add_noise=False, noise_var=0.005,
                           add_cdr=False
                           ) -> Tuple[Protein, torch.Tensor]:
    """Parse a PDB file into a torchdrug ``Protein`` for GearNet.

    Only ``ATOM`` records are kept (HETATM rows are dropped), and hydrogen
    ("H") and deuterium ("D") atoms are removed. Residues are keyed by
    ``chain:residue_name:residue_number:insertion`` and ordered by their first
    appearance in the file. ``residue_type`` and ``chain_id`` are per residue;
    ``atom_type``, ``atom_name``, ``atom2residue`` and ``node_position`` are
    per atom.

    Args:
        pdb_file: Input forwarded to ``cpdb.parse`` (e.g. a path to a PDB file).
        chain_map: Mapping from the PDB chain identifier to an integer chain
            index. Every chain present in the ATOM records must be a key.
        add_noise: If True, add ``get_noise(node_position, noise_var=noise_var)``
            to the atom coordinates.
        noise_var: Forwarded to ``get_noise``.
        add_cdr: If True, store a CDR mask in ``protein.mol_feature``. Residues
            with chain index 0 are treated as the heavy chain and residues with
            chain index 1 as the light chain.

    Returns:
        ``(protein, node_position)``. ``protein`` is the built ``Protein`` with
        ``node_feature`` set to a 21-class one-hot of ``residue_type``.
        ``node_position`` is the per-atom coordinate tensor, including the
        noise when ``add_noise`` is True.

    Raises:
        KeyError: if a residue name, atom name or element symbol is missing
            from ``Protein.residue2id``, ``Protein.atom_name2id`` or
            ``ATOMIC_NUMS``. A missing chain id raises whatever ``chain_map``
            raises (KeyError for a plain dict).
    """
    df = cpdb.parse(pdb_file)

    # Keep ATOM records only (drops HETATM) and drop hydrogen/deuterium atoms.
    # `.copy()` makes the filtered frame independent. Without it, the column
    # assignment below writes into a slice and pandas emits a
    # SettingWithCopyWarning.
    keep = (df["record_name"] == "ATOM") & ~df["element_symbol"].isin(["H", "D"])
    df = df.loc[keep].copy()

    # Unique per-residue key: "chain:residue_name:residue_number:insertion".
    df["residue_id"] = (
        df["chain_id"] + ":"
        + df["residue_name"] + ":"
        + df["residue_number"].astype(str) + ":"
        + df["insertion"]
    )

    # Residue keys in order of first appearance (pandas `unique` keeps order).
    # Fields are read by position after splitting on ":".
    residue_ids = df["residue_id"].unique()
    residue_type = torch.as_tensor(
        [Protein.residue2id[rid.split(":")[1]] for rid in residue_ids]
    )
    residue_index = {rid: i for i, rid in enumerate(residue_ids)}

    # Per-atom coordinates, optionally perturbed.
    node_position = torch.as_tensor(df[["x_coord", "y_coord", "z_coord"]].values)
    if add_noise:
        node_position += get_noise(node_position, noise_var=noise_var)
    num_atom = node_position.shape[0]

    # Per-residue integer chain index.
    chains = torch.as_tensor([chain_map[rid.split(":")[0]] for rid in residue_ids])

    # Per-atom index of the owning residue, atom-name id and atomic number.
    atom_residue_number = torch.as_tensor(df["residue_id"].map(residue_index).tolist())
    atom_name = torch.as_tensor(
        [Protein.atom_name2id[name] for name in df["atom_name"]]
    )
    atom_type = torch.as_tensor(
        [ATOMIC_NUMS[symbol] for symbol in df["element_symbol"]]
    )

    # Placeholder edges. The original note says they are overwritten later
    # by `self.transform`.
    edge_list = torch.as_tensor([[0, 0, 0]])
    bond_type = torch.as_tensor([0])

    p = Protein(
        bond_type=bond_type,
        atom_type=atom_type,
        edge_list=edge_list,
        residue_type=residue_type,
        atom_name=atom_name,
        atom2residue=atom_residue_number,
        num_node=num_atom,
        num_residue=residue_type.shape[0],
        node_position=node_position,
        chain_id=chains,
    )
    if add_cdr:
        chain_id = p.chain_id
        # Select residues by chain index rather than by position. This keeps
        # residues of any other chain out of the light-chain sequence and
        # stays correct even if the heavy chain is not listed first.
        heavy_sel = chain_id == 0
        light_sel = chain_id == 1

        heavy_mask = torch.zeros(int(heavy_sel.sum()), dtype=torch.long)
        light_mask = torch.zeros(int(light_sel.sum()), dtype=torch.long)
        # NOTE(review): protein_sequence_decoder returns a one-element list of
        # strings here, not a bare string. Confirm that this is what
        # annotate_cdrs_for_sequence expects.
        heavy = protein_sequence_decoder([p.residue_type[heavy_sel].tolist()])
        light = protein_sequence_decoder([p.residue_type[light_sel].tolist()])
        indices_h, indices_l = annotate_cdrs_for_sequence(heavy, light)
        heavy_mask[indices_h] = 1
        light_mask[indices_l] = 1
        # Layout is [heavy residues, light residues]. It lines up with
        # residue order only when all heavy residues precede the light ones.
        cdr_mask = torch.cat([heavy_mask, light_mask], dim=0)
        p.mol_feature = cdr_mask
    # One-hot over 21 classes: ids 0-19 plus id 20 ("X" in residue_symbol2id).
    p.node_feature = F.one_hot(p.residue_type, num_classes=21).float()
    return p, node_position


def protein_sequence_decoder(sequences: list):
    """Turn integer-encoded residue sequences back into one-letter strings.

    Each element of ``sequences`` is an iterable of integer ids. Each id is
    looked up in the module-level ``id2residue_symbol`` table, and an unknown
    id raises KeyError. Returns a list with one string per input sequence,
    in input order.
    """
    decoded = []
    for encoded in sequences:
        letters = (id2residue_symbol[token] for token in encoded)
        decoded.append("".join(letters))
    return decoded


def protein_sequence_encoder(sequences: list):
    """Turn one-letter residue strings into lists of integer ids.

    Each character of every sequence is looked up in the module-level
    ``residue_symbol2id`` table, and an unknown letter raises KeyError.
    Returns a list of int lists, one per input sequence, in input order.
    """
    result = []
    for sequence in sequences:
        ids = [int(residue_symbol2id[letter]) for letter in sequence]
        result.append(ids)
    return result


def get_edges_from_gearnet_struct(struct):
    """Return the residue-by-residue Euclidean distance matrix of ``struct``.

    Each residue ``r`` (for ``r`` in ``range(struct.residue_type.shape[0])``)
    is represented by its first atom, the lowest atom index ``a`` with
    ``struct.atom2residue[a] == r``. The result ``D`` is square, and
    ``D[i, j]`` is the distance between the representative atoms of
    residues ``i`` and ``j``. Raises IndexError if a residue has no atom.
    """
    owner = struct.atom2residue
    representative = [
        int(torch.nonzero(owner == res, as_tuple=True)[0][0])
        for res in range(struct.residue_type.shape[0])
    ]
    index = torch.tensor(
        representative, dtype=torch.long, device=struct.residue_type.device
    )
    coords = struct.node_position[index, :]
    offsets = coords.unsqueeze(0) - coords.unsqueeze(1)
    return torch.linalg.norm(offsets, dim=-1)

        
        
