import os
import random
import logging
import datetime
import pandas as pd
import joblib
import pickle
import lmdb
import subprocess
import torch
import h5py
import model.utils.parsers as parsers
import model.utils.chemical as chemical
from model.utils.torsion import get_torsion_angle
from model.utils.geometry import center_and_realign_missing
from Bio import PDB, SeqRecord, SeqIO, Seq
from Bio.PDB import PDBExceptions
from torch.utils.data import Dataset
from tqdm.auto import tqdm

# Accepted values of the SAbDab `antigen_type` column: protein antigens made of
# one to five chains. Entries with no antigen type at all are accepted as well
# (see the filter in SAbDabDataset._load_sabdab_entries).
ALLOWED_AG_TYPES = {
    'protein',
    'protein | protein',
    'protein | protein | protein',
    'protein | protein | protein | protein | protein',
    'protein | protein | protein | protein',
}

# Entries whose parsed resolution exceeds this value, or whose resolution is
# missing, are dropped in SAbDabDataset._load_sabdab_entries.
# NOTE(review): presumably Angstrom (SAbDab's unit) — confirm.
RESOLUTION_THRESHOLD = 4.0

# Antigen names (SAbDab `antigen_name` column, matched exactly) whose entries
# make up the test split in SAbDabDataset._load_split.
TEST_ANTIGENS = [
    'sars-cov-2 receptor binding domain',
    'hiv-1 envelope glycoprotein gp160',
    'mers s',
    'influenza a virus',
    'cd27 antigen',
]


def nan_to_empty_string(val):
    """Return ``val`` unchanged, or ``''`` when it is NaN or falsy (None, '', 0, ...).

    Used to normalise optional pandas cells before string formatting.
    """
    is_nan = val != val  # NaN is the only value that compares unequal to itself
    return '' if (is_nan or not val) else val


def nan_to_none(val):
    """Return ``val`` unchanged, or None when it is NaN or falsy (None, '', 0, ...)."""
    if val == val and val:  # `val == val` is False only for NaN-like values
        return val
    return None


def split_sabdab_delimited_str(val):
    """Split a SAbDab ``'A | B | C'`` style field into stripped parts.

    Falsy input (None, '') yields an empty list.
    """
    if val:
        return [part.strip() for part in val.split('|')]
    return []


def parse_sabdab_resolution(val):
    """Parse a SAbDab resolution cell into a float, or None if unusable.

    ``'NOT'``, falsy values and NaN yield None. A comma-separated string
    (several reported resolutions) yields its first value.
    """
    unusable = (val == 'NOT') or (not val) or (val != val)
    if unusable:
        return None
    if not (isinstance(val, str) and ',' in val):
        return float(val)
    head = val.split(',', 1)[0]
    return float(head.strip())


def _aa_tensor_to_sequence(aa):
    """Turn a tensor of residue codes into a string via ``chemical.aa2num``.

    Every element of ``aa`` (after flattening) is used as a key into
    ``chemical.aa2num``, and the looked-up values are concatenated in order.
    """
    # NOTE(review): `aa2num` is indexed with integer codes here, so in this
    # project's `chemical` module it must map code -> residue string; confirm.
    return ''.join(chemical.aa2num[code] for code in aa.flatten().tolist())


