import os, sys
from typing import Dict, Any
from argparse import ArgumentParser
from datetime import datetime
import pytorch_lightning as pl
import ml_collections as mlc
from types import ModuleType
import pickle
import torch
import re
import numpy as np
from openfold.np import protein
import string
from entity import (
    entity_constants as ec,
    residue_constants as rc
)
from openfold.utils.tensor_utils import tensor_tree_map
from rdkit import Chem
from Bio.PDB import PDBIO, Atom, Chain, Residue, Model
from io import StringIO
import copy
# Shared RDKit periodic table; the ligand writers below use it to turn an
# atomic number into an element symbol via `GetElementSymbol`.
pse = Chem.GetPeriodicTable()

def to_pdb(prot) -> str:
    """Converts a `Protein` instance to a PDB string.

    One ATOM record is written for every atom whose mask is at least 0.5.
    A TER record is written whenever the next residue belongs to a different
    chain and after the last residue.

    Args:
      prot: The protein to convert to PDB. The attributes `atom_mask`,
        `aatype`, `atom_positions`, `residue_index`, `b_factors` and
        `chain_index` are read. `chain_index` may be None, in which case
        every residue is written to chain "A".

    Returns:
      PDB string.

    Raises:
      ValueError: If any entry of `prot.aatype` exceeds `rc.restype_num`.
    """
    restypes = rc.restypes + ["X"]

    def res_1to3(r):
        return rc.restype_1to3.get(restypes[r], "UNK")

    atom_types = rc.atom_types

    pdb_lines = []

    atom_mask = prot.atom_mask
    aatype = prot.aatype
    atom_positions = prot.atom_positions
    residue_index = prot.residue_index.astype(np.int32)
    b_factors = prot.b_factors
    chain_index = prot.chain_index

    if np.any(aatype > rc.restype_num):
        raise ValueError("Invalid aatypes.")

    pdb_lines.extend(protein.get_pdb_headers(prot))

    n = aatype.shape[0]
    atom_index = 1
    # NOTE(review): chain tracking starts from index 0, so a protein whose
    # first chain index is not 0 gets a TER right after its first residue.
    prev_chain_index = 0
    chain_tags = string.ascii_uppercase
    # Add all atom sites.
    for i in range(n):
        res_name_3 = res_1to3(aatype[i])

        # Resolve the chain tag once per residue (not per atom) so the TER
        # record below is always labelled with this residue's chain, even
        # when every atom of the residue is masked out.
        chain_tag = "A"
        if chain_index is not None:
            chain_tag = chain_tags[chain_index[i]]

        for atom_name, pos, mask, b_factor in zip(
            atom_types, atom_positions[i], atom_mask[i], b_factors[i]
        ):
            if mask < 0.5:
                continue

            record_type = "ATOM"
            name = atom_name if len(atom_name) == 4 else f" {atom_name}"
            alt_loc = ""
            insertion_code = ""
            occupancy = 1.00
            # Protein supports only C, N, O, S, so the first letter of the
            # atom name is the element.
            element = atom_name[0]
            charge = ""

            # PDB is a columnar format, every space matters here!
            atom_line = (
                f"{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}"
                f"{res_name_3:>3} {chain_tag:>1}"
                f"{residue_index[i]:>4}{insertion_code:>1}   "
                f"{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}"
                f"{occupancy:>6.2f}{b_factor:>6.2f}          "
                f"{element:>2}{charge:>2}"
            )
            pdb_lines.append(atom_line)
            atom_index += 1

        is_last_residue = (i == n - 1)
        should_terminate = is_last_residue
        if chain_index is not None:
            if not is_last_residue and chain_index[i + 1] != prev_chain_index:
                should_terminate = True
                prev_chain_index = chain_index[i + 1]

        if should_terminate:
            # Close the chain. The TER record consumes an atom serial number.
            chain_end = "TER"
            chain_termination_line = (
                f"{chain_end:<6}{atom_index:>5}      "
                f"{res_name_3:>3} "
                f"{chain_tag:>1}{residue_index[i]:>4}"
            )
            pdb_lines.append(chain_termination_line)
            atom_index += 1

            if not is_last_residue:
                # "prev" is a misnomer here. This happens at the beginning of
                # each new chain.
                pdb_lines.extend(protein.get_pdb_headers(prot, prev_chain_index))

    return "\n".join(pdb_lines)

