import json
from collections import defaultdict
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, Dict, List, Set, Tuple

import gin
import numpy as np
import pandas as pd
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem, Descriptors, Draw, Lipinski
from rdkit.Chem.QED import qed
from rdkit.Chem.Scaffolds.MurckoScaffold import MurckoScaffoldSmiles

from rgfn.api.trajectories import Trajectories
from rgfn.gfns.reaction_gfn.api.data_structures import Molecule
from rgfn.gfns.reaction_gfn.api.reaction_api import (
    ReactionAction0,
    ReactionActionC,
    ReactionState0Invalid,
    ReactionStateEarlyTerminal,
    ReactionStateTerminal,
)
from rgfn.shared.proxies.cached_proxy import CachedProxyBase
from rgfn.trainer.metrics.metric_base import MetricsBase, MetricsList


@gin.configurable()
class QED(MetricsBase):
    """
    Reports the mean RDKit QED score of the molecules in the sampled terminal states.
    """

    def compute_metrics(self, trajectories: Trajectories) -> Dict[str, float]:
        """
        Average the QED score over the last states that are `ReactionStateTerminal`.

        Last states of any other type are skipped.

        Args:
            trajectories: batch of trajectories whose last states are scored.

        Returns:
            A dict with a single "qed" entry. The value is NaN when none of the last states
            is a `ReactionStateTerminal`.
        """
        qed_scores = [
            qed(state.molecule.rdkit_mol)
            for state in trajectories.get_last_states_flat()
            if isinstance(state, ReactionStateTerminal)
        ]
        if not qed_scores:
            # np.mean([]) yields NaN as well, but emits a RuntimeWarning on every call.
            return {"qed": float("nan")}
        return {"qed": float(np.mean(qed_scores))}


@gin.configurable()
class PathCostProxy:
    """
    Estimates the cost of synthesis paths from a table of per-fragment costs.
    """

    def __init__(self, path: str):
        """
        Args:
            path: path to a JSON file holding an object that maps fragment SMILES to a cost.
                Every value is converted with `float()`.
        """
        # Context manager so the file handle is closed deterministically instead of leaking.
        with open(path, encoding="utf-8") as fp:
            raw_costs = json.load(fp)
        self.fragment_to_cost = {smiles: float(cost) for smiles, cost in raw_costs.items()}
        # These two tables are not used by any method of this class; missing keys default to inf.
        self.forward_terminal_to_cost: Dict[str, float] = defaultdict(lambda: float("inf"))
        self.replay_terminal_to_cost: Dict[str, float] = defaultdict(lambda: float("inf"))

    def _compute_costs(self, trajectories: Trajectories) -> List[float]:
        """
        Compute one cost per trajectory.

        The cost starts at the cost of the first action, or at infinity when the first state is
        a `ReactionState0Invalid`. Every subsequent `ReactionActionC` adds the cost of its input
        fragments and then divides the running total by the reaction yield.

        Args:
            trajectories: trajectories to evaluate.

        Returns:
            A list with one cost per trajectory, in the order of the trajectories.

        Raises:
            KeyError: if a fragment SMILES is missing from the cost table.
        """
        path_costs = []
        for actions, states in zip(trajectories._actions_list, trajectories._states_list):
            if isinstance(states[0], ReactionState0Invalid):
                current_cost = float("inf")
            else:
                current_cost = self.get_action_cost(actions[0])

            # The states are zipped only to keep the original truncation to the shorter sequence.
            for action, _ in zip(actions[1:], states[2:]):
                if isinstance(action, ReactionActionC):
                    fragment_cost = self.get_action_cost(action)
                    yield_value = self.compute_yield(action)
                    current_cost = (current_cost + fragment_cost) * yield_value**-1

            path_costs.append(current_cost)
        return path_costs

    def compute_yield(self, action: ReactionActionC) -> float:
        """
        Return the yield assumed for `action`; this implementation always returns 0.75.
        """
        return 0.75

    def get_fragment_cost(self, fragment: Molecule) -> float:
        """
        Look up the cost of a single fragment by its SMILES.

        Raises:
            KeyError: if the fragment SMILES is missing from the cost table.
        """
        return self.fragment_to_cost[fragment.smiles]

    def get_action_cost(self, action: ReactionActionC | ReactionAction0) -> float:
        """
        Return the cost of the fragments introduced by `action`.

        For a `ReactionAction0` this is the cost of `action.fragment`; for any other action it is
        the sum over `action.input_fragments` (0 when there are none).

        Raises:
            KeyError: if a fragment SMILES is missing from the cost table.
        """
        if isinstance(action, ReactionAction0):
            return self.get_fragment_cost(action.fragment)
        return sum(self.get_fragment_cost(fragment) for fragment in action.input_fragments)


