# Utilities for diffusion.
from typing import Optional, List, Union

import d4rl
import gin
import gym
import numpy as np
import torch
from torch import nn

# GIN-required Imports.
from diffusion.elucidated_diffusion import ElucidatedDiffusion
from diffusion.norm import normalizer_factory


@gin.configurable
def make_inputs(
        env: gym.Env,
        modelled_terminals: bool = False,
) -> np.ndarray:
    """Build the flat transition matrix used as diffusion training data.

    The dataset is obtained from ``d4rl.qlearning_dataset(env)`` and its
    fields are joined column-wise in the order
    ``[observations | actions | rewards | next_observations]``.

    Args:
        env: Environment handed to ``d4rl.qlearning_dataset``.
        modelled_terminals: When true, the ``terminals`` field is cast to
            ``float32`` and appended as one extra trailing column.

    Returns:
        A 2-D array with one row per transition.
    """
    data = d4rl.qlearning_dataset(env)
    state, action, next_state, reward = (
        data[key]
        for key in ('observations', 'actions', 'next_observations', 'rewards')
    )
    # Rewards are given a column axis so they can be joined along axis 1.
    transitions = np.concatenate(
        [state, action, reward[:, np.newaxis], next_state],
        axis=1,
    )
    if not modelled_terminals:
        return transitions

    done = data['terminals'].astype(np.float32)
    return np.concatenate([transitions, done[:, np.newaxis]], axis=1)


@gin.configurable
def split_diffusion_samples(
        samples: Union[np.ndarray, torch.Tensor],
        env: gym.Env,
        modelled_terminals: bool = False,
        terminal_threshold: Optional[float] = None,
):
    """Cut flat diffusion samples back into ``(s, a, r, s')`` pieces.

    The slicing assumes each row is laid out as
    ``[obs | action | reward | next_obs | ...]`` with the terminal value,
    when modelled, in the last column.

    Args:
        samples: Batch of samples, indexed as ``samples[:, columns]``.
        env: Supplies ``observation_space.shape[0]`` and
            ``action_space.shape[0]``, which determine the column offsets.
        modelled_terminals: When true, the last column is also returned as
            the terminal signal.
        terminal_threshold: If given, terminals are replaced by a float 0/1
            mask of ``terminals > terminal_threshold``.

    Returns:
        ``(obs, actions, rewards, next_obs)``, with ``terminals`` appended
        when ``modelled_terminals`` is true. ``rewards`` is selected with a
        single integer column index, so that axis is dropped.
    """
    n_obs = env.observation_space.shape[0]
    n_act = env.action_space.shape[0]
    reward_col = n_obs + n_act
    next_start = reward_col + 1

    transition = (
        samples[:, :n_obs],
        samples[:, n_obs:reward_col],
        samples[:, reward_col],
        samples[:, next_start:next_start + n_obs],
    )
    if not modelled_terminals:
        return transition

    done = samples[:, -1]
    if terminal_threshold is not None:
        # Tensors and arrays spell the bool -> float cast differently.
        if isinstance(done, torch.Tensor):
            done = (done > terminal_threshold).float()
        else:
            done = (done > terminal_threshold).astype(np.float32)
    return transition + (done,)

@gin.configurable
def split_diffusion_samples_no_sa(
        samples: Union[np.ndarray, torch.Tensor],
        env: gym.Env,
        modelled_terminals: bool = False,
        terminal_threshold: Optional[float] = None,
):
    """Extract only the reward / next-observation part of flat samples.

    Uses the same column offsets as the full ``(s, a, r, s')`` layout but
    skips the leading observation and action columns.

    Args:
        samples: Batch of samples, indexed as ``samples[:, columns]``.
        env: Supplies ``observation_space.shape[0]`` and
            ``action_space.shape[0]``, which determine the column offsets.
        modelled_terminals: When true, the last column is also returned as
            the terminal signal.
        terminal_threshold: If given, terminals are replaced by a float 0/1
            mask of ``terminals > terminal_threshold``.

    Returns:
        ``(rewards, next_obs)``, with ``terminals`` appended when
        ``modelled_terminals`` is true. ``rewards`` is taken with a
        width-one slice, so it keeps a trailing axis of size one.
    """
    n_obs = env.observation_space.shape[0]
    n_act = env.action_space.shape[0]
    reward_start = n_obs + n_act
    next_start = reward_start + 1
    next_stop = 2 * n_obs + n_act + 1

    reward_block = samples[:, reward_start:next_start]
    next_state = samples[:, next_start:next_stop]
    if not modelled_terminals:
        return reward_block, next_state

    done = samples[:, -1]
    if terminal_threshold is not None:
        # Tensors and arrays spell the bool -> float cast differently.
        if isinstance(done, torch.Tensor):
            done = (done > terminal_threshold).float()
        else:
            done = (done > terminal_threshold).astype(np.float32)
    return reward_block, next_state, done