def create_full_prot(
        atom37: np.ndarray,
        atom37_mask: np.ndarray,
        aatype=None,
        b_factors=None,
    ):
    """Wrap atom37 coordinates in a single-chain `protein.Protein`.

    Args:
      atom37: Coordinates of shape [num_res, 37, 3]; the shape is enforced
        with assertions.
      atom37_mask: Passed through unchanged as the protein's `atom_mask`.
      aatype: Optional residue types. Defaults to `num_res` zeros.
      b_factors: Optional b-factors. Defaults to zeros of shape
        [num_res, 37].

    Returns:
      A `protein.Protein` whose `residue_index` is 0..num_res-1 and whose
      `chain_index` is all zeros.
    """
    assert atom37.ndim == 3
    num_res, atoms_per_res, coord_dim = atom37.shape
    assert coord_dim == 3
    assert atoms_per_res == 37

    fields = dict(
        atom_positions=atom37,
        atom_mask=atom37_mask,
        aatype=np.zeros(num_res, dtype=int) if aatype is None else aatype,
        residue_index=np.arange(num_res, dtype=int),
        chain_index=np.zeros(num_res, dtype=int),
        b_factors=np.zeros([num_res, 37]) if b_factors is None else b_factors,
    )
    return protein.Protein(**fields)

def _indexed_prot_pdb_path(file_path, overwrite, no_indexing):
    """Create the target directory and return the path to write the PDB to."""
    file_dir, base_name = os.path.split(file_path)
    if file_dir:
        # A bare file name has no directory component; os.makedirs('') raises.
        os.makedirs(file_dir, exist_ok=True)
    if no_indexing:
        return file_path

    # Remove only a trailing ".pdb"; str.strip('.pdb') would strip any of the
    # characters '.', 'p', 'd', 'b' from both ends of the name.
    stem = base_name[:-len('.pdb')] if base_name.endswith('.pdb') else base_name

    max_existing_idx = 0
    if not overwrite:
        # Only files named exactly "<stem>_<N>.pdb" count towards the index.
        indexed_name = re.compile(re.escape(stem) + r'_(\d+)\.pdb')
        for entry in os.listdir(file_dir or '.'):
            match = indexed_name.fullmatch(entry)
            if match:
                max_existing_idx = max(max_existing_idx, int(match.group(1)))

    # Keep the caller's directory prefix byte-for-byte.
    prefix = file_path[:len(file_path) - len(base_name)]
    return f'{prefix}{stem}_{max_existing_idx + 1}.pdb'


def write_prot_to_pdb(
        prot_pos: np.ndarray,
        file_path: str,
        aatype: np.ndarray=None,
        overwrite=False,
        no_indexing=False,
        b_factors=None,
    ):
    """Write atom37 protein coordinates to a PDB file.

    Atoms whose coordinates are all (numerically) zero are treated as absent.

    Args:
      prot_pos: Either one structure of shape [num_res, 37, 3] or a stack of
        structures of shape [num_frames, num_res, 37, 3]. A stack is written
        one frame at a time via `protein.to_pdb(prot, model=..., add_end=False)`
        (NOTE(review): this needs a `protein.to_pdb` that accepts `model` and
        `add_end` - confirm against the installed `protein` module).
      file_path: Requested output path. Its directory is created if missing.
      aatype: Optional residue types forwarded to `create_full_prot`.
      overwrite: If True the index suffix is always `_1`; otherwise it is one
        more than the largest existing `<stem>_<N>.pdb` in the directory.
      no_indexing: If True, write to `file_path` exactly as given.
      b_factors: Optional b-factors forwarded to `create_full_prot`.

    Returns:
      The path that was written.

    Raises:
      ValueError: If `prot_pos` is neither 3- nor 4-dimensional. This is
        checked before anything is created on disk.
    """
    if prot_pos.ndim not in (3, 4):
        raise ValueError(f'Invalid positions shape {prot_pos.shape}')

    save_path = _indexed_prot_pdb_path(file_path, overwrite, no_indexing)

    with open(save_path, 'w') as f:
        if prot_pos.ndim == 4:
            for t, pos37 in enumerate(prot_pos):
                atom37_mask = np.sum(np.abs(pos37), axis=-1) > 1e-7
                prot = create_full_prot(
                    pos37, atom37_mask, aatype=aatype, b_factors=b_factors)
                f.write(protein.to_pdb(prot, model=t + 1, add_end=False))
        else:
            atom37_mask = np.sum(np.abs(prot_pos), axis=-1) > 1e-7
            prot = create_full_prot(
                prot_pos, atom37_mask, aatype=aatype, b_factors=b_factors)
            f.write('MODEL     1\n' + protein.to_pdb(prot))
    return save_path

