import numpy as np
import torch
import collections
from typing import Optional, Union, Tuple, Dict

import scipy.special
from functools import reduce
from sklearn.linear_model import LinearRegression

import time
def dice_dataset(env, standardize_observation=True, absorbing_state=True, standardize_reward=True):
    """
    Build a flat transition dataset (plus summary statistics) from ``env.get_dataset()``.

    Args:
        env: d4rl environment. Must provide ``get_dataset()``; ``_max_episode_steps``
            is only read when the dataset has no ``'timeouts'`` entry.
        standardize_observation: if true, observations, next observations and
            initial observations are shifted/scaled with the mean and std of the
            kept observations.
        absorbing_state: if true, every observation gets one extra trailing
            dimension (0 for real states). Each terminal transition is redirected
            to a one-hot "absorbing" observation, and one extra self-looping
            absorbing transition (reward 0, terminal 1, action copied from the
            terminal transition) is appended per terminal transition.
        standardize_reward: if true, rewards are standardized with the mean and
            std of the kept rewards (the appended absorbing rewards stay 0).

    Returns:
        ``(dataset, dataset_statistics)``. ``dataset`` holds float32 arrays under
        ``observations``, ``actions``, ``next_observations``, ``rewards``,
        ``terminals`` and ``init_observations``. ``dataset_statistics`` holds the
        means/stds, the counts ``N`` / ``N_initial_observations`` and the
        (un-augmented) ``observation_dim`` / ``action_dim``.
    """
    raw = env.get_dataset()
    N = raw['rewards'].shape[0]
    use_timeouts = ('timeouts' in raw)

    initial_obs_, obs_, next_obs_, action_, reward_, done_ = [], [], [], [], [], []

    episode_step = 0
    # The last raw step is never emitted: its successor would be raw['observations'][N].
    for i in range(N - 1):
        if use_timeouts:
            is_final_timestep = raw['timeouts'][i]
        else:
            is_final_timestep = (episode_step == env._max_episode_steps - 1)
        if is_final_timestep:
            # Skip this transition and don't apply terminals on the last step of an episode
            episode_step = 0
            continue

        obs = raw['observations'][i].astype(np.float32)
        done_bool = bool(raw['terminals'][i])

        if episode_step == 0:
            initial_obs_.append(obs)

        obs_.append(obs)
        next_obs_.append(raw['observations'][i + 1].astype(np.float32))
        action_.append(raw['actions'][i].astype(np.float32))
        reward_.append(raw['rewards'][i].astype(np.float32))
        done_.append(done_bool)

        # A terminal step closes the episode; the next kept step starts a new one.
        episode_step = 0 if done_bool else episode_step + 1

    initial_observations = np.array(initial_obs_, dtype=np.float32)
    dataset = {
        'observations': np.array(obs_, dtype=np.float32),
        'actions': np.array(action_, dtype=np.float32),
        'next_observations': np.array(next_obs_, dtype=np.float32),
        'rewards': np.array(reward_, dtype=np.float32),
        'terminals': np.array(done_, dtype=np.float32)
    }
    dataset_statistics = {
        'observation_mean': np.mean(dataset['observations'], axis=0),
        'observation_std': np.std(dataset['observations'], axis=0),
        'reward_mean': np.mean(dataset['rewards']),
        'reward_std': np.std(dataset['rewards']),
        'N_initial_observations': len(initial_obs_),
        'N': len(obs_),
        'observation_dim': dataset['observations'].shape[-1],
        'action_dim': dataset['actions'].shape[-1]
    }

    if standardize_observation:
        obs_mean = dataset_statistics['observation_mean']
        obs_scale = dataset_statistics['observation_std'] + 1e-10
        initial_observations = (initial_observations - obs_mean) / obs_scale
        dataset['observations'] = (dataset['observations'] - obs_mean) / obs_scale
        dataset['next_observations'] = (dataset['next_observations'] - obs_mean) / obs_scale
    if standardize_reward:
        dataset['rewards'] = (dataset['rewards'] - dataset_statistics['reward_mean']) / (dataset_statistics['reward_std'] + 1e-10)

    if absorbing_state:
        # add additional dimension to observations to deal with absorbing state
        def _with_flag_column(array):
            flag_column = np.zeros((array.shape[0], 1), dtype=np.float32)
            return np.concatenate((array, flag_column), axis=1).astype(np.float32)

        initial_observations = _with_flag_column(initial_observations)
        observations = _with_flag_column(dataset['observations'])
        next_observations = _with_flag_column(dataset['next_observations'])

        terminal_indices = np.flatnonzero(dataset['terminals'])
        n_terminal = terminal_indices.size

        # One-hot vector whose only non-zero entry is the extra (last) dimension.
        absorbing_obs = np.zeros(dataset_statistics['observation_dim'] + 1, dtype=np.float32)
        absorbing_obs[-1] = 1.0
        absorbing_rows = np.tile(absorbing_obs, (n_terminal, 1))

        # Terminal transitions now lead into the absorbing state ...
        next_observations[terminal_indices] = absorbing_obs
        # ... and one absorbing -> absorbing transition is appended per terminal.
        dataset['observations'] = np.concatenate((observations, absorbing_rows), axis=0).astype(np.float32)
        dataset['actions'] = np.concatenate(
            (dataset['actions'], dataset['actions'][terminal_indices]), axis=0).astype(np.float32)
        dataset['rewards'] = np.concatenate(
            (dataset['rewards'], np.zeros(n_terminal, dtype=np.float32))).astype(np.float32)
        dataset['next_observations'] = np.concatenate((next_observations, absorbing_rows), axis=0).astype(np.float32)
        dataset['terminals'] = np.concatenate(
            (dataset['terminals'], np.ones(n_terminal, dtype=np.float32))).astype(np.float32)

    # Always expose the initial states; previously they were only attached
    # when absorbing_state was enabled and were lost otherwise.
    dataset['init_observations'] = initial_observations.astype(np.float32)

    return dataset, dataset_statistics


