from typing import (
    Any,
    Callable,
    Iterable,
    Iterator,
    Mapping,
    Optional,
    Tuple,
    Type,
    Union,
)

import gym
import numpy as np
import torch
from stable_baselines3.common import policies, utils

from stable_baselines3.common.preprocessing import get_flattened_obs_dim

from delphicORL.algos import base as algo_base
from delphicORL.algos.imitation.bc import *
from delphicORL.utils import data
from delphicORL.networks.vae_network import *
from delphicORL.algos.confounding.worldmodel import *
from delphicORL.networks.simple_q_func import Discrete_Q_func, Continuous_Q_func


def get_wm_inputdim(observation_space, action_space):
    """Return the flat input width of the world model (observation width + action width).

    A ``Discrete`` space contributes a single scalar (its index). A
    ``MultiDiscrete`` action space contributes one entry per sub-space. Any
    other space contributes its flattened size as reported by
    ``get_flattened_obs_dim``.
    """
    discrete = gym.spaces.Discrete

    obs_width = (
        1
        if isinstance(observation_space, discrete)
        else get_flattened_obs_dim(observation_space)
    )

    if isinstance(action_space, discrete):
        act_width = 1
    elif isinstance(action_space, gym.spaces.MultiDiscrete):
        act_width = action_space.shape[0]
    else:
        act_width = get_flattened_obs_dim(action_space)

    return obs_width + act_width
                   