def prot_to_pdb_block(
        prot_pos: np.ndarray,
        aatype: np.ndarray=None,
        overwrite=False,
        no_indexing=False,
        b_factors=None,
        model_id=1,
    ):
    """Render atom37 protein coordinates as PDB text instead of a file.

    Atoms whose coordinates are all (numerically) zero are treated as absent.

    Args:
      prot_pos: One structure of shape [num_res, 37, 3], or a stack of
        structures of shape [num_frames, num_res, 37, 3].
      aatype: Optional residue types forwarded to `create_full_prot`.
      overwrite: Unused here.
      no_indexing: Unused here.
      b_factors: Optional b-factors forwarded to `create_full_prot`.
      model_id: Number placed in the leading MODEL line of a single
        structure. Ignored for a stack, whose frames are numbered from 1.

    Returns:
      The PDB text. For a stack this is the concatenation of all frames.

    Raises:
      ValueError: If `prot_pos` is neither 3- nor 4-dimensional.
    """
    rank = prot_pos.ndim
    if rank not in (3, 4):
        raise ValueError(f'Invalid positions shape {prot_pos.shape}')

    def as_protein(pos37):
        # An atom counts as present when any of its coordinates is non-zero.
        present = np.sum(np.abs(pos37), axis=-1) > 1e-7
        return create_full_prot(
            pos37, present, aatype=aatype, b_factors=b_factors)

    if rank == 3:
        return f'MODEL     {model_id}\n' + protein.to_pdb(as_protein(prot_pos))

    frames = [
        protein.to_pdb(as_protein(pos37), model=frame_no + 1, add_end=False)
        for frame_no, pos37 in enumerate(prot_pos)
    ]
    return "".join(frames)

def _indexed_lig_pdb_path(file_path, overwrite, no_indexing):
    """Create the target directory and return the path to write the PDB to."""
    file_dir, base_name = os.path.split(file_path)
    if file_dir:
        # A bare file name has no directory component; os.makedirs('') raises.
        os.makedirs(file_dir, exist_ok=True)
    if no_indexing:
        return file_path

    # Remove only a trailing ".pdb"; str.strip('.pdb') would strip any of the
    # characters '.', 'p', 'd', 'b' from both ends of the name.
    stem = base_name[:-len('.pdb')] if base_name.endswith('.pdb') else base_name

    max_existing_idx = 0
    if not overwrite:
        # Only files named exactly "<stem>_<N>.pdb" count towards the index.
        indexed_name = re.compile(re.escape(stem) + r'_(\d+)\.pdb')
        for entry in os.listdir(file_dir or '.'):
            match = indexed_name.fullmatch(entry)
            if match:
                max_existing_idx = max(max_existing_idx, int(match.group(1)))

    # Keep the caller's directory prefix byte-for-byte.
    prefix = file_path[:len(file_path) - len(base_name)]
    return f'{prefix}{stem}_{max_existing_idx + 1}.pdb'


def write_lig_to_pdb(
        lig_pos: np.ndarray,
        tok_type: np.ndarray,
        file_path: str,
        overwrite=False,
        no_indexing=False,
        b_factors=None
):
    """Write ligand atoms to a PDB file as one "LIG" residue on chain "B".

    Args:
      lig_pos: Atom coordinates; row `i` holds x, y, z of atom `i`.
      tok_type: Per-atom indices into `ec.token_types`. Each looked-up entry
        is stripped of its first character and parsed as an atomic number
        (presumably entries look like "<prefix><Z>" - confirm against
        `entity_constants`).
      file_path: Requested output path. Its directory is created if missing.
      overwrite: If True the index suffix is always `_1`; otherwise it is one
        more than the largest existing `<stem>_<N>.pdb` in the directory.
      no_indexing: If True, write to `file_path` exactly as given.
      b_factors: Unused; every atom is written with a b-factor of 0.0.

    Returns:
      The path that was written.
    """
    save_path = _indexed_lig_pdb_path(file_path, overwrite, no_indexing)

    # Top-level container handed to PDBIO (a Bio.PDB Model); the id is
    # arbitrary.
    structure_id = "1"
    structure = Model.Model(structure_id)
    # Residue id is (hetero flag, sequence number, insertion code).
    residue_id = (' ', 1, ' ')
    # All ligand atoms share a single residue named "LIG".
    residue = Residue.Residue(residue_id, "LIG", '')

    # The ligand is always written to chain "B".
    chain_id = "B"
    chain = Chain.Chain(chain_id)
    num_atoms = lig_pos.shape[0]
    for i in range(num_atoms):
        x, y, z = lig_pos[i][0], lig_pos[i][1], lig_pos[i][2]
        atom_type = pse.GetElementSymbol(int(ec.token_types[tok_type[i]][1:]))
        # Atom names are the element symbol followed by the atom's position,
        # which keeps them unique within the residue.
        atom = Atom.Atom(
            name=atom_type+str(i),
            coord=[x, y, z],
            bfactor=0.0,
            occupancy=1.0,
            altloc=" ",
            fullname=atom_type+str(i),
            serial_number=i,
            element=atom_type
        )
        residue.add(atom)
    chain.add(residue)

    structure.add(chain)

    # Serialise the structure with PDBIO.
    io = PDBIO()
    io.set_structure(structure)
    io.save(save_path)

    return save_path