@gin.configurable
def split_trajectory_diffusion_samples(
        samples: Union[np.ndarray, torch.Tensor],
        env: gym.Env,
        seq_len: int,
        output_seq: bool = False,
        modelled_terminals: bool = False,
        terminal_threshold: Optional[float] = None,
):
    """Split trajectory-level diffusion samples into their components.

    Three layouts are handled:

    * ``seq_len == 1``: each row is sliced as a flat transition
      ``[obs | action | reward | next_obs | ...]`` and every returned piece
      gets a length-one axis inserted at position 1. ``output_seq`` is
      ignored in this case.
    * ``seq_len != 1`` and ``output_seq``: samples are reshaped to
      ``(-1, seq_len, event)`` with per-step events
      ``[obs | action | reward | (terminal)]`` and all steps are returned,
      without ``next_obs``.
    * ``seq_len != 1`` and not ``output_seq``: same reshape, but the last
      step is dropped from obs/actions/rewards/terminals and ``next_obs``
      is the observation of the following step.

    Args:
        samples: Batch of samples as a numpy array or torch tensor.
        env: Supplies ``observation_space.shape[0]`` and
            ``action_space.shape[0]``, which determine the column offsets.
        seq_len: Number of steps per sample.
        output_seq: Return whole sequences instead of shifted transitions.
        modelled_terminals: When true, the last column of each event is
            also returned as the terminal signal.
        terminal_threshold: If given, terminals are replaced by a float 0/1
            mask of ``terminals > terminal_threshold``.

    Returns:
        A tuple ``(obs, actions, rewards[, next_obs][, terminals])`` as
        described above.
    """
    obs_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    reward_idx = obs_dim + action_dim

    def _binarize(terminals):
        # Apply the optional threshold; tensors and arrays use different
        # spellings for the bool -> float cast.
        if terminal_threshold is None:
            return terminals
        if isinstance(terminals, torch.Tensor):
            return (terminals > terminal_threshold).float()
        return (terminals > terminal_threshold).astype(np.float32)

    if seq_len == 1:
        next_start = reward_idx + 1
        pieces = [
            samples[:, :obs_dim],
            samples[:, obs_dim:reward_idx],
            samples[:, reward_idx],
            samples[:, next_start:next_start + obs_dim],
        ]
        if modelled_terminals:
            pieces.append(_binarize(samples[:, -1]))
        # `x[:, None]` inserts the sequence axis for both numpy arrays and
        # torch tensors; `.unsqueeze(1)` exists only on tensors and made
        # this branch fail for numpy input.
        return tuple(piece[:, None] for piece in pieces)

    event_shape = reward_idx + 1
    if modelled_terminals:
        event_shape += 1
    aug_data = samples.reshape(-1, seq_len, event_shape)

    # Whole sequences keep every step; transitions drop the final step,
    # which only serves as the `next_obs` of the step before it.
    steps = aug_data if output_seq else aug_data[:, :-1]
    pieces = [
        steps[:, :, :obs_dim],
        steps[:, :, obs_dim:reward_idx],
        steps[:, :, reward_idx],
    ]
    if not output_seq:
        pieces.append(aug_data[:, 1:, :obs_dim])
    if modelled_terminals:
        pieces.append(_binarize(steps[:, :, -1]))
    return tuple(pieces)

@gin.configurable
def construct_diffusion_model(
        inputs: torch.Tensor,
        normalizer_type: str,
        denoising_network: nn.Module,
        cond_normalizer_type='minmax',
        disable_terminal_norm: bool = False,
        skip_dims: Optional[List[int]] = None,
        cond_dim: Optional[int] = None,
        added_dims: Optional[List[int]] = None,
        output_dim: Optional[int] = None,
        no_cond: bool = False,
        args=None
) -> ElucidatedDiffusion:
    """Build an ``ElucidatedDiffusion`` model together with its normalizers.

    Args:
        inputs: Training data; ``inputs.shape[1]`` is used as the event
            dimension and the data is passed to both normalizer factories.
        normalizer_type: Normalizer name passed to ``normalizer_factory``
            for the data normalizer. It is also printed.
        denoising_network: Callable that builds the denoiser. It is invoked
            with keyword arguments ``d_in``, ``cond_dim``, ``output_dim``,
            ``cfg_dropout`` and ``no_cond``.
        cond_normalizer_type: Normalizer name for the conditioning
            normalizer, which is built from the same ``inputs``.
        disable_terminal_norm: When true, the last input dimension
            (``event_dim - 1``, presumably the terminal flag) is added to
            the dimensions skipped by normalization.
        skip_dims: Dimensions to leave un-normalized. ``None`` means no
            dimensions. The caller's list is never modified.
        cond_dim: Forwarded to ``denoising_network``.
        added_dims: Forwarded to both ``normalizer_factory`` calls.
        output_dim: Forwarded to ``denoising_network``.
        no_cond: Forwarded to ``denoising_network``.
        args: Forwarded to ``ElucidatedDiffusion``.

    Returns:
        The constructed ``ElucidatedDiffusion`` instance.
    """
    event_dim = inputs.shape[1]
    model = denoising_network(d_in=event_dim,
                              cond_dim=cond_dim,
                              output_dim=output_dim,
                              cfg_dropout=0.25,
                              no_cond=no_cond)
    print(normalizer_type)

    # Work on a private copy. The previous `[]` default was shared across
    # calls and mutated in place, so the terminal index leaked into later
    # invocations and into caller-owned lists.
    skip_dims = [] if skip_dims is None else list(skip_dims)

    if disable_terminal_norm:
        terminal_dim = event_dim - 1
        if terminal_dim not in skip_dims:
            skip_dims.append(terminal_dim)

    if skip_dims:
        print(f"Skipping normalization for dimensions {skip_dims}.")

    normalizer = normalizer_factory(normalizer_type, inputs,
                                    skip_dims=skip_dims,
                                    added_dims=added_dims)
    cond_normalizer = normalizer_factory(cond_normalizer_type, inputs,
                                         skip_dims=skip_dims,
                                         added_dims=added_dims)

    return ElucidatedDiffusion(
        net=model,
        normalizer=normalizer,
        cond_normalizer=cond_normalizer,
        event_shape=[event_dim],
        args=args,
    )


