import sys
import warnings

import torch
import torch.nn.functional as F
from torch_geometric.data import Data
import numpy as np
import pandas as pd
from fragment.dataset import PDB
from fragment import mol_unit_sphere
from fragment.protein_fragments import constants
from fragment.aligment2.torch_canon.pointcloud import CanonEn as Canon
import tqdm
import time

# One-letter amino-acid code -> integer residue id (1..20). The position of a
# code in the string below fixes its id, so reordering it changes the ids
# stored in every saved graph:
#    1 A Alanine         2 R Arginine        3 N Asparagine      4 D Aspartic acid
#    5 C Cysteine        6 E Glutamic acid   7 Q Glutamine       8 G Glycine
#    9 H Histidine      10 I Isoleucine     11 L Leucine        12 K Lysine
#   13 M Methionine     14 F Phenylalanine  15 P Proline        16 S Serine
#   17 T Threonine      18 W Tryptophan     19 Y Tyrosine       20 V Valine
# Id 0 is never produced; this file maps unknown codes to -1 via `.get(x, -1)`.
one_letter_to_number = {
    code: residue_id
    for residue_id, code in enumerate("ARNDCEQGHILKMFPSTWYV", start=1)
}


class CustomData(Data):
    """PyG ``Data`` holding a residue graph (A) plus a secondary-structure graph (B).

    When graphs are batched, PyG offsets index-valued attributes by whatever
    ``__inc__`` returns for them. ``mapping_a_to_b`` and ``ch_b_edge_index``
    hold indices into graph B, so they must be shifted by B's node count
    (``self.num_nodes_b``). Every other key keeps PyG's default increment.
    """

    def __inc__(self, key, value, *args, **kwargs):
        indexes_graph_b = key in ("mapping_a_to_b", "ch_b_edge_index")
        if not indexes_graph_b:
            return super().__inc__(key, value, *args, **kwargs)
        # Cumulative offset for these keys is the number of B nodes per graph.
        return self.num_nodes_b


def _normalize(tensor, dim=-1):
    '''
    Normalizes a `torch.Tensor` along dimension `dim` without `nan`s.
    '''
    return torch.nan_to_num(
        torch.div(tensor, torch.norm(tensor, dim=dim, keepdim=True)))


def compute_dihedrals(v1, v2, v3):
    """Return the signed torsion angle defined by three consecutive bond vectors.

    Args:
        v1, v2, v3: tensors of shape ``(..., 3)``. Row ``i`` of each forms one
            chain of three consecutive bond vectors.

    Returns:
        Tensor of shape ``(...)`` holding angles in radians in ``[-pi, pi]``.
        Degenerate inputs (e.g. a zero-length ``v2``) give 0 instead of ``nan``.
    """
    # `dim=-1` is required: without it torch.cross picks the FIRST dimension
    # of size 3, which is the batch dimension when there are exactly 3 rows
    # (e.g. a 2-residue chain). The result is then a cross product taken across
    # rows, and recent PyTorch versions warn about this default.
    n1 = torch.cross(v1, v2, dim=-1)
    n2 = torch.cross(v2, v3, dim=-1)
    a = (n1 * n2).sum(dim=-1)
    b = torch.nan_to_num(
        (torch.cross(n1, n2, dim=-1) * v2).sum(dim=-1) / v2.norm(dim=-1))
    torsion = torch.nan_to_num(torch.atan2(b, a))
    return torsion