def lig_to_pdb_block(
    lig_pos: np.ndarray,
    tok_type: np.ndarray,
    edges: np.ndarray,
    chain_id: str,
    serial_offset,
    res_offset,
    overwrite=False,
    no_indexing=False,
    b_factors=None
):
    """Render ligand atoms as HETATM and CONECT lines of PDB text.

    Args:
      lig_pos: Atom coordinates; row `i` holds x, y, z of atom `i`.
      tok_type: Per-atom indices into `ec.token_types`. Each looked-up entry
        is stripped of its first character and parsed as an atomic number.
      edges: 2-D adjacency lookup; atoms `i` and `j` are bonded when
        `edges[i + res_offset, j + res_offset] == 1`.
      chain_id: Chain identifier for the ligand.
      serial_offset: Atom `i` gets serial number `serial_offset + i + 1`.
      res_offset: Row/column offset of the first ligand atom in `edges`.
      overwrite: Unused here.
      no_indexing: Unused here.
      b_factors: Unused; every atom is written with a b-factor of 0.0.

    Returns:
      PDB text: one HETATM line per atom followed by one CONECT line per
      bonded pair `(i, j)` with `i < j`.
    """
    # All ligand atoms share one residue named "LIG". Its id is
    # (hetero flag, sequence number, insertion code).
    ligand_residue = Residue.Residue((' ', 1, ' '), "LIG", '')

    num_atoms = lig_pos.shape[0]
    for atom_idx in range(num_atoms):
        xyz = lig_pos[atom_idx]
        atomic_number = int(ec.token_types[tok_type[atom_idx]][1:])
        symbol = pse.GetElementSymbol(atomic_number)
        # Name = element symbol + position, unique within the residue.
        label = symbol + str(atom_idx)
        ligand_residue.add(
            Atom.Atom(
                name=label,
                coord=[xyz[0], xyz[1], xyz[2]],
                bfactor=0.0,
                occupancy=1.0,
                altloc=" ",
                fullname=label,
                serial_number=serial_offset+atom_idx+1,
                element=symbol
            )
        )

    # Model -> Chain -> Residue hierarchy expected by PDBIO; the model id is
    # arbitrary.
    ligand_chain = Chain.Chain(chain_id)
    ligand_chain.add(ligand_residue)
    model = Model.Model("1")
    model.add(ligand_chain)

    # Serialise into an in-memory buffer, keeping the serial numbers set above.
    buffer = StringIO()
    writer = PDBIO()
    writer.set_structure(model)
    writer.save(buffer, preserve_atom_numbering=True)

    # Drop the last three split elements (presumably the TER line, the END
    # line and the empty string after the final newline - confirm against
    # PDBIO output) so only atom records remain.
    atom_records = buffer.getvalue().split('\n')[:-3]
    atom_text = '\n'.join(atom_records) + '\n'
    atom_text = atom_text.replace("ATOM  ", "HETATM")

    # One CONECT record per bonded pair, visiting pairs in (i, j>i) order.
    conect_records = []
    for i in range(num_atoms):
        for j in range(i+1, num_atoms):
            if edges[i + res_offset, j + res_offset] == 1:
                conect_records.append(
                    f"CONECT{i+serial_offset+1:5d}{j+serial_offset+1:5d}\n")

    return atom_text + "".join(conect_records)

