import hydra
import numpy as np
import pandas as pd
import torch
from botorch.utils.multi_objective import pareto

from lambo.candidate import FoldedCandidate
from torch_seq_moo.tasks.base_task import BaseTask

from typing import Optional, Dict

from Bio.PDB import PDBParser, SASA


class SurfaceArea:
    """Callable that measures the solvent accessible surface area (SASA) of a PDB structure.

    The numerical work is delegated to the Shrake-Rupley implementation shipped with Biopython.
    """

    def __init__(self, probe_radius: float = 1.4, n_points: int = 100, radii_dict: Optional[Dict] = None):
        """
            Build the PDB parser and the Shrake-Rupley calculator used by ``__call__``.

            Args (descriptions taken from the biopython docs):
                probe_radius (float) – radius of the probe in A. Default is 1.40, roughly the radius of a
                    water molecule.
                n_points (int) – resolution of the surface of each atom. Default is 100.
                    More points give more precise measurements at the cost of a slower calculation.
                radii_dict (dict) – user-provided dictionary of atomic radii to use in the calculation.
                    Values replace/complement those in the default ATOMIC_RADII dictionary.
                    When omitted, ``{'X': 2.0}`` is passed instead (see the workaround note below).

            TODO: assess these arguments and work out api a bit better

        """
        # Workaround for a BioPython bug, see https://github.com/biopython/biopython/pull/3777:
        # always hand the calculator a radii dictionary, falling back to a radius for element 'X'.
        atom_radii = {'X': 2.0} if radii_dict is None else radii_dict

        self.structure_computer = SASA.ShrakeRupley(
            probe_radius=probe_radius,
            n_points=n_points,
            radii_dict=atom_radii,
        )
        self.parser = PDBParser(QUIET=True)

    def __call__(self, name, loc) -> float:
        """
            Parse a PDB file and return the SASA of the whole structure.

            Args:
                name – identifier given to the parsed structure.
                loc – PDB file to read, in whatever form ``PDBParser.get_structure`` accepts.

            Returns:
                The ``sasa`` attribute attached to the structure by the calculator.
        """
        structure = self.parser.get_structure(name, loc)
        # level="S" asks for the result to be aggregated at structure level,
        # which is what makes ``structure.sasa`` available afterwards.
        self.structure_computer.compute(structure, level="S")
        return structure.sasa