def qlearning_dataset(env, dataset=None, terminate_on_end=False,  **kwargs):
    """
    Flatten a raw offline dataset into per-transition arrays.

    Args:
        env: An OfflineEnv object. Only used to fetch the dataset when
            ``dataset`` is None and to read ``_max_episode_steps`` when the
            dataset carries no ``'timeouts'`` entry.
        dataset: An optional dataset to process. If None, the dataset
            defaults to ``env.get_dataset(**kwargs)``.
        terminate_on_end (bool): When False (default) the step flagged as the
            last one of a trajectory (timeout / max length) is discarded.
            When True that step is processed like any other episode end; the
            stored terminal flag is still the dataset's own ``terminals`` value.
        **kwargs: Arguments to pass to env.get_dataset().

    Returns:
        A dictionary containing keys:
            observations: kept observations (float32).
            actions: kept actions (float32).
            next_observations: successor observations, taken from
                ``dataset['next_observations']`` when present, otherwise from
                the following row of ``dataset['observations']``.
            rewards: kept rewards (float32).
            terminals: boolean ``terminals`` flags of the kept transitions.
            init_observations: observation of every step that directly follows
                an episode end (and of the very first step).

    Note:
        The last row of the dataset is never emitted, and when the dataset has
        no ``'next_observations'`` an episode-ending step is dropped as well.
    """
    if dataset is None:
        dataset = env.get_dataset(**kwargs)

    next_obs_stored = 'next_observations' in dataset.keys()
    # Newer datasets carry an explicit timeouts field; older ones fall back
    # to counting steps against env._max_episode_steps.
    timeouts_stored = 'timeouts' in dataset

    N = dataset['rewards'].shape[0]
    print('N', N)

    states, successors, controls, payoffs, flags = [], [], [], [], []
    first_states = []

    steps_taken = 0
    at_episode_start = True
    for idx in range(N - 1):
        state = dataset['observations'][idx].astype(np.float32)
        if next_obs_stored:
            successor = dataset['next_observations'][idx].astype(np.float32)
        else:
            successor = dataset['observations'][idx + 1].astype(np.float32)
        control = dataset['actions'][idx].astype(np.float32)
        payoff = dataset['rewards'][idx].astype(np.float32)
        terminated = bool(dataset['terminals'][idx])

        if at_episode_start:
            first_states.append(state)

        if timeouts_stored:
            truncated = dataset['timeouts'][idx]
        else:
            truncated = (steps_taken == env._max_episode_steps - 1)

        if truncated and not terminate_on_end:
            # Skip this transition and don't apply terminals on the last step of an episode
            steps_taken = 0
            at_episode_start = True
            continue

        if terminated or truncated:
            steps_taken = 0
            at_episode_start = True
            if not next_obs_stored:
                # Without stored successors the "next" row belongs to another episode.
                continue
        else:
            at_episode_start = False

        states.append(state)
        successors.append(successor)
        controls.append(control)
        payoffs.append(payoff)
        flags.append(terminated)
        steps_taken += 1

    return {
        'observations': np.array(states),
        'actions': np.array(controls),
        'next_observations': np.array(successors),
        'rewards': np.array(payoffs),
        'terminals': np.array(flags),
        'init_observations': np.array(first_states)
    }