def write_inference_result(eval_dir, feats, infer_out, name, subdir=None):
    """Write all inference artefacts for every sample in the batch.

    For each sample `i` the following files are written under `eval_dir`
    (or `eval_dir/subdir` when `subdir` is given):

      * `final_pdb/{name[i]}_protein.pdb` - protein of trajectory frame 0.
      * `final_pdb/{name[i]}_mol.pdb`     - ligand of trajectory frame 0.
      * `final_pdb/{name[i]}.pdb`         - protein + ligand of frame 0.
      * `genie_npy/{name[i]}.npy`         - CA coordinates of frame 0, written
        as comma-separated text by `np.savetxt`.
      * `traj_pdb/{name[i]}.pdb`          - protein + ligand of every frame.
      * `info/{name[i]}.txt`              - `str(infer_out['info'])`.
      * `denoise_traj/{name[i]}.pkl`      - pickle of the whole
        `infer_out['denoise_traj']` (the same object for every sample).

    Args:
      eval_dir: Root output directory.
      feats: Mapping of tensors; the keys `fixed_mask`, `seq_mask`,
        `all_atom_mask`, `token_type` and `edges` are read.
      infer_out: Mapping with `prot_traj` (sequence of frames indexed
        `[sample][token]`), `info` and `denoise_traj`.
      name: Per-sample names used for the output file names.
      subdir: Optional sub-directory appended to `eval_dir`.
    """
    fixed_mask = feats["fixed_mask"].float().cpu().detach().numpy()
    seq_mask = feats["seq_mask"].float().cpu().detach().numpy().astype(bool)
    all_atom_mask = feats["all_atom_mask"].float().cpu().detach().numpy()
    ca_idx = ec.atom_order["CA"]
    molAtom_idx = ec.atom_order["*MolAtom"]
    # Protein tokens are those with a CA atom; ligand tokens are those with
    # the "*MolAtom" pseudo atom.
    bb_mask = all_atom_mask[..., ca_idx].astype(bool)
    mol_mask = all_atom_mask[..., molAtom_idx].astype(bool)
    token_type = feats["token_type"].cpu().detach().numpy()
    # Collapse the last axis of the edge features into a single bonded flag.
    edges = feats["edges"].cpu().detach().numpy().max(axis=-1)

    if subdir:
        eval_dir = os.path.join(eval_dir, subdir)

    traj_dir = os.path.join(eval_dir, "traj_pdb")
    final_dir = os.path.join(eval_dir, "final_pdb")
    npy_dir = os.path.join(eval_dir, "genie_npy")
    info_dir = os.path.join(eval_dir, "info")
    denoise_dir = os.path.join(eval_dir, "denoise_traj")
    for out_dir in (traj_dir, final_dir, npy_dir, info_dir, denoise_dir):
        os.makedirs(out_dir, exist_ok=True)

    for i in range(bb_mask.shape[0]):
        sample_name = name[i]

        # These depend only on the sample, not on the trajectory frame.
        unpad_fixed_mask = fixed_mask[i][seq_mask[i]]
        # b-factor 100 marks diffused tokens, 0 marks fixed ones.
        b_factors = np.tile(1 - unpad_fixed_mask[..., None], 37) * 100
        mol_type = token_type[i][mol_mask[i]]

        frame_blocks = []
        for traj_id, frame in enumerate(infer_out['prot_traj']):
            unpad_prot = frame[i][bb_mask[i]][..., :37, :]
            mol_coord = frame[i][mol_mask[i]][..., molAtom_idx, :]

            # NOTE(review): MODEL numbering follows traj_id and so starts
            # at 0.
            pdb_string_prot = prot_to_pdb_block(
                unpad_prot,
                no_indexing=True,
                b_factors=b_factors,
                model_id=traj_id
            )
            prot_lines = pdb_string_prot.split('\n')

            pdb_string_mol = lig_to_pdb_block(
                mol_coord,
                mol_type,
                edges[i],
                chain_id="B",
                # Assumes the protein block has exactly five non-atom
                # entries after splitting on newlines - TODO confirm against
                # `protein.to_pdb`.
                serial_offset=len(prot_lines) - 5,
                res_offset=unpad_prot.shape[0],
                no_indexing=True,
                b_factors=b_factors
            )

            # Drop the last three split entries of the protein block before
            # appending the ligand records.
            frame_block = '\n'.join(prot_lines[:-3]) + '\n' + pdb_string_mol
            frame_blocks.append(frame_block)

            if traj_id == 0:
                write_prot_to_pdb(
                    unpad_prot,
                    os.path.join(final_dir, f'{sample_name}_protein.pdb'),
                    no_indexing=True,
                    b_factors=b_factors
                )

                write_lig_to_pdb(
                    mol_coord,
                    mol_type,
                    os.path.join(final_dir, f'{sample_name}_mol.pdb'),
                    no_indexing=True,
                    b_factors=b_factors
                )

                with open(
                    os.path.join(final_dir, f"{sample_name}.pdb"), "w"
                ) as f:
                    f.write(frame_block)

                np.savetxt(
                    os.path.join(npy_dir, f'{sample_name}.npy'),
                    unpad_prot[..., ca_idx, :],
                    fmt='%.3f',
                    delimiter=','
                )

        with open(os.path.join(info_dir, f"{sample_name}.txt"), "w") as f:
            f.write(f"{infer_out['info']}")

        with open(os.path.join(denoise_dir, f"{sample_name}.pkl"), "wb") as f:
            pickle.dump(infer_out['denoise_traj'], f)

        with open(os.path.join(traj_dir, f'{sample_name}.pdb'), 'w') as f:
            f.write("".join(frame_blocks))