class WorldModelLearner(algo_base.DemonstrationAlgorithm):
    """Train a ``ComaptibleEpistemicWorldModels`` ensemble on demonstration data.

    On construction the demonstrations are normalised using statistics
    computed on the training split. Unless ``no_train_q_func`` is set, they
    are then relabelled with a ``"vaetarget"`` field produced by a Q-function
    built from the demonstration data (``Discrete_Q_func`` when both spaces
    are discrete, ``Continuous_Q_func`` otherwise). ``train`` then optimises
    every model in ``worldmodel.models`` against those targets.
    """

    def __init__(
        self,
        *,
        observation_space: gym.Space,
        action_space: gym.Space,
        wm_klweight=1e-2,
        wm_latent_dim=32,
        wm_hidden_dims=None,
        wm_target_dim=None,
        demonstrations=None,
        test_demonstrations=None,
        batch_size: int = 32,
        optimizer_cls: Type[torch.optim.Optimizer] = torch.optim.Adam,
        optimizer_kwargs=None,
        device="cuda",
        custom_logger=None,
        lstm=False,
        max_len=1000,
        no_train_q_func=False,
    ):
        """Build the world model, normalise the data and (optionally) relabel it.

        Args:
            observation_space: environment observation space.
            action_space: environment action space.
            wm_klweight: forwarded to the world model as ``kl_weight``.
            wm_latent_dim: latent size of the world model.
            wm_hidden_dims: hidden layer sizes; defaults to ``[128, 64, 32]``.
            wm_target_dim: accepted for interface compatibility; not used here.
            demonstrations: training demonstrations (handed to the base class).
            test_demonstrations: optional held-out demonstrations.
            batch_size: stored on the instance.
            optimizer_cls: optimiser class built over the world-model parameters.
            optimizer_kwargs: extra keyword arguments for ``optimizer_cls``.
            device: device spec resolved with ``utils.get_device``; used both
                for the world model and for every training batch.
            custom_logger: forwarded to the base class.
            lstm: when True, batches are handled as sequences and ``"masks"``
                are read from each batch.
            max_len: forwarded to the world model.
            no_train_q_func: skip relabelling; the caller must then add a
                ``"vaetarget"`` field before training.
        """
        # Fresh objects per call: a list/dict literal default would be shared
        # by every instance.
        if wm_hidden_dims is None:
            wm_hidden_dims = [128, 64, 32]
        if optimizer_kwargs is None:
            optimizer_kwargs = {}

        # Initialised before the base constructor, which presumably installs
        # the (test) demonstrations into these loaders.
        self._demo_data_loader = None
        self._test_demo_data_loader = None
        self.batch_size = batch_size
        self.lstm = lstm
        self.wm_latent_dim = wm_latent_dim
        self.wm_hidden_dims = wm_hidden_dims
        self.max_len = max_len

        super().__init__(
            custom_logger=custom_logger,
            demonstrations=demonstrations,
            test_demonstrations=test_demonstrations,
        )

        self.action_space = action_space
        self.observation_space = observation_space
        # Resolve the device once so that the model and every training batch
        # agree (get_device falls back to CPU when CUDA is unavailable).
        self.wm_device = utils.get_device(device)

        # Per-step action width handed to the world model.
        if isinstance(action_space, gym.spaces.Discrete):
            action_dim = 1
        else:
            action_dim = action_space.shape[0]

        self.worldmodel = ComaptibleEpistemicWorldModels(
            kl_weight=wm_klweight,
            input_dim=get_wm_inputdim(self.observation_space, action_space),
            latent_dim=wm_latent_dim,
            max_len=max_len,
            hidden_size=wm_hidden_dims,
            action_dim=action_dim,
        ).to(self.wm_device)

        # Normalise states and actions. The test split reuses the training
        # statistics.
        self.statistics = data.compute_norm_statistics(self._demo_data_loader)
        self.set_demonstrations(data.normalise(self._demo_data_loader, self.statistics))
        if test_demonstrations is not None:
            self.set_test_demonstrations(
                data.normalise(self._test_demo_data_loader, self.statistics)
            )

        if not no_train_q_func:
            if isinstance(observation_space, gym.spaces.Discrete) and isinstance(
                action_space, gym.spaces.Discrete
            ):
                qf = Discrete_Q_func(self._demo_data_loader)
            else:
                qf = Continuous_Q_func(
                    self._demo_data_loader,
                    get_wm_inputdim(self.observation_space, action_space),
                )
            self.set_demonstrations(self.label_trajectories_w_func(self._demo_data_loader, qf))
            # Without a test split the loader is None, and relabelling it would
            # fail on ``None.dataset``.
            if self._test_demo_data_loader is not None:
                self.set_test_demonstrations(
                    self.label_trajectories_w_func(self._test_demo_data_loader, qf)
                )
        else:
            print("Make sure to relabel trajectories with wm_target before training!")

        self.optimizer = optimizer_cls(
            self.worldmodel.parameters(),
            **optimizer_kwargs,
        )

    def trainer(self, batch):
        """Run one optimiser step for each model in ``worldmodel.models``.

        ``batch`` must provide ``"obs"``, ``"acts"`` and ``"vaetarget"``, plus
        ``"masks"`` when ``self.lstm`` is True. Returns the merged metrics
        dicts produced by ``worldmodel.loss`` for every model.
        """
        device = self.wm_device

        def as_float(key):
            # Tensor on the model's device, detached, float32, with the
            # sequence axis handled by ``unsqueezed_array``.
            tensor = torch.as_tensor(batch[key], device=device).detach().to(torch.float32)
            return algo_base.unsqueezed_array(tensor, self.lstm)

        obs = as_float("obs")
        acts = as_float("acts")

        loss_kw = {"qfunc_target_seq": as_float("vaetarget")}
        if self.lstm:
            loss_kw["mask"] = (
                torch.as_tensor(batch["masks"], device=device).detach().to(torch.bool)
            )

        loss_metrics = {}
        for k in range(len(self.worldmodel.models)):
            loss_metrics.update(
                self.worldmodel.loss(states=obs, actions=acts, model_k=k, **loss_kw)
            )
            # The single optimiser covers all models. Only model k's loss is
            # back-propagated in this step.
            self.optimizer.zero_grad()
            loss_metrics[f'WM{k}_loss'].backward()
            self.optimizer.step()

        return loss_metrics

    def label_trajectories_w_func(self, dataloader, func):
        """Return a new ``data.Transitions`` whose ``vaetarget`` is ``func(x=obs, action=acts)``.

        The targets are computed without gradients and stored as a NumPy array.
        """
        transitions = data.dataclass_quick_asdict(dataloader.dataset)

        with torch.no_grad():
            obs = torch.as_tensor(transitions["obs"]).detach().to(torch.float32)
            acts = torch.as_tensor(transitions["acts"]).detach().to(torch.float32)
            obs = algo_base.unsqueezed_array(obs, self.lstm)
            acts = algo_base.unsqueezed_array(acts, self.lstm)
            wm_target_vals = func(x=obs, action=acts)
            transitions["vaetarget"] = wm_target_vals.detach().cpu().numpy()

        return data.Transitions(**transitions)

    def train(
        self,
        *,
        n_epochs: Optional[int] = None,
        log_interval: int = 500,
    ):
        """Train the world models over the batches from ``setup_training``.

        Metrics are logged every ``log_interval`` batches.
        """
        batches_with_stats = self.setup_training(n_epochs=n_epochs)

        self.worldmodel.train()

        for (batch_num, batch_size, num_samples_so_far), batch in batches_with_stats:
            loss = self.trainer(batch)

            if batch_num % log_interval == 0:
                self._logger.log_batch(
                    batch_num,
                    batch_size,
                    num_samples_so_far,
                    loss,
                )

    def save_clvae(self, wm_path) -> None:
        """Serialise the whole world-model module to ``wm_path`` with ``torch.save``."""
        torch.save(self.worldmodel, wm_path)

    def save_vae(self, wm_path):
        """Alias of ``save_clvae``, kept for callers using the generic name."""
        self.save_clvae(wm_path)