def _label_heavy_chain_cdr(data, seq_map, max_cdr3_length=30):
    """Annotate a parsed heavy chain with Chothia CDR labels and CDR sequences.

    Adds ``data['cdr_flag']`` (a tensor shaped like ``data['aa']``, 0 outside
    CDRs, the ``chemical.CDR`` label inside) and ``data['H1_seq']``,
    ``data['H2_seq']``, ``data['H3_seq']``. ``data`` is modified in place.

    Args:
        data: parsed chain dict with at least an ``'aa'`` tensor, or None.
        seq_map: mapping from residue position tuples (element 1 is the
            residue number) to row indices in ``data``, or None.
        max_cdr3_length: chains with a longer CDR-H3 are rejected.

    Returns:
        ``(data, seq_map)`` on success; the inputs unchanged if either is None;
        ``(None, None)`` when CDR-H3 is missing or longer than allowed.
    """
    if data is None or seq_map is None:
        return data, seq_map

    residues = data['aa']
    flags = torch.zeros_like(residues)
    for position, row in seq_map.items():
        label = chemical.ChothiaCDRRange.to_cdr('H', position[1])
        if label is not None:
            flags[row] = label
    data['cdr_flag'] = flags

    for tag in ('H1', 'H2', 'H3'):
        data[f'{tag}_seq'] = _aa_tensor_to_sequence(residues[flags == getattr(chemical.CDR, tag)])

    h3_mask = flags == chemical.CDR.H3
    h3_len = h3_mask.sum().item()
    if h3_len > max_cdr3_length:
        flags[h3_mask] = 0
        logging.warning(f'CDR-H3 too long {h3_len}. Removed.')
        return None, None
    if h3_len == 0:
        # A heavy chain without CDR-H3 is unusable downstream.
        logging.warning('No CDR-H3 found in the heavy chain.')
        return None, None

    return data, seq_map


def _label_light_chain_cdr(data, seq_map, max_cdr3_length=30):
    """Annotate a parsed light chain with Chothia CDR labels and CDR sequences.

    Adds ``data['cdr_flag']`` (a tensor shaped like ``data['aa']``, 0 outside
    CDRs, the ``chemical.CDR`` label inside) and ``data['L1_seq']``,
    ``data['L2_seq']``, ``data['L3_seq']``. ``data`` is modified in place.

    Args:
        data: parsed chain dict with at least an ``'aa'`` tensor, or None.
        seq_map: mapping from residue position tuples (element 1 is the
            residue number) to row indices in ``data``, or None.
        max_cdr3_length: chains with a longer CDR-L3 are rejected.

    Returns:
        ``(data, seq_map)`` on success; the inputs unchanged if either is None;
        ``(None, None)`` when CDR-L3 is missing or longer than allowed.
    """
    if data is None or seq_map is None:
        return data, seq_map

    residues = data['aa']
    flags = torch.zeros_like(residues)
    for position, row in seq_map.items():
        label = chemical.ChothiaCDRRange.to_cdr('L', position[1])
        if label is not None:
            flags[row] = label
    data['cdr_flag'] = flags

    for tag in ('L1', 'L2', 'L3'):
        data[f'{tag}_seq'] = _aa_tensor_to_sequence(residues[flags == getattr(chemical.CDR, tag)])

    l3_mask = flags == chemical.CDR.L3
    l3_len = l3_mask.sum().item()
    if l3_len > max_cdr3_length:
        flags[l3_mask] = 0
        logging.warning(f'CDR-L3 too long {l3_len}. Removed.')
        return None, None
    if l3_len == 0:
        # Only CDR-L3 presence is checked, despite the wording of the message.
        logging.warning('No CDRs found in the light chain.')
        return None, None

    return data, seq_map