def get_generation_pdb(feats, infer_out):
    """Return the protein + ligand PDB text of the first sample's first frame.

    Only sample 0 and trajectory frame 0 are rendered.

    NOTE(review): the previous implementation looped over samples and frames
    but returned unconditionally inside the first inner iteration, so this
    has always been the effective behaviour. Confirm whether aggregating
    over samples or frames was intended.

    Args:
      feats: Mapping of tensors; the keys `fixed_mask`, `seq_mask`,
        `all_atom_mask`, `token_type` and `edges` are read.
      infer_out: Mapping with `prot_traj`, a sequence of frames indexed
        `[sample][token]`.

    Returns:
      The PDB text, or None when the batch or the trajectory is empty.
    """
    fixed_mask = feats["fixed_mask"].float().cpu().detach().numpy()
    seq_mask = feats["seq_mask"].float().cpu().detach().numpy().astype(bool)
    all_atom_mask = feats["all_atom_mask"].float().cpu().detach().numpy()
    ca_idx = ec.atom_order["CA"]
    molAtom_idx = ec.atom_order["*MolAtom"]
    # Protein tokens are those with a CA atom; ligand tokens are those with
    # the "*MolAtom" pseudo atom.
    bb_mask = all_atom_mask[..., ca_idx].astype(bool)
    mol_mask = all_atom_mask[..., molAtom_idx].astype(bool)
    token_type = feats["token_type"].cpu().detach().numpy()
    # Collapse the last axis of the edge features into a single bonded flag.
    edges = feats["edges"].cpu().detach().numpy().max(axis=-1)

    if bb_mask.shape[0] == 0 or len(infer_out['prot_traj']) == 0:
        return None

    i, traj_id = 0, 0
    frame = infer_out['prot_traj'][traj_id]

    unpad_fixed_mask = fixed_mask[i][seq_mask[i]]
    # b-factor 100 marks diffused tokens, 0 marks fixed ones.
    b_factors = np.tile(1 - unpad_fixed_mask[..., None], 37) * 100
    unpad_prot = frame[i][bb_mask[i]][..., :37, :]
    mol_coord = frame[i][mol_mask[i]][..., molAtom_idx, :]
    mol_type = token_type[i][mol_mask[i]]

    pdb_string_prot = prot_to_pdb_block(
        unpad_prot,
        no_indexing=True,
        b_factors=b_factors,
        model_id=traj_id
    )
    prot_lines = pdb_string_prot.split('\n')

    pdb_string_mol = lig_to_pdb_block(
        mol_coord,
        mol_type,
        edges[i],
        chain_id="B",
        # Assumes the protein block has exactly five non-atom entries after
        # splitting on newlines - TODO confirm against `protein.to_pdb`.
        serial_offset=len(prot_lines) - 5,
        res_offset=unpad_prot.shape[0],
        no_indexing=True,
        b_factors=b_factors
    )

    # Drop the last three split entries of the protein block before
    # appending the ligand records.
    return '\n'.join(prot_lines[:-3]) + '\n' + pdb_string_mol