def _top_final_timestep_flags(env, dataset, terminals):
    """Return a boolean array marking the steps at which a trajectory is cut off (timeout / max length)."""
    if 'timeouts' in dataset:
        return np.asarray(dataset['timeouts']).astype(bool).reshape(-1)

    # Older datasets have no timeouts field: count steps against the env limit.
    max_steps = env._max_episode_steps
    flags = np.zeros(terminals.shape[0], dtype=bool)
    step = 0
    for i, done in enumerate(terminals):
        if step == max_steps - 1:
            flags[i] = True
            step = 0
        elif done:
            step = 0
        else:
            step += 1
    return flags


def top_qlearning_dataset(env, dataset=None, terminate_on_end=False, ratio = 1.0, **kwargs):
    """
    Like a Q-learning transition dataset, but restricted to the best trajectories.

    Complete trajectories (those closed by a terminal or final-timestep flag)
    are ranked by their undiscounted return; only transitions of trajectories
    whose return is at least that of the ``int(ratio * n_trajectories)``-th
    best one are emitted. At least the best trajectory is always kept.
    Transitions after the last episode end (an unfinished trajectory) are
    never emitted.

    Args:
        env: An OfflineEnv object. Used to fetch the dataset when ``dataset``
            is None and to read ``_max_episode_steps`` when the dataset has no
            ``'timeouts'`` entry.
        dataset: An optional dataset to pass in for processing. If None,
            the dataset will default to env.get_dataset()
        terminate_on_end (bool): When False (default) the final-timestep
            transition of each trajectory is discarded; when True it is
            handled like any other episode end.
        ratio (float): Fraction of trajectories (best first) to keep.
        **kwargs: Arguments to pass to env.get_dataset().

    Returns:
        A dictionary containing keys:
            observations, actions, next_observations, rewards: float32 arrays
                of the kept transitions.
            terminals: boolean terminal flags of the kept transitions.
            init_observations: first observation of every trajectory in the
                dataset (not only of the selected ones).

    Raises:
        ValueError: if the dataset contains no complete trajectory.
    """
    if dataset is None:
        dataset = env.get_dataset(**kwargs)
    has_next_obs = 'next_observations' in dataset
    N = dataset['rewards'].shape[0]

    rewards = np.asarray(dataset['rewards'])
    terminals = np.asarray(dataset['terminals']).astype(bool).reshape(-1)
    final_steps = _top_final_timestep_flags(env, dataset, terminals)
    episode_over = terminals | final_steps

    # --- rank complete trajectories by return --------------------------------
    ends = np.flatnonzero(episode_over)
    if ends.size == 0:
        raise ValueError('dataset contains no complete trajectory to rank')
    starts = np.concatenate(([0], ends[:-1] + 1))
    # The same sums are used for ranking and for selection, so the threshold
    # trajectory always satisfies its own ">=" test.
    returns = np.array([rewards[s:e + 1].sum() for s, e in zip(starts, ends)])

    ranked = np.sort(returns)[::-1]
    # Clamp so that a tiny ratio keeps the single best trajectory (instead of
    # wrapping around to index -1) and ratio > 1 keeps everything.
    cutoff_rank = min(max(int(ratio * len(ranked)) - 1, 0), len(ranked) - 1)
    critical_return = ranked[cutoff_rank]

    tag = np.zeros(N, dtype=bool)
    for s, e, trajectory_return in zip(starts, ends, returns):
        if trajectory_return >= critical_return:
            tag[s:e + 1] = True

    # --- extract the selected transitions -------------------------------------
    # The last row is never emitted (it has no successor row to fall back on).
    M = max(N - 1, 0)
    over = episode_over[:M]
    if terminate_on_end:
        skipped = np.zeros(M, dtype=bool)
    else:
        # Skip the final-timestep transition of each trajectory.
        skipped = final_steps[:M]
    if has_next_obs:
        dropped = skipped
    else:
        # Without stored successors, the row after an episode end belongs to
        # another trajectory, so episode-ending transitions cannot be used.
        dropped = skipped | over
    keep_idx = np.flatnonzero(tag[:M] & ~dropped)

    # A step is an initial state if it is the first row or follows an episode end.
    init_mask = np.concatenate(([True], over[:-1]))[:M]

    observations = np.asarray(dataset['observations'])
    if has_next_obs:
        next_observations = np.asarray(dataset['next_observations'])[keep_idx]
    else:
        next_observations = observations[keep_idx + 1]

    print(len(keep_idx))

    return {
        'observations': observations[keep_idx].astype(np.float32),
        'actions': np.asarray(dataset['actions'])[keep_idx].astype(np.float32),
        'next_observations': next_observations.astype(np.float32),
        'rewards': rewards[keep_idx].astype(np.float32),
        'terminals': terminals[keep_idx],
        'init_observations': observations[:M][init_mask].astype(np.float32)
    }





