import numpy as np
import torch
import collections
from typing import Optional, Union, Tuple, Dict

import scipy.special
from functools import reduce
from sklearn.linear_model import LinearRegression

import pickle


def dice_dataset(env, standardize_observation=True, absorbing_state=True, standardize_reward=True):
    """
    Build a transition dataset (plus summary statistics) from ``env.get_dataset()``.

    Args:
        env: object exposing ``get_dataset()``; the returned mapping is read for the
            keys 'observations', 'actions', 'rewards', 'terminals' and, if present,
            'timeouts'. ``env._max_episode_steps`` is only read when 'timeouts' is
            absent.
        standardize_observation: if True, observations, next observations and
            initial observations are shifted/scaled by the per-dimension mean/std
            of the collected observations.
        absorbing_state: if True, every observation gets one extra trailing
            feature (0 for real states). The next observation of each terminal
            transition is replaced by the absorbing state (one-hot on the extra
            feature), and one absorbing -> absorbing transition with reward 0 and
            terminal flag 1 is appended per terminal transition, reusing that
            terminal transition's action.
        standardize_reward: if True, rewards are standardized with their mean/std
            (appended absorbing transitions still get reward exactly 0).

    Returns:
        (dataset, dataset_statistics):
            dataset: dict of float32 arrays with keys 'observations', 'actions',
                'next_observations', 'rewards', 'terminals', 'init_observations'.
            dataset_statistics: dict with 'observation_mean', 'observation_std',
                'reward_mean', 'reward_std', 'N_initial_observations', 'N',
                'observation_dim', 'action_dim'. All of them are computed before
                standardization and before any absorbing-state augmentation.
    """
    raw = env.get_dataset()
    N = raw['rewards'].shape[0]
    initial_obs_, obs_, next_obs_, action_, reward_, done_ = [], [], [], [], [], []

    use_timeouts = ('timeouts' in raw)

    episode_step = 0
    for i in range(N - 1):
        is_final_timestep = raw['timeouts'][i] if use_timeouts else (episode_step == env._max_episode_steps - 1)
        if is_final_timestep:
            # Skip this transition and don't apply terminals on the last step of an episode
            episode_step = 0
            continue

        obs = raw['observations'][i].astype(np.float32)
        if episode_step == 0:
            initial_obs_.append(obs)

        done_bool = bool(raw['terminals'][i])
        obs_.append(obs)
        next_obs_.append(raw['observations'][i + 1].astype(np.float32))
        action_.append(raw['actions'][i].astype(np.float32))
        reward_.append(raw['rewards'][i].astype(np.float32))
        done_.append(done_bool)

        # A terminal transition closes the episode; the next kept step starts a new one.
        episode_step = 0 if done_bool else episode_step + 1

    init_observations = np.array(initial_obs_, dtype=np.float32)
    dataset = {
        'observations': np.array(obs_, dtype=np.float32),
        'actions': np.array(action_, dtype=np.float32),
        'next_observations': np.array(next_obs_, dtype=np.float32),
        'rewards': np.array(reward_, dtype=np.float32),
        'terminals': np.array(done_, dtype=np.float32)
    }
    dataset_statistics = {
        'observation_mean': np.mean(dataset['observations'], axis=0),
        'observation_std': np.std(dataset['observations'], axis=0),
        'reward_mean': np.mean(dataset['rewards']),
        'reward_std': np.std(dataset['rewards']),
        'N_initial_observations': len(initial_obs_),
        'N': len(obs_),
        'observation_dim': dataset['observations'].shape[-1],
        'action_dim': dataset['actions'].shape[-1]
    }

    if standardize_observation:
        obs_mean = dataset_statistics['observation_mean']
        obs_scale = dataset_statistics['observation_std'] + 1e-10  # avoid division by zero
        init_observations = (init_observations - obs_mean) / obs_scale
        dataset['observations'] = (dataset['observations'] - obs_mean) / obs_scale
        dataset['next_observations'] = (dataset['next_observations'] - obs_mean) / obs_scale
    if standardize_reward:
        dataset['rewards'] = (dataset['rewards'] - dataset_statistics['reward_mean']) / (dataset_statistics['reward_std'] + 1e-10)

    if absorbing_state:
        def _with_absorbing_flag(states):
            # Append a zero "is absorbing" feature to every real state.
            flag = np.zeros((states.shape[0], 1))
            return np.concatenate((states, flag), axis=1).astype(np.float32)

        init_observations = _with_absorbing_flag(init_observations)
        observations = _with_absorbing_flag(dataset['observations'])
        next_observations = _with_absorbing_flag(dataset['next_observations'])

        # One-hot on the extra feature; named differently from the boolean
        # parameter so the flag is not shadowed.
        absorbing_vec = np.eye(dataset_statistics['observation_dim'] + 1)[-1].astype(np.float32)
        terminal_indices = np.where(dataset['terminals'])[0]
        n_absorbing = len(terminal_indices)
        absorbing_rows = np.tile(absorbing_vec, (n_absorbing, 1))

        # Terminal transitions now lead into the absorbing state ...
        next_observations[terminal_indices] = absorbing_vec
        # ... and one absorbing self-loop is appended per terminal transition.
        dataset['observations'] = np.concatenate((observations, absorbing_rows)).astype(np.float32)
        dataset['next_observations'] = np.concatenate((next_observations, absorbing_rows)).astype(np.float32)
        dataset['actions'] = np.concatenate(
            (dataset['actions'], dataset['actions'][terminal_indices])).astype(np.float32)
        dataset['rewards'] = np.concatenate(
            (dataset['rewards'], np.zeros(n_absorbing, dtype=np.float32))).astype(np.float32)
        dataset['terminals'] = np.concatenate(
            (dataset['terminals'], np.ones(n_absorbing, dtype=np.float32))).astype(np.float32)

    # Always expose the episode-start observations, not only on the
    # absorbing-state path, so the returned dict has the same keys either way.
    dataset['init_observations'] = init_observations

    return dataset, dataset_statistics

