import torch
import dgl
import tempfile
import os
import numpy as np
from Bio import PDB
from Bio import SeqIO

from project.utils.prot_feats import GeometricProteinFeatures, gather_edges
from project.utils.saprot.foldseek_util import get_struc_seq

# One-letter amino-acid codes; a residue's position in this list is its one-hot index
# (see amino_acid_one_hot). The 20 standard residues come first, followed by 'U'
# (selenocysteine in IUPAC notation).
AMINO_ACIDS = list('ARNDCEQGHILKMFPSTWYVU')

# Atom names whose coordinates are recorded per residue; a name's position here is its
# slot along the atom axis of the coordinate array built by extract_pdb_info.
ATOM_TYPES = ['N', 'CA', 'C', 'O', 'CB']

# Shared geometric featurizer; get_features() passes it to process_graph_and_dist(),
# so every graph built through this module uses these fixed hyper-parameters.
# NOTE(review): dropout_rate=0.1 is set on a featurizer used for preprocessing. If
# GeometricProteinFeatures is a torch.nn.Module it starts in training mode, so any
# dropout it applies would make extracted features non-deterministic across runs —
# confirm whether GPF.eval() should be called here.
GPF = GeometricProteinFeatures(
    num_pos_embed=20,       
    num_rbf=18,
    dropout_rate=0.1,
)


def min_max_normalize_tensor(tensor: torch.Tensor, device=None):
    """Linearly rescale ``tensor`` so its values lie in ``[0, 1]``.

    The minimum over all elements maps to 0 and the maximum to 1. The output has the
    same shape as the input and a floating-point dtype. Integer inputs, such as node
    indices, are promoted by true division.

    Args:
        tensor: values to normalize.
        device: if given, the result is moved to this device; otherwise it stays on
            the device of ``tensor``.

    Returns:
        The normalized tensor. When all values are equal (including a single element)
        the result is all zeros rather than NaN. An empty input yields an empty
        floating-point tensor.
    """
    if tensor.numel() == 0:
        # min()/max() are undefined on an empty tensor; there is nothing to scale.
        if tensor.is_floating_point():
            result = tensor.clone()
        else:
            result = tensor.to(torch.get_default_dtype())
    else:
        min_value = tensor.min()
        value_range = tensor.max() - min_value
        # For constant input the numerator below is all zeros, so substituting any
        # non-zero divisor yields the intended all-zero output instead of 0/0 = NaN.
        safe_range = torch.where(value_range == 0, torch.ones_like(value_range), value_range)
        result = (tensor - min_value) / safe_range
    return result if device is None else result.to(device)


def amino_acid_one_hot(seq):
    """One-hot encode a one-letter amino-acid sequence.

    Each residue's column is its position in ``AMINO_ACIDS``; the result is a
    ``[len(seq), len(AMINO_ACIDS)]`` integer tensor. A residue letter missing from
    ``AMINO_ACIDS`` raises ``KeyError``.
    """
    lookup = dict(zip(AMINO_ACIDS, range(len(AMINO_ACIDS))))
    positions = [lookup[residue] for residue in seq]
    index_tensor = torch.tensor(positions, dtype=torch.long)
    return torch.nn.functional.one_hot(index_tensor, num_classes=len(AMINO_ACIDS))


def _sample_in_edge_ids(graph, nodes, knn, geo_nbrhd_size):
    """Return a ``[len(nodes), geo_nbrhd_size]`` tensor of randomly chosen in-edge IDs per node.

    Relies on every queried node having exactly ``knn`` in-edges. It also assumes that
    ``graph.in_edges`` returns them grouped per queried node, in query order.
    """
    in_src, in_dst = graph.in_edges(nodes)
    # One row of knn (src, dst) node pairs per queried node: [len(nodes), knn, 2].
    in_edges = torch.stack((in_src, in_dst), dim=1).reshape(-1, knn, 2)
    # argsort of i.i.d. uniforms gives an independent random permutation per row, so
    # taking its first geo_nbrhd_size entries samples uniformly without replacement.
    # This is the same distribution as a per-row torch.randperm, in one vectorized call.
    picks = torch.rand(in_edges.shape[0], knn, device=in_edges.device).argsort(dim=1)[:, :geo_nbrhd_size]
    picked = torch.gather(in_edges, 1, picks.unsqueeze(-1).expand(-1, -1, 2))
    e_ids = graph.edge_ids(torch.flatten(picked[:, :, 0]), torch.flatten(picked[:, :, 1]))
    return e_ids.reshape(-1, geo_nbrhd_size)