def get_bb_embs(X):
    """Backbone-torsion features, one row of 6 values per residue.

    Adapted from https://github.com/jingraham/neurips19-graph-protein-design

    Args:
        X: ``(num_residues, 3, 3)`` tensor. Axis 1 is ordered N, C-alpha, C
            (``X[:, 0]`` = N, ``X[:, 1]`` = CA, ``X[:, 2]`` = C).

    Returns:
        ``(num_residues, 6)`` tensor: ``cos`` of the three per-residue
        dihedrals followed by their ``sin``. Given the N, CA, C atom order, the
        three dihedrals of a residue come out as (phi, psi, omega). phi of the
        first residue and psi/omega of the last are undefined and padded with
        angle 0.
    """
    # Flatten to one atom per row: N0, CA0, C0, N1, CA1, C1, ...
    atoms = X.reshape(3 * X.shape[0], 3)
    bonds = _normalize(atoms[1:] - atoms[:-1], dim=-1)

    # Each dihedral uses three consecutive bond vectors.
    torsions = compute_dihedrals(bonds[:-2], bonds[1:-1], bonds[2:])

    # One zero in front (phi of residue 0) and two at the end (psi, omega of
    # the last residue) make the count divisible by 3.
    torsions = F.pad(torsions, [1, 2]).reshape(-1, 3)
    return torch.cat((torsions.cos(), torsions.sin()), dim=1)


def _stack_coords(values):
    """Stack a column of per-residue coordinates into one float32 tensor.

    Assumes each entry of ``values`` is an xyz triple, giving an ``(n, 3)``
    tensor — TODO confirm against ``PDB.get_pdb_info_EC``.
    """
    return torch.tensor(np.vstack(values), dtype=torch.float32)


def _canonicalize_segments(df, amino_types):
    """Canonicalize the CA coordinates of each secondary-structure segment.

    Segments are the distinct values of ``df['ss_ser']``, visited in order of
    first appearance (``Series.unique`` order).
    - Segments with more than one residue go through ``Canon.get_frame``.
    - Single-residue segments get an identity rotation, their own CA position
      as translation, and an all-zero normalized position.

    Args:
        df: residue table with ``'ss_ser'`` and ``'CA'`` columns.
        amino_types: integer residue ids aligned row-by-row with ``df``.

    Returns:
        ``(frame_R_ts, b_frame_t, a_normalized_positions)``:
        - ``frame_R_ts``: ``(num_segments, 3, 3)`` rotations.
        - ``b_frame_t``: ``(num_segments, 3)`` translations.
        - ``a_normalized_positions``: per-residue canonicalized CA coordinates,
          concatenated segment by segment.
    """
    canonicalization = Canon(tol=1e-6, save=False)
    ca_values = df['CA'].to_numpy()
    segment_ids = df['ss_ser'].unique()

    frame_R_ts = torch.zeros(len(segment_ids), 3, 3)
    translations = []
    positions = []
    for seg_idx, ss_ser in enumerate(segment_ids):
        in_segment = (df['ss_ser'] == ss_ser).to_numpy()
        coords_matrix = np.array([list(row) for row in ca_values[in_segment]], dtype=np.float32)
        residue_names_matrix = amino_types[in_segment]
        if len(residue_names_matrix) > 1:
            normalized_positions, frame_R, frame_t = canonicalization.get_frame(
                coords_matrix, residue_names_matrix)
            normalized_positions = torch.tensor(normalized_positions, dtype=torch.float32)
            frame_R = frame_R.clone().detach().float()
            frame_t = frame_t.clone().detach().float().reshape(1, 3)
        else:
            # A lone residue has no orientation to canonicalize: identity
            # rotation, translation = its own CA, normalized position = origin.
            normalized_positions = torch.zeros(coords_matrix.shape, dtype=torch.float32)
            frame_R = torch.eye(3)
            frame_t = torch.tensor(coords_matrix, dtype=torch.float32).reshape(1, 3)
        frame_R_ts[seg_idx] = frame_R
        translations.append(frame_t)
        positions.append(normalized_positions)

    # NOTE(review): concatenating segment by segment reproduces df's row order
    # only if each ss_ser value occupies one contiguous run of rows — confirm.
    return frame_R_ts, torch.cat(translations, 0), torch.cat(positions, 0)