def qlearning_dataset(env, dataset=None, terminate_on_end=False, **kwargs):
    """
    Flatten a trajectory dataset into per-transition arrays for Q-learning.

    Args:
        env: Object used for ``env.get_dataset(**kwargs)`` when ``dataset`` is
            None, and for ``env._max_episode_steps`` when the dataset has no
            'timeouts' entry.
        dataset: Optional pre-loaded dataset mapping. It is read for the keys
            'observations', 'actions', 'rewards', 'terminals' and, when present,
            'next_observations' and 'timeouts'.
        terminate_on_end (bool): When False (default), steps flagged as the
            final timestep of an episode are dropped. When True they are kept,
            provided the dataset supplies 'next_observations'.
        **kwargs: Forwarded to ``env.get_dataset()``.

    Returns:
        A dictionary with keys:
            observations: kept observations.
            actions: kept actions.
            next_observations: successor observation of each kept step.
            rewards: kept rewards.
            terminals: boolean 'terminals' flag of each kept step.
            init_observations: observation seen at the first step examined
                after each episode boundary (and at index 0).

    Note:
        The last row of the dataset (index N-1) is never examined.
    """
    if dataset is None:
        dataset = env.get_dataset(**kwargs)

    has_next_obs = 'next_observations' in dataset.keys()
    # Newer datasets carry an explicit 'timeouts' field; without it the episode
    # end is inferred from a running step counter (backwards compatibility).
    use_timeouts = 'timeouts' in dataset
    total_steps = dataset['rewards'].shape[0]

    # Insertion order here is the key order of the returned dictionary.
    columns = {
        'observations': [],
        'actions': [],
        'next_observations': [],
        'rewards': [],
        'terminals': [],
        'init_observations': [],
    }

    step_in_episode = 0
    at_episode_start = True
    for idx in range(total_steps - 1):
        current_obs = dataset['observations'][idx].astype(np.float32)
        if has_next_obs:
            following_obs = dataset['next_observations'][idx].astype(np.float32)
        else:
            following_obs = dataset['observations'][idx + 1].astype(np.float32)
        act = dataset['actions'][idx].astype(np.float32)
        rew = dataset['rewards'][idx].astype(np.float32)
        terminal = bool(dataset['terminals'][idx])

        # Recorded before any skip decision, so a skipped first step still
        # contributes its observation.
        if at_episode_start:
            columns['init_observations'].append(current_obs)

        if use_timeouts:
            truncated = dataset['timeouts'][idx]
        else:
            truncated = (step_in_episode == env._max_episode_steps - 1)

        if truncated and not terminate_on_end:
            # Skip this transition and don't apply terminals on the last step of an episode
            step_in_episode = 0
            at_episode_start = True
            continue

        at_episode_start = bool(terminal or truncated)
        if at_episode_start:
            step_in_episode = 0
            # Without stored next observations, row idx+1 belongs to another
            # episode, so the transition cannot be formed.
            if not has_next_obs:
                continue

        columns['observations'].append(current_obs)
        columns['next_observations'].append(following_obs)
        columns['actions'].append(act)
        columns['rewards'].append(rew)
        columns['terminals'].append(terminal)
        step_in_episode += 1

    return {key: np.array(values) for key, values in columns.items()}








