import dataclasses

from typing import (
    Any,
    Callable,
    Iterable,
    Mapping,
    Optional,
    Type,
    Union,
)

import gym
import numpy as np
import torch
import tqdm
from stable_baselines3.common import policies, utils, vec_env

from delphicORL.algos import base as algo_base
from delphicORL.utils import scorers


class BC(algo_base.DemonstrationAlgorithm):
    """Behavioral cloning (BC).

    Recovers a policy via supervised learning from observation-action pairs.
    """

    def __init__(
        self,
        *,
        observation_space: gym.Space,
        action_space: gym.Space,
        policy: Optional[policies.ActorCriticPolicy] = None,
        demonstrations=None,
        test_demonstrations=None,
        batch_size: int = 32,
        optimizer_cls: Type[torch.optim.Optimizer] = torch.optim.Adam,
        optimizer_kwargs: Optional[Mapping[str, Any]] = None,
        ent_weight: float = 1e-3,
        l2_weight: float = 0.0,
        device: Union[str, torch.device] = "auto",
        custom_logger=None,
        lstm: bool = False,
        lstm_model: str = 'lstm',
        input_name: str = 'obs',
        ope: bool = False,
    ):
        """Builds BC.

        Args:
            observation_space: the observation space of the environment.
            action_space: the action space of the environment. When it is a
                `gym.spaces.Discrete`, `self.discrete` is set to True.
            policy: accepted for interface compatibility only. This constructor
                does not read it; the policy is always created with
                `algo_base.get_policy(observation_space, action_space, lstm,
                lstm_model)`.
            demonstrations: training demonstrations (optional); forwarded
                unchanged to the base-class constructor.
            test_demonstrations: held-out demonstrations (optional); forwarded
                unchanged to the base-class constructor.
            batch_size: stored as `self.batch_size` before the base-class
                constructor runs (presumably the number of samples per batch of
                expert data -- confirm against the base class).
            optimizer_cls: optimiser class to use for supervised training.
            optimizer_kwargs: keyword arguments forwarded as-is to
                `optimizer_cls` (this is where the learning rate goes). `None`
                (the default) means "no extra arguments". NOTE: a
                `weight_decay` entry is not rejected here; it would apply in
                addition to the explicit L2 term controlled by `l2_weight`.
            ent_weight: scaling applied to the policy's entropy regularization.
            l2_weight: scaling applied to the policy's L2 regularization.
            device: name/identity of device to place policy on; resolved with
                `stable_baselines3.common.utils.get_device`.
            custom_logger: forwarded unchanged to the base-class constructor.
            lstm: forwarded to `algo_base.get_policy`, and later to the batch
                preparation helper and the imitation scorers.
            lstm_model: forwarded to `algo_base.get_policy`.
            input_name: forwarded to the batch preparation helper and the
                imitation scorers (presumably the key of the batch field used
                as policy input -- confirm against `get_bc_input_output`).
            ope: if truthy, `train()` additionally builds an off-policy
                evaluation scorer from `self.env` and logs its result under
                `'ope_dm_return'`.
        """
        # These attributes are assigned before super().__init__() on purpose:
        # presumably the base constructor needs them while it sets up the
        # demonstration loaders -- do not reorder without checking the base class.
        self.batch_size = batch_size
        self.lstm = lstm
        self.lstm_model = lstm_model
        self.input_name = input_name

        super().__init__(
            demonstrations=demonstrations,
            custom_logger=custom_logger,
            test_demonstrations=test_demonstrations,
        )

        self.action_space = action_space
        self.observation_space = observation_space
        self.discrete = isinstance(action_space, gym.spaces.Discrete)

        self.policy = algo_base.get_policy(
            observation_space, action_space, lstm, lstm_model
        ).to(utils.get_device(device))

        # `optimizer_kwargs` is Optional: treat None as "no extra arguments"
        # instead of failing on `**None`.
        self.optimizer = optimizer_cls(
            self.policy.parameters(),
            **(optimizer_kwargs or {}),
        )

        self.ent_weight = ent_weight
        self.l2_weight = l2_weight

        self.ope = ope
        # Not populated here. `train()` reads `self.env` when `ope` is enabled,
        # so it has to be assigned by the caller beforehand in that case.
        self.reward_fn = None
        self.env = None

    def loss(
        self,
        policy: policies.ActorCriticPolicy,
        obs: Union[torch.Tensor, np.ndarray],
        acts: Union[torch.Tensor, np.ndarray],
        masks=None,
    ) -> dict:
        """Compute the behavioural-cloning objective and its components.

        The total loss is the sum of three terms: the negative mean
        log-probability of `acts` under `policy`, an entropy bonus scaled by
        `self.ent_weight`, and an L2 penalty on all policy parameters scaled by
        `self.l2_weight`.

        Args:
            policy: policy whose `evaluate_actions` and `parameters` are used.
            obs: observations handed to `policy.evaluate_actions`.
            acts: actions handed to `policy.evaluate_actions`.
            masks: optional third argument for `policy.evaluate_actions`; it is
                only passed along when it is not None.

        Returns:
            A dict with the keys `neglogp`, `entropy`, `ent_loss`,
            `prob_true_act`, `l2_norm`, `l2_loss` and `loss`.
        """
        # Only forward `masks` when the caller actually supplied one.
        eval_args = (obs, acts) if masks is None else (obs, acts, masks)
        _, action_log_probs, action_entropies = policy.evaluate_actions(*eval_args)

        # Mean probability of the demonstrated action; must be taken from the
        # un-averaged log-probabilities.
        prob_true_act = torch.exp(action_log_probs).mean()
        mean_log_prob = action_log_probs.mean()
        mean_entropy = action_entropies.mean()

        # Half the summed squared parameters, so the gradient is the weights
        # themselves (the 1/2 cancels the 2 from differentiating the square).
        l2_norm = sum(torch.square(param).sum() for param in policy.parameters()) / 2

        neglogp = -mean_log_prob
        ent_loss = -self.ent_weight * mean_entropy
        l2_loss = self.l2_weight * l2_norm
        total_loss = neglogp + ent_loss + l2_loss

        return {
            "neglogp": neglogp,
            "entropy": mean_entropy,
            "ent_loss": ent_loss,
            "prob_true_act": prob_true_act,
            "l2_norm": l2_norm,
            "l2_loss": l2_loss,
            "loss": total_loss,
        }

    def trainer(self, batch) -> dict:
        """Run a single optimisation step on `batch`.

        The batch is converted into the positional inputs of `self.loss` by
        `algo_base.get_bc_input_output`, the policy is switched to training
        mode, and one backward pass plus optimiser step is performed on the
        `'loss'` entry.

        Args:
            batch: one batch of demonstration data, as accepted by
                `algo_base.get_bc_input_output`.

        Returns:
            The dict produced by `self.loss` for this batch.
        """
        policy = self.policy
        loss_inputs = algo_base.get_bc_input_output(
            batch,
            self.input_name,
            device=policy.device,
            lstm=self.lstm,
        )

        policy.train()
        metrics = self.loss(policy, *loss_inputs)

        optimizer = self.optimizer
        optimizer.zero_grad()
        metrics['loss'].backward()
        optimizer.step()

        return metrics


    def train(
        self,
        *,
        n_epochs: Optional[int] = None,
        log_interval: int = 500,
        log_rollouts_venv: Optional[vec_env.VecEnv] = None,
        log_rollouts_n_episodes: int = 5,
    ):
        """Train the policy with supervised learning and log periodically.

        Each batch yielded by `self.setup_training(n_epochs=n_epochs)` is used
        for one call to `self.trainer`. Whenever the batch number is a multiple
        of `log_interval`, rollout statistics and action-difference statistics
        on the training and test demonstration loaders are computed and passed
        to `self._logger.log_batch` together with the loss terms of that batch.
        If `self.ope` is set, an off-policy evaluation score is added to those
        loss terms under the key `'ope_dm_return'`.

        Args:
            n_epochs: forwarded to `self.setup_training` (presumably the number
                of passes over the expert data -- confirm against the base
                class).
            log_interval: log stats after every `log_interval` batches.
            log_rollouts_venv: environment handed to
                `scorers.imitation_rollout_stats` for generating rollout stats.
            log_rollouts_n_episodes: number of episodes handed to
                `scorers.imitation_rollout_stats`.
        """
        rollout_scorer = scorers.imitation_rollout_stats(
            log_rollouts_venv, log_rollouts_n_episodes
        )

        # The action-difference metric depends on the type of action space.
        action_diff_scorer = (
            scorers.discrete_action_diff_scorer
            if self.discrete
            else scorers.continuous_action_diff_scorer
        )

        # Only built when off-policy evaluation is enabled; relies on
        # `self.env` having been assigned by the caller.
        if self.ope:
            ope_scorer = scorers.d3rlpy_ope_scorer(
                self.env, self._demo_data_loader['infos']
            )

        for progress, batch in self.setup_training(n_epochs=n_epochs):
            batch_num, batch_size, num_samples_so_far = progress
            metrics = self.trainer(batch)

            if batch_num % log_interval == 0:
                rollout_stats = rollout_scorer(self.policy)
                train_imit_stats = action_diff_scorer(
                    self.policy,
                    self._demo_data_loader,
                    input_name=self.input_name,
                    lstm=self.lstm,
                )
                test_imit_stats = action_diff_scorer(
                    self.policy,
                    self._test_demo_data_loader,
                    input_name=self.input_name,
                    lstm=self.lstm,
                )

                if self.ope:
                    metrics['ope_dm_return'] = ope_scorer(
                        self.policy, self._test_demo_data_loader
                    )

                self._logger.log_batch(
                    batch_num,
                    batch_size,
                    num_samples_so_far,
                    metrics,
                    train_imit_stats,
                    test_imit_stats,
                    rollout_stats,
                )