def RW(env, dataset=None, harness_alpha=0.1, terminate_on_end=False, **kwargs):
    """
    Compute return-based sampling weights, one per kept transition.

    The dataset is scanned with the same skip rules as ``qlearning_dataset``
    in this file (final-timestep transitions are dropped unless
    ``terminate_on_end``; episode-ending transitions are dropped when the
    dataset has no ``'next_observations'``), so the weights are presumably
    meant to line up with the transitions that function emits -- verify
    against the caller. Every kept transition receives the undiscounted
    return of its trajectory; returns are min-max normalized, divided by
    ``harness_alpha`` and pushed through a softmax.

    Args:
        env: environment; used to fetch the dataset when ``dataset`` is None
            and to read ``_max_episode_steps`` when there is no ``'timeouts'``.
        dataset: optional pre-loaded dataset dict.
        harness_alpha: softmax temperature (smaller = more peaked).
        terminate_on_end: keep final-timestep transitions instead of dropping them.
        **kwargs: forwarded to ``env.get_dataset``.

    Returns:
        1-D float array of non-negative weights summing to 1 (empty if no
        transition is kept).
    """
    if dataset is None:
        dataset = env.get_dataset(**kwargs)

    has_next_obs = 'next_observations' in dataset
    print('has_next_obs', has_next_obs)
    N = dataset['rewards'].shape[0]

    # The newer version of the dataset adds an explicit
    # timeouts field. Keep old method for backwards compatibility.
    use_timeouts = 'timeouts' in dataset

    rewards = np.asarray(dataset['rewards'], dtype=np.float64).reshape(-1)
    terminals = np.asarray(dataset['terminals']).astype(bool).reshape(-1)
    if use_timeouts:
        timeouts = np.asarray(dataset['timeouts']).astype(bool).reshape(-1)

    G, T = [], []          # return and number of kept transitions per trajectory
    g, t = 0.0, 0          # running values for the open trajectory
    episode_step = 0
    for i in range(N - 1):
        if use_timeouts:
            final_timestep = bool(timeouts[i])
        else:
            final_timestep = (episode_step == env._max_episode_steps - 1)

        dropped = final_timestep and not terminate_on_end
        episode_over = dropped or final_timestep or bool(terminals[i])
        # An episode-ending transition is only kept when its successor is stored.
        counted = (not dropped) and (has_next_obs or not episode_over)

        if counted:
            # Credit the transition to the trajectory it belongs to *before*
            # that trajectory is closed.
            g += rewards[i]
            t += 1

        if episode_over:
            G.append(g)
            T.append(t)
            g, t = 0.0, 0
            # Mirrors the step counter of qlearning_dataset, which restarts at
            # 1 (not 0) after a kept episode-ending transition.
            episode_step = int(counted)
        else:
            episode_step += 1

    if t > 0:
        # Unfinished trailing trajectory.
        G.append(g)
        T.append(t)

    G = np.asarray(G, dtype=np.float64)
    T = np.asarray(T, dtype=np.int64)

    # Broadcast each trajectory return to its transitions.
    G_it = np.repeat(G, T)
    if G_it.size == 0:
        return G_it

    G_it = (G_it - G_it.min()) / (G_it.max() - G_it.min() + 1e-6)
    w_it = scipy.special.softmax(G_it / harness_alpha)
    w_it = w_it / w_it.sum()  # guard against rounding drift

    return w_it












