from collections import defaultdict
from typing import Any, Dict, List

import gin
import numpy as np
import torch

from rgfn.api.reward import Reward
from rgfn.api.reward_output import RewardOutput
from rgfn.api.training_hooks_mixin import TrainingHooksMixin
from rgfn.api.trajectories import Trajectories, TrajectoriesContainer
from rgfn.gfns.reaction_gfn.api.data_structures import (
    AnchoredReaction,
    Molecule,
    ParallelCache,
)
from rgfn.gfns.reaction_gfn.api.reaction_api import (
    ReactionAction0,
    ReactionAction0Invalid,
    ReactionActionC,
    ReactionState0Invalid,
    ReactionStateTerminal,
)
from rgfn.gfns.reaction_gfn.api.reaction_data_factory import ReactionDataFactory


@gin.configurable()
class PathCostProxy(TrainingHooksMixin):
    def __init__(
        self,
        data_factory: ReactionDataFactory,
        yield_value: float | None = None,
        update_reward: bool = False,
        target_weight: float = 1.0,
        cost_weight: float = 1.0,
        cost_temperature: float = 1.0,
        reward: Reward | None = None,
    ):
        """Set up the cost/yield lookup tables and the running cost caches.

        Args:
            data_factory: source of the fragment-to-cost map, the
                reaction-to-yield map and the anchor-to-reaction map.
            yield_value: when not None, `compute_yield` and
                `compute_yield_raw` return this value for every reaction.
                Must be given if the factory's reaction-to-yield map is empty.
            update_reward: when True, `_update_reward` replaces the reward
                outputs of trajectories with a cost-aware version.
            target_weight: weight of the original proxy value in the
                combined proxy.
            cost_weight: weight of the `exp(-cost * cost_temperature)` term
                in the combined proxy.
            cost_temperature: multiplier applied to the cost inside the
                exponential.
            reward: optional reward whose `beta` scales the combined proxy
                into the log-reward; 1.0 is used when it is None.
        """
        # Lookup tables provided by the data factory.
        self.data_factory = data_factory
        self.fragment_to_cost = data_factory.get_fragment_to_cost()
        smiles_to_cost = {}
        for molecule, molecule_cost in self.fragment_to_cost.items():
            smiles_to_cost[molecule.smiles] = molecule_cost
        self.fragment_smiles_to_cost = smiles_to_cost
        self.reaction_to_yield = data_factory.get_reaction_to_yield()
        self.anchor_to_reaction = data_factory.get_anchor_to_reaction_map()
        # Either per-reaction yields or a constant yield must be available.
        assert len(self.reaction_to_yield) > 0 or yield_value is not None

        # Cheapest known cost per terminal SMILES; unseen molecules map to infinity.
        self.forward_terminal_to_cost: Dict[str, float] = defaultdict(lambda: float("inf"))
        self.replay_terminal_to_cost: Dict[str, float] = defaultdict(lambda: float("inf"))
        # Caches keyed by (SMILES, number of reactions) in `_compute_costs`.
        self.molecule_num_reaction_to_cost = ParallelCache(max_size=1_000_000)
        self.n_recent_updates = 0
        self.negative_molecule_num_reaction = ParallelCache(max_size=50_000)

        # Plain configuration values.
        self.yield_value = yield_value
        self.update_reward = update_reward
        self.target_weight = target_weight
        self.cost_weight = cost_weight
        self.cost_temperature = cost_temperature
        if reward is None:
            self.beta = 1.0
        else:
            self.beta = reward.beta

        cost_statistics = self.get_fragment_costs_mean_std()
        print("Cost mean and variance", cost_statistics)
        highest_cost = self.get_fragment_costs_max()
        print("Cost max", highest_cost)

    def get_fragment_costs_mean_std(self) -> tuple[float, float]:
        costs = list(self.fragment_to_cost.values())
        return np.mean(costs), np.std(costs)

    def get_fragment_costs_max(self) -> float:
        return max(self.fragment_to_cost.values())

    def _compute_costs(self, trajectories: Trajectories) -> List[float]:
        path_costs = []
        for actions, states in zip(
            trajectories.get_all_actions_grouped(),
            trajectories.get_all_states_grouped(),
        ):
            current_cost = (
                self.get_action_cost(actions[0])
                if not isinstance(states[0], ReactionState0Invalid)
                else float("inf")
            )
            for action, state in zip(actions, states[1:]):
                if isinstance(action, ReactionAction0Invalid):
                    if state.num_reactions == 0:
                        raise ValueError(f"States with num_reactions == 0, {state}")
                    item = (state.molecule.smiles, state.num_reactions)
                    self.negative_molecule_num_reaction[item] = float("inf")
                    self.molecule_num_reaction_to_cost[item] = float("inf")
                elif isinstance(action, ReactionActionC):
                    if state.num_reactions == 0:
                        raise ValueError(f"States with num_reactions == 0, {state}")
                    fragment_cost = self.get_action_cost(action)
                    yield_value = self.compute_yield(action)
                    current_cost = (current_cost + fragment_cost) * yield_value**-1
                    item = (state.molecule.smiles, state.num_reactions)
                    previous_cost = self.molecule_num_reaction_to_cost[item] or float("inf")
                    cost = min(previous_cost, current_cost)
                    if previous_cost > current_cost:
                        self.molecule_num_reaction_to_cost[item] = current_cost
                        self.n_recent_updates += 1
                    if cost == float("inf"):
                        self.negative_molecule_num_reaction[item] = float("inf")
                    else:
                        self.negative_molecule_num_reaction.pop(item)

            path_costs.append(current_cost)
        return path_costs

    def _update_reward(self, trajectories: Trajectories) -> None:
        if not self.update_reward or len(trajectories) == 0:
            return
        reward_output: RewardOutput = trajectories.get_reward_outputs()
        assert reward_output.proxy_components is None
        target_value = reward_output.proxy
        cost_value = torch.tensor(trajectories.get_costs(), device=target_value.device)
        inverse_cost_value = torch.exp(-cost_value * self.cost_temperature)
        proxy = target_value * self.target_weight + inverse_cost_value * self.cost_weight
        log_reward = proxy * self.beta
        new_reward_output = RewardOutput(
            log_reward=log_reward,
            reward=torch.exp(log_reward),
            proxy=proxy,
            proxy_components={"target": target_value, "inverse_cost": inverse_cost_value},
        )
        trajectories.set_reward_outputs(new_reward_output)

    def assign_costs(self, trajectories_container: TrajectoriesContainer) -> Dict[str, Any]:
        self.n_recent_updates = 0
        if trajectories_container.forward_trajectories is not None:
            costs = self._compute_costs(trajectories_container.forward_trajectories)
            trajectories_container.forward_trajectories.set_costs(costs)
            self._update_reward(trajectories_container.forward_trajectories)
            for terminal_state, cost in zip(
                trajectories_container.forward_trajectories.get_last_states_flat(),
                costs,
            ):
                if isinstance(terminal_state, ReactionStateTerminal):
                    self.forward_terminal_to_cost[terminal_state.molecule.smiles] = min(
                        self.forward_terminal_to_cost[terminal_state.molecule.smiles],
                        cost,
                    )
        if trajectories_container.replay_trajectories is not None:
            costs = self._compute_costs(trajectories_container.replay_trajectories)
            trajectories_container.replay_trajectories.set_costs(costs)
            self._update_reward(trajectories_container.replay_trajectories)
            for terminal_state, cost in zip(
                trajectories_container.replay_trajectories.get_last_states_flat(), costs
            ):
                if isinstance(terminal_state, ReactionStateTerminal):
                    self.replay_terminal_to_cost[terminal_state.molecule.smiles] = min(
                        self.replay_terminal_to_cost[terminal_state.molecule.smiles],
                        cost,
                    )
        if trajectories_container.backward_trajectories is not None:
            costs = self._compute_costs(trajectories_container.backward_trajectories)
            trajectories_container.backward_trajectories.set_costs(costs)
            self._update_reward(trajectories_container.backward_trajectories)
        return {"cost_updates": self.n_recent_updates}

    def compute_yield(self, action: ReactionActionC) -> float:
        if self.yield_value is not None:
            return self.yield_value
        reaction = self.anchor_to_reaction[action.input_reaction]
        return self.reaction_to_yield[reaction]

    def compute_yield_raw(
        self, input_smiles_list: List[str], output_smiles: str, reaction: str
    ) -> float:
        if self.yield_value is not None:
            return self.yield_value
        anchored_reaction = AnchoredReaction(reaction, 0, 0)
        reaction = self.anchor_to_reaction[anchored_reaction]
        return self.reaction_to_yield[reaction]

    def get_fragment_cost(self, fragment: Molecule | str) -> float:
        """Return the cost of a fragment given as a `Molecule` or a string.

        A string argument is first converted with `Molecule(fragment)`; the
        resulting molecule is looked up in `fragment_to_cost`.
        """
        if isinstance(fragment, Molecule):
            molecule = fragment
        else:
            molecule = Molecule(fragment)
        return self.fragment_to_cost[molecule]

    def get_fragment_smiles_cost(self, fragment: str) -> float:
        return self.fragment_smiles_to_cost[fragment]

    def get_action_cost(self, action: ReactionActionC | ReactionAction0) -> float:
        """Return the fragment cost incurred by an action.

        For a `ReactionAction0` this is the cost of its single `fragment`;
        for any other action it is the sum of the costs of its
        `input_fragments`.
        """
        if not isinstance(action, ReactionAction0):
            total = 0
            for fragment in action.input_fragments:
                total = total + self.fragment_to_cost[fragment]
            return total
        return self.fragment_to_cost[action.fragment]

    def on_update_fragments_library(
        self,
        iteration_idx: int,
        fragments: List[Molecule],
        costs: List[float],
        recursive: bool = True,
    ) -> Dict[str, Any]:
        for fragment, cost in zip(fragments, costs):
            self.fragment_to_cost[fragment] = cost
        return {}