def get_edge_neighbor(graph, knn=20, geo_nbrhd_size=2):
    """Randomly sample neighbouring edges for every edge ``u -> v`` of ``graph``.

    The neighbouring edges of ``u -> v`` are the in-edges of ``u`` (source side) and
    the in-edges of ``v`` (destination side). ``geo_nbrhd_size`` of each are drawn
    uniformly at random without replacement.

    Args:
        graph: a ``dgl`` graph in which every node has exactly ``knn`` in-edges, e.g.
            the output of ``dgl.knn_graph(..., knn)``.
        knn: number of in-edges per node; must match how ``graph`` was built.
        geo_nbrhd_size: number of neighbouring edges to sample per edge endpoint.

    Returns:
        ``(src_e_ids, dst_e_ids)``, each of shape ``[num_edges, geo_nbrhd_size]``. Row
        ``e`` holds edge IDs sampled from the in-edges of edge ``e``'s source node and
        destination node, respectively.

    Raises:
        ValueError: if ``geo_nbrhd_size`` is not in ``[1, knn]``.
    """
    if not 0 < geo_nbrhd_size <= knn:
        raise ValueError(f'geo_nbrhd_size must be in [1, {knn}], got {geo_nbrhd_size}')

    src_nodes, dst_nodes = graph.edges()
    src_e_ids = _sample_in_edge_ids(graph, src_nodes, knn, geo_nbrhd_size)
    dst_e_ids = _sample_in_edge_ids(graph, dst_nodes, knn, geo_nbrhd_size)
    return src_e_ids, dst_e_ids


def extract_pdb_info(pdb_file):
    """Parse the first model of a PDB file into a one-letter sequence and atom coordinates.

    Args:
        pdb_file: path (or handle) accepted by ``Bio.PDB.PDBParser.get_structure``.

    Returns:
        tuple ``(pdb_seq, coords, chain_ids)``:
          * ``pdb_seq`` -- one-letter sequence of all kept residues, with the chains of
            model 0 concatenated in file order.
          * ``coords`` -- float64 array of shape ``[len(pdb_seq), len(ATOM_TYPES), 3]``.
            ``coords[i, j]`` is the position of atom ``ATOM_TYPES[j]`` in residue ``i``;
            it is left as zeros when that atom is absent from the residue.
          * ``chain_ids`` -- IDs of the chains in model 0, in file order.

    Residues named ``UNK`` are skipped with a printed warning, and water residues are
    skipped silently.

    Raises:
        ValueError: if any other residue does not map to a letter in ``AMINO_ACIDS``.
    """
    # seq1 is available across Biopython versions, whereas
    # Bio.PDB.Polypeptide.three_to_one was removed in Biopython 1.80.
    from Bio.SeqUtils import seq1

    num_atom_types = len(ATOM_TYPES)
    atom_slot = {name: idx for idx, name in enumerate(ATOM_TYPES)}

    structure = PDB.PDBParser().get_structure('structure', pdb_file)

    seq_letters = []
    residue_frames = []
    chain_ids = []
    for chain in structure[0]:
        chain_ids.append(chain.get_id())
        for residue in chain:
            residue_id = residue.get_id()
            # Biopython flags water residues with hetero field 'W'; they are not amino acids.
            if residue_id[0] == 'W':
                continue
            resname = residue.get_resname()
            if resname == 'UNK':
                print(f'Warning: Unknown amino acid in {pdb_file} at residue {residue_id}')
                continue
            # seq1 returns 'X' for unrecognised three-letter codes.
            letter = seq1(resname)
            if letter not in AMINO_ACIDS:
                raise ValueError(
                    f'Unknown amino acid {resname!r} in {pdb_file} at residue {residue_id} '
                    f'(sequence so far: {"".join(seq_letters)})'
                )
            frame = np.zeros((num_atom_types, 3))
            for atom in residue:
                slot = atom_slot.get(atom.get_name())
                if slot is not None:
                    frame[slot, :] = atom.get_coord()
            seq_letters.append(letter)
            residue_frames.append(frame)

    # Stack once at the end instead of re-concatenating per residue (quadratic copying).
    if residue_frames:
        coords = np.stack(residue_frames, axis=0)
    else:
        coords = np.zeros((0, num_atom_types, 3))
    return ''.join(seq_letters), coords, chain_ids


