import os
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Tuple

import numpy as np
import torch
from torch_geometric.data import Batch, Data
from tqdm import tqdm

from ltsgns_mp.architectures.simulators.abstract_simulator import AbstractSimulator
from ltsgns_mp.envs.env import Env
from ltsgns_mp.util import keys
from ltsgns_mp.util.own_types import ConfigDict, ValueDict


class AbstractAlgorithm(ABC):
    """
    Abstract class for the full algorithm, including the Simulator.

    Bundles a simulator, an environment and an optimizer, and provides the epoch-level
    training loop, checkpointing of simulator and optimizer, and a gradient-step helper.
    Subclasses implement `_single_train_step` and `predict_trajectory`.
    """
    def __init__(self, config: ConfigDict, simulator: AbstractSimulator,
                 env: Env, optimizer: torch.optim.Optimizer, loading_config: ConfigDict, device: str):
        """
        Stores the components of the algorithm.
        :param config: Algorithm configuration. `config.verbose` is read by `train_step`.
        :param simulator: The simulator that is trained and checkpointed.
        :param env: The environment. Its `train_iterator` provides the training batches.
        :param optimizer: The optimizer that is stepped in `_apply_loss` and saved in `save_checkpoint`.
        :param loading_config: Not used by this base initializer. It is accepted so that all algorithms
            share one constructor signature.
        :param device: The device identifier. Only stored here.
        """
        self._config: ConfigDict = config
        self._simulator = simulator
        self._env = env
        self._optimizer = optimizer
        self._device: str = device

        self._best_eval_loss = np.inf
        self._save_checkpoint_this_epoch = False

    @staticmethod
    def _loss_to_float(loss) -> float:
        """
        Converts the loss of a single training step into a plain Python float.
        :param loss: Either a torch.Tensor or anything that `float()` accepts.
        Returns: The loss as a float. A tensor with more than one element is reduced to its mean.

        """
        if isinstance(loss, torch.Tensor):
            # detach first, so that no reference to the autograd graph survives the conversion
            detached = loss.detach()
            if detached.numel() == 1:
                return float(detached.item())
            return float(detached.mean().item())
        return float(loss)

    def train_step(self, epoch: int) -> ValueDict:
        """
        Performs one epoch of training. Computes the mean train loss and returns it as a ValueDict.
        :param epoch: The index of the current epoch. Only used for the progress output.
        Returns: The training metrics as a ValueDict.

        """
        self.simulator.train()
        scalar_metrics = defaultdict(list)
        if not self.config.verbose:
            # kinda weird, but in not verbose mode, you only print this one line instead of the progress bar
            print(f"Training Epoch {epoch}...")
        for batch in tqdm(self.env.train_iterator, desc=f"Training Epoch {epoch}", disable=not self.config.verbose):
            step_loss = self._single_train_step(batch)
            # Store a plain float instead of the tensor: keeping the tensors would hold them (and a possibly
            # attached autograd graph) in memory for the whole epoch, and averaging them with numpy
            # does not work for tensors that live on the GPU.
            scalar_metrics[keys.TOTAL_LOSS].append(self._loss_to_float(step_loss))

        # take the average of the training metrics over the training step
        averaged_metrics = {key: float(np.mean(value)) for key, value in scalar_metrics.items()}
        training_metrics = {keys.SCALARS: {keys.TRAIN: {keys.ALL_TRAIN_TASKS: averaged_metrics}}}
        return training_metrics

    @abstractmethod
    def _single_train_step(self, batch: Batch) -> torch.Tensor:
        """
        Performs a single training step. This means, compute the loss and perform a gradient step.
        Returns: The loss of the training step.

        """
        raise NotImplementedError

    @abstractmethod
    def predict_trajectory(self, traj: Data, visualize: bool = False, eval_only: bool = False) -> Tuple[torch.Tensor, ValueDict]:
        """
        Predicts a trajectory for the given trajectory.
        Returns: A tuple consisting of the predicted trajectory positions. Shape: [num_timesteps, num_nodes, world_dim]
                    and a dictionary containing visualizations and other logging information.
                    (Should only be non-empty if visualize is True)
        """
        raise NotImplementedError

    def save_checkpoint(self, directory: str, iteration: int, is_initial_save: bool, is_final_save: bool = False):
        """
        Saves the simulator and the state of the optimizer to the given directory.
        :param directory: The directory the checkpoint files are written to.
        :param iteration: The current iteration. Part of the optimizer file name unless this is the final save.
        :param is_initial_save: Forwarded to the checkpointing of the simulator.
        :param is_final_save: If True, the optimizer file gets the suffix "_final" instead of the iteration.
        """
        # simulator
        self.simulator.save_checkpoint(directory, iteration, is_initial_save, is_final_save)
        # optimizer
        if is_final_save:
            optimizer_file_name = f"{keys.OPTIMIZER}_final.pt"
        else:
            optimizer_file_name = f"{keys.OPTIMIZER}_{iteration}.pt"
        torch.save(self._optimizer.state_dict(), os.path.join(directory, optimizer_file_name))

    @property
    def save_checkpoint_this_epoch(self) -> bool:
        """
        Flag that marks whether a checkpoint should be saved in the current epoch. Initialized to False.
        """
        return self._save_checkpoint_this_epoch

    @save_checkpoint_this_epoch.setter
    def save_checkpoint_this_epoch(self, value: bool):
        self._save_checkpoint_this_epoch = value

    def update_best_eval_loss(self, small_eval_metrics):
        """
        Hook for keeping track of the best evaluation loss. Not implemented in the base class,
        it always raises a NotImplementedError here.
        """
        raise NotImplementedError

    @property
    def simulator(self) -> AbstractSimulator:
        """
        The simulator of the algorithm. Raises a ValueError if no simulator is set.
        """
        if self._simulator is None:
            raise ValueError("Simulator not set")
        return self._simulator

    @property
    def config(self) -> ConfigDict:
        """
        The configuration that was passed to the constructor.
        """
        return self._config

    @property
    def env(self) -> Env:
        """
        The environment of the algorithm. Raises a ValueError if no environment is set.
        """
        if self._env is None:
            raise ValueError("Env not set")
        return self._env

    def _apply_loss(self, loss: torch.Tensor):
        """
        Applies the loss to the parameters of the simulator.
        :param loss: The loss to be applied.
        """
        # clear the old gradients, backpropagate the new loss and take one optimizer step
        self._optimizer.zero_grad()
        loss.backward()
        self._optimizer.step()