def augmented_obs_space(observation_space, context_dim):
    """Extend ``observation_space`` with an unbounded ``context_dim``-wide context.

    For a ``Box`` space, a single ``Box`` is returned whose bounds are the
    original bounds concatenated (along the last axis) with ``-inf``/``+inf``
    bounds for the context. Any other space is paired with a separate unbounded
    ``Box`` inside a ``Tuple`` space.
    """
    ctx_high = np.array([np.inf] * context_dim)
    ctx_low = -ctx_high

    if not isinstance(observation_space, gym.spaces.Box):
        context_box = gym.spaces.Box(low=ctx_low, high=ctx_high)
        return gym.spaces.Tuple((observation_space, context_box))

    lows = np.concatenate([observation_space.low, ctx_low], -1)
    highs = np.concatenate([observation_space.high, ctx_high], -1)
    return gym.spaces.Box(low=lows, high=highs)




class BC_wDelphicPenalty(BC):
    """Behavioural cloning with the log-likelihood divided by Delphic uncertainty.

    Extends ``BC`` and owns a ``WorldModelLearner`` (``self.wmlearner``).
    In ``loss``, each sample's policy log-probability is divided by the world
    model's ``predict_delphic_uncertainty`` estimate. Assuming the uncertainty
    is positive, samples with larger estimated uncertainty therefore contribute
    less to the imitation objective.
    """

    def __init__(
        self,
        *,
        observation_space: gym.Space,
        action_space: gym.Space,
        policy: Optional[policies.ActorCriticPolicy] = None,
        wm_klweight=1e-2,
        wm_latent_dim=32,
        wm_hidden_dims=None,
        wm_target_dim=None,
        demonstrations=None,
        test_demonstrations=None,
        batch_size: int = 32,
        optimizer_cls: Type[torch.optim.Optimizer] = torch.optim.Adam,
        optimizer_kwargs: Optional[Mapping[str, Any]] = None,
        ent_weight: float = 1e-3,
        l2_weight: float = 0.0,
        device: Union[str, torch.device] = "auto",
        custom_logger=None,
        lstm=False,
        lstm_model='lstm',
        max_len=1000,
    ):
        """Initialise the BC policy and the world-model learner on the same data.

        ``wm_*`` and ``max_len`` configure the world model; ``wm_hidden_dims``
        defaults to ``[128, 64, 32]``. ``optimizer_cls`` / ``optimizer_kwargs``
        are used for both the policy and the world model (each gets its own
        copy of the kwargs). The remaining arguments are forwarded to ``BC``.
        """
        # Fresh objects per call instead of shared mutable defaults; None is
        # also accepted for optimizer_kwargs (``**None`` would otherwise fail).
        if wm_hidden_dims is None:
            wm_hidden_dims = [128, 64, 32]
        if optimizer_kwargs is None:
            optimizer_kwargs = {}

        super().__init__(
            observation_space=observation_space,
            action_space=action_space,
            policy=policy,
            demonstrations=demonstrations,
            test_demonstrations=test_demonstrations,
            batch_size=batch_size,
            optimizer_cls=optimizer_cls,
            optimizer_kwargs=dict(optimizer_kwargs),
            ent_weight=ent_weight,
            l2_weight=l2_weight,
            device=device,
            custom_logger=custom_logger,
            lstm=lstm,
            lstm_model=lstm_model,
            input_name='obs',
        )

        # NOTE(review): ``device`` is not forwarded here, so the world model
        # uses WorldModelLearner's own default device. Confirm it matches the
        # policy device when a non-default ``device`` is requested.
        self.wmlearner = WorldModelLearner(
            observation_space=observation_space,
            action_space=action_space,
            wm_klweight=wm_klweight,
            wm_latent_dim=wm_latent_dim,
            wm_hidden_dims=wm_hidden_dims,
            wm_target_dim=wm_target_dim,
            demonstrations=demonstrations,
            test_demonstrations=test_demonstrations,
            batch_size=batch_size,
            optimizer_cls=optimizer_cls,
            optimizer_kwargs=dict(optimizer_kwargs),
            custom_logger=custom_logger,
            lstm=lstm,
            max_len=max_len,
        )

    def loss(
        self,
        policy: policies.ActorCriticPolicy,
        vae,
        obs: Union[torch.Tensor, np.ndarray],
        acts: Union[torch.Tensor, np.ndarray],
        masks=None,
    ) -> dict:
        """Compute the uncertainty-weighted BC loss and its logged components.

        Args:
            policy: policy whose action log-probabilities and entropy are
                evaluated. The same policy is handed to the world model's
                uncertainty estimate.
            vae: world model exposing ``predict_delphic_uncertainty``.
            obs: batch of observations.
            acts: batch of demonstrated actions.
            masks: sequence masks. When given, the sequence code paths are used
                for both the policy and the world model.

        Returns:
            dict with ``loss`` (``neglogp + ent_loss``) and the components
            ``neglogp``, ``uncertainty`` (batch mean), ``entropy``,
            ``ent_loss`` and ``prob_true_act``.
        """
        sequential = masks is not None

        if sequential:
            _, log_prob, entropy = policy.evaluate_actions(obs, acts, masks)
        else:
            _, log_prob, entropy = policy.evaluate_actions(obs, acts)

        obs = algo_base.unsqueezed_array(obs, sequential)
        acts = algo_base.unsqueezed_array(acts, sequential)
        # NOTE(review): ``uncertainty`` is not detached, so gradients from the
        # loss can flow through it. Confirm this is intended.
        if sequential:
            uncertainty = vae.predict_delphic_uncertainty(obs, acts, masks, policy=policy)
        else:
            uncertainty = vae.predict_delphic_uncertainty(obs, acts, policy=policy)

        prob_true_act = torch.exp(log_prob).mean()
        # Divide each sample's log-likelihood by its uncertainty. The epsilon
        # guards against division by zero.
        log_prob = (log_prob / (uncertainty + 1e-6)).mean()
        entropy = entropy.mean()

        ent_loss = -self.ent_weight * entropy
        neglogp = -log_prob
        # NOTE: ``l2_weight`` is forwarded to ``BC.__init__``, but no L2
        # penalty is added to this loss.

        return dict(
            neglogp=neglogp,
            uncertainty=uncertainty.mean(),
            entropy=entropy,
            ent_loss=ent_loss,
            prob_true_act=prob_true_act,
            loss=neglogp + ent_loss,
        )

    def trainer(self, batch) -> dict:
        """Take one optimiser step on ``loss`` for ``batch``; return the metrics dict."""
        input_output = algo_base.get_bc_input_output(
            batch, self.input_name, device=self.policy.device, lstm=self.lstm
        )
        self.policy.train()
        # The world model only supplies uncertainty estimates here. It is
        # trained through ``self.wmlearner`` (see ``train_contextlearner``).
        self.wmlearner.worldmodel.eval()
        training_metrics = self.loss(self.policy, self.wmlearner.worldmodel, *input_output)

        self.optimizer.zero_grad()
        training_metrics['loss'].backward()
        self.optimizer.step()

        return training_metrics

    def train_contextlearner(self, **cl_train_kwargs):
        """Train the world model; keyword arguments are forwarded to ``WorldModelLearner.train``."""
        self.wmlearner.train(**cl_train_kwargs)