def _merge_heavy_light(parsed, chain_jump=50):
    """Merge separately parsed heavy and light chains into one Fv record.

    Heavy-chain residues come first, followed by light-chain residues.

    Args:
        parsed: dict with ``'id'`` and non-None ``'heavy'`` / ``'light'`` chain
            dicts as produced by ``_label_heavy_chain_cdr`` /
            ``_label_light_chain_cdr``.
        chain_jump: gap inserted in the residue numbering between the last
            heavy residue and the first light residue.

    Returns:
        dict with keys ``id``, ``aa``, ``cdr_flag``, ``is_heavy`` (1 heavy,
        0 light), ``chain_id`` (0 heavy, 1 light), ``res_nb`` and ``resseq``
        (the same renumbered tensor), ``pos_heavyatom`` (processed jointly by
        ``center_and_realign_missing`` and multiplied by the atom mask),
        ``mask_heavyatom``, ``torsion_angle``, ``torsion_angle_mask`` and the
        six CDR sequences ``H1_seq`` ... ``L3_seq``.
    """
    heavy, light = parsed['heavy'], parsed['light']
    n_heavy = len(heavy['aa'])

    merged = {'id': parsed['id']}
    merged['aa'] = torch.concat([heavy['aa'], light['aa']])
    merged['cdr_flag'] = torch.concat([heavy['cdr_flag'], light['cdr_flag']])

    is_heavy = torch.zeros_like(merged['aa'])
    is_heavy[:n_heavy] = 1
    merged['is_heavy'] = is_heavy

    chain_id = torch.ones_like(merged['aa'])
    chain_id[:n_heavy] = 0
    merged['chain_id'] = chain_id

    # Renumber heavy residues 1..H, then continue the light chain after a
    # `chain_jump` gap so the two chains' indices are clearly separated.
    heavy_idx = torch.arange(1, len(heavy['res_nb']) + 1)
    light_idx = torch.arange(1, len(light['res_nb']) + 1) + heavy_idx[-1] + chain_jump
    res_nb = torch.cat([heavy_idx, light_idx], dim=0)
    merged['res_nb'] = res_nb
    merged['resseq'] = res_nb

    # Translate the concatenated coordinates together rather than per chain,
    # so the relative placement of heavy and light chains is preserved.
    pos_heavyatom = torch.concat([heavy['pos_heavyatom'], light['pos_heavyatom']], dim=0)
    mask_heavyatom = torch.concat([heavy['mask_heavyatom'], light['mask_heavyatom']], dim=0)
    merged['pos_heavyatom'] = center_and_realign_missing(pos_heavyatom, mask_heavyatom)[0] * mask_heavyatom[:, :, None]
    merged['mask_heavyatom'] = mask_heavyatom

    # Psi and side-chain torsions are computed per chain so that no angle is
    # taken across the heavy/light chain break.
    heavy_tor, heavy_tor_mask = get_torsion_angle(heavy['pos_heavyatom'], heavy['aa'])
    light_tor, light_tor_mask = get_torsion_angle(light['pos_heavyatom'], light['aa'])
    merged['torsion_angle'] = torch.concat([heavy_tor, light_tor])
    merged['torsion_angle_mask'] = torch.concat([heavy_tor_mask, light_tor_mask])

    for key in ('H1_seq', 'H2_seq', 'H3_seq'):
        merged[key] = heavy[key]
    for key in ('L1_seq', 'L2_seq', 'L3_seq'):
        merged[key] = light[key]
    return merged


def preprocess_sabdab_structure(task):
    """Parse one SAbDab PDB file into a merged heavy+light Fv record.

    Args:
        task: dict with ``'id'`` (entry id, used in log messages), ``'entry'``
            (an entry dict as built by ``SAbDabDataset._load_sabdab_entries``)
            and ``'pdb_path'`` (path to the Chothia-numbered PDB file).

    Returns:
        The merged record built by ``_merge_heavy_light``, or None when the
        file cannot be read or the entry lacks a usable H/L chain pair.

    Raises:
        RuntimeError: if merging the successfully parsed chains fails; the
            underlying exception is chained as ``__cause__`` so the failing
            entry and traceback reach the joblib parent.
    """
    entry = task['entry']
    pdb_path = task['pdb_path']

    parser = PDB.PDBParser(QUIET=True)
    try:
        # Biopython's first argument is the structure id string.
        model = parser.get_structure(entry['id'], pdb_path)[0]
    except Exception as e:
        # Best-effort: an unreadable structure only drops this entry.
        logging.warning('[{}] Failed to read {}: {}: {}'.format(
            task['id'], pdb_path, e.__class__.__name__, str(e)
        ))
        return None

    parsed = {
        'id': entry['id'],
        'heavy': None,
        'heavy_seqmap': None,
        'light': None,
        'light_seqmap': None,
        'antigen': None,
        'antigen_seqmap': None,
    }
    try:
        if entry['H_chain'] is not None:
            (
                parsed['heavy'],
                parsed['heavy_seqmap']
            ) = _label_heavy_chain_cdr(*parsers.parse_biopython_structure_onlyV(
                model[entry['H_chain']],
                max_resseq=113  # Chothia, end of Heavy chain Fv
            ))

        if entry['L_chain'] is not None:
            (
                parsed['light'],
                parsed['light_seqmap']
            ) = _label_light_chain_cdr(*parsers.parse_biopython_structure_onlyV(
                model[entry['L_chain']],
                max_resseq=106  # Chothia, end of Light chain Fv
            ))

        # Both chains are required; either one missing or rejected drops the entry.
        if parsed['heavy'] is None or parsed['light'] is None:
            raise ValueError('Both a valid H-chain and a valid L-chain are required.')

        # The antigen is parsed (a parse failure drops the entry) but is not
        # part of the returned record.
        if len(entry['ag_chains']) > 0:
            chains = [model[c] for c in entry['ag_chains']]
            (
                parsed['antigen'],
                parsed['antigen_seqmap']
            ) = parsers.parse_biopython_structure_onlyV(chains)

    except (
            PDBExceptions.PDBConstructionException,
            parsers.ParsingException,
            KeyError,
            ValueError,
    ) as e:
        logging.warning('[{}] {}: {}'.format(
            task['id'],
            e.__class__.__name__,
            str(e)
        ))
        return None

    try:
        return _merge_heavy_light(parsed)
    except Exception as e:
        # Unexpected failure after successful parsing: fail loudly, but keep the
        # entry id and original traceback instead of printing and exiting.
        raise RuntimeError('[{}] Failed to merge heavy/light chains'.format(task['id'])) from e


