import argparse
import sys
import os
import warnings
import tempfile
import pandas as pd

from Bio.PDB import PDBParser
from pathlib import Path
from rdkit import Chem
from torch.utils.data import DataLoader
from functools import partial

basedir = Path(__file__).resolve().parent.parent
sys.path.append(str(basedir))
warnings.filterwarnings("ignore")

from src import utils
from src.analysis.visualization_utils import mols_to_pdbfile
from src.data.dataset import ProcessedLigandPocketDataset
from src.data.data_utils import TensorDict, process_raw_pair
from src.model.lightning import DrugFlow
from src.size_predictor.size_model import SizeModel
from src.sbdd_metrics.metrics import FullEvaluator

from tqdm import tqdm
from pdb import set_trace


def _average_prefixed_columns(table, prefix, exclude_suffix=None):
    """Store in ``table[prefix]`` the row-wise mean of the columns named ``prefix...``.

    Every column whose name starts with ``prefix`` (other than ``prefix``
    itself, and other than names ending in ``exclude_suffix`` when given)
    contributes one term. Missing values count as 0 and every term is cast
    to float before summing.

    If no such column exists the aggregate is set to NaN. This is the value
    the former ``0 / 0`` produced implicitly; it makes any ``== 1`` check on
    the aggregate fail rather than pass.
    """
    members = [
        column for column in table.columns
        if column.startswith(prefix)
        and column != prefix
        and (exclude_suffix is None or not column.endswith(exclude_suffix))
    ]
    if not members:
        table[prefix] = float('nan')
        return

    # Accumulate column by column, in column order, to keep the summation
    # order (and hence the exact float results) of the original loops.
    running = pd.Series(0.0, index=table.index)
    for column in members:
        running = running + table[column].fillna(0).astype(float)
    table[prefix] = running / len(members)


def aggregate_metrics(table):
    """Add aggregate 'posebusters', 'reos' and 'chembl_ring_systems' columns.

    Each aggregate is the row-wise mean of the sub-metric columns sharing its
    prefix, with missing values treated as 0. ChEMBL ring-system columns whose
    name ends in ``'smi'`` are skipped; presumably they hold SMILES strings
    rather than scores.

    This script treats an aggregate equal to 1 as "all sub-checks passed".
    That assumes the sub-metric columns are 0/1 or boolean flags (TODO
    confirm against FullEvaluator's output schema).

    Args:
        table: DataFrame with one row per evaluated molecule. It is modified
            in place.

    Returns:
        The same DataFrame, for call chaining.
    """
    _average_prefixed_columns(table, 'posebusters')
    _average_prefixed_columns(table, 'reos')
    _average_prefixed_columns(table, 'chembl_ring_systems', exclude_suffix='smi')
    return table


