import abc
from typing import Any, Generic, Iterable, Mapping, Optional, TypeVar, Union, Tuple, Callable, Iterator
import tqdm

import torch
import itertools
import numpy as np
import torch.utils.data as th_data
from stable_baselines3.common import policies


from delphicORL.policies import recurrent as policy_recurrent
from delphicORL.utils import data


def get_bc_input_output(batch, input_name='obs', output_name='acts', device='cpu', lstm=False):
    """Extract float32 model inputs/targets (and optionally masks) from a batch.

    Args:
        batch: Mapping from field name to array-like data (NumPy arrays,
            tensors or nested sequences) for one batch.
        input_name: Key of the input field, or a sequence of keys whose values
            are concatenated along the last axis. Fields with fewer dimensions
            than the highest-rank field get one trailing axis added before
            concatenation.
        output_name: Key of the target field.
        device: Device the returned tensors are moved to.
        lstm: If true, also return ``batch['masks']`` as a bool tensor.

    Returns:
        ``(x, y)`` as float32 tensors on ``device``, or ``(x, y, mask)`` when
        ``lstm`` is true.
    """
    # ``torch.as_tensor`` accepts tensors of any dtype, whereas the legacy
    # ``torch.Tensor(...)`` constructor raises TypeError for a tensor that is
    # not already float32 (e.g. int64 discrete actions or float64 obs).
    if isinstance(input_name, str):
        x = torch.as_tensor(batch[input_name])
    else:
        # Cast each part first so torch.cat never sees mixed dtypes.
        parts = [torch.as_tensor(batch[name]).to(torch.float32) for name in input_name]
        max_rank = max(part.dim() for part in parts)
        parts = [part.unsqueeze(-1) if part.dim() < max_rank else part for part in parts]
        x = torch.cat(parts, -1)
    x = x.detach().to(torch.float32).to(device)
    y = torch.as_tensor(batch[output_name]).detach().to(torch.float32).to(device)
    if lstm:
        mask = torch.as_tensor(batch['masks']).detach().to(torch.bool).to(device)
        return x, y, mask
    return x, y


def get_policy(observation_space, action_space, lstm, lstm_model):
    """Build the policy network used for behavioural cloning.

    Args:
        observation_space: Observation space passed to the policy constructor.
        action_space: Action space passed to the policy constructor.
        lstm: If truthy, build the sequence model selected by ``lstm_model``;
            otherwise build a feed-forward SB3 ``ActorCriticPolicy`` with a
            ``[32, 32]`` network.
        lstm_model: One of ``'lstm'``, ``'gru'`` or ``'transformer'``. It is
            only consulted when ``lstm`` is truthy.

    Returns:
        The constructed policy object.

    Raises:
        NotImplementedError: If ``lstm`` is truthy and ``lstm_model`` is not a
            supported name.
    """
    if not lstm:
        return policies.ActorCriticPolicy(
            observation_space=observation_space,
            action_space=action_space,
            # Set lr_schedule to max value to force error if policy.optimizer
            # is used by mistake (should use optimizer instead).
            lr_schedule=lambda _: torch.finfo(torch.float32).max,
            net_arch=[32, 32],
        )

    if lstm_model == 'lstm':
        return policy_recurrent.RecurrentActorCriticPolicy(
            observation_space=observation_space,
            action_space=action_space,
        )
    if lstm_model == 'gru':
        return policy_recurrent.GRUActorCriticPolicy(
            observation_space=observation_space,
            action_space=action_space,
        )
    if lstm_model == 'transformer':
        return policy_recurrent.TransformerActorCriticPolicy(
            observation_space=observation_space,
            action_space=action_space,
        )
    raise NotImplementedError(
        f"Unsupported lstm_model {lstm_model!r}; "
        "expected 'lstm', 'gru' or 'transformer'."
    )


class BaseAlgorithm:
    """Root class for the imitation-learning algorithms in this module.

    It only stores the optional logger handed in by the caller.
    """

    def __init__(self, *, custom_logger=None):
        # Kept as-is; may be None when no logger was supplied.
        logger_obj = custom_logger
        self._logger = logger_obj