class SAbDabDataset(Dataset):
    """Paired heavy/light-chain antibody dataset built from SAbDab.

    Construction pipeline (each stage cached under ``processed_dir``):

    1. ``_load_sabdab_entries`` filters the SAbDab summary TSV by antigen type
       and resolution.
    2. ``_load_structures`` parses the Chothia PDB files in parallel and stores
       the merged records in an LMDB file (``structures.lmdb``).
    3. ``_load_clusters`` clusters CDR-H3/L3 sequences with MMseqs2.
    4. ``_load_split`` assigns ids to the ``'train'`` / ``'val'`` / ``'test'`` split.

    ``__getitem__`` returns a cached record (optionally transformed) plus an
    ``antibody_emb`` tensor read from an HDF5 file keyed by entry id.

    The LMDB and HDF5 handles are (re)opened lazily per process, so the dataset
    can be used from forked ``DataLoader`` workers.
    """

    MAP_SIZE = 32 * (1024 * 1024 * 1024)  # 32GB
    # File name (inside processed_dir) of the pickled, de-duplicated id list
    # used for the train/val splits when `dedup_ids_path` is not given.
    DEDUP_IDS_FILENAME = 'onlyV_drop_duplicates_id.pkl'

    def __init__(
            self,
            summary_path='/srv/storage/hdd/zzl/lzj/dataset/SAbDab/SAbDab_summary_all_25_3_13.tsv',
            chothia_dir='/srv/storage/hdd/zzl/lzj/dataset/SAbDab/chothia',
            processed_dir='/srv/storage/hdd/zzl/lzj/dataset/SAbDab/processed_onlyV',
            embedding_h5_dir='/srv/storage/hdd/zzl/lzj/dataset/SAbDab/processed/SAbDab_processed_sequences_embedding.h5',
            split='val',
            split_seed=2025,
            transform=None,
            reset=False,
            dedup_ids_path=None,
    ):
        """
        Args:
            summary_path: SAbDab summary TSV.
            chothia_dir: directory of Chothia-numbered ``<pdbcode>.pdb`` files.
            processed_dir: cache directory for the LMDB, FASTA and cluster files.
            embedding_h5_dir: path to the HDF5 file holding per-id embeddings.
            split: one of ``'train'``, ``'val'``, ``'test'``.
            split_seed: currently has no effect (shuffling is disabled in
                ``_load_split``).
            transform: optional callable applied to each record.
            reset: rebuild the LMDB cache and the clusters from scratch.
            dedup_ids_path: pickled list of ids used for train/val; defaults to
                ``<processed_dir>/onlyV_drop_duplicates_id.pkl``.

        Raises:
            FileNotFoundError: if ``chothia_dir`` does not exist.
            ValueError: if ``split`` is not a recognised split name.
        """
        super().__init__()
        self.summary_path = summary_path
        self.chothia_dir = chothia_dir
        if not os.path.exists(chothia_dir):
            raise FileNotFoundError(
                f"SAbDab structures not found in {chothia_dir}. "
                "Please download them from http://opig.stats.ox.ac.uk/webapps/newsabdab/sabdab/archive/all/"
            )
        self.processed_dir = processed_dir
        self.split = split
        os.makedirs(processed_dir, exist_ok=True)
        if dedup_ids_path is None:
            dedup_ids_path = os.path.join(processed_dir, self.DEDUP_IDS_FILENAME)
        self.dedup_ids_path = dedup_ids_path

        self.sabdab_entries = None
        self._load_sabdab_entries()  # load antibody entry metadata

        self.db_conn = None
        self._db_pid = None  # pid of the process that opened db_conn
        self.db_ids = None
        self._load_structures(reset)

        self.clusters = None
        self.id_to_cluster = None
        self._load_clusters(reset)

        self.ids_in_split = None
        self._load_split(split, split_seed)

        self.transform = transform

        self.embedding_h5_dir = embedding_h5_dir
        self.embedding_h5 = h5py.File(embedding_h5_dir, 'r')
        self._h5_pid = os.getpid()  # pid of the process that opened embedding_h5

    def _load_sabdab_entries(self):
        """Read the summary TSV and keep entries passing the antigen/resolution filter."""
        df = pd.read_csv(self.summary_path, sep='\t')
        entries_all = []
        for _, row in tqdm(
                df.iterrows(),
                dynamic_ncols=True,
                desc='Loading entries',
                total=len(df),
        ):
            # Antigen chain ids, e.g. 'A | B' -> ['A', 'B'].
            ag_chains = split_sabdab_delimited_str(
                nan_to_empty_string(row['antigen_chain'])
            )
            # Entry id: <pdb code>_<H chain>_<L chain>_<antigen chains concatenated>.
            entry_id = "{pdbcode}_{H}_{L}_{Ag}".format(
                pdbcode=row['pdb'],
                H=nan_to_empty_string(row['Hchain']),
                L=nan_to_empty_string(row['Lchain']),
                Ag=''.join(ag_chains),
            )
            entry = {
                'id': entry_id,
                'pdbcode': row['pdb'],
                'H_chain': nan_to_none(row['Hchain']),
                'L_chain': nan_to_none(row['Lchain']),
                'ag_chains': ag_chains,
                'ag_type': nan_to_none(row['antigen_type']),
                'ag_name': nan_to_none(row['antigen_name']),
                'date': datetime.datetime.strptime(row['date'], '%m/%d/%y'),
                'resolution': parse_sabdab_resolution(row['resolution']),  # experimental resolution
                'method': row['method'],
                'scfv': row['scfv'],
            }
            # Keep the entry if its antigen type (when present) is allowed and
            # its resolution is known and within RESOLUTION_THRESHOLD.
            if ((entry['ag_type'] in ALLOWED_AG_TYPES or entry['ag_type'] is None)
                    and entry['resolution'] is not None
                    and entry['resolution'] <= RESOLUTION_THRESHOLD):
                entries_all.append(entry)
        self.sabdab_entries = entries_all

    def _load_structures(self, reset):
        """Build the LMDB cache if needed, then keep only entries present in it."""
        if reset or not os.path.exists(self._structure_cache_path):
            if os.path.exists(self._structure_cache_path):
                os.unlink(self._structure_cache_path)
            self._preprocess_structures()

        with open(self._structure_cache_path + '-ids', 'rb') as f:
            self.db_ids = pickle.load(f)
        cached_ids = set(self.db_ids)  # O(1) membership instead of a list scan per entry
        self.sabdab_entries = [e for e in self.sabdab_entries if e['id'] in cached_ids]

    @property
    def _structure_cache_path(self):
        return os.path.join(self.processed_dir, 'structures.lmdb')

    def _close_db(self):
        """Close the LMDB handle (if any) and forget the cached id list."""
        if self.db_conn is not None:
            self.db_conn.close()
        self.db_conn = None
        self.db_ids = None

    def _preprocess_structures(self):
        """Parse all available PDB files in parallel and write the results to LMDB.

        Also writes the list of successfully stored ids to ``<cache path>-ids``.
        """
        tasks = []
        for entry in self.sabdab_entries:
            pdb_path = os.path.join(self.chothia_dir, '{}.pdb'.format(entry['pdbcode']))
            if not os.path.exists(pdb_path):
                logging.warning(f"PDB not found: {pdb_path}")
                continue
            tasks.append({
                'id': entry['id'],
                'entry': entry,
                'pdb_path': pdb_path,
            })

        data_list = joblib.Parallel(n_jobs=max(joblib.cpu_count(), 1))(
            joblib.delayed(preprocess_sabdab_structure)(task)
            for task in tqdm(tasks, dynamic_ncols=True, desc='Preprocess')
        )

        self._close_db()
        db_conn = lmdb.open(
            self._structure_cache_path,
            map_size=self.MAP_SIZE,
            create=True,
            subdir=False,
            readonly=False,
        )
        ids = []
        try:
            with db_conn.begin(write=True, buffers=True) as txn:
                for data in tqdm(data_list, dynamic_ncols=True, desc='Write to LMDB'):
                    if data is None:
                        continue
                    ids.append(data['id'])
                    txn.put(data['id'].encode('utf-8'), pickle.dumps(data))
        finally:
            # Close the writer so the later read-only open of the same file does
            # not leave two environments open on it in this process.
            db_conn.close()

        with open(self._structure_cache_path + '-ids', 'wb') as f:
            pickle.dump(ids, f)

    @property
    def _cluster_path(self):
        return os.path.join(self.processed_dir, 'cluster_result_cluster.tsv')

    def _load_clusters(self, reset):
        """Load (creating if needed) the MMseqs2 ``representative<TAB>member`` cluster table."""
        if reset or not os.path.exists(self._cluster_path):
            self._create_clusters()

        clusters, id_to_cluster = {}, {}
        with open(self._cluster_path, 'r') as f:
            for line in f:
                fields = line.split()
                if not fields:  # tolerate blank lines
                    continue
                cluster_name, data_id = fields
                clusters.setdefault(cluster_name, []).append(data_id)
                id_to_cluster[data_id] = cluster_name
        self.clusters = clusters
        self.id_to_cluster = id_to_cluster

    def _create_clusters(self):
        """Write CDR-H3/L3 sequences to FASTA and cluster them with ``mmseqs easy-cluster``."""
        cdr_records = []
        for data_id in self.db_ids:
            structure = self.get_structure(data_id)
            # NOTE(review): H3 and L3 records share the same FASTA id, so an id
            # may appear in two clusters; _load_clusters keeps only the last one.
            for cdr_key in ('H3_seq', 'L3_seq'):
                cdr_records.append(SeqRecord.SeqRecord(
                    Seq.Seq(structure[cdr_key]),
                    id=structure['id'],
                    name='',
                    description='',
                ))
        fasta_path = os.path.join(self.processed_dir, 'cdr_sequences.fasta')
        SeqIO.write(cdr_records, fasta_path, 'fasta')

        # Argument list (no shell) so paths containing spaces are passed intact.
        cmd = [
            'mmseqs', 'easy-cluster',
            os.path.realpath(fasta_path),
            'cluster_result', 'cluster_tmp',
            '--min-seq-id', '0.5',
            '-c', '0.8',
            '--cov-mode', '1',
        ]
        subprocess.run(cmd, cwd=self.processed_dir, check=True)

    def _load_split(self, split, split_seed):
        """Populate ``self.ids_in_split`` for ``split``.

        ``'test'`` holds every entry whose antigen name is in TEST_ANTIGENS.
        ``'val'`` is the first 20 and ``'train'`` the rest of the ids from
        ``self.dedup_ids_path`` whose cluster contains no test entry.

        Raises:
            ValueError: if ``split`` is not ``'train'``, ``'val'`` or ``'test'``.
        """
        if split not in ('train', 'val', 'test'):
            raise ValueError(f"split must be 'train', 'val' or 'test', got {split!r}")
        ids_test = [
            entry['id']
            for entry in self.sabdab_entries
            if entry['ag_name'] in TEST_ANTIGENS
        ]
        test_relevant_clusters = {self.id_to_cluster[data_id] for data_id in ids_test}

        with open(self.dedup_ids_path, 'rb') as f:
            ids = pickle.load(f)
        # Compare cluster membership, not the raw id: cluster names are the
        # representative ids, so a raw-id check would only exclude representatives.
        ids_train_val = [
            pdb_id
            for pdb_id in ids
            if self.id_to_cluster.get(pdb_id, pdb_id) not in test_relevant_clusters
        ]

        # Shuffling is disabled, so the pickled id order decides which 20 ids
        # form 'val' and split_seed currently has no effect.
        if split == 'test':
            self.ids_in_split = ids_test
        elif split == 'val':
            self.ids_in_split = ids_train_val[:20]
        else:
            self.ids_in_split = ids_train_val[20:]

    def _connect_db(self):
        """Open the LMDB cache read-only, at most once per process."""
        pid = os.getpid()
        if self.db_conn is not None and getattr(self, '_db_pid', None) == pid:
            return
        # An LMDB environment must not be used after fork(), so a forked
        # DataLoader worker opens its own handle instead of the parent's.
        self.db_conn = lmdb.open(
            self._structure_cache_path,
            map_size=self.MAP_SIZE,
            create=False,
            subdir=False,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )
        self._db_pid = pid

    def get_structure(self, id):
        """Return the cached record for entry ``id``.

        Raises:
            KeyError: if ``id`` is not stored in the LMDB cache.
        """
        self._connect_db()
        with self.db_conn.begin() as txn:
            raw = txn.get(id.encode())
        if raw is None:
            raise KeyError(f'{id!r} not found in {self._structure_cache_path}')
        return pickle.loads(raw)

    def _get_embedding_h5(self):
        """Return the embedding HDF5 handle, reopening it in a forked process."""
        pid = os.getpid()
        if getattr(self, '_h5_pid', None) != pid:
            # HDF5 handles inherited across fork() are not safe to use.
            self.embedding_h5 = h5py.File(self.embedding_h5_dir, 'r')
            self._h5_pid = pid
        return self.embedding_h5

    def __len__(self):
        return len(self.ids_in_split)

    def __getitem__(self, index):
        data_id = self.ids_in_split[index]
        data = self.get_structure(data_id)
        if self.transform is not None:
            data = self.transform(data)
        # Replace None values with empty dicts (presumably so batching/collation
        # does not choke on None — TODO confirm).
        for key, value in data.items():
            if value is None:
                data[key] = {}
        data['antibody_emb'] = torch.from_numpy(self._get_embedding_h5()[data_id][:])

        return data


