# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
import typing as tp
from collections import defaultdict

import flashy
import torch

from ..optim import ModuleDictEMA
from .utils import copy_state

logger = logging.getLogger(__name__)


class BestStateDictManager(flashy.state.StateDictSource):
    """BestStateDictManager maintains a copy of best state_dict() for registered sources.

    BestStateDictManager has two main attributes:
        states (dict): Copied state dict of each registered StateDictSource, keyed by the
            name it was registered under.
        param_ids (dict): Maps a group key to a ``{id(tensor): state_dict_key}`` dict. The group
            key is the registration name for ModuleDictEMA sources and ``"base"`` for every
            other source.

    When registering new sources, the BestStateDictManager will ensure two conflicting sources between
    ModuleDictEMA and original modules are not both registered as it would otherwise create ambiguity about
    what to consider for best state.

    Args:
        device (torch.device or str): Device on which we keep the copy.
        dtype (torch.dtype, optional): Data type forwarded to ``copy_state`` when copying states.
    """

    def __init__(
        self,
        device: tp.Union[torch.device, str] = "cpu",
        dtype: tp.Optional[torch.dtype] = None,
    ):
        self.device = device
        # Registration name -> copied state dict of that source.
        self.states: dict = {}
        # Group key ("base" or an EMA registration name) -> {id(tensor): state_dict key}.
        self.param_ids: dict = defaultdict(dict)
        self.dtype = dtype

    def _get_parameter_ids(self, state_dict):
        """Return ``{id(tensor): key}`` for the top-level tensor values of ``state_dict``.

        Non-tensor values (including nested dicts) are ignored.
        """
        return {
            id(p): name for name, p in state_dict.items() if isinstance(p, torch.Tensor)
        }

    def _validate_no_parameter_ids_overlap(self, name: str, param_ids: dict):
        """Check that ``param_ids`` shares no tensor id with any group other than ``name``.

        Args:
            name (str): Group key the new ids are about to be added to.
            param_ids (dict): ``{id(tensor): key}`` mapping of the source being registered.
        Raises:
            ValueError: If any id is already registered under a different group.
        """
        # NOTE(review): if a source's state_dict() builds new tensor objects on every call
        # (e.g. detached views), the stored ids belong to short-lived objects and may be
        # reused later -- TODO confirm this check is meaningful for such sources.
        for registered_name, registered_param_ids in self.param_ids.items():
            if registered_name == name:
                continue
            # Key views support set operators directly; `set.intersection` cannot take a
            # `dict_keys` object as its first argument.
            overlap = registered_param_ids.keys() & param_ids.keys()
            if overlap:
                # Ids are plain ints, so report the human-readable state dict keys instead.
                overlapping = " ".join(sorted(str(param_ids[i]) for i in overlap))
                raise ValueError(
                    f"Found {len(overlap)} / {len(param_ids)} overlapping parameters"
                    f" in {name} and already registered {registered_name}: {overlapping}"
                )

    def update(self, name: str, source: flashy.state.StateDictSource):
        """Replace the stored copy for ``name`` with a fresh copy of ``source.state_dict()``.

        Raises:
            ValueError: If ``name`` was never registered.
        """
        if name not in self.states:
            raise ValueError(f"{name} missing from registered states.")
        self.states[name] = copy_state(
            source.state_dict(), device=self.device, dtype=self.dtype
        )

    def register(self, name: str, source: flashy.state.StateDictSource):
        """Register ``source`` under ``name`` and store an initial copy of its state.

        Raises:
            ValueError: If ``name`` is already registered, or if the source's tensors overlap
                with those of a conflicting, already registered group.
        """
        if name in self.states:
            raise ValueError(f"{name} already present in states.")
        # Fetch the state dict once: it is reused for the id bookkeeping and the copy below,
        # and holding it keeps its tensors alive while their ids are computed and compared.
        state_dict = source.state_dict()
        # Registering parameter ids for EMA and non-EMA states allows us to check that
        # there is no overlap that would create ambiguity about how to handle the best state
        param_ids = self._get_parameter_ids(state_dict)
        if isinstance(source, ModuleDictEMA):
            logger.debug(
                f"Registering to best state: ModuleDictEMA '{name}' with {len(param_ids)} params"
            )
            self._validate_no_parameter_ids_overlap(name, param_ids)
            self.param_ids[name] = param_ids
        else:
            logger.debug(
                f"Registering to best state: StateDictSource '{name}' with {len(param_ids)} params"
            )
            self._validate_no_parameter_ids_overlap("base", param_ids)
            self.param_ids["base"].update(param_ids)
        # Register state
        self.states[name] = copy_state(state_dict, device=self.device, dtype=self.dtype)

    def state_dict(self) -> flashy.state.StateDict:
        """Return the stored copies (the internal dict itself, not a copy)."""
        return self.states

    def load_state_dict(self, state: flashy.state.StateDict):
        """Load ``state`` into the stored copies, in place where possible.

        Tensors are copied into the existing stored tensors with ``Tensor.copy_``, so the
        stored device and dtype are preserved. Nested dicts are walked recursively. Any
        other value replaces the stored one.

        Raises:
            KeyError: If ``state`` contains a name, or a key, that is not already stored.
        """
        for name, sub_state in state.items():
            self._load_into(self.states[name], sub_state)

    @classmethod
    def _load_into(cls, target: dict, source: dict) -> None:
        """Recursively copy the values of ``source`` into the existing dict ``target``."""
        for key, value in source.items():
            current = target[key]  # unknown keys raise KeyError, as before
            if isinstance(current, torch.Tensor) and isinstance(value, torch.Tensor):
                current.copy_(value)
            elif isinstance(current, dict) and isinstance(value, dict):
                cls._load_into(current, value)
            else:
                target[key] = value