def AW(env, dataset=None, aw_alpha=0.1, terminate_on_end=False, **kwargs):
    """
    Compute advantage-weighted sampling probabilities, one per kept transition.

    The dataset is split into episodes using the same keep/skip rules as
    ``qlearning_dataset`` in this file, so the returned vector has one entry
    per transition that function would emit, in the same order. Each episode
    gets the advantage ``G - V`` where ``G`` is the sum of rewards of its kept
    steps and ``V`` is a linear least-squares fit of ``G`` on the episode's
    first observation. Every transition inherits its episode's advantage; the
    advantages are min-max scaled to [0, 1], divided by ``aw_alpha`` and passed
    through a softmax.

    Args:
        env: Object used for ``env.get_dataset(**kwargs)`` when ``dataset`` is
            None, and for ``env._max_episode_steps`` when the dataset has no
            'timeouts' entry.
        dataset: Optional pre-loaded dataset mapping; read for 'observations',
            'rewards', 'terminals' and the presence of 'next_observations' /
            'timeouts'.
        aw_alpha: Softmax temperature; smaller values concentrate the weight on
            high-advantage episodes.
        terminate_on_end: When False, steps flagged as the final timestep of an
            episode are not counted.
        **kwargs: Forwarded to ``env.get_dataset()``.

    Returns:
        1-D float array of non-negative weights that sums to 1.

    Raises:
        ValueError: if no transition is kept.
    """
    if dataset is None:
        dataset = env.get_dataset(**kwargs)

    has_next_obs = 'next_observations' in dataset.keys()
    N = dataset['rewards'].shape[0]

    # The newer version of the dataset adds an explicit
    # timeouts field. Keep old method for backwards compatibility.
    use_timeouts = 'timeouts' in dataset

    episode_step = 0
    init_flag = True
    G, T, s0 = [], [], []  # per-episode return, kept-step count, first observation
    g, t = 0.0, 0
    for i in range(N - 1):
        if init_flag:
            s0.append(dataset['observations'][i].astype(np.float32))

        done_bool = bool(dataset['terminals'][i])
        if use_timeouts:
            final_timestep = bool(dataset['timeouts'][i])
        else:
            final_timestep = (episode_step == env._max_episode_steps - 1)

        episode_ends = done_bool or final_timestep
        # Same two skip conditions as qlearning_dataset: a final timestep that
        # must not be terminated on, or an episode end without a stored next
        # observation.
        dropped = (final_timestep and not terminate_on_end) or (episode_ends and not has_next_obs)

        # The step counter deliberately follows qlearning_dataset (reset first,
        # then count the kept step) so both functions select identical steps
        # when 'timeouts' is absent.
        if episode_ends:
            episode_step = 0
        if not dropped:
            # Credit the step to the episode it belongs to *before* that
            # episode is closed; otherwise a kept terminal step would be
            # counted towards the following episode.
            g += float(dataset['rewards'][i])
            t += 1
            episode_step += 1

        if episode_ends:
            G.append(g)
            T.append(t)
            g, t = 0.0, 0
        init_flag = episode_ends

    # Flush an unfinished trailing episode. If the loop stopped exactly on an
    # episode boundary there is no matching entry in s0, so nothing is added.
    if not init_flag:
        G.append(g)
        T.append(t)

    if not s0 or sum(T) == 0:
        raise ValueError('AW: the dataset contains no usable transitions')

    G = np.asarray(G, dtype=np.float64)
    T = np.asarray(T, dtype=np.int64)
    s0 = np.stack(s0)
    V = LinearRegression().fit(s0, G).predict(s0)

    # Broadcast each episode's advantage to all of its kept transitions.
    A_it = np.repeat(G - V, T)
    spread = A_it.max() - A_it.min()
    if spread > 0:
        A_it = (A_it - A_it.min()) / spread
    else:
        # All advantages equal (e.g. a single episode): fall back to uniform
        # weights instead of dividing 0 by 0.
        A_it = np.zeros_like(A_it)
    w_it = scipy.special.softmax(A_it / aw_alpha)
    w_it /= w_it.sum() # Numerical errors

    return w_it