@gin.configurable()
class SaveSynthesisPaths(MetricsBase):
    """
    Appends the synthesis path of every forward trajectory to a CSV file in the run directory.
    """

    def __init__(self, run_dir: str, n_forward: int, file_name: str = "paths.csv"):
        """
        Create (or truncate) the output file and write its header line.

        Args:
            run_dir: directory in which the file is created.
            n_forward: number of leading trajectories in each batch treated as forward ones.
            file_name: name of the output file inside `run_dir`.
        """
        super().__init__()
        self.n_forward = n_forward
        self.trajectories_counter = 0
        self.iterations_counter = 0
        self.path = Path(run_dir) / file_name
        with open(self.path, "w") as f:
            f.write("iteration,path,proxy\n")

    def compute_metrics(self, trajectories: Trajectories) -> Dict[str, float]:
        """
        Write one line per forward trajectory and return the running trajectory count.

        Each line holds the trajectory index, the path (starting fragment SMILES followed by a
        reaction tuple and the resulting molecule SMILES for every `ReactionActionC`), and the
        proxy score. The score is written as 0.0 when the last state is not a
        `ReactionStateTerminal`.

        Returns:
            A dict with "num_visited_molecules" set to the number of trajectories (forward and
            otherwise) passed to this metric so far.
        """
        n_total = len(trajectories)
        is_forward = np.array([True] * self.n_forward + [False] * (n_total - self.n_forward))
        forward = trajectories.masked_select(is_forward)
        proxy_scores = forward.get_reward_outputs().proxy

        # All lines are built before the file is touched, so a failure writes nothing.
        lines = []
        batch = zip(forward._states_list, forward._actions_list, proxy_scores)
        for offset, (states, actions, score) in enumerate(batch):
            steps = [actions[0].fragment.smiles]
            for action, state in zip(actions[1:], states[2:]):
                if not isinstance(action, ReactionActionC):
                    continue
                reactants = tuple(f.smiles for f in action.input_fragments)
                # A reaction without input fragments is recorded with a single None reactant.
                steps.append((action.input_reaction.reaction,) + (reactants or (None,)))
                assert state.molecule == action.output_molecule
                steps.append(state.molecule.smiles)

            if isinstance(states[-1], ReactionStateTerminal):
                proxy_value = score.item()
            else:
                proxy_value = 0.0
            lines.append(f'{self.trajectories_counter + offset},"{steps}",{proxy_value}\n')

        with open(self.path, "a") as f:
            f.writelines(lines)

        self.trajectories_counter += n_total
        self.iterations_counter += 1
        return {"num_visited_molecules": self.trajectories_counter}


