"""
Copyright (c) Meta Platforms, Inc. and affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""

from __future__ import annotations

import copy
import logging
from abc import ABCMeta, abstractmethod
from typing import TYPE_CHECKING

import torch
from torch import nn

from fairchem.core.common.registry import registry
from fairchem.core.common.utils import (
    load_model_and_weights_from_checkpoint,
)

if TYPE_CHECKING:
    from fairchem.core.datasets.atomic_data import AtomicData


class HeadInterface(metaclass=ABCMeta):
    """Abstract interface for an output head that consumes backbone embeddings."""

    @property
    def use_amp(self):
        """Flag telling the caller whether to enable autocast around this head.

        The base implementation always reports ``False``.
        """
        return False

    @abstractmethod
    def forward(
        self, data: AtomicData, emb: dict[str, torch.Tensor]
    ) -> dict[str, torch.Tensor]:
        """Compute the head outputs.

        Arguments
        ---------
        data: AtomicData
            The atomic systems given to the model as input
        emb: dict[str->torch.Tensor]
            The embeddings the backbone produced for that input

        Returns
        -------
        outputs: dict[str->torch.Tensor]
            One or more targets predicted by this head
        """
        # The base implementation has no result; concrete heads must override it.
        return None


class BackboneInterface(metaclass=ABCMeta):
    """Abstract interface for a backbone that turns inputs into embeddings."""

    @abstractmethod
    def forward(self, data: AtomicData) -> dict[str, torch.Tensor]:
        """Compute the backbone embeddings.

        Arguments
        ---------
        data: AtomicData
            The atomic systems given to the model as input

        Returns
        -------
        embedding: dict[str->torch.Tensor]
            The embeddings the backbone produced for the given input
        """
        # The base implementation has no result; concrete backbones must override it.
        return None


@registry.register_model("hydra")
class HydraModel(nn.Module):
    """Model made of one backbone and a set of named output heads.

    The backbone and the heads are each either instantiated from a config dict
    through the model registry, or reused from a ``HydraModel`` loaded from a
    fine-tuning checkpoint when no config is given for them.

    Arguments
    ---------
    backbone: dict | None
        Backbone config. Its ``"model"`` entry names the registered class; the
        remaining entries are passed to that class as keyword arguments. When
        ``None``, the backbone of the starting checkpoint is used.
    heads: dict | None
        Mapping of head name to head config. Each config's ``"module"`` entry
        names the registered class, which is called with the backbone followed
        by the remaining entries as keyword arguments. When ``None``, the heads
        of the starting checkpoint are used.
    finetune_config: dict | None
        May only contain ``"starting_checkpoint"`` and ``"override"``. Entries
        of ``"override"`` are set as attributes on the loaded backbone.
    otf_graph: bool
        Stored on the instance as ``self.otf_graph``.
    pass_through_head_outputs: bool
        When true, ``forward`` merges every head's output dict into one flat
        dict instead of nesting the outputs under the head name.
    freeze_backbone: bool
        When true, ``requires_grad`` is turned off on all backbone parameters.

    Raises
    ------
    RuntimeError
        If the backbone (or the heads) is neither configured nor available
        from a starting checkpoint.
    ValueError
        If a head config has no ``"module"`` entry.
    AssertionError
        If ``finetune_config`` has unexpected keys or the loaded checkpoint is
        not a ``HydraModel``.
    """

    def __init__(
        self,
        backbone: dict | None = None,
        heads: dict | None = None,
        finetune_config: dict | None = None,
        otf_graph: bool = True,
        pass_through_head_outputs: bool = False,
        freeze_backbone: bool = False,
    ):
        super().__init__()
        self.device = None
        self.otf_graph = otf_graph
        # Needed for hydras whose heads emit several outputs each: the trainers
        # expect the output names directly rather than head_name.property_name,
        # and the old config system will be deprecated at some point, so passing
        # the head outputs through avoids major changes to the trainer.
        self.pass_through_head_outputs = pass_through_head_outputs

        # With a finetune_config, try to load the model from its checkpoint.
        starting_model: HydraModel | None = None
        if finetune_config is not None:
            # Make it hard to sneak more fields into the finetune config.
            allowed_keys = {"starting_checkpoint", "override"}
            unexpected_keys = set(finetune_config.keys()) - allowed_keys
            assert len(unexpected_keys) == 0

            starting_model = load_model_and_weights_from_checkpoint(
                finetune_config["starting_checkpoint"]
            )
            logging.info(
                f"Found and loaded fine-tuning checkpoint: {finetune_config['starting_checkpoint']} (Note we are NOT loading the training state from this checkpoint, only parts of the model and weights)"
            )
            assert isinstance(
                starting_model, HydraModel
            ), "Can only finetune starting from other hydra models!"

            # TODO: overriding attributes on the backbone like this is a bit hacky.
            overrides = (
                finetune_config["override"] if "override" in finetune_config else {}
            )
            for attr_name, attr_value in overrides.items():
                setattr(starting_model.backbone, attr_name, attr_value)

        if backbone is not None:
            # Work on a copy so popping "model" does not modify the caller's config.
            backbone_kwargs = copy.deepcopy(backbone)
            backbone_cls = registry.get_model_class(backbone_kwargs.pop("model"))
            self.backbone: BackboneInterface = backbone_cls(**backbone_kwargs)
        elif starting_model is not None:
            self.backbone = starting_model.backbone
            logging.info(
                f"User did not specify a backbone, using the backbone from the starting checkpoint {self.backbone}"
            )
        else:
            raise RuntimeError(
                "Backbone not specified and not found in the starting checkpoint"
            )

        if freeze_backbone:
            for backbone_param in self.backbone.parameters():
                backbone_param.requires_grad = False

        if heads is not None:
            # Work on a copy so popping "module" does not modify the caller's config.
            head_configs = copy.deepcopy(heads)
            built_heads: dict[str, HeadInterface] = {}

            # Heads are created in sorted-name order.
            ordered_names = sorted(head_configs.keys())
            assert len(set(ordered_names)) == len(
                ordered_names
            ), "Head names must be unique!"
            for head_name in ordered_names:
                head_kwargs = head_configs[head_name]
                if "module" not in head_kwargs:
                    raise ValueError(
                        f"{head_name} head does not specify module to use for the head"
                    )

                head_cls = registry.get_model_class(head_kwargs.pop("module"))
                built_heads[head_name] = head_cls(self.backbone, **head_kwargs)

            self.output_heads = torch.nn.ModuleDict(built_heads)
        elif starting_model is not None:
            self.output_heads = starting_model.output_heads
            logging.info(
                f"User did not specify heads, using the output heads from the starting checkpoint {self.output_heads}"
            )
        else:
            raise RuntimeError(
                "Heads not specified and not found in the starting checkpoint"
            )

    def forward(self, data: AtomicData):
        """Run the backbone once, then every head on the resulting embeddings.

        Arguments
        ---------
        data: AtomicData
            Atomic systems as input; at least one of its values must be a
            tensor so the device type can be determined on the first call

        Returns
        -------
        out: dict
            The merged head outputs when ``pass_through_head_outputs`` is set,
            otherwise each head's output stored under that head's name
        """
        # The device type is found lazily from the input tensors on the first
        # call (it is needed for autocast) and cached on the instance afterwards.
        if not self.device:
            device_from_tensors = {
                value.device.type
                for value in data.values()
                if isinstance(value, torch.Tensor)
            }
            assert (
                len(device_from_tensors) == 1
            ), f"all inputs must be on the same device, found the following devices {device_from_tensors}"
            self.device = device_from_tensors.pop()

        emb = self.backbone(data)

        # For now, predict every output property for every structure in the batch.
        out = {}
        for head_name, head in self.output_heads.items():
            with torch.autocast(device_type=self.device, enabled=head.use_amp):
                head_result = head(data, emb)
                if self.pass_through_head_outputs:
                    out.update(head_result)
                else:
                    out[head_name] = head_result

        return out


class HydraModelV2(nn.Module):
    """Model made of an already-built backbone and a set of named output heads.

    Arguments
    ---------
    backbone: BackboneInterface
        Backbone module; it is called with the input data to get the embeddings
    heads: dict[str, HeadInterface]
        Mapping of head name to head module, wrapped in a ``ModuleDict``
    freeze_backbone: bool
        When true, ``requires_grad`` is turned off on all backbone parameters
    """

    def __init__(
        self,
        backbone: BackboneInterface,
        heads: dict[str, HeadInterface],
        freeze_backbone: bool = False,
    ):
        super().__init__()
        self.backbone = backbone
        self.output_heads = torch.nn.ModuleDict(heads)
        self.device = None

        if not freeze_backbone:
            return
        for backbone_param in self.backbone.parameters():
            backbone_param.requires_grad = False

    def forward(self, data):
        """Run the backbone once, then every head on the resulting embeddings.

        Arguments
        ---------
        data:
            Model input exposing ``values()``; at least one of its values must
            be a tensor so the device type can be determined on the first call

        Returns
        -------
        out: dict
            Each head's output stored under that head's name
        """
        # The device type is found lazily from the input tensors on the first
        # call (it is needed for autocast) and cached on the instance afterwards.
        if not self.device:
            device_from_tensors = {
                value.device.type
                for value in data.values()
                if isinstance(value, torch.Tensor)
            }
            assert (
                len(device_from_tensors) == 1
            ), f"all inputs must be on the same device, found the following devices {device_from_tensors}"
            self.device = device_from_tensors.pop()

        emb = self.backbone(data)

        # For now, predict every output property for every structure in the batch.
        out = {}
        for head_name, head in self.output_heads.items():
            with torch.autocast(device_type=self.device, enabled=head.use_amp):
                out[head_name] = head(data, emb)
        return out
