import random
import pandas as pd
from typing import List, Optional, Tuple

import functools
import math
import numpy as np
from rdkit import Chem
from rdkit.Chem import AllChem
from tqdm.auto import tqdm
from paretoset import paretoset

from disallow_tracker import DisallowTracker
from reagent import Reagent
from ts_logger import get_logger
from ts_utils import read_reagents
from evaluators import DBEvaluator


class ThompsonSampler:
    """Thompson sampling over the reagent lists of a combinatorial reaction.

    Typical use: construct with a mode, call ``read_reagents``, ``set_reaction`` and
    ``set_evaluator``, then ``warm_up`` followed by ``search``.
    """

    # Modes that require an evaluator with exactly one objective.
    _SINGLE_OBJECTIVE_MODES = ("maximize", "minimize", "maximize_boltzmann", "minimize_boltzmann")
    # Modes that require an evaluator with at least two objectives.
    _MULTI_OBJECTIVE_MODES = ("mo_maximize_TTPFTS", "mo_maximize_TS")

    def __init__(self, mode="maximize", log_filename: Optional[str] = None):
        """
        Basic init
        :param mode: one of "maximize", "minimize", "maximize_boltzmann", "minimize_boltzmann",
            "mo_maximize_TTPFTS" or "mo_maximize_TS"
        :param log_filename: Optional filename to write logging to. If None, logging will be output to stdout
        :raises ValueError: if mode is not one of the supported values
        """
        # A list of lists of Reagents. Each component in the reaction will have one list of Reagents in this list
        self.reagent_lists: List[List[Reagent]] = []
        self.reaction = None
        self.evaluator = None
        self.num_prods = 0
        self.logger = get_logger(__name__, filename=log_filename)
        self._disallow_tracker = None
        self.hide_progress = False
        self._post_log_file_stem = None
        self._log_posteriors = False
        self._mode = mode
        # pick_function maps a (num_reagents, num_objs) array of sampled scores to a row index.
        # _top_func picks the best item of a sequence, ranking on element 0; None for MO modes.
        if mode == "maximize":
            self.pick_function = lambda s: np.nanargmax(s[:, 0])
            self._top_func = lambda s: max(s, key=lambda x: x[0])
        elif mode == "minimize":
            self.pick_function = lambda s: np.nanargmin(s[:, 0])
            self._top_func = lambda s: min(s, key=lambda x: x[0])
        elif mode == "maximize_boltzmann":
            # See documentation for _boltzmann_reweighted_pick
            self.pick_function = self._boltzmann_reweighted_pick
            self._top_func = lambda s: max(s, key=lambda x: x[0])
        elif mode == "minimize_boltzmann":
            # See documentation for _boltzmann_reweighted_pick
            self.pick_function = self._boltzmann_reweighted_pick
            self._top_func = lambda s: min(s, key=lambda x: x[0])
        elif mode == "mo_maximize_TTPFTS":
            self.pick_function = self._mo_maximize_TTPFTS
            self._top_func = None
        elif mode == "mo_maximize_TS":
            # rho=1 always samples from the first Pareto front
            self.pick_function = functools.partial(self._mo_maximize_TTPFTS, rho=1)
            self._top_func = None
        else:
            raise ValueError(f"{mode} is not a supported argument")
        # Set by warm_up; used as the temperature in _boltzmann_reweighted_pick
        self._warmup_std = None

        self.pareto_history = []
        self._pareto_log_enabled: bool = False
        self.pareto_log_filename: Optional[str] = None

    def enable_pareto_logging(self, filename: Optional[str] = None) -> None:
        """
        Enable logging of pareto-optimal reagent indices for each bandit during MO runs.
        Records are kept in memory; if filename is provided, get_pareto_history writes
        them to that file in parquet format.
        :param filename: optional path of the parquet file written by get_pareto_history
        """
        self._pareto_log_enabled = True
        self.pareto_log_filename = filename

    def get_pareto_history(self) -> pd.DataFrame:
        """Return the in-memory pareto history as a DataFrame.

        If a pareto log filename was configured, the DataFrame is also written to it as parquet.
        :return: DataFrame built from the logged records (iteration, bandit, pareto_indices)
        """
        pareto_history_df = pd.DataFrame(self.pareto_history)
        if self.pareto_log_filename is not None:
            pareto_history_df.to_parquet(self.pareto_log_filename)
        return pareto_history_df

    @staticmethod
    def _pareto_front_indices(scores: np.ndarray, distinct: bool = True) -> np.ndarray:
        """Return the row indices of `scores` on the (maximizing) Pareto front.

        Rows containing NaN (disallowed reagents) are removed before calling paretoset and
        can therefore never be reported as Pareto-optimal.
        :param scores: array of shape (num_reagents, num_objs)
        :param distinct: forwarded to paretoset
        :return: indices into the rows of `scores`; empty if every row contains NaN
        """
        candidate_rows = np.where(~np.isnan(scores).any(axis=1))[0]
        if candidate_rows.size == 0:
            return candidate_rows
        front_mask = paretoset(scores[candidate_rows], sense=["max"] * scores.shape[1], distinct=distinct)
        return candidate_rows[np.where(front_mask)[0]]

    def log_bandit_pareto_indices(self, iteration, bandit, choice_row) -> None:
        """Append the Pareto-optimal row indices of choice_row to the pareto history.

        :param iteration: search iteration the sample belongs to
        :param bandit: index of the reaction component being filled
        :param choice_row: sampled scores, shape (num_reagents, num_objs); NaN rows are ignored
        """
        pareto_indices = self._pareto_front_indices(choice_row, distinct=False).tolist()
        self.pareto_history.append({
            "iteration": iteration,
            "bandit": bandit,
            "pareto_indices": pareto_indices
        })

    def _boltzmann_reweighted_pick(self, scores: np.ndarray):
        """Rather than choosing the top sampled score, use a reweighted probability.

        Zhao, H., Nittinger, E. & Tyrchan, C. Enhanced Thompson Sampling by Roulette
        Wheel Selection for Screening Ultra-Large Combinatorial Libraries.
        bioRxiv 2024.05.16.594622 (2024) doi:10.1101/2024.05.16.594622
        suggested several modifications to the Thompson Sampling procedure.
        This method implements one of those, namely a Boltzmann style probability distribution
        from the sampled values. The reagent is chosen based on that distribution rather than
        simply the max sample.

        :param scores: sampled scores, shape (num_reagents, 1); NaN marks disallowed reagents
        :return: index of the chosen reagent
        """
        if self._mode == "minimize_boltzmann":
            scores = -scores
        scaled = scores / self._warmup_std
        # Subtracting the maximum leaves the normalised probabilities unchanged but keeps
        # np.exp from overflowing to inf (which would turn every probability into NaN).
        exp_terms = np.exp(scaled - np.nanmax(scaled))
        probs = exp_terms / np.nansum(exp_terms)
        # Disallowed (NaN) reagents get zero probability
        probs[np.isnan(probs)] = 0.0
        return np.random.choice(probs.shape[0], p=probs[:, 0])

    def _mo_maximize_TTPFTS(self, scores: np.ndarray, rho=0.5) -> int:
        """
        Top-Two Pareto Fronts Thompson Sampling (TTPFTS) strategy for MO maximization.
        Can also be used with rho=1 to act as standard Pareto Thompson Sampling for MO maximization.

        :param scores: sampled scores, shape (num_reagents, num_objs); NaN rows are never selected
        :param rho: probability of sampling from the first Pareto front rather than the second
        :return: index of the chosen reagent
        :raises ValueError: if every row of scores contains NaN
        """
        first_front = self._pareto_front_indices(scores)
        if first_front.size == 0:
            raise ValueError("No selectable reagent: every sampled score is NaN")
        if np.random.random() < rho:
            # Return a random pareto index
            return np.random.choice(first_front)
        # Second front: the Pareto front of whatever remains once the first front is removed
        remaining = scores.astype(float)
        remaining[first_front] = np.nan
        second_front = self._pareto_front_indices(remaining)
        if second_front.size > 0:
            return np.random.choice(second_front)
        # Nothing left outside the first front, fall back to it
        return np.random.choice(first_front)

    def set_hide_progress(self, hide_progress: bool) -> None:
        """
        Hide the progress bars
        :param hide_progress: set to True to hide the progress bars
        """
        self.hide_progress = hide_progress

    def read_reagents(self, reagent_file_list, num_to_select: Optional[int] = None, num_objs: int = 1) -> None:
        """
        Reads the reagents from reagent_file_list
        :param reagent_file_list: List of reagent filepaths
        :param num_to_select: Max number of reagents to select from the reagents file (for dev purposes only)
        :param num_objs: number of objectives being optimized
        :return: None
        """
        self.reagent_lists = read_reagents(reagent_file_list, num_to_select, num_objs)
        reagent_counts = [len(x) for x in self.reagent_lists]
        self.num_prods = math.prod(reagent_counts)
        self.logger.info(f"{self.num_prods:.2e} possible products")
        self._disallow_tracker = DisallowTracker(reagent_counts)

    def get_num_prods(self) -> int:
        """
        Get the total number of possible products
        :return: num_prods
        """
        return self.num_prods

    def set_evaluator(self, evaluator):
        """
        Define the evaluator
        :param evaluator: evaluator class, must define an evaluate method that takes an RDKit molecule
            and a num_objs attribute
        :raises ValueError: if evaluator.num_objs does not fit the sampling mode
        """
        if self._mode in self._SINGLE_OBJECTIVE_MODES and evaluator.num_objs != 1:
            raise ValueError(f"Evaluator must have num_objs=1 for mode {self._mode}, found {evaluator.num_objs}")
        if self._mode in self._MULTI_OBJECTIVE_MODES and evaluator.num_objs < 2:
            raise ValueError(f"Evaluator must have num_objs>1 for mode {self._mode}, found {evaluator.num_objs}")
        self.evaluator = evaluator

    def set_reaction(self, rxn_smarts):
        """
        Define the reaction
        :param rxn_smarts: reaction SMARTS
        """
        self.reaction = AllChem.ReactionFromSmarts(rxn_smarts)

    def evaluate(self, choice_list: List[int]) -> Tuple[str, str, list[float]]:
        """Evaluate a set of reagents
        :param choice_list: list of reagent ids, one per reaction component
        :return: smiles for the reaction product ("FAIL" if the reaction gave no product),
            name of the product, score for the reaction product (NaN if the reaction gave no product)
        """
        selected_reagents = [self.reagent_lists[idx][choice] for idx, choice in enumerate(choice_list)]
        prod = self.reaction.RunReactants([reagent.mol for reagent in selected_reagents])
        product_name = "_".join([reagent.reagent_name for reagent in selected_reagents])
        res = np.nan
        product_smiles = "FAIL"
        if prod:
            prod_mol = prod[0][0]  # RunReactants returns Tuple[Tuple[Mol]]
            Chem.SanitizeMol(prod_mol)
            product_smiles = Chem.MolToSmiles(prod_mol)
            if isinstance(self.evaluator, DBEvaluator):
                # DBEvaluator is queried by product name rather than by molecule
                res = float(self.evaluator.evaluate(product_name))
            else:
                res = self.evaluator.evaluate(prod_mol)
            # Only finite scores are fed back to the reagents
            if np.all(np.isfinite(res)):
                for reagent in selected_reagents:
                    reagent.add_score(res)
        return product_smiles, product_name, res

    def _top_record(self, records, score_index: int):
        """Return (score, record) for the best record, or None when there is nothing to report.

        :param records: list of result records
        :param score_index: position of the score inside each record
        :return: None in MO modes (no _top_func) or when records is empty, otherwise a tuple of
            the first score component as a float and the record it came from
        """
        if self._top_func is None or not records:
            return None
        # np.ravel lets plain floats and score arrays be handled the same way
        keyed = [(float(np.ravel(record[score_index])[0]), record) for record in records]
        return self._top_func(keyed)

    def warm_up(self, num_warmup_trials=3):
        """Warm-up phase, each reagent is sampled with num_warmup_trials random partners
        :param num_warmup_trials: number of times to sample each reagent
        :return: list of [score, product_smiles, product_name] for every evaluation with a finite score
        """
        # get the list of reagent indices
        idx_list = list(range(len(self.reagent_lists)))
        # get the number of reagents for each component in the reaction
        reagent_count_list = [len(x) for x in self.reagent_lists]
        warmup_results = []
        for i in idx_list:
            partner_list = [x for x in idx_list if x != i]
            # For each reagent of this component...
            for j in tqdm(range(reagent_count_list[i]),
                          desc=f"Warmup {i + 1} of {len(idx_list)} (trials={num_warmup_trials})",
                          disable=self.hide_progress):
                # For each warmup trial...
                for _ in range(num_warmup_trials):
                    current_list = [DisallowTracker.Empty] * len(idx_list)
                    current_list[i] = DisallowTracker.To_Fill
                    disallow_mask = self._disallow_tracker.get_disallowed_selection_mask(current_list)
                    if j in disallow_mask:
                        # this reagent cannot be selected any more
                        continue
                    current_list[i] = j
                    # Randomly select reagents for each additional component of the reaction
                    for p in partner_list:
                        # tell the disallow tracker which site we are filling
                        current_list[p] = DisallowTracker.To_Fill
                        # get the new disallow mask
                        disallow_mask = self._disallow_tracker.get_disallowed_selection_mask(current_list)
                        selection_scores = np.random.uniform(size=reagent_count_list[p])
                        # null out the disallowed ones
                        selection_scores[list(disallow_mask)] = np.nan
                        # and select a random one
                        current_list[p] = np.nanargmax(selection_scores).item(0)
                    self._disallow_tracker.update(current_list)
                    product_smiles, product_name, score = self.evaluate(current_list)
                    if np.all(np.isfinite(score)):
                        warmup_results.append([score, product_smiles, product_name])
        warmup_scores = [ws[0] for ws in warmup_results]
        # initialize each reagent with a prior taken from all warmup scores
        prior_mean = np.mean(warmup_scores, axis=0)
        prior_std = np.std(warmup_scores, axis=0)
        self._warmup_std = prior_std
        for i, reagent_list in enumerate(self.reagent_lists):
            for j, reagent in enumerate(reagent_list):
                try:
                    reagent.init_given_prior(prior_mean=prior_mean, prior_std=prior_std)
                except ValueError:
                    self.logger.info(
                        f"Skipping reagent {reagent.reagent_name} because there were no successful evaluations during warmup")
                    self._disallow_tracker.retire_one_synthon(i, j)
        # Score sits at position 0 of each warmup record
        best = self._top_record(warmup_results, score_index=0)
        if best is not None:
            top_score, _ = best
            self.logger.info(f"Top score found during warmup: {top_score:0.4f}")
        return warmup_results

    def enable_log_posteriors(self, file_stem: str):
        """Enable periodic dumps of the reagent posteriors during search
        :param file_stem: prefix of the parquet files written by _log_reagent_posteriors
        """
        self._post_log_file_stem = file_stem
        self._log_posteriors = True

    def _log_reagent_posteriors(self, step: int):
        """Write the current mean and std of every reagent to one parquet file per component
        :param step: search iteration, used in the output filename
        """
        for i, reagent_list in enumerate(self.reagent_lists):
            df = pd.DataFrame({
                "reagent_name": [r.reagent_name for r in reagent_list],
                "means": [r.current_mean.tolist() for r in reagent_list],
                "stds": [r.current_std.tolist() for r in reagent_list],
            })
            df.to_parquet(f"{self._post_log_file_stem}_t{step}_component{i}.parquet", index=False)

    def search(self, num_cycles=25):
        """Run the search
        :param: num_cycles: number of search iterations
        :return: a list of [smiles, name, score] for every evaluation with a finite score
        """
        out_list = []
        rng = np.random.default_rng()
        num_components = len(self.reagent_lists)
        num_objs = self.evaluator.num_objs
        log_pareto = self._pareto_log_enabled and self._mode in self._MULTI_OBJECTIVE_MODES
        for i in tqdm(range(0, num_cycles), desc="Cycle", disable=self.hide_progress):
            selected_reagents = [DisallowTracker.Empty] * num_components
            # Fill the components in random order
            for cycle_id in random.sample(range(num_components), num_components):
                reagent_list = self.reagent_lists[cycle_id]
                selected_reagents[cycle_id] = DisallowTracker.To_Fill
                disallow_mask = self._disallow_tracker.get_disallowed_selection_mask(selected_reagents)
                stds = np.array([r.current_std for r in reagent_list])
                mu = np.array([r.current_mean for r in reagent_list])
                # One sample per reagent and objective from the current posteriors
                choice_row = rng.normal(size=(len(reagent_list), num_objs)) * stds + mu
                if disallow_mask:
                    # NaN rows are skipped by every pick function
                    choice_row[np.array(list(disallow_mask))] = np.nan

                if log_pareto:
                    self.log_bandit_pareto_indices(i, cycle_id, choice_row)

                # Select a reagent for this component, according to the choice function
                selected_reagents[cycle_id] = self.pick_function(choice_row)
            self._disallow_tracker.update(selected_reagents)
            smiles, name, score = self.evaluate(selected_reagents)
            if np.all(np.isfinite(score)):
                out_list.append([smiles, name, score])
            if i % 100 == 0:
                # Score sits at position 2 of each search record
                best = self._top_record(out_list, score_index=2)
                if best is not None:
                    top_score, (top_smiles, top_name, _) = best
                    self.logger.info(f"Iteration: {i} max score: {top_score:2f} smiles: {top_smiles} {top_name}")
            if self._log_posteriors and i % 500 == 0:
                self._log_reagent_posteriors(step=i)
        return out_list