@gin.configurable()
class ScaffoldCost(MetricsBase):
    """
    Tracks, per Murcko scaffold, the mean and the minimum path cost of the molecules whose proxy
    value exceeds a threshold, and reports the average over the n cheapest scaffolds.

    The per-scaffold statistics accumulate over all calls to `compute_metrics`.
    """

    def __init__(
        self,
        path_cost_proxy: PathCostProxy,
        threshold: float,
        proxy_component_name: str | None = None,
        n_cheapest_list: List[int] = (100,),
        forward_only: bool = False,
    ):
        """
        Args:
            path_cost_proxy: object used to compute the cost of every trajectory.
            threshold: a molecule is counted only if its proxy value is strictly greater.
            proxy_component_name: name of the proxy component to compare against the threshold;
                the overall proxy value is used when None.
            n_cheapest_list: values of n for which the "n cheapest" averages are reported.
            forward_only: when True, "_forward" is appended to the cost metric names.
                NOTE(review): this flag does not filter the trajectories in this class.
        """
        super().__init__()
        self.path_cost_proxy = path_cost_proxy
        self.threshold = threshold
        self.proxy_component_name = proxy_component_name
        self.n_cheapest_list = n_cheapest_list
        self.forward_only = forward_only
        self.scaffold_to_mean_cost: Dict[str, float] = defaultdict(lambda: 0)
        self.scaffold_to_count: Dict[str, int] = defaultdict(lambda: 0)
        self.scaffold_to_min_cost: Dict[str, float] = defaultdict(lambda: float("inf"))

    def compute_metrics(self, trajectories: Trajectories) -> Dict[str, float]:
        """
        Update the per-scaffold statistics with a batch and report the cheapest-scaffold costs.

        Returns:
            For every n in `n_cheapest_list`, the mean over the n smallest per-scaffold mean
            costs and over the n smallest per-scaffold minimum costs, plus the number of
            scaffolds seen so far (this last key carries no "_forward" suffix).
        """
        last_states = trajectories.get_last_states_flat()
        rewards = trajectories.get_reward_outputs()
        if self.proxy_component_name is None:
            proxy_values = rewards.proxy
        else:
            proxy_values = rewards.proxy_components[self.proxy_component_name]
        path_costs = self.path_cost_proxy._compute_costs(trajectories)

        for state, proxy, cost in zip(last_states, proxy_values, path_costs):
            if not (proxy.item() > self.threshold):
                continue
            scaffold = MurckoScaffoldSmiles(state.molecule.smiles)
            previous_mean = self.scaffold_to_mean_cost[scaffold]
            seen = self.scaffold_to_count[scaffold]
            # Incremental mean: down-weight the old mean and add the new sample's share.
            self.scaffold_to_mean_cost[scaffold] = previous_mean * (seen / (seen + 1)) + cost / (
                seen + 1
            )
            self.scaffold_to_min_cost[scaffold] = min(self.scaffold_to_min_cost[scaffold], cost)
            self.scaffold_to_count[scaffold] += 1

        suffix = "_forward" if self.forward_only else ""
        ordered_costs = (
            ("mean", sorted(self.scaffold_to_mean_cost.values())),
            ("min", sorted(self.scaffold_to_min_cost.values())),
        )
        results = {}
        for n in self.n_cheapest_list:
            for kind, ordered in ordered_costs:
                results[f"cost_{n}_cheapest_{kind}_{self.threshold}{suffix}"] = np.mean(
                    ordered[:n]
                )

        results[f"num_scaffolds_{self.threshold}"] = len(self.scaffold_to_mean_cost)
        return results


@gin.configurable()
class TrajectoryCost(MetricsBase):
    """
    Reports the mean path cost of the forward trajectories and, when present, of the remaining
    (replay) trajectories of a batch.
    """

    def __init__(self, n_forward: int, path_cost_proxy: PathCostProxy):
        """
        Args:
            n_forward: number of leading trajectories in each batch treated as forward ones.
            path_cost_proxy: object used to compute the cost of every trajectory.
        """
        super().__init__()
        self.path_cost_proxy = path_cost_proxy
        self.n_forward = n_forward

    def compute_metrics(self, trajectories: Trajectories) -> Dict[str, float]:
        """
        Compute the mean cost of the first `n_forward` trajectories and of the rest.

        Returns:
            A dict with "forward_mean_cost" and, only if the batch holds more than `n_forward`
            trajectories, "replay_mean_cost".
        """
        n_replay = len(trajectories) - self.n_forward
        is_forward = np.array([True] * self.n_forward + [False] * n_replay)

        forward_costs = self.path_cost_proxy._compute_costs(trajectories.masked_select(is_forward))
        result = {"forward_mean_cost": np.mean(forward_costs).item()}

        is_replay = ~is_forward
        if is_replay.any():
            replay_costs = self.path_cost_proxy._compute_costs(
                trajectories.masked_select(is_replay)
            )
            result["replay_mean_cost"] = np.mean(replay_costs).item()

        return result


