from __future__ import annotations
from dataclasses import dataclass

import logging
import os
import sys
import time
from abc import abstractclassmethod, abstractmethod
from collections import defaultdict
from functools import partial
from typing import (
  Tuple,
  Dict,
  List,
  Optional,
  DefaultDict,
  Callable,
  Iterable,
  Union,
  Type,
  Any,
  Generic,
  TypeVar,
)
from typing_extensions import Literal

import numpy as np
import torch
from pyprojroot import here as project_root

sys.path.insert(0, str(project_root()))

from fs_mol.data import (
  FSMolBatcher,
  FSMolBatchIterable,
  FSMolTaskSample,
)
from fs_mol.utils.logging import PROGRESS_LOG_LEVEL
from fs_mol.utils.metric_logger import MetricLogger
from fs_mol.utils.metrics import (
  avg_task_metrics_list,
  compute_metrics,
  BinaryEvalMetrics,
  BinaryMetricType,
)

# Module-level logger; the helpers below (e.g. resolve_starting_model_file, eval_context_model) log through it.
logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class TorchFSMolModelOutput:
  """Immutable container for the minimal output of an FS-Mol torch model.

  Implementations may return a subclass carrying extra fields; the default
  ``compute_loss`` only reads ``molecule_binary_label``.
  """

  # One prediction per input molecule, as a [NUM_MOLECULES, 1] float tensor. The default
  # loss feeds it to BCEWithLogitsLoss, so these are unnormalized logits, not probabilities.
  molecule_binary_label: torch.Tensor


@dataclass
class TorchFSMolModelLoss:
  """Container for the loss terms computed for one batch.

  ``total_loss`` is the quantity that gets optimized; ``metrics_to_log`` exposes each
  term by name for logging.
  """

  # Loss on the binary activity label.
  label_loss: torch.Tensor

  @property
  def total_loss(self) -> torch.Tensor:
    # With a single loss term the total is just that term.
    return self.label_loss

  @property
  def metrics_to_log(self) -> Dict[str, Any]:
    return dict(total_loss=self.total_loss, label_loss=self.label_loss)


# Type parameters shared by AbstractTorchFSMolModel and the helper functions in this module.
# Batch feature representation; unconstrained, chosen by each model implementation.
BatchFeaturesType = TypeVar("BatchFeaturesType")
# Model output type; bound to TorchFSMolModelOutput, so it always has `molecule_binary_label`.
BatchOutputType = TypeVar("BatchOutputType", bound=TorchFSMolModelOutput)
# Loss container type; bound to TorchFSMolModelLoss, so it has `total_loss` / `metrics_to_log`.
BatchLossType = TypeVar("BatchLossType", bound=TorchFSMolModelLoss)
# Metric selector: a BinaryMetricType value or the literal "loss". eval_context_model uses it
# as an attribute name on the BinaryEvalMetrics it reports.
MetricType = Union[BinaryMetricType, Literal["loss"]]
# Serialized model state, as returned by get_model_state and accepted by load_model_state.
ModelStateType = Dict[str, Any]