def AW(env, dataset=None, harness_alpha=0.1, terminate_on_end=False, **kwargs):
    """
    Compute advantage-based sampling weights, one per kept transition.

    The dataset is scanned with the same skip rules as ``qlearning_dataset``
    in this file, so the weights are presumably meant to line up with the
    transitions that function emits -- verify against the caller. For each
    trajectory the undiscounted return ``G`` is regressed (ordinary linear
    regression) on the trajectory's first observation; the residual
    ``G - V(s0)`` is assigned to every kept transition of the trajectory,
    min-max normalized, divided by ``harness_alpha`` and pushed through a
    softmax.

    Args:
        env: environment; used to fetch the dataset when ``dataset`` is None
            and to read ``_max_episode_steps`` when there is no ``'timeouts'``.
        dataset: optional pre-loaded dataset dict.
        harness_alpha: softmax temperature (smaller = more peaked).
        terminate_on_end: keep final-timestep transitions instead of dropping them.
        **kwargs: forwarded to ``env.get_dataset``.

    Returns:
        1-D float array of non-negative weights summing to 1 (empty if no
        transition is kept).
    """
    if dataset is None:
        dataset = env.get_dataset(**kwargs)

    has_next_obs = 'next_observations' in dataset
    print('has_next_obs', has_next_obs)
    N = dataset['rewards'].shape[0]

    # The newer version of the dataset adds an explicit
    # timeouts field. Keep old method for backwards compatibility.
    use_timeouts = 'timeouts' in dataset

    rewards = np.asarray(dataset['rewards'], dtype=np.float64).reshape(-1)
    terminals = np.asarray(dataset['terminals']).astype(bool).reshape(-1)
    if use_timeouts:
        timeouts = np.asarray(dataset['timeouts']).astype(bool).reshape(-1)

    # Per-trajectory return, kept-transition count and first observation.
    # The three lists are always appended together so they stay aligned.
    G, T, s0 = [], [], []
    g, t = 0.0, 0
    start_obs = None       # first observation of the open trajectory, if any
    episode_step = 0
    for i in range(N - 1):
        if start_obs is None:
            start_obs = dataset['observations'][i].astype(np.float32)

        if use_timeouts:
            final_timestep = bool(timeouts[i])
        else:
            final_timestep = (episode_step == env._max_episode_steps - 1)

        dropped = final_timestep and not terminate_on_end
        episode_over = dropped or final_timestep or bool(terminals[i])
        # An episode-ending transition is only kept when its successor is stored.
        counted = (not dropped) and (has_next_obs or not episode_over)

        if counted:
            # Credit the transition to the trajectory it belongs to *before*
            # that trajectory is closed.
            g += rewards[i]
            t += 1

        if episode_over:
            s0.append(start_obs)
            G.append(g)
            T.append(t)
            g, t = 0.0, 0
            start_obs = None
            # Mirrors the step counter of qlearning_dataset, which restarts at
            # 1 (not 0) after a kept episode-ending transition.
            episode_step = int(counted)
        else:
            episode_step += 1

    if start_obs is not None:
        # Unfinished trailing trajectory. Nothing is appended when the scan
        # stops exactly on an episode boundary, which keeps len(G) == len(s0).
        s0.append(start_obs)
        G.append(g)
        T.append(t)

    G = np.asarray(G, dtype=np.float64)
    T = np.asarray(T, dtype=np.int64)
    if T.sum() == 0:
        return np.zeros(0, dtype=np.float64)

    s0 = np.stack(s0)
    # Baseline: linear fit of the return on the initial observation.
    V = LinearRegression().fit(s0, G).predict(s0)

    # Broadcast each trajectory's advantage to its transitions.
    A_it = np.repeat(G - V, T)
    A_it = (A_it - A_it.min()) / (A_it.max() - A_it.min() + 1e-6)
    w_it = scipy.special.softmax(A_it / harness_alpha)
    w_it /= w_it.sum() # Numerical errors

    return w_it