class DemonstrationAlgorithm(BaseAlgorithm):
    """An algorithm that learns from demonstration: BC, IRL, etc.

    Wraps the training demonstrations (and optional held-out test
    demonstrations) in data loaders built by ``make_data_loader`` and provides
    an epoch-based batch iterator for training loops.
    """

    # Whether demonstrations are prepared for a recurrent policy. It is read by
    # ``set_demonstrations``/``set_test_demonstrations``, which ``__init__``
    # calls, so subclasses that need ``True`` must set ``self.lstm`` before
    # calling ``DemonstrationAlgorithm.__init__``.
    lstm = False

    def __init__(
        self,
        *,
        demonstrations, custom_logger = None,
        test_demonstrations = None,
        batch_size = 32
    ):
        """Store settings and build data loaders for any given demonstrations.

        Args:
            demonstrations: Training data passed to ``set_demonstrations``;
                skipped when None.
            custom_logger: Optional logger exposing ``log_epoch(int)``.
            test_demonstrations: Held-out data passed to
                ``set_test_demonstrations``; skipped when None.
            batch_size: Batch size used for both data loaders.
        """
        super().__init__(
            custom_logger=custom_logger,
        )
        self.batch_size = batch_size

        if demonstrations is not None:
            self.set_demonstrations(demonstrations)

        if test_demonstrations is not None:
            self.set_test_demonstrations(test_demonstrations)

    def set_demonstrations(self, demonstrations) -> None:
        """Sets the demonstration data.

        Changing the demonstration data on-demand can be useful for
        interactive algorithms like DAgger.

        Args:
            demonstrations: Data passed to ``make_data_loader``. It is either a
                dataset accepted by ``torch.utils.data.DataLoader`` or an
                iterable of ``data.Trajectory`` objects, which is flattened
                first. The result is always wrapped in a new ``DataLoader``.
        """
        self._demo_data_loader = make_data_loader(
            demonstrations, self.batch_size, lstm=self.lstm
        )

    def set_test_demonstrations(self, test_demonstrations) -> None:
        """Build the held-out data loader; see ``set_demonstrations``."""
        self._test_demo_data_loader = make_data_loader(
            test_demonstrations, self.batch_size, lstm=self.lstm
        )

    def save_policy(self, policy_path):
        """Serialize ``self.policy`` to ``policy_path`` with ``torch.save``.

        ``self.policy`` is not set by this class; a subclass must provide it.
        """
        torch.save(self.policy, policy_path)

    def setup_training(self, *,
        n_epochs
    ):
        """Log epoch 0 and return a lazy iterator over training batches.

        Args:
            n_epochs: Number of epochs to iterate (see ``batch_iterator``).

        Returns:
            The generator returned by ``batch_iterator(n_epochs)``.
        """
        # The logger is optional (custom_logger defaults to None).
        if self._logger is not None:
            self._logger.log_epoch(0)
        return self.batch_iterator(n_epochs)

    def batch_iterator(self, n_epochs):
        """Yield demonstration batches for ``n_epochs`` passes over the data.

        Args:
            n_epochs: Number of epochs; ``None`` makes ``itertools.islice``
                unbounded, so iteration continues indefinitely.

        Yields:
            ``((num_batches_so_far, batch_size, num_samples_so_far), batch)``.
            The counters are cumulative across epochs, and ``batch_size`` is
            ``len(batch["obs"])`` for the current batch.

        Raises:
            RuntimeError: If the demonstration loader yields no batches in an
                epoch. Without this check, training would silently do nothing
                (or spin forever when ``n_epochs`` is None).
        """
        num_batches_so_far = 0
        num_samples_so_far = 0
        pbar = tqdm.tqdm(itertools.islice(itertools.count(), n_epochs))
        try:
            for epoch_num in pbar:
                batches_this_epoch = 0
                for batch in self._demo_data_loader:
                    batches_this_epoch += 1
                    num_batches_so_far += 1
                    batch_size = len(batch["obs"])
                    num_samples_so_far += batch_size
                    yield (num_batches_so_far, batch_size, num_samples_so_far), batch

                if batches_this_epoch == 0:
                    raise RuntimeError(
                        "Demonstration data loader yielded no batches. With "
                        "drop_last=True (the make_data_loader default) this "
                        "happens when there are fewer samples than batch_size."
                    )

                pbar.display(f"Epoch {epoch_num} of {n_epochs}",
                        pos=1,
                    )

                if self._logger is not None:
                    self._logger.log_epoch(epoch_num + 1)
        finally:
            # Release the progress bar even if the consumer stops iterating
            # early or an error is raised; tqdm.close() is idempotent.
            pbar.close()
            
            

def unsqueezed_array(array, lstm):
    """Append a trailing singleton axis to ``array`` when its rank is too low.

    The minimum rank is 3 when ``lstm`` is truthy and 2 otherwise. Arrays that
    already meet it are returned unchanged (the same object).

    Args:
        array: Tensor to check; ``torch.unsqueeze`` adds the axis.
        lstm: Selects the rank-3 threshold when truthy.
    """
    required_rank = 3 if lstm else 2
    if len(array.shape) >= required_rank:
        return array
    return torch.unsqueeze(array, -1)



def make_data_loader(
    transitions,
    batch_size: int,
    data_loader_kwargs: Optional[Mapping[str, Any]] = None,
    lstm=False
):
    """Wrap demonstration data in a ``torch.utils.data.DataLoader``.

    If ``transitions`` is an iterable whose first element is a
    ``data.Trajectory``, the trajectories are flattened with
    ``data.flatten_trajectories`` (forwarding ``lstm``) before wrapping.

    Args:
        transitions: A dataset accepted by ``DataLoader``, or an iterable of
            ``data.Trajectory``. One-shot iterators such as generators are
            supported: the element consumed to inspect the type is put back.
        batch_size: Number of samples per batch.
        data_loader_kwargs: Extra ``DataLoader`` keyword arguments. They
            override the defaults ``shuffle=True`` and ``drop_last=True``.
        lstm: Forwarded to ``data.flatten_trajectories``.

    Returns:
        A ``DataLoader`` that uses ``data.transitions_collate_fn``.
    """
    if isinstance(transitions, Iterable):
        iterator = iter(transitions)
        try:
            first_item = next(iterator)
        except StopIteration:
            first_item = None
        else:
            if iterator is transitions:
                # ``transitions`` was itself a one-shot iterator, so ``next``
                # consumed its first element; chain it back so that element
                # (e.g. the first trajectory) is not silently dropped.
                transitions = itertools.chain([first_item], iterator)
        if isinstance(first_item, data.Trajectory):
            transitions = data.flatten_trajectories(list(transitions), lstm)

    extra_kwargs = {"shuffle": True, "drop_last": True}
    if data_loader_kwargs is not None:
        extra_kwargs.update(data_loader_kwargs)

    return th_data.DataLoader(
        transitions,
        batch_size=batch_size,
        collate_fn=data.transitions_collate_fn,
        **extra_kwargs,
    )