class AbstractTorchFSMolModel(
  Generic[BatchFeaturesType, BatchOutputType, BatchLossType], torch.nn.Module
):
  """Base class for PyTorch few-shot molecular activity models.

  Implementors provide ``forward``, the model-state (de)serialization hooks,
  ``is_param_task_specific`` (which splits parameters into optimizer groups) and
  ``build_from_model_file``. ``compute_loss`` has a default BCE-with-logits implementation.

  NOTE: this class does not derive from ``abc.ABC``, so ``@abstractmethod`` documents intent
  but is not enforced at instantiation time; unimplemented methods raise
  ``NotImplementedError`` when called.
  """

  def __init__(self):
    super().__init__()
    # Element-wise loss (reduction="none"); compute_loss averages it itself.
    self.criterion = torch.nn.BCEWithLogitsLoss(reduction="none")

  @abstractmethod
  def forward(self, batch: BatchFeaturesType) -> BatchOutputType:
    """
    Given the features of a batch of molecules, compute for each molecule a score (logit) for
    having an "active" label in the currently learned assay.

    The score is an unnormalized logit, not a probability: the default loss applies
    BCEWithLogitsLoss to it, and a sigmoid must be applied to obtain probabilities.

    Args:
        batch: representation of the features of NUM_MOLECULES, as chosen by the implementor.

    Returns:
        Model output, a subtype of TorchFSMolModelOutput, ensuring that at least molecule_binary_label
        is present.
    """
    raise NotImplementedError()

  def compute_loss(
          self, batch: BatchFeaturesType, model_output: BatchOutputType, labels: torch.Tensor
  ) -> BatchLossType:
    """
    Compute loss; can be overridden by implementors to add extra objectives.

    Args:
        batch: representation of the features of NUM_MOLECULES, as chosen by the implementor.
        model_output: output of the model, as chosen by the implementor.
        labels: Tensor of shape [NUM_MOLECULES] with the target label of each molecule
            (converted to float before computing the loss).

    Returns:
        A TorchFSMolModelLoss whose label_loss is the mean binary cross-entropy between the
        logits in model_output.molecule_binary_label and labels. Optimization is performed
        over its total_loss.
    """
    # [NUM_MOLECULES, 1] -> [NUM_MOLECULES] so the shape matches labels.
    predictions = model_output.molecule_binary_label.squeeze(dim=-1)
    label_loss = torch.mean(self.criterion(predictions, labels.float()))
    return TorchFSMolModelLoss(label_loss=label_loss)

  @abstractmethod
  def get_model_state(self) -> ModelStateType:
    """
    Return the state of the model such as configuration and learnable parameters.

    Returns:
        Dictionary that can later be passed to load_model_state.
    """
    raise NotImplementedError()

  @abstractmethod
  def load_model_state(
          self,
          model_state: ModelStateType,
          load_task_specific_weights: bool,
          quiet: bool = False,
  ) -> None:
    """Load model weights from a model state as generated by get_model_state.

    Args:
        model_state: a dictionary representing model state, as returned by model.get_model_state().
        load_task_specific_weights: a flag specifying whether, if applicable, task-specific weights
            should be loaded or not. This would be False for the case of loading the weights when
            transferring the model to a new task.
        quiet: if True, suppress reporting of additional details (e.g., which weights have been
            loaded / re-initialized).
    """
    raise NotImplementedError()

  @abstractmethod
  def is_param_task_specific(self, param_name: str) -> bool:
    """Return True if the named parameter is task-specific rather than shared.

    create_optimizer uses this to place parameters into separate optimizer groups.
    """
    raise NotImplementedError()

  @classmethod
  @abstractmethod
  def build_from_model_file(
          cls,
          model_file: str,
          config_overrides: Optional[Dict[str, Any]] = None,
          quiet: bool = False,
          device: Optional[torch.device] = None,
  ) -> AbstractTorchFSMolModel[BatchFeaturesType, BatchOutputType, BatchLossType]:
    """Build the model architecture based on a saved checkpoint.

    Args:
        model_file: path of the checkpoint to read the model configuration from.
        config_overrides: optional configuration entries overriding those stored in the
            checkpoint; None means no overrides.
        quiet: if True, suppress additional reporting.
        device: device to place the model on, if given.
    """
    raise NotImplementedError()


def linear_warmup(cur_step: int, warmup_steps: int = 0) -> float:
  """Learning-rate multiplier for a linear warmup, suitable as a LambdaLR ``lr_lambda``.

  Grows linearly as ``cur_step / warmup_steps`` while ``cur_step < warmup_steps`` and is
  1.0 from then on (so ``warmup_steps=0`` disables warmup for non-negative steps).
  """
  return 1.0 if cur_step >= warmup_steps else cur_step / warmup_steps