def OPER_A(weight, weight_func='linear', exp_lambd=1.0, std=1.0, eps=0.0, eps_max=None):
    """
    Turn per-sample weights into a sampling distribution.

    Args:
        weight: array-like of per-sample weights; its first dimension is the
            number of samples.
        weight_func: ``'linear'`` (shift so the minimum is 0, then normalize)
            or ``'exp'`` (softmax of ``exp_lambd * weight / mean(|weight|)``).
        exp_lambd: exponent scale for ``'exp'``.
        std: for ``'linear'``; if truthy, the distribution is rescaled around
            the uniform value ``1/size`` so that ``size * prob`` has this
            standard deviation. Falsy disables the rescaling.
        eps: for ``'linear'`` with rescaling; if truthy, probabilities are
            floored at ``eps/size``. With a falsy ``eps`` (the default) the
            rescaled values are not floored and may be negative.
        eps_max: for ``'linear'`` with rescaling; if truthy, probabilities are
            capped at ``eps_max/size``.

    Returns:
        Array with the same shape as ``weight``. Degenerate input (all
        weights equal for ``'linear'``, all weights zero for ``'exp'``)
        yields the uniform distribution instead of NaNs.

    Raises:
        ValueError: if ``weight_func`` is not ``'linear'`` or ``'exp'``.
    """
    weight = np.asarray(weight, dtype=np.float64)
    size = weight.shape[0]
    uniform = np.full(weight.shape, 1.0 / size)

    if weight_func == 'linear':
        shifted = weight - weight.min()
        total = shifted.sum()
        if total == 0:
            # All weights identical: nothing to prefer.
            return uniform
        prob = shifted / total
        # keep mean, scale std
        if std:
            scale = std / (prob.std() * size)
            prob = scale * (prob - 1 / size) + 1 / size
            if eps:  # after scaling, the prob may be negative.
                prob = np.maximum(prob, eps / size)
            if eps_max:  # after scaling, the prob may be too large.
                prob = np.minimum(prob, eps_max / size)
        return prob / prob.sum()  # norm to 1 again

    if weight_func == 'exp':
        magnitude = np.abs(weight).mean()
        if magnitude == 0:
            # All weights are zero: every exponent would be equal.
            return uniform
        logits = exp_lambd * (weight / magnitude)
        # Subtracting the maximum leaves the normalized result unchanged but
        # keeps np.exp from overflowing.
        logits = logits - logits.max()
        unnormalized = np.exp(logits)
        return unnormalized / unnormalized.sum()

    raise ValueError(f"unknown weight_func {weight_func!r}; expected 'linear' or 'exp'")
  