class ReplayBuffer:
    """
    Ring buffer of (obs, next_obs, action, reward, terminal) transitions.

    Transitions are stored in NumPy arrays and handed out as torch tensors on
    ``device`` by ``sample``. A whole offline dataset can be installed with
    ``load_dataset``, which may additionally carry a separate pool of
    episode-start observations served by ``sample_init``.
    """

    def __init__(
        self,
        buffer_size: int,
        obs_shape: Tuple,
        obs_dtype: np.dtype,
        action_dim: int,
        action_dtype: np.dtype,
        device: str = "cpu"
    ) -> None:
        """Allocate zero-filled storage for ``buffer_size`` transitions."""
        self._max_size = buffer_size
        self.obs_shape = obs_shape
        self.obs_dtype = obs_dtype
        self.action_dim = action_dim
        self.action_dtype = action_dtype

        self._ptr = 0   # next write position
        self._size = 0  # number of valid transitions

        self.observations = np.zeros((self._max_size,) + self.obs_shape, dtype=obs_dtype)
        self.next_observations = np.zeros((self._max_size,) + self.obs_shape, dtype=obs_dtype)
        self.actions = np.zeros((self._max_size, self.action_dim), dtype=action_dtype)
        self.rewards = np.zeros((self._max_size, 1), dtype=np.float32)
        self.terminals = np.zeros((self._max_size, 1), dtype=np.float32)

        # Pool of episode-start observations; only filled by load_dataset().
        # Defined here so sample_init() fails with a clear error when unset.
        self.init_observations = np.zeros((0,) + self.obs_shape, dtype=obs_dtype)
        self._init_size = 0

        self.device = torch.device(device)
        self.debug_count = 0  # number of sample_init() calls so far

    def _to_device(self, array: np.ndarray) -> torch.Tensor:
        """Copy a NumPy array into a tensor on the configured device."""
        return torch.tensor(array).to(self.device)

    def add(
        self,
        obs: np.ndarray,
        next_obs: np.ndarray,
        action: np.ndarray,
        reward: np.ndarray,
        terminal: np.ndarray
    ) -> None:
        """Store one transition at the write pointer, overwriting the oldest when full."""
        # Copy to avoid modification by reference
        self.observations[self._ptr] = np.array(obs).copy()
        self.next_observations[self._ptr] = np.array(next_obs).copy()
        self.actions[self._ptr] = np.array(action).copy()
        self.rewards[self._ptr] = np.array(reward).copy()
        self.terminals[self._ptr] = np.array(terminal).copy()

        self._ptr = (self._ptr + 1) % self._max_size
        self._size = min(self._size + 1, self._max_size)

    def add_batch(
        self,
        obss: np.ndarray,
        next_obss: np.ndarray,
        actions: np.ndarray,
        rewards: np.ndarray,
        terminals: np.ndarray
    ) -> None:
        """Store ``len(obss)`` transitions starting at the write pointer, wrapping around."""
        batch_size = len(obss)
        indexes = np.arange(self._ptr, self._ptr + batch_size) % self._max_size

        self.observations[indexes] = np.array(obss).copy()
        self.next_observations[indexes] = np.array(next_obss).copy()
        self.actions[indexes] = np.array(actions).copy()
        self.rewards[indexes] = np.array(rewards).copy()
        self.terminals[indexes] = np.array(terminals).copy()

        self._ptr = (self._ptr + batch_size) % self._max_size
        self._size = min(self._size + batch_size, self._max_size)

    def load_dataset(self, dataset: Dict[str, np.ndarray]) -> None:
        """
        Replace the buffer contents with ``dataset``.

        Required keys: "observations", "next_observations", "actions",
        "rewards", "terminals". The key "init_observations" is optional; when
        it is missing the pool used by ``sample_init`` is left empty.

        The storage arrays are replaced by the dataset arrays, so the capacity
        becomes the dataset length and later ``add``/``add_batch`` calls
        overwrite from index 0.
        """
        observations = np.array(dataset["observations"], dtype=self.obs_dtype)
        next_observations = np.array(dataset["next_observations"], dtype=self.obs_dtype)
        actions = np.array(dataset["actions"], dtype=self.action_dtype)
        rewards = np.array(dataset["rewards"], dtype=np.float32).reshape(-1, 1)
        terminals = np.array(dataset["terminals"], dtype=np.float32).reshape(-1, 1)
        if "init_observations" in dataset:
            init_observations = np.array(dataset["init_observations"], dtype=self.obs_dtype)
        else:
            init_observations = np.zeros((0,) + observations.shape[1:], dtype=self.obs_dtype)

        self.observations = observations
        self.next_observations = next_observations
        self.actions = actions
        self.rewards = rewards
        self.terminals = terminals
        self.init_observations = init_observations

        # Keep the ring-buffer bookkeeping in sync with the new arrays; a
        # pointer equal to the array length would make add() index out of range.
        self._size = len(observations)
        self._max_size = self._size
        self._ptr = 0
        self._init_size = len(init_observations)

    def normalize_obs(self, eps: float = 1e-3) -> Tuple[np.ndarray, np.ndarray]:
        """
        Standardize stored observations in place and return ``(mean, std)``.

        The statistics are taken over the whole ``observations`` array
        (axis 0); ``eps`` is added to the standard deviation. The same
        transform is applied to ``next_observations`` and to the
        ``init_observations`` pool, so everything the buffer samples is on the
        same scale.
        """
        mean = self.observations.mean(0, keepdims=True)
        std = self.observations.std(0, keepdims=True) + eps
        self.observations = (self.observations - mean) / std
        self.next_observations = (self.next_observations - mean) / std
        if self._init_size > 0:
            self.init_observations = (self.init_observations - mean) / std
        obs_mean, obs_std = mean, std
        return obs_mean, obs_std

    def sample(self, batch_size: int) -> Dict[str, torch.Tensor]:
        """Draw ``batch_size`` transitions uniformly with replacement from the valid entries."""
        batch_indexes = np.random.randint(0, self._size, size=batch_size)

        return {
            "observations": self._to_device(self.observations[batch_indexes]),
            "actions": self._to_device(self.actions[batch_indexes]),
            "next_observations": self._to_device(self.next_observations[batch_indexes]),
            "terminals": self._to_device(self.terminals[batch_indexes]),
            "rewards": self._to_device(self.rewards[batch_indexes])
        }

    def sample_init(self, batch_size: int) -> Dict[str, torch.Tensor]:
        """
        Draw ``batch_size`` episode-start observations uniformly with replacement.

        Raises:
            ValueError: if no initial observations have been loaded.
        """
        if self._init_size <= 0:
            raise ValueError(
                "no initial observations available; call load_dataset() with an "
                "'init_observations' entry first"
            )
        batch_indexes = np.random.randint(0, self._init_size, size=batch_size)

        self.debug_count += 1
        return {
            "init_observations": self._to_device(self.init_observations[batch_indexes])
        }

    def sample_all(self) -> Dict[str, np.ndarray]:
        """Return copies of all valid transitions as NumPy arrays."""
        return {
            "observations": self.observations[:self._size].copy(),
            "actions": self.actions[:self._size].copy(),
            "next_observations": self.next_observations[:self._size].copy(),
            "terminals": self.terminals[:self._size].copy(),
            "rewards": self.rewards[:self._size].copy()
        }