class ProxyRFPTask(BaseTask):
    """Two-objective sequence task built from the ``fpbase`` RFP assets.

    Every candidate is scored with the pair ``(-surface area, total energy)``.
    The only mutation operation declared by the task is substitution (``"sub"``).
    """

    def __init__(self, tokenizer, candidate_pool, obj_dim, transform=lambda x: x,
                 num_start_examples=1024, **kwargs):
        """
            Args:
                tokenizer – tokenizer forwarded to ``BaseTask``; its ``sampling_vocab`` is used
                    to decode residue indices in ``_evaluate``.
                candidate_pool – indexable pool of base candidates, forwarded to ``BaseTask``.
                obj_dim – number of objectives, forwarded to ``BaseTask``.
                transform – callable applied to raw scores in ``_evaluate`` (identity by default).
                num_start_examples – number of rows sampled from the seed-data CSV in ``task_setup``.
                **kwargs – forwarded unchanged to ``BaseTask``.
        """
        super().__init__(tokenizer, candidate_pool, obj_dim, transform, **kwargs)
        self.op_types = ["sub"]
        self.num_start_examples = num_start_examples

    def task_setup(self, config, project_root=None, *args, **kwargs):
        """
            Load the known structures and the seed data and build the starting pool.

            Args:
                config – object exposing ``log_dir``, ``job_name`` and ``timestamp``
                    (used to build the FoldX working directory).
                project_root – repository root; defaults to hydra's original working directory.

            Returns:
                base_candidates – 1-d array of ``FoldedCandidate`` objects, one per non-dominated
                    known structure whose sequence fits within ``self.max_len``.
                base_targets – objective rows matching ``base_candidates``.
                all_seqs – sequences of the known structures followed by the sampled seed
                    sequences, restricted to length <= ``self.max_len``.
                all_targets – objective rows matching ``all_seqs``.
        """
        project_root = hydra.utils.get_original_cwd() if project_root is None else project_root
        work_dir = f'{project_root}/{config.log_dir}/{config.job_name}/{config.timestamp}/foldx'
        asset_dir = f'{project_root}/torch_seq_moo/assets'

        rfp_known_structures = pd.read_csv(f'{asset_dir}/fpbase/rfp_known_structures.csv')
        known_seqs = rfp_known_structures.foldx_seq.values
        known_targets = np.stack([
            -rfp_known_structures.SASA.values,
            rfp_known_structures.foldx_total_energy.values,
        ], axis=-1)

        seed_data = pd.read_csv(f'{asset_dir}/fpbase/proxy_rfp_seed_data.csv')
        seed_data = seed_data.sample(self.num_start_examples)
        # NOTE(review): the seed file supplies `stability` where the known structures supply
        # `foldx_total_energy`; presumably stability == -total_energy -- confirm.
        seed_targets = np.stack([
            -seed_data.SASA.values,
            -seed_data.stability.values,
        ], axis=-1)

        all_seqs = np.concatenate((known_seqs, seed_data.foldx_seq.values))
        all_targets = np.concatenate((known_targets, seed_targets))

        # Explicit bool dtype keeps the mask usable for indexing even when it is empty.
        seq_len_mask = np.array([len(seq) <= self.max_len for seq in all_seqs], dtype=bool)
        all_seqs = all_seqs[seq_len_mask]
        all_targets = all_targets[seq_len_mask]

        # The known structures are the first entries of the concatenated arrays, so the
        # leading part of the mask is exactly their length filter.
        known_len_mask = seq_len_mask[:len(known_seqs)]
        valid_structures = rfp_known_structures[known_len_mask]
        valid_targets = known_targets[known_len_mask]

        # find valid, non-dominated starting candidates
        # (botorch's is_non_dominated assumes maximisation, hence the sign flip: the
        # targets themselves are treated as quantities to minimise)
        pareto_mask = pareto.is_non_dominated(-torch.tensor(valid_targets)).cpu().numpy()
        base_targets = valid_targets[pareto_mask]

        base_candidates = []
        # Boolean filtering is positional, so the frame's index labels do not matter here.
        for _, datum in valid_structures[pareto_mask].iterrows():
            print(f'{datum.Name} is non-dominated, adding to start pool')
            pdb_id = datum.pdb_id.lower()
            chain_id = datum.longest_chain
            parent_pdb_path = f'{asset_dir}/foldx/{pdb_id}_{chain_id}/wt_input_Repair.pdb'
            base_candidates.append(
                FoldedCandidate(work_dir, parent_pdb_path, [], self.tokenizer,
                                skip_minimization=True, chain=chain_id, wild_name=datum.Name)
            )
        base_candidates = np.array(base_candidates).reshape(-1)

        return base_candidates, base_targets, all_seqs, all_targets

    def make_new_candidates(self, base_candidates, new_seqs):
        """
            Derive one new candidate per (base candidate, target sequence) pair.

            Each target sequence must have the same length as the base candidate's
            ``mutant_residue_seq``; every differing position becomes one substitution.

            Args:
                base_candidates – array of candidates to mutate.
                new_seqs – array of target sequences, aligned with ``base_candidates``.

            Returns:
                Array of the new candidates, stacked with ``np.stack``.

            Raises:
                AssertionError – if the two arrays differ in length, or a sequence pair does.
        """
        assert base_candidates.shape[0] == new_seqs.shape[0]
        new_candidates = []
        for b_cand, n_seq in zip(base_candidates, new_seqs):
            b_seq = b_cand.mutant_residue_seq
            assert len(b_seq) == len(n_seq), 'FoldX only accepts substitutions'
            mutation_ops = [
                b_cand.new_mutation(pos, n_char, 'sub')
                for pos, (b_char, n_char) in enumerate(zip(b_seq, n_seq))
                if b_char != n_char
            ]
            new_candidates.append(b_cand.new_candidate(mutation_ops))
        return np.stack(new_candidates)

    def _evaluate(self, x, out, *args, **kwargs):
        """
            Apply one substitution per query row and record candidates, sequences and scores.

            Args:
                x – 2-d array; each row unpacks to (index into ``self.candidate_pool``,
                    mutation position, index into ``self.tokenizer.sampling_vocab``, ignored value).
                out – dict filled in place with ``"X_cand"``, ``"X_seq"`` and ``"F"``.
        """
        assert x.ndim == 2
        x_cands, x_seqs = [], []
        for cand_idx, mut_pos, mut_res_idx, _ in x:
            base_candidate = self.candidate_pool[cand_idx]
            mut_res = self.tokenizer.sampling_vocab[mut_res_idx]
            mut_list = [base_candidate.new_mutation(mut_pos, mut_res, mutation_type='sub')]
            candidate = base_candidate.new_candidate(mut_list)
            x_cands.append(candidate)
            x_seqs.append(candidate.mutant_residue_seq)
        x_seqs = np.array(x_seqs).reshape(-1)
        x_cands = np.array(x_cands).reshape(-1)

        out["X_cand"] = x_cands
        out["X_seq"] = x_seqs
        out["F"] = self.transform(self.score(x_cands))

    def score(self, candidates):
        """
            Return the raw objective matrix for ``candidates``.

            Returns:
                Array with one row per candidate: ``[-mutant_surface_area, mutant_total_energy]``.
        """
        return np.stack([
            np.array([-cand.mutant_surface_area, cand.mutant_total_energy])
            for cand in candidates
        ])