def pdb_info_to_graph(df, edges, s_graph_dict):
    """Build a two-level ``CustomData`` graph from one protein's residue table.

    Args:
        df: one row per residue. Columns read: ``'N'``, ``'CA'``, ``'C'``
            (coordinates), ``'asa'``, ``'phi'``, ``'psi'`` and the
            ``NH_O_*``/``O_NH_*`` hydrogen-bond columns, ``'aa'`` (one-letter
            code) and ``'ss_ser'`` (secondary-structure segment id). ``df`` is
            not modified.
        edges: residue-graph edge list, converted to a ``long`` ``edge_index``.
        s_graph_dict: secondary-structure graph with keys ``'coords'`` (node
            coordinates of graph B) and ``'ss_num'`` (node classes).

    Returns:
        ``CustomData`` with residue-level (graph A) attributes ``x``,
        ``coords_a_*``, ``side_chain_embs``, ``bb_embs``,
        ``a_normalized_positions`` and graph-B attributes ``coords_b_``,
        ``ss_x``, ``num_nodes_b``, ``b_frame_R_ts``, ``b_frame_t``, linked by
        ``mapping_a_to_b``.

    Raises:
        ValueError: if ``df`` is empty or the per-residue tensors disagree in
            length.
    """
    if len(df) == 0:
        raise ValueError("pdb_info_to_graph: residue table is empty")

    data = CustomData()

    pos_n = _stack_coords(df['N'].values)
    pos_ca = _stack_coords(df['CA'].values)
    pos_c = _stack_coords(df['C'].values)
    data.coords_a_ca = pos_ca

    fea_a = df[['asa', 'phi', 'psi', 'NH_O_1_relidx',
                'NH_O_1_energy', 'O_NH_1_relidx',
                'O_NH_1_energy', 'NH_O_2_relidx',
                'NH_O_2_energy', 'O_NH_2_relidx',
                'O_NH_2_energy']].values
    data.side_chain_embs = torch.tensor(np.vstack(fea_a), dtype=torch.float32)

    # Unknown one-letter codes are encoded as -1. Kept as a local array instead
    # of a new df column so the caller's DataFrame is left untouched.
    amino_types = df['aa'].apply(lambda x: one_letter_to_number.get(x, -1)).to_numpy()
    frame_R_ts, b_frame_t, a_normalized_positions = _canonicalize_segments(df, amino_types)

    data.x = torch.tensor(amino_types, dtype=torch.float32)

    # Three backbone torsion angles per residue from its (N, CA, C) atoms.
    bb_embs = get_bb_embs(torch.stack((pos_n, pos_ca, pos_c), dim=1))
    bb_embs[torch.isnan(bb_embs)] = 0
    data.bb_embs = bb_embs
    data.coords_a_n = pos_n
    data.coords_a_c = pos_c

    # Residue-graph edges.
    data.edge_index = torch.tensor(edges, dtype=torch.long).contiguous()

    # Secondary-structure graph (B).
    coords_b_ = torch.tensor(s_graph_dict['coords'], dtype=torch.float32)
    data.coords_b_ = coords_b_
    data.ss_x = torch.tensor(s_graph_dict['ss_num'], dtype=torch.float32)
    data.num_nodes_b = len(coords_b_)

    # Residue -> graph-B node index (offset by num_nodes_b when batching, see
    # CustomData.__inc__).
    # NOTE(review): b_frame_R_ts / b_frame_t are ordered by first appearance of
    # each ss_ser. That lines up with this mapping only if ss_ser ids are
    # 0..k-1 in that order — confirm.
    data.mapping_a_to_b = torch.tensor(df['ss_ser'].values, dtype=torch.long).contiguous()
    data.b_frame_R_ts = frame_R_ts
    data.b_frame_t = b_frame_t
    data.a_normalized_positions = a_normalized_positions

    # Explicit check (an `assert` would be stripped under `python -O`).
    per_residue_lengths = {
        'x': len(data.x),
        'coords_a_ca': len(data.coords_a_ca),
        'coords_a_n': len(data.coords_a_n),
        'coords_a_c': len(data.coords_a_c),
        'side_chain_embs': len(data.side_chain_embs),
        'bb_embs': len(data.bb_embs),
    }
    if len(set(per_residue_lengths.values())) != 1:
        raise ValueError(f"per-residue tensors disagree in length: {per_residue_lengths}")

    return data


