from __future__ import annotations
from dataclasses import dataclass

import logging
import os
import sys
import time
from abc import abstractclassmethod, abstractmethod
from functools import partial
from typing import (
    Tuple,
    Dict,
    Optional,
    Iterable,
    Union,
    Type,
    Any,
    Generic,
    TypeVar,
)
from typing_extensions import Literal

import torch
from pyprojroot import here as project_root

sys.path.insert(0, str(project_root()))

from utils.logging import PROGRESS_LOG_LEVEL
from utils.metric_logger import MetricLogger
from utils.metrics import (
    BinaryMetricType,
)

logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class TorchFSMolModelOutput:
    """Immutable container for the per-molecule predictions of a model.

    Attributes:
        molecule_binary_label: prediction for each input molecule, expected to be
            a float tensor of shape [NUM_MOLECULES, 1].
    """

    molecule_binary_label: torch.Tensor


@dataclass
class TorchFSMolModelLoss:
    """Loss terms computed for one batch; `total_loss` combines them (currently just `label_loss`)."""

    label_loss: torch.Tensor

    @property
    def total_loss(self) -> torch.Tensor:
        """The combined objective; at present identical to `label_loss`."""
        combined = self.label_loss
        return combined

    @property
    def metrics_to_log(self) -> Dict[str, Any]:
        """Map metric names to the loss values that should be reported."""
        logged: Dict[str, Any] = {}
        logged["total_loss"] = self.total_loss
        logged["label_loss"] = self.label_loss
        return logged


# Type parameters of AbstractTorchFSMolModel: implementors pick the concrete batch-feature
# representation and may subclass the output / loss containers.
BatchFeaturesType = TypeVar("BatchFeaturesType")
BatchOutputType = TypeVar("BatchOutputType", bound=TorchFSMolModelOutput)
BatchLossType = TypeVar("BatchLossType", bound=TorchFSMolModelLoss)

# Model state as returned by get_model_state (configuration, learnable parameters, ...).
ModelStateType = Dict[str, Any]
# Name of a metric: any binary metric, or the literal "loss".
MetricType = Union[BinaryMetricType, Literal["loss"]]


class AbstractTorchFSMolModel(
    Generic[BatchFeaturesType, BatchOutputType, BatchLossType], torch.nn.Module
):
    """Common interface for PyTorch FS-Mol models.

    Implementors provide ``forward`` and model-state (de)serialization; the default
    ``compute_loss`` applies an element-wise MSE between predictions and labels.

    NOTE(review): neither ``Generic`` nor ``torch.nn.Module`` uses ``abc.ABCMeta``, so the
    ``@abstractmethod`` markers below document the contract but are likely not enforced at
    instantiation time — confirm before relying on that.
    """

    def __init__(self) -> None:
        super().__init__()
        # Unreduced (element-wise) MSE so compute_loss can return the per-molecule losses
        # alongside their mean.
        self.criterion = torch.nn.MSELoss(reduction="none")

    @abstractmethod
    def forward(self, batch: BatchFeaturesType) -> BatchOutputType:
        """
        Given the features of a batch of molecules, compute the probability of each of these
        molecules having an "active" label in the currently learned assay.

        Args:
            batch: representation of the features of NUM_MOLECULES, as chosen by the implementor.

        Returns:
            Model output, a subtype of TorchFSMolModelOutput, ensuring that at least
            molecule_binary_label is present.
        """
        raise NotImplementedError()

    def compute_loss(
        self, batch: BatchFeaturesType, model_output: BatchOutputType, labels: torch.Tensor
    ):
        """
        Compute the loss; can be overridden by implementors to add extra objectives.

        Args:
            batch: representation of the features of NUM_MOLECULES, as chosen by the implementor.
                Not used by this default implementation.
            model_output: output of the model, as chosen by the implementor; its
                molecule_binary_label holds the predictions.
            labels: Tensor of shape [NUM_MOLECULES] with the target label of each molecule;
                converted to float before comparison.

        Returns:
            A pair ``(loss, per_molecule_loss)``:
              - ``loss``: a TorchFSMolModelLoss whose ``label_loss`` is the mean squared error
                over the batch (this is what ``total_loss`` exposes for optimization);
              - ``per_molecule_loss``: the unreduced squared errors, one per molecule.
        """
        # molecule_binary_label is documented as [NUM_MOLECULES, 1]; drop the trailing
        # dimension so predictions line up with the [NUM_MOLECULES] labels.
        predictions = model_output.molecule_binary_label.squeeze(dim=-1)
        per_molecule_loss = self.criterion(predictions, labels.float())
        mean_loss = torch.mean(per_molecule_loss)
        return TorchFSMolModelLoss(label_loss=mean_loss), per_molecule_loss

    @abstractmethod
    def get_model_state(self) -> ModelStateType:
        """
        Return the state of the model such as configuration and learnable parameters.

        Returns:
            Dictionary
        """
        raise NotImplementedError()

    @abstractmethod
    def load_model_state(
        self,
        model_state: ModelStateType,
        load_task_specific_weights: bool,
        quiet: bool = False,
    ) -> None:
        """Load model weights from a model state as generated by get_model_state.

        Args:
            model_state: a dictionary representing model state, as returned by model.get_model_state().
            load_task_specific_weights: a flag specifying whether, if applicable, task-specific weights
                should be loaded or not. This would be False for the case of loading the weights when
                transferring the model to a new task.
            quiet: flag controlling whether the loading reports additional details (e.g., which
                weights have been loaded / re-initialized); presumably suppressed when True.
        """
        raise NotImplementedError()

    @abstractmethod
    def is_param_task_specific(self, param_name: str) -> bool:
        """Return True if the parameter called ``param_name`` is task-specific.

        NOTE(review): presumably used together with ``load_task_specific_weights`` to decide
        which weights to (re)load when transferring to a new task — confirm in implementors.
        """
        raise NotImplementedError()

    # `@classmethod` stacked over `@abstractmethod` replaces the deprecated
    # `abc.abstractclassmethod`.
    @classmethod
    @abstractmethod
    def build_from_model_file(
        cls,
        model_file: str,
        config_overrides: Dict[str, Any] = {},
        quiet: bool = False,
        device: Optional[torch.device] = None,
    ) -> AbstractTorchFSMolModel[BatchFeaturesType, BatchOutputType, BatchLossType]:
        """Build the model architecture based on a saved checkpoint.

        Args:
            model_file: location of the saved checkpoint.
            config_overrides: configuration values to apply on top of the saved ones
                (presumably). The default ``{}`` is a single shared object, so implementations
                must not mutate it — copy it first if changes are needed.
            quiet: presumably suppresses additional reporting, as in ``load_model_state``.
            device: optional torch device for the built model; ``None`` leaves the choice
                to the implementation.
        """
        raise NotImplementedError()