def process_graph_and_dist(seq, coords, GPF):
    """Build the residue-level KNN graph and the CA pairwise-distance matrix for one chain.

    Args:
        seq: one-letter amino-acid sequence. Its length is used as the node count N and
            must match the number of residues in ``coords``.
        coords: per-residue atom coordinates, indexed as ``[N, M, 3]`` with atom index 1
            taken as CA (the layout produced by ``extract_pdb_info``).
        GPF: geometric featurizer, called as ``GPF(coords, dist_gather, neighbor_idx,
            edges)``; it must return ``(bb_angles, dist_rbf, dir_ori, amide_angles)``.

    Returns:
        ``(graph, dist)``. ``graph`` is a ``dgl`` KNN graph over CA atoms with K=20 and
        carries these fields:
          * node ``f`` -- normalized position, one-hot residue, backbone angles;
          * node ``x`` -- CA coordinates;
          * node ``x_5`` -- all M atoms;
          * edge ``f`` -- sequence-offset sine, normalized distance, RBF, orientation,
            amide angles;
          * edge ``src_nbr_e_ids`` / ``dst_nbr_e_ids`` -- sampled neighbour edge IDs.
        ``dist`` is the ``[1, N, N]`` CA-CA distance tensor.
    """
    num_neighbors = 20
    num_nodes = len(seq)
    num_edges = num_nodes * num_neighbors

    all_atoms = torch.tensor(coords).unsqueeze(0).float()  # [1, N, M, 3]
    ca_atoms = all_atoms[:, :, 1, :]  # [1, N, 3]

    graph = dgl.knn_graph(ca_atoms, num_neighbors)
    edge_src, edge_dst = graph.edges()
    # Source node of every edge, laid out as [1, N, K]. Row i is read as node i's K
    # neighbours; this assumes dgl emits KNN edges grouped by destination node.
    neighbor_idx = edge_src.reshape(1, num_nodes, num_neighbors)

    # Pairwise CA distances; the epsilon keeps sqrt differentiable at zero distance.
    pair_diff = ca_atoms.unsqueeze(2) - ca_atoms.unsqueeze(1)
    dist = torch.sqrt((pair_diff ** 2.0).sum(dim=-1) + 1e-7)  # [1, N, N]
    dist_gather = gather_edges(dist.unsqueeze(-1), neighbor_idx)  # [1, N, K, 1]

    bb_angles, dist_rbf, dir_ori, amide_angles = GPF(
        all_atoms, dist_gather, neighbor_idx, (edge_src, edge_dst)
    )

    # Node features: normalized node index (position embedding), residue one-hot,
    # backbone angles.
    node_position = min_max_normalize_tensor(graph.nodes()).reshape(-1, 1)
    residue_onehot = amino_acid_one_hot(seq).float()
    graph.ndata['f'] = torch.cat([node_position, residue_onehot, bb_angles.squeeze()], dim=-1)

    # Coordinates: CA only, and all M tracked atoms.
    graph.ndata['x'] = ca_atoms.squeeze()
    graph.ndata['x_5'] = all_atoms.squeeze()

    # Edge features: sine of the sequence offset, normalized CA distance, then the
    # geometric edge features from GPF.
    edge_position = torch.sin((edge_src - edge_dst).float()).reshape(-1, 1)
    edge_weight = min_max_normalize_tensor(dist_gather.reshape(num_edges, -1)).reshape(-1, 1)
    graph.edata['f'] = torch.cat(
        [
            edge_position,
            edge_weight,
            dist_rbf.squeeze(0).reshape(num_edges, -1),
            dir_ori.squeeze(0).reshape(num_edges, -1),
            amide_angles,
        ],
        dim=-1,
    )

    # Randomly sampled neighbouring edges of each edge's endpoints.
    graph.edata['src_nbr_e_ids'], graph.edata['dst_nbr_e_ids'] = get_edge_neighbor(
        graph, knn=num_neighbors
    )

    return graph, dist