@gin.configurable()
class ScaffoldCostsList(MetricsBase):
    """
    Bundles one pair of `ScaffoldCost` metrics (forward_only False and True) per proxy threshold.
    """

    def __init__(
        self,
        path_cost_proxy: PathCostProxy,
        proxy_value_threshold_list: List[float] = (8,),
        proxy_component_name: str | None = None,
        n_cheapest_list: List = (100,),
    ):
        """
        Args:
            path_cost_proxy: cost proxy shared by all created `ScaffoldCost` metrics.
            proxy_value_threshold_list: thresholds; each one yields two `ScaffoldCost` metrics.
            proxy_component_name: forwarded to every `ScaffoldCost`.
            n_cheapest_list: forwarded to every `ScaffoldCost`.
        """
        super().__init__()
        scaffold_costs = []
        for threshold in proxy_value_threshold_list:
            for forward_only in (False, True):
                scaffold_costs.append(
                    ScaffoldCost(
                        path_cost_proxy=path_cost_proxy,
                        threshold=threshold,
                        proxy_component_name=proxy_component_name,
                        n_cheapest_list=n_cheapest_list,
                        forward_only=forward_only,
                    )
                )
        self.metrics = MetricsList(scaffold_costs)

    def compute_metrics(self, trajectories) -> Dict[str, float]:
        """
        Delegate to the wrapped `MetricsList` and return its result unchanged.
        """
        results = self.metrics.compute_metrics(trajectories)
        return results


@gin.configurable()
class NumScaffoldsFound(MetricsBase):
    """
    Counts the distinct Murcko scaffolds whose proxy value passes each of several thresholds.

    The scaffold sets accumulate over all calls to `compute_metrics`.
    """

    def __init__(
        self,
        proxy_value_threshold_list: List[float],
        proxy_component_name: str | None,
        proxy_higher_better: bool = True,
    ):
        """
        Args:
            proxy_value_threshold_list: thresholds for which scaffolds are counted.
            proxy_component_name: name of the proxy component compared against the thresholds;
                the overall proxy value is used when None.
            proxy_higher_better: when True a value passes if it is strictly greater than the
                threshold, otherwise if it is strictly smaller.
        """
        super().__init__()
        self.proxy_value_threshold_list = proxy_value_threshold_list
        self.proxy_higher_better = proxy_higher_better
        self.threshold_to_set: Dict[float, Set[str]] = {
            threshold: set() for threshold in proxy_value_threshold_list
        }
        self.proxy_component_name = proxy_component_name

    def compute_metrics(self, trajectories: Trajectories) -> Dict[str, float]:
        """
        Add the scaffolds of the batch to the per-threshold sets and report the set sizes.

        Returns:
            A dict mapping "num_scaffolds_{threshold}" to the number of distinct scaffolds
            recorded for that threshold so far.
        """
        reward_outputs = trajectories.get_reward_outputs()
        terminal_states = trajectories.get_last_states_flat()
        values = (
            reward_outputs.proxy
            if self.proxy_component_name is None
            else reward_outputs.proxy_components[self.proxy_component_name]
        )
        for state, proxy_value in zip(terminal_states, values):
            value = proxy_value.item()
            # Computed at most once per state, and only if some threshold is passed, instead of
            # once per passing threshold. `None` is the sentinel because "" is a valid scaffold.
            scaffold = None
            for threshold in self.proxy_value_threshold_list:
                passes = value > threshold if self.proxy_higher_better else value < threshold
                if not passes:
                    continue
                if scaffold is None:
                    scaffold = MurckoScaffoldSmiles(state.molecule.smiles)
                self.threshold_to_set[threshold].add(scaffold)

        return {
            f"num_scaffolds_{threshold}": len(self.threshold_to_set[threshold])
            for threshold in self.proxy_value_threshold_list
        }