def create_optimizer(
        model: AbstractTorchFSMolModel[BatchFeaturesType, BatchOutputType, BatchLossType],
        lr: float = 0.001,
        task_specific_lr: float = 0.005,
        warmup_steps: int = 1000,
        task_specific_warmup_steps: int = 100,
) -> Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler._LRScheduler]:
  """Build an Adam optimizer with separate task-specific / shared groups, plus warmup.

  Parameters for which ``model.is_param_task_specific`` is truthy go into the first
  parameter group (learning rate ``task_specific_lr``); all others go into the second
  (learning rate ``lr``). A LambdaLR scheduler applies a linear warmup per group, over
  ``task_specific_warmup_steps`` and ``warmup_steps`` respectively.

  Returns:
      (optimizer, scheduler) tuple.
  """
  # Bucket parameters by the model's own task-specificity predicate.
  buckets = {True: [], False: []}
  for name, parameter in model.named_parameters():
    buckets[bool(model.is_param_task_specific(name))].append(parameter)

  # Group order matters: the lr_lambda list below is matched to these groups by index.
  param_groups = [
    {"params": buckets[True], "lr": task_specific_lr},
    {"params": buckets[False], "lr": lr},
  ]
  per_group_warmup = [
    partial(linear_warmup, warmup_steps=task_specific_warmup_steps),  # task-specific group
    partial(linear_warmup, warmup_steps=warmup_steps),  # shared group
  ]

  optimizer = torch.optim.Adam(param_groups)
  scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=per_group_warmup)
  return optimizer, scheduler


def save_model(
        path: str,
        model: AbstractTorchFSMolModel[BatchFeaturesType, BatchOutputType, BatchLossType],
        optimizer: Optional[torch.optim.Optimizer] = None,
        epoch: Optional[int] = None,
) -> None:
  """Write a checkpoint for ``model`` to ``path`` using ``torch.save``.

  The checkpoint is the dict returned by ``model.get_model_state()`` (updated in place),
  extended with ``"optimizer_state_dict"`` when an optimizer is given and ``"epoch"``
  when an epoch is given.
  """
  checkpoint = model.get_model_state()

  extra_entries = {}
  if optimizer is not None:
    extra_entries["optimizer_state_dict"] = optimizer.state_dict()
  if epoch is not None:
    extra_entries["epoch"] = epoch
  checkpoint.update(extra_entries)

  torch.save(checkpoint, path)


def load_model_weights(
        model: AbstractTorchFSMolModel[BatchFeaturesType, BatchOutputType, BatchLossType],
        path: str,
        load_task_specific_weights: bool,
        quiet: bool = False,
        device: Optional[torch.device] = None,
) -> None:
  """Read the checkpoint at ``path`` and hand it to ``model.load_model_state``.

  ``device`` is used as ``map_location`` for ``torch.load``. Note that ``torch.load``
  unpickles the file, so only load checkpoints from trusted sources.
  """
  model.load_model_state(
    torch.load(path, map_location=device),
    load_task_specific_weights,
    quiet,
  )


def resolve_starting_model_file(
        model_file: str,
        model_cls: Type[AbstractTorchFSMolModel[BatchFeaturesType, BatchOutputType, BatchLossType]],
        out_dir: str,
        use_fresh_param_init: bool,
        config_overrides: Optional[Dict[str, Any]] = None,
        device: Optional[torch.device] = None,
) -> str:
  """Return the path of the checkpoint that training should start from.

  Args:
      model_file: existing checkpoint; used directly, or only as the architecture source
          when ``use_fresh_param_init`` is set.
      model_cls: model class whose ``build_from_model_file`` builds the fresh model.
      out_dir: directory in which the freshly initialized checkpoint is written.
      use_fresh_param_init: if True, build a model from ``model_file`` with freshly
          initialized parameters, save it as ``<out_dir>/fresh_init.pkl`` and return that
          path; otherwise return ``model_file`` unchanged.
      config_overrides: configuration overrides forwarded to ``build_from_model_file``;
          None (the default) means no overrides.
      device: device forwarded to ``build_from_model_file``.

  Returns:
      Path of the checkpoint to load.
  """
  if not use_fresh_param_init:
    logger.info(f"Using model weights loaded from {model_file}.")
    return model_file

  # Fresh init: create a model, do a random init, and store that away somewhere.
  logger.info("Using fresh model init.")
  model = model_cls.build_from_model_file(
    model_file=model_file,
    # A new dict per call, so a callee that mutates it cannot leak state across calls.
    config_overrides={} if config_overrides is None else config_overrides,
    device=device,
  )

  resolved_model_file = os.path.join(out_dir, "fresh_init.pkl")
  save_model(resolved_model_file, model)

  # Hack to give AML some time to actually save.
  time.sleep(1)
  return resolved_model_file