def gen_struct_seq(pdb_path, chain, foldseek_path="utils/saprot/foldseek"):
    """Encode one chain of a PDB file into a structure-aware sequence via foldseek.

    Before running foldseek, the file text is normalized:
      * ``HETATM`` records become ``ATOM  `` records;
      * a few modified residues are renamed to their parent amino acids (see
        ``substitutions`` below).
    If anything was changed, the result is written to a temporary ``.pdb`` file in the
    current directory. That file is removed afterwards, even if foldseek fails.

    Args:
        pdb_path: path to the PDB file.
        chain: chain ID to encode.
        foldseek_path: path to the foldseek binary handed to ``get_struc_seq``. It is
            resolved relative to the current directory unless absolute.

    Returns:
        The third element (``combined_seq``) of ``get_struc_seq``'s result for ``chain``.
        This is presumably the interleaved amino-acid + 3Di sequence used by SaProt;
        confirm against ``foldseek_util``.
    """
    with open(pdb_path, 'r') as file:
        content = file.read()

    # Plain substring replacements over the whole file, applied in this order.
    substitutions = (
        ("HETATM", "ATOM  "),
        ("MSE", "MET"),
        ("CSO", "CYS"),
        ("SEP", "SER"),
        ("MLZ", "LYS"),
        ("ALY", "LYS"),
        ("CME", "CYS"),
    )
    modified = False
    for old, new in substitutions:
        if old in content:
            content = content.replace(old, new)
            modified = True

    temp_file_path = None
    try:
        if modified:
            with tempfile.NamedTemporaryFile(delete=False, mode='w+', suffix='.pdb', dir='./') as temp_file:
                # Record the name first so the file is cleaned up even if writing fails.
                temp_file_path = temp_file.name
                temp_file.write(content)
            pdb_path = temp_file_path

        # pLDDT is used to mask low-confidence regions if "plddt_mask" is True. Set it
        # to True when using AF2 structures for best performance.
        parsed_seqs = get_struc_seq(foldseek_path, pdb_path, [chain], plddt_mask=False)[chain]
    finally:
        if temp_file_path is not None:
            os.remove(temp_file_path)

    _seq, _foldseek_seq, combined_seq = parsed_seqs
    return combined_seq


def get_features(pdb_file):
    """Build all model inputs for a single-chain PDB file.

    Returns:
        ``(graph, dist, pdb_seq, struct_seq)``:
          * ``graph`` and ``dist`` -- the residue KNN graph and CA distance matrix from
            ``process_graph_and_dist``;
          * ``pdb_seq`` -- the one-letter sequence parsed from the file;
          * ``struct_seq`` -- the foldseek structure sequence of its chain.

    Raises:
        ValueError: if the file's first model does not contain exactly one chain.
    """
    pdb_seq, coords, chain_ids = extract_pdb_info(pdb_file)

    # Validate before the expensive graph and foldseek work. extract_pdb_info
    # concatenates all chains into one sequence, which is only meaningful for one chain.
    if len(chain_ids) > 1:
        raise ValueError(f'More than one chain found in PDB file: {chain_ids}')
    if not chain_ids:
        raise ValueError(f'No chain found in PDB file: {pdb_file}')

    graph, dist = process_graph_and_dist(pdb_seq, coords, GPF)
    struct_seq = gen_struct_seq(pdb_file, chain_ids[0])

    return graph, dist, pdb_seq, struct_seq

def test(pdb_file='/data_gu02/pclin/DeepGT/data/complex/5FGL/5FGL_A.pdb'):
    """Manual smoke test: featurize one single-chain PDB file and print the results.

    The default path points at a developer-local dataset; pass ``pdb_file`` to run the
    test elsewhere.
    """
    # get_features takes only the PDB path and returns four values.
    graph, dist, pdb_seq, struct_seq = get_features(pdb_file)

    print(graph)
    print(dist)
    print(pdb_seq)
    print(struct_seq)

# if __name__ == '__main__':
#     test()