class ReplayBuffer:
    """
    Transition storage backed by NumPy arrays, sampled as torch tensors.

    Transitions are either added one by one / in batches into a ring buffer of
    ``buffer_size`` slots, or loaded wholesale from a dataset dict with
    ``load_dataset`` (optionally with per-transition sampling weights ``w``).
    """

    # Number of batches' worth of weighted indexes drawn per refill of the cache.
    _CACHED_BATCHES = 1000

    def __init__(
        self,
        buffer_size: int,
        obs_shape: Tuple,
        obs_dtype: np.dtype,
        action_dim: int,
        action_dtype: np.dtype,
        device: str = "cpu"
    ) -> None:
        self._max_size = buffer_size
        self.obs_shape = obs_shape
        self.obs_dtype = obs_dtype
        self.action_dim = action_dim
        self.action_dtype = action_dtype

        self._ptr = 0
        self._size = 0

        self.observations = np.zeros((self._max_size,) + self.obs_shape, dtype=obs_dtype)
        self.next_observations = np.zeros((self._max_size,) + self.obs_shape, dtype=obs_dtype)
        self.actions = np.zeros((self._max_size, self.action_dim), dtype=action_dtype)
        self.rewards = np.zeros((self._max_size, 1), dtype=np.float32)
        self.terminals = np.zeros((self._max_size, 1), dtype=np.float32)

        # Initial observations and sampling weights only exist after
        # load_dataset; define them here so sample()/sample_init() never hit
        # a missing attribute.
        self.init_observations = np.zeros((0,) + self.obs_shape, dtype=obs_dtype)
        self._init_size = 0
        self.w = None
        self.count = 0                    # batches consumed from the index cache
        self.batch_indexes_catch = None   # cached weighted indexes
        self._cached_batch_size = None    # batch size the cache was drawn for

        self.device = torch.device(device)

    def add(
        self,
        obs: np.ndarray,
        next_obs: np.ndarray,
        action: np.ndarray,
        reward: np.ndarray,
        terminal: np.ndarray
    ) -> None:
        """Store one transition at the write pointer, overwriting the oldest slot when full."""
        # Assigning into the preallocated arrays copies the values, so the
        # caller's arrays are never referenced.
        self.observations[self._ptr] = np.asarray(obs)
        self.next_observations[self._ptr] = np.asarray(next_obs)
        self.actions[self._ptr] = np.asarray(action)
        self.rewards[self._ptr] = np.asarray(reward)
        self.terminals[self._ptr] = np.asarray(terminal)

        self._ptr = (self._ptr + 1) % self._max_size
        self._size = min(self._size + 1, self._max_size)

    def add_batch(
        self,
        obss: np.ndarray,
        next_obss: np.ndarray,
        actions: np.ndarray,
        rewards: np.ndarray,
        terminals: np.ndarray
    ) -> None:
        """Store a batch of transitions starting at the write pointer, wrapping around the end."""
        batch_size = len(obss)
        indexes = np.arange(self._ptr, self._ptr + batch_size) % self._max_size

        self.observations[indexes] = np.asarray(obss)
        self.next_observations[indexes] = np.asarray(next_obss)
        self.actions[indexes] = np.asarray(actions)
        self.rewards[indexes] = np.asarray(rewards)
        self.terminals[indexes] = np.asarray(terminals)

        self._ptr = (self._ptr + batch_size) % self._max_size
        self._size = min(self._size + batch_size, self._max_size)

    def load_dataset(self, dataset: Dict[str, np.ndarray], w: Optional[np.ndarray] = None) -> None:
        """
        Replace the buffer contents with ``dataset``.

        Args:
            dataset: dict with ``observations``, ``next_observations``,
                ``actions``, ``rewards`` and ``terminals``; ``init_observations``
                is optional (without it ``sample_init`` is unavailable).
            w: optional per-transition sampling probabilities used by
                ``sample``; ``None`` means uniform sampling.
        """
        observations = np.array(dataset["observations"], dtype=self.obs_dtype)
        next_observations = np.array(dataset["next_observations"], dtype=self.obs_dtype)
        actions = np.array(dataset["actions"], dtype=self.action_dtype)
        rewards = np.array(dataset["rewards"], dtype=np.float32).reshape(-1, 1)
        terminals = np.array(dataset["terminals"], dtype=np.float32).reshape(-1, 1)
        if "init_observations" in dataset:
            init_observations = np.array(dataset["init_observations"], dtype=self.obs_dtype)
        else:
            init_observations = np.zeros((0,) + observations.shape[1:], dtype=self.obs_dtype)

        self.observations = observations
        self.next_observations = next_observations
        self.actions = actions
        self.rewards = rewards
        self.terminals = terminals
        self.init_observations = init_observations
        self.w = w
        # Invalidate any index cache drawn for the previous contents.
        self.count = 0
        self.batch_indexes_catch = None
        self._cached_batch_size = None

        # NOTE(review): the storage arrays now have len(observations) rows while
        # _max_size is unchanged, so add()/add_batch() after load_dataset only
        # work if the two sizes match -- confirm intended usage.
        self._ptr = len(observations)
        self._size = len(observations)
        self._init_size = len(init_observations)

    def normalize_obs(self, eps: float = 1e-3) -> Tuple[np.ndarray, np.ndarray]:
        """
        Standardize ``observations`` and ``next_observations`` in place.

        Returns the ``(mean, std)`` used, where ``std`` already includes ``eps``.
        ``init_observations`` is left untouched.
        """
        mean = self.observations.mean(0, keepdims=True)
        std = self.observations.std(0, keepdims=True) + eps
        self.observations = (self.observations - mean) / std
        self.next_observations = (self.next_observations - mean) / std
        obs_mean, obs_std = mean, std
        return obs_mean, obs_std

    def _next_weighted_indexes(self, batch_size: int) -> np.ndarray:
        """Return the next ``batch_size`` indexes drawn according to ``self.w``."""
        # Weighted np.random.choice is expensive, so indexes for many batches
        # are drawn at once and handed out slice by slice.
        cache_stale = (
            self.count == 0
            or self.batch_indexes_catch is None
            or self._cached_batch_size != batch_size
            or self.count >= self._CACHED_BATCHES
        )
        if cache_stale:
            probs = np.asarray(self.w, dtype=np.float64).reshape(-1)
            probs = probs / probs.sum()  # absorb float32 rounding drift
            self.batch_indexes_catch = np.random.choice(
                self._size, batch_size * self._CACHED_BATCHES, p=probs)
            self._cached_batch_size = batch_size
            self.count = 0

        start = self.count * batch_size
        self.count += 1
        return self.batch_indexes_catch[start:start + batch_size]

    def sample(self, batch_size: int) -> Dict[str, torch.Tensor]:
        """
        Draw ``batch_size`` transitions (with replacement) as tensors on ``self.device``.

        Sampling follows the probabilities ``self.w`` when they were supplied
        to ``load_dataset`` and is uniform otherwise.
        """
        if self.w is None:
            batch_indexes = np.random.randint(0, self._size, size=batch_size)
        else:
            batch_indexes = self._next_weighted_indexes(batch_size)

        return {
            "observations": torch.tensor(self.observations[batch_indexes]).to(self.device),
            "actions": torch.tensor(self.actions[batch_indexes]).to(self.device),
            "next_observations": torch.tensor(self.next_observations[batch_indexes]).to(self.device),
            "terminals": torch.tensor(self.terminals[batch_indexes]).to(self.device),
            "rewards": torch.tensor(self.rewards[batch_indexes]).to(self.device),
        }

    def sample_init(self, batch_size: int) -> Dict[str, torch.Tensor]:
        """
        Draw ``batch_size`` initial observations uniformly (with replacement).

        Raises:
            ValueError: if no initial observations have been loaded.
        """
        if self._init_size == 0:
            raise ValueError("no initial observations loaded; pass 'init_observations' to load_dataset")

        batch_indexes = np.random.randint(0, self._init_size, size=batch_size)

        return {
            "init_observations": torch.tensor(self.init_observations[batch_indexes]).to(self.device)
        }

    def sample_all(self) -> Dict[str, np.ndarray]:
        """Return copies of all stored transitions as NumPy arrays."""
        return {
            "observations": self.observations[:self._size].copy(),
            "actions": self.actions[:self._size].copy(),
            "next_observations": self.next_observations[:self._size].copy(),
            "terminals": self.terminals[:self._size].copy(),
            "rewards": self.rewards[:self._size].copy()
        }