def eval_context_model(
        model,
        context_length,
        task_sample: FSMolTaskSample,
        batcher: FSMolBatcher[BatchFeaturesType, torch.Tensor],
        learning_rate: float,
        task_specific_learning_rate: float,
        metric_to_use: MetricType = "avg_precision",
        max_num_epochs: int = 50,
        patience: int = 10,
        seed: int = 0,
        quiet: bool = False,
        device: Optional[torch.device] = None,
) -> BinaryEvalMetrics:
  """Evaluate an in-context model on one task sample, without any gradient updates.

  The task's training samples are batched and passed as context to ``model.forward_test``,
  which scores every test batch. Sigmoid-transformed scores are grouped by task ID and
  handed to ``compute_metrics``; the results are logged.

  Args:
      model: model exposing ``eval()`` and
          ``forward_test(train_batch, test_batch, train_labels, test_labels, context_length=...)``,
          which must return a tensor of logits (a sigmoid is applied here).
      context_length: forwarded unchanged to ``model.forward_test``.
      task_sample: task whose ``train_samples`` form the context and whose
          ``test_samples`` are scored.
      batcher: batcher used to turn samples into (batch, labels) pairs.
      metric_to_use: name of the BinaryEvalMetrics attribute reported in the progress log.
      seed: seed for shuffling the training samples.
      learning_rate, task_specific_learning_rate, max_num_epochs, patience, quiet, device:
          not used by this function (presumably kept for signature compatibility with
          other evaluation routines — verify with callers).

  Returns:
      The first value of the dict returned by ``compute_metrics`` (keyed by task ID; when
      batches carry no ``sample_to_task_id`` every sample uses task ID 0).

  Raises:
      ValueError: if the task sample yields no training batch or no test samples.
  """
  train_data = FSMolBatchIterable(task_sample.train_samples, batcher, shuffle=True, seed=seed)
  train_batch = train_labels = None
  # NOTE(review): only the LAST training batch is kept as context; if the training samples
  # span several batches, the earlier ones are silently dropped. Confirm the batcher is
  # sized to hold the whole support set.
  for batch, labels in train_data:
    train_batch, train_labels = batch, labels
  if train_labels is None:
    raise ValueError("Task sample produced no training batch to use as context.")
  # Identical for every test batch, so convert once.
  context_labels = train_labels.to(torch.float32)

  test_data = FSMolBatchIterable(task_sample.test_samples, batcher)
  model.eval()

  per_task_preds: DefaultDict[int, List[float]] = defaultdict(list)
  per_task_labels: DefaultDict[int, List[float]] = defaultdict(list)

  for batch, labels in test_data:
    with torch.no_grad():
      predictions: torch.Tensor = model.forward_test(
        train_batch,
        batch,
        context_labels,
        labels.to(torch.float32),
        context_length=context_length,
      )

    # === Finally, collect per-task results to be used for further eval:
    if hasattr(batch, "sample_to_task_id"):
      sample_to_task_id = batch.sample_to_task_id
    else:
      # If we don't have a sample task information, just use 0 as default task ID:
      sample_to_task_id = defaultdict(lambda: torch.tensor(0))

    # Apply sigmoid to have predictions in appropriate range for computing (scikit) scores.
    predicted_labels = torch.sigmoid(predictions).detach().cpu()
    for i in range(labels.shape[0]):
      task_id = sample_to_task_id[i].item()
      per_task_preds[task_id].append(predicted_labels[i].item())
      per_task_labels[task_id].append(labels[i].item())

  if not per_task_preds:
    raise ValueError("Task sample produced no test samples to evaluate.")

  metrics = compute_metrics(per_task_preds, per_task_labels)
  # No loss is computed during in-context evaluation; the logged test loss is a fixed placeholder.
  test_loss = 0.0

  test_metrics = next(iter(metrics.values()))
  logger.log(PROGRESS_LOG_LEVEL, f" Test loss:                   {float(test_loss):.5f}")
  logger.info(f" Test metrics: {test_metrics}")
  logger.info(
    f"Dataset sample has {task_sample.test_pos_label_ratio:.4f} positive label ratio in test data.",
  )
  logger.log(
    PROGRESS_LOG_LEVEL,
    f"Dataset sample test {metric_to_use}: {getattr(test_metrics, metric_to_use):.4f}",
  )
  return test_metrics