@gin.configurable()
class UniqueMolecules(MetricsBase):
    """
    Keeps the latest scores of every distinct terminal molecule (keyed by SMILES) and
    periodically dumps the whole collection to a JSON file.
    """

    def __init__(self, run_dir: str, dump_every_n: int | None = None):
        """
        Args:
            run_dir: run directory; dumps go to its "unique_molecules" sub-directory, which is
                created here if missing.
            dump_every_n: dump period in calls to `compute_metrics`; None disables dumping.
        """
        super().__init__()
        self.dump_every_n = dump_every_n
        self.iterations = 0
        self.molecules: Dict[Any, Any] = {}
        self.dump_path = Path(run_dir) / "unique_molecules"
        self.dump_path.mkdir(exist_ok=True, parents=True)

    def compute_metrics(self, trajectories: Trajectories) -> Dict[str, float]:
        """
        Record the scores of the batch's terminal molecules and dump them when due.

        Only last states that are `ReactionStateTerminal` are recorded; a molecule seen again
        overwrites its previous entry.

        Returns:
            A dict with "num_unique_molecules" set to the number of distinct SMILES recorded.
        """
        last_states = trajectories.get_last_states_flat()
        scores = trajectories.get_reward_outputs().proxy
        components = trajectories.get_reward_outputs().proxy_components

        for index, state in enumerate(last_states):
            if not isinstance(state, ReactionStateTerminal):
                continue
            record = {"score": scores[index].item()}
            if components is not None:
                record.update(
                    {f"term_{name}": values[index].item() for name, values in components.items()}
                )
            self.molecules[state.molecule.smiles] = record

        # Iteration 0 is never dumped, even though 0 is a multiple of every period.
        dump_due = (
            self.dump_every_n is not None
            and self.iterations % self.dump_every_n == 0
            and self.iterations > 0
        )
        if dump_due:
            target = self.dump_path / f"molecules_{self.iterations}.json"
            with open(target, "w") as fp:
                json.dump(self.molecules, fp)
        self.iterations += 1

        return {"num_unique_molecules": len(self.molecules)}


@gin.configurable()
class AllMolecules(MetricsBase):
    """
    Keeps every visited terminal molecule together with its scores (duplicates included) and
    periodically dumps the whole list to a text file.
    """

    def __init__(self, run_dir: str, dump_every_n: int | None = None):
        """
        Args:
            run_dir: run directory; dumps go to its "all_molecules" sub-directory, which is
                created here if missing.
            dump_every_n: dump period in calls to `compute_metrics`; None disables dumping.
        """
        super().__init__()
        self.dump_every_n = dump_every_n
        self.iterations = 0
        self.molecules: list = []
        self.dump_path = Path(run_dir) / "all_molecules"
        self.dump_path.mkdir(exist_ok=True, parents=True)

    def compute_metrics(self, trajectories: Trajectories) -> Dict[str, float]:
        """
        Append the batch's terminal molecules with their scores and dump the list when due.

        Only last states that are `ReactionStateTerminal` are recorded.

        Returns:
            A dict with "num_visited_molecules" set to the number of recorded entries.
        """
        last_states = trajectories.get_last_states_flat()
        scores = trajectories.get_reward_outputs().proxy
        components = trajectories.get_reward_outputs().proxy_components

        for index, state in enumerate(last_states):
            if not isinstance(state, ReactionStateTerminal):
                continue
            record = {"score": scores[index].item()}
            if components is not None:
                record.update(
                    {f"term_{name}": values[index].item() for name, values in components.items()}
                )
            self.molecules.append((state.molecule.smiles, record))

        # Iteration 0 is never dumped, even though 0 is a multiple of every period.
        dump_due = (
            self.dump_every_n is not None
            and self.iterations % self.dump_every_n == 0
            and self.iterations > 0
        )
        if dump_due:
            target = self.dump_path / f"molecules_{self.iterations}.txt"
            with open(target, "w") as fp:
                fp.writelines(f"{smiles}, {record}\n" for smiles, record in self.molecules)
        self.iterations += 1

        return {"num_visited_molecules": len(self.molecules)}


@gin.configurable()
class TanimotoSimilarityModes(MetricsBase):
    """
    Extracts "modes" from the proxy cache: high-scoring terminal molecules that are pairwise
    dissimilar in terms of Tanimoto similarity of their Morgan fingerprints. The modes are
    counted and saved to an Excel sheet with molecule drawings and basic descriptors.
    """

    def __init__(
        self,
        run_dir: str,
        proxy: CachedProxyBase,
        term_name: str = "value",
        proxy_term_threshold: float = -np.inf,
        similarity_threshold: float = 0.7,
        max_modes: int | None = 5000,
        compute_every_n: int = 1,
    ):
        """
        Args:
            run_dir: run directory; sheets go to its "modes" sub-directory, which is created
                here if missing.
            proxy: proxy whose `cache` is the source of candidate molecules.
            term_name: score entry used for thresholding and sorting the candidates.
            proxy_term_threshold: minimal value of `term_name` for a candidate to be considered.
            similarity_threshold: a candidate is rejected if its similarity to any already
                accepted mode is strictly greater than this value.
            max_modes: maximal number of modes to extract; None means no limit.
            compute_every_n: modes are computed on every n-th call to `compute_metrics`.
        """
        super().__init__()
        self.proxy = proxy
        self.term_name = term_name
        self.proxy_term_threshold = proxy_term_threshold
        self.similarity_threshold = similarity_threshold
        self.max_modes = max_modes
        self.compute_every_n = compute_every_n
        self.iterations = 0
        self.dump_path = Path(run_dir) / "modes"
        self.dump_path.mkdir(exist_ok=True, parents=True)
        # Path of the most recent sheet not yet handed out by `collect_files`.
        self.xlsx_path = None

    def _extract_top_sorted_smiles(self) -> Dict[str, float | Dict[str, float]]:
        """
        Fetches SMILES from proxy cache, extracts the ones with reward above thresholds,
        and sorts them by reward.

        Returns:
            A dict from SMILES to scores, ordered by decreasing `term_name` score. It is empty
            when the cache is empty.
        """
        first_scores = next(iter(self.proxy.cache.values()), None)
        if first_scores is None:
            # Empty cache: nothing to extract (a bare next() would raise StopIteration here).
            return {}

        if isinstance(first_scores, float):
            # Plain float scores are wrapped so that both cache layouts are handled uniformly.
            cache = {k: {"value": v} for k, v in self.proxy.cache.items()}
        else:
            cache = self.proxy.cache

        d = {}
        for state, scores in cache.items():
            if (
                isinstance(state, ReactionStateTerminal)
                and scores[self.term_name] >= self.proxy_term_threshold
            ):
                d[state.molecule.smiles] = scores

        return dict(sorted(d.items(), key=lambda item: item[1][self.term_name], reverse=True))

    def _extract_modes(self) -> Dict[str, float | Dict[str, float]]:
        """
        Greedily select modes from the candidates, best score first.

        A candidate becomes a mode unless its Tanimoto similarity to an already selected mode
        exceeds `similarity_threshold`. Selection stops once `max_modes` modes are found.

        Returns:
            A dict from the SMILES of every mode to its scores.
        """
        candidates = self._extract_top_sorted_smiles()
        modes = []
        for smiles, scores in candidates.items():
            # `max_modes` may be None (no limit), which must not be compared with an int.
            if self.max_modes is not None and len(modes) >= self.max_modes:
                break
            # Fingerprints are computed lazily so that nothing is computed past the limit.
            ecfp = AllChem.GetMorganFingerprintAsBitVect(
                Chem.MolFromSmiles(smiles),
                radius=3,
                nBits=2048,
                useFeatures=False,
                useChirality=False,
            )
            is_mode = all(
                DataStructs.TanimotoSimilarity(ecfp, mode_ecfp) <= self.similarity_threshold
                for mode_ecfp, _, _ in modes
            )
            if is_mode:
                modes.append((ecfp, smiles, scores))
        return {smiles: scores for _, smiles, scores in modes}

    @staticmethod
    def _modes_to_df(modes: Dict[str, float | Dict[str, float]]) -> pd.DataFrame:
        """
        Build a table with one row per mode.

        Args:
            modes: non-empty dict from SMILES to a dict of scores that contains a "value" entry.

        Returns:
            A data frame with the reward, the remaining reward terms, and molecular descriptors.
            The "Molecule" column is left empty; `_save_modes_xlsx` puts a drawing there.
        """
        reward_terms = [k for k in next(iter(modes.values())).keys() if k != "value"]

        rows = []
        for smiles, scores in modes.items():
            reward = scores["value"]
            mol = Chem.MolFromSmiles(smiles)
            heavy_atoms = mol.GetNumHeavyAtoms()
            # Ligand efficiency: reward per heavy atom.
            efficiency = reward / heavy_atoms
            row = (
                [
                    "",
                    f"{np.round(reward, 2):.2f}",
                ]
                + [f"{np.round(scores[term], 2):.2f}" for term in reward_terms]
                + [
                    f"{np.round(Descriptors.ExactMolWt(mol), 2):.2f}",
                    Lipinski.NumHDonors(mol),
                    Lipinski.NumHAcceptors(mol),
                    heavy_atoms,
                    f"{np.round(Descriptors.MolLogP(mol), 3):.3f}",
                    Chem.rdMolDescriptors.CalcNumRotatableBonds(mol),
                    f"{np.round(efficiency, 4):.4f}",
                    smiles,
                ]
            )
            rows.append(row)

        columns = (
            [
                "Molecule",
                "Reward",
            ]
            + [f"Reward ({term})" for term in reward_terms]
            + [
                "MW",
                "H-bond donors",
                "H-bond acceptors",
                "Heavy atoms",
                "cLogP",
                "Rotatable bonds",
                "Ligand efficiency",
                "SMILES",
            ]
        )

        return pd.DataFrame(rows, columns=columns)

    @staticmethod
    def _save_modes_xlsx(df: pd.DataFrame, file_path: Path | str):
        """
        Save the table to an Excel file with a drawing of every molecule in the first column.

        Args:
            df: table with a "SMILES" column, as produced by `_modes_to_df`.
            file_path: path of the Excel file to write.
        """
        # Context managers release the writer and the temporary images even if drawing fails.
        # The directory is the outer one, so the images still exist when the writer is closed
        # and the workbook is written.
        with TemporaryDirectory() as directory:
            with pd.ExcelWriter(file_path, engine="xlsxwriter") as writer:
                df.to_excel(writer, sheet_name="Molecules", index=False)
                worksheet = writer.sheets["Molecules"]
                worksheet.set_column(0, 0, 21)
                worksheet.set_column(1, len(df.columns), 15)

                for i, row in df.iterrows():
                    mol = Chem.MolFromSmiles(row["SMILES"])
                    image_path = Path(directory) / f"molecule_{i}.png"
                    Draw.MolToFile(mol, filename=image_path, size=(150, 150))

                    # Row 0 holds the header, hence the offset of one.
                    worksheet.set_row(i + 1, 120)
                    worksheet.insert_image(i + 1, 0, image_path)

    def compute_metrics(self, trajectories: Trajectories) -> Dict[str, float]:
        """
        Extract the modes and save them on every `compute_every_n`-th call.

        The trajectories are not used; the modes come from the proxy cache.

        Returns:
            A dict with "num_modes" on the calls when modes are computed, an empty dict
            otherwise.
        """
        if self.iterations % self.compute_every_n != 0:
            self.iterations += 1
            return {}

        modes = self._extract_modes()
        if len(modes) > 0:
            df = self._modes_to_df(modes)
            self.xlsx_path = self.dump_path / f"modes_{self.iterations}.xlsx"
            self._save_modes_xlsx(df, self.xlsx_path)

        self.iterations += 1
        return {"num_modes": len(modes)}

    def collect_files(self) -> List[Path | str]:
        """
        Return the most recently saved sheet, if any, and forget it.

        Returns:
            A one-element list with the path of the last saved sheet, or an empty list if no
            sheet was saved since the previous call.
        """
        if self.xlsx_path is None:
            return []
        result = [self.xlsx_path]
        self.xlsx_path = None
        return result


@gin.configurable()
class FractionEarlyTerminate(MetricsBase):
    """
    Reports the fraction of trajectories whose last state is a `ReactionStateEarlyTerminal`.
    """

    def compute_metrics(self, trajectories: Trajectories) -> Dict[str, float]:
        """
        Compute the share of last states that are `ReactionStateEarlyTerminal`.

        Args:
            trajectories: batch of trajectories to inspect.

        Returns:
            A dict with a single "fraction_early_terminate" entry. The value is NaN when there
            are no last states at all.
        """
        terminal_states = trajectories.get_last_states_flat()
        num_states = len(terminal_states)
        if num_states == 0:
            # The fraction is undefined for an empty batch; avoid a ZeroDivisionError.
            return {"fraction_early_terminate": float("nan")}
        num_early_terminate = sum(
            isinstance(state, ReactionStateEarlyTerminal) for state in terminal_states
        )
        return {"fraction_early_terminate": num_early_terminate / num_states}