def process(pdb_id, case_id, chain_function, split):
    """Build, label and save the graph for one chain of one PDB entry.

    Steps:
    - Reads ``constants.RAW_DATA_DIR/<pdb_id>.pdb``.
    - Builds the residue / secondary-structure graph with
      ``pdb_info_to_graph``.
    - Attaches the ``ch_b_*`` data returned by
      ``mol_unit_sphere.Frame.get_frame`` (names suggest a convex hull over
      the graph-B coordinates — confirm) and the integer label.
    - Saves the result to
      ``constants.PROCESSED_DATA_DIR/<split>/<pdb_id>.<case_id>.pt``.

    Args:
        pdb_id: PDB identifier; selects the input file.
        case_id: forwarded to ``PDB.get_pdb_info_EC`` and used in the output
            file name (presumably a chain id — confirm).
        chain_function: class label; anything ``int()`` accepts.
        split: output sub-directory; created if it does not exist.

    Returns:
        None. A missing input file or any processing error is reported with
        ``print`` and swallowed, so a batch driver can keep going.
    """
    raw_protein_file = constants.RAW_DATA_DIR/f'{pdb_id}.pdb'
    if not raw_protein_file.exists():
        print(f'{raw_protein_file.name} does not exist!')
        return

    frame = mol_unit_sphere.Frame()
    try:
        df, edges, s_graph_dict = PDB.get_pdb_info_EC(raw_protein_file, case_id)
        curProtein = pdb_info_to_graph(df, edges, s_graph_dict)
        if curProtein.x is None:
            print(f"pdb graph's x of {pdb_id} is None!")
            return

        curProtein.id = f'{pdb_id}.{case_id}'
        chain_function = int(chain_function)
        curProtein.y = torch.tensor(chain_function)

        # Hull ("ch") data over the secondary-structure node coordinates.
        pos = curProtein.coords_b_
        _, shell_data_ch, edge_index_hull = frame.get_frame(pos.numpy())
        ch_pos = torch.tensor(shell_data_ch, dtype=torch.float)
        # Distance of each hull point from the hull points' centroid.
        ch_r = torch.norm(ch_pos - torch.mean(ch_pos, dim=0), dim=-1)

        curProtein['ch_b_pos'] = ch_pos
        curProtein['ch_b_r'] = ch_r
        curProtein['ch_b_edge_index'] = torch.tensor(edge_index_hull, dtype=torch.long)

        # Create the split directory first; otherwise torch.save fails and the
        # processed graph is lost.
        out_dir = constants.PROCESSED_DATA_DIR/split
        out_dir.mkdir(parents=True, exist_ok=True)
        torch.save(curProtein, out_dir/f'{curProtein.id}.pt')

    except Exception as e:
        # Best-effort batch processing: report and continue with the next entry.
        print(f"Error in processing {pdb_id}: {e}")
        return


if __name__ == "__main__":
    assert len(sys.argv) == 4
    assert sys.argv[1] in ('training', 'val', 'test')
    process(sys.argv[1], sys.argv[2], sys.argv[3], sys.argv[4])
    # df = pd.read_csv(constants.METADATA_DIR/'ECDataset'/'training_with_chain_functions.csv')
    # len_df = len(df)
    # start_time = time.time()
    # for i, row in df.iterrows():
    #     process(row['pdb_id'], row['case_id'], row['chain_function'])
    #     time_elapsed = time.time() - start_time
    #     ETA = time_elapsed/(i+1) * (len_df - i - 1)
    #     print(f'{i+1}/{len_df} processed in {time_elapsed:.2f}s, ETA: {ETA:.2f}s')