if __name__ == "__main__":
    p = argparse.ArgumentParser()
    p.add_argument('--protein', type=str, required=True)
    p.add_argument('--ligand', type=str, required=True)
    p.add_argument('--checkpoint', type=str, required=True)
    # An integer, a "min,max" pair, or any other string to request size
    # prediction with the model given by --size_checkpoint.
    p.add_argument('--molecule_size', type=str, required=False, default=None)
    # Size-prediction model checkpoint. The script already referenced this
    # option, but it was never declared.
    p.add_argument('--size_checkpoint', type=str, required=False, default=None)
    p.add_argument('--output', type=str, required=False, default='samples.sdf')
    p.add_argument('--n_samples', type=int, required=False, default=10)
    p.add_argument('--batch_size', type=int, required=False, default=64)
    p.add_argument('--pocket_distance_cutoff', type=float, required=False, default=8.0)
    p.add_argument('--n_steps', type=int, required=False, default=None)
    p.add_argument('--device', type=str, required=False, default='cuda:0')
    p.add_argument('--datadir', type=str, required=False, default=None)
    p.add_argument('--seed', type=int, required=False, default=42)
    p.add_argument('--filter', action='store_true', required=False, default=False)
    p.add_argument('--backoff', action='store_true', required=False, default=False)
    args = p.parse_args()

    utils.set_deterministic(seed=args.seed)
    utils.disable_rdkit_logging()

    # Loading model
    model = DrugFlow.load_from_checkpoint(args.checkpoint, map_location=args.device, strict=False)
    if args.datadir is not None:
        model.datadir = args.datadir

    model.setup(stage='generation')
    model.batch_size = model.eval_batch_size = args.batch_size
    model.eval().to(args.device)
    if args.n_steps is not None:
        model.T = args.n_steps

    # Loading size model
    size_model = None
    molecule_size = None
    if args.molecule_size is not None:
        if args.molecule_size.isdigit():
            molecule_size = int(args.molecule_size)
            print(f'Will generate molecules of size {molecule_size}')
        else:
            boundaries = [x.strip() for x in args.molecule_size.split(',')]
            if len(boundaries) == 2 and boundaries[0].isdigit() and boundaries[1].isdigit():
                left = int(boundaries[0])
                right = int(boundaries[1])
                molecule_size = f"uniform_{left}_{right}"
                print(f'Will generate molecules with numbers of atoms sampled from U({left}, {right})')
            else:
                if args.size_checkpoint is None:
                    p.error('--size_checkpoint is required when --molecule_size '
                            'is neither an integer nor a "min,max" pair')
                molecule_size = "nn_prediction"
                size_model = SizeModel.load_from_checkpoint(args.size_checkpoint, map_location=args.device)
                size_model.batch_size = size_model.eval_batch_size = args.batch_size
                size_model.eval().to(args.device)
                print(f'Loaded size prediction model {args.size_checkpoint}')

    # Preparing input
    pdb_model = PDBParser(QUIET=True).get_structure('', args.protein)[0]
    rdmol = Chem.SDMolSupplier(str(args.ligand))[0]
    if rdmol is None:
        # RDKit returns None for records it cannot parse; fail early with a clear message.
        sys.exit(f'Could not parse a molecule from {args.ligand}')

    ligand, pocket = process_raw_pair(
        pdb_model, rdmol,
        dist_cutoff=args.pocket_distance_cutoff,
        pocket_representation=model.pocket_representation,
        compute_nerf_params=True,
        nma_input=args.protein if model.dynamics.add_nma_feat else None
    )
    ligand['name'] = 'ligand'
    # One batch that holds batch_size copies of the same ligand/pocket pair.
    dataset = [{'ligand': ligand, 'pocket': pocket} for _ in range(args.batch_size)]
    dataloader = DataLoader(
        dataset=dataset,
        batch_size=args.batch_size,
        collate_fn=partial(ProcessedLigandPocketDataset.collate_fn, ligand_transform=None),
        pin_memory=True
    )

    # SMILES already present in the output; used for novelty/deduplication when filtering.
    smiles = set()
    output_handle = None
    if args.backoff and os.path.exists(args.output):
        for mol in Chem.SDMolSupplier(str(args.output)):
            if mol is None:  # unreadable record in the existing file
                continue
            try:
                smiles.add(Chem.MolToSmiles(mol))
            except Exception:
                continue
        already_generated = len(smiles)
        # Number still to generate, clamped at 0 when the file is already complete.
        n_samples = max(args.n_samples - already_generated, 0)
        output_handle = open(args.output, 'a')
        w = Chem.SDWriter(output_handle)
        print(f'Already generated {already_generated} molecules. Will generate {n_samples}')
    else:
        Path(args.output).parent.absolute().mkdir(parents=True, exist_ok=True)
        n_samples = args.n_samples
        w = Chem.SDWriter(args.output)
        print(f'Will generate {n_samples}')

    evaluator = FullEvaluator()

    # Count writes directly. Without --filter, `smiles` never grows, so it
    # cannot serve as the loop condition.
    n_written = 0
    try:
        with tqdm(total=n_samples) as pbar:
            while n_written < n_samples:
                for data in dataloader:
                    if n_written >= n_samples:
                        break
                    new_data = {
                        'ligand': TensorDict(**data['ligand']).to(args.device),
                        'pocket': TensorDict(**data['pocket']).to(args.device),
                    }
                    rdmols, rdpockets, _ = model.sample(
                        new_data,
                        n_samples=1,
                        timesteps=args.n_steps,
                        num_nodes=molecule_size,
                        size_model=size_model,
                    )
                    if not rdmols:
                        continue

                    if not args.filter:
                        # No SMILES key: unfiltered molecules are written as-is.
                        candidates = [(mol, None) for mol in rdmols]
                    else:
                        results = []
                        with tempfile.TemporaryDirectory() as tmpdir:
                            receptor_path = Path(tmpdir, 'receptor.pdb')
                            for mol, receptor in zip(rdmols, rdpockets):
                                Chem.MolToPDBFile(receptor, str(receptor_path))
                                results.append(evaluator(mol, receptor_path))

                        table = pd.DataFrame(results)
                        table['novel'] = ~table['representation.smiles'].isin(smiles)
                        table = aggregate_metrics(table)
                        table['passed_filters'] = (
                            (table['posebusters'] == 1) &
                            (table['reos'] == 1) &
                            (table['chembl_ring_systems'] == 1) &
                            (table['novel'] == 1)
                        )
                        candidates = [
                            (mol, smi)
                            for mol, passed, smi in zip(
                                rdmols, table['passed_filters'], table['representation.smiles'])
                            if passed
                        ]

                    for mol, smi in candidates:
                        if n_written >= n_samples:
                            break  # do not overshoot the requested count
                        if mol is None:
                            # NOTE(review): guards against failed samples, in
                            # case model.sample can return None — confirm.
                            continue
                        if smi is not None:
                            # `novel` only compares against earlier batches;
                            # also reject repeats within this batch.
                            if smi in smiles:
                                continue
                            smiles.add(smi)
                        w.write(mol)
                        n_written += 1
                        pbar.update(1)
    finally:
        # Flush the SDF writer and release the append-mode handle, if any.
        w.close()
        if output_handle is not None:
            output_handle.close()