# @register_dataset('sabdab')
# def get_sabdab_dataset(cfg, transform):
#     return SAbDabDataset(
#         summary_path = cfg.summary_path,
#         chothia_dir = cfg.chothia_dir,
#         processed_dir = cfg.processed_dir,
#         split = cfg.split,
#         split_seed = cfg.get('split_seed', 2022),
#         transform = transform,
#     )


if __name__ == '__main__':
    # Small CLI for building the caches and inspecting one sample.
    import argparse
    import sys

    parser = argparse.ArgumentParser()
    parser.add_argument('--split', type=str, default='val')
    parser.add_argument('--processed_dir', type=str, default=None,
                        help='Override the dataset processed_dir (default: the class default).')
    parser.add_argument('--reset', action='store_true', default=False)
    args = parser.parse_args()
    if args.reset:
        sure = input('Sure to reset? (y/n): ')
        if sure != 'y':
            sys.exit()
    dataset_kwargs = {'split': args.split, 'reset': args.reset}
    # Forward --processed_dir only when given, so running without the flag
    # keeps the class default.
    if args.processed_dir is not None:
        dataset_kwargs['processed_dir'] = args.processed_dir
    dataset = SAbDabDataset(**dataset_kwargs)
    print(dataset[0])
    print(len(dataset), len(dataset.clusters))
