import torch
import numpy as np
import os
import random
from torch.utils.data import Dataset

class eval_mode(object):
    """Context manager that switches models to eval mode inside a ``with`` body.

    Each argument only needs a ``training`` attribute and a ``train(bool)``
    method (for example a ``torch.nn.Module``). On exit every model is given
    back the ``training`` value it had on entry, and exceptions raised inside
    the body are not suppressed.
    """

    def __init__(self, *models):
        self.models = models

    def __enter__(self):
        # Record every flag *before* switching anything. For torch modules,
        # train(False) on a parent also flips its children, so recording
        # inside the same loop would capture the already-modified flag of a
        # child (or of a model passed twice) and leave it in eval mode after
        # __exit__.
        self.prev_states = [model.training for model in self.models]
        for model in self.models:
            model.train(False)
        return self

    def __exit__(self, *args):
        for model, state in zip(self.models, self.prev_states):
            model.train(state)
        # Never swallow an exception raised inside the with-block.
        return False


def soft_update_params(net, target_net, tau):
    """Polyak-average ``net``'s parameters into ``target_net`` in place.

    For each parameter pair: ``target = tau * param + (1 - tau) * target``.
    Parameters are paired positionally via ``parameters()``, so both networks
    are expected to have the same architecture.

    Args:
        net: source module whose parameters are blended in.
        target_net: module updated in place.
        tau: interpolation weight; 1.0 copies ``net`` exactly, 0.0 is a no-op.
    """
    for param, target_param in zip(net.parameters(), target_net.parameters()):
        # copy_ writes into the existing storage. Rebinding ``.data`` would
        # allocate a new tensor every step and detach any alias of the old one.
        target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)


def set_seed_everywhere(seed):
    """Seed Python's ``random``, NumPy's global RNG and torch with ``seed``.

    The torch CUDA generators of all devices are seeded too, but only when
    CUDA is available.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


def module_hash(module):
    """Return the sum of all values in ``module.state_dict()``.

    This is a cheap fingerprint for noticing that weights changed, not a
    collision-resistant hash. An empty state dict yields ``0``.
    """
    return sum(entry.sum().item() for entry in module.state_dict().values())


def make_dir(dir_path):
    """Create ``dir_path`` (and any missing parents) if needed and return it.

    An already-existing directory is not an error. Other failures, such as a
    permission error or ``dir_path`` existing as a regular file, now
    propagate instead of being silently ignored. Previously callers only
    failed later and less clearly, when writing into the directory.
    """
    os.makedirs(dir_path, exist_ok=True)
    return dir_path



class ReplayBuffer(Dataset):
    """Buffer to store environment transitions.

    A fixed-capacity circular buffer: ``add`` writes at ``self.idx`` and, once
    ``capacity`` transitions are stored, overwrites the oldest entry. Two
    sampling modes are provided: independent transitions
    (``sample_proprio``) and time-consecutive sequences
    (``sample_consecutive``).
    """

    def __init__(self, obs_shape, action_shape, capacity, batch_size, device):
        """
        Args:
            obs_shape: shape of one observation. A 1-D shape is stored as
                float32 (proprioceptive); any other shape as uint8 (pixels).
            action_shape: shape of one action (stored as float32).
            capacity: maximum number of transitions kept.
            batch_size: number of samples returned by each ``sample_*`` call.
            device: torch device the sampled tensors are created on.
        """
        self.capacity = capacity
        self.batch_size = batch_size
        self.device = device

        # the proprioceptive obs is stored as float32, pixels obs as uint8
        obs_dtype = np.float32 if len(obs_shape) == 1 else np.uint8

        self.obses = np.empty((capacity, *obs_shape), dtype=obs_dtype)
        self.next_obses = np.empty((capacity, *obs_shape), dtype=obs_dtype)
        self.actions = np.empty((capacity, *action_shape), dtype=np.float32)
        self.rewards = np.empty((capacity, 1), dtype=np.float32)
        self.not_dones = np.empty((capacity, 1), dtype=np.float32)
        # Second "not done" flag. sample_consecutive uses it (not not_dones)
        # to decide where a sequence may not be continued.
        # NOTE(review): the done vs. done_max distinction is defined by the
        # caller; confirm its meaning there.
        self.not_done_maxs = np.empty((capacity, 1), dtype=np.float32)

        self.idx = 0  # next write position
        # NOTE(review): not read within this class; presumably used by
        # save/load code elsewhere.
        self.last_save = 0
        self.full = False  # True once the buffer has wrapped around

    def add(self, obs, action, reward, next_obs, done, done_max):
        """Store one transition at the current write position and advance it."""
        np.copyto(self.obses[self.idx], obs)
        np.copyto(self.actions[self.idx], action)
        np.copyto(self.rewards[self.idx], reward)
        np.copyto(self.next_obses[self.idx], next_obs)
        np.copyto(self.not_dones[self.idx], not done)
        np.copyto(self.not_done_maxs[self.idx], not done_max)

        self.idx = (self.idx + 1) % self.capacity
        self.full = self.full or self.idx == 0

    def sample_proprio(self):
        """Sample ``batch_size`` independent transitions (with replacement).

        Returns:
            ``(obses, actions, rewards, next_obses, not_dones)`` tensors on
            ``self.device``. Observations are converted to float.

        Raises:
            ValueError: from ``np.random.randint`` when the buffer is empty.
        """
        idxs = np.random.randint(
            0, self.capacity if self.full else self.idx, size=self.batch_size
        )

        obses = torch.as_tensor(self.obses[idxs], device=self.device).float()
        actions = torch.as_tensor(self.actions[idxs], device=self.device)
        rewards = torch.as_tensor(self.rewards[idxs], device=self.device)
        next_obses = torch.as_tensor(
            self.next_obses[idxs], device=self.device
        ).float()
        not_dones = torch.as_tensor(self.not_dones[idxs], device=self.device)
        return obses, actions, rewards, next_obses, not_dones

    def sample_consecutive(self, t):
        """Sample ``batch_size`` time-consecutive sequences of length ``t``.

        A start index ``s`` is accepted only when both conditions hold:

        * ``not_done_maxs`` is non-zero for steps ``s .. s + t - 2``. The last
          step may itself be terminal.
        * Once the buffer has wrapped, the window ``s .. s + t - 1`` does not
          straddle the write position, where the newest transition sits next
          to the oldest one.

        Start indices are drawn with replacement until ``batch_size`` valid
        ones are found. NOTE: if no window can ever be valid, this loops
        forever.

        Args:
            t: sequence length, must be greater than 1.

        Returns:
            obses: float tensor of shape ``(t, batch_size, *obs_shape)``.
            actions: tensor of shape ``(t, batch_size, *action_shape)``.
            rewards, next_obses, not_dones: taken at the last step of each
                sequence. ``next_obses`` is converted to float like ``obses``.

        Raises:
            ValueError: from ``np.random.randint`` when fewer than ``t``
                transitions are stored.
        """
        assert t > 1, 'timestep should be greater than 1'

        # Start s is in range iff its last step s + t - 1 is a stored index.
        num_stored = self.capacity if self.full else self.idx
        num_starts = num_stored - t + 1

        valid_idxs = np.empty(0, dtype=np.int64)
        while valid_idxs.size < self.batch_size:
            idxs = np.random.randint(0, num_starts, size=self.batch_size)
            keep = np.ones(idxs.shape, dtype=bool)
            # allow     [1,1,..., 0]
            # time_step:[1,2,..., t]
            for i in range(t - 1):
                keep &= self.not_done_maxs[idxs + i, 0] != 0.0
            if self.full:
                # After wrapping, index self.idx - 1 holds the newest
                # transition and self.idx the oldest, so a window containing
                # both is not a real trajectory.
                keep &= ~((idxs < self.idx) & (idxs + t - 1 >= self.idx))
            required_size = self.batch_size - valid_idxs.size
            valid_idxs = np.concatenate([valid_idxs, idxs[keep][:required_size]])

        # (t, batch_size) buffer indices: row k is step k of every sequence.
        seq_idxs = valid_idxs[None, :] + np.arange(t)[:, None]
        last_idxs = seq_idxs[-1]

        obses = torch.as_tensor(self.obses[seq_idxs], device=self.device).float()
        actions = torch.as_tensor(self.actions[seq_idxs], device=self.device)
        rewards = torch.as_tensor(self.rewards[last_idxs], device=self.device)
        next_obses = torch.as_tensor(
            self.next_obses[last_idxs], device=self.device
        ).float()
        # not_dones is read at the last step, paired with rewards/next_obses.
        not_dones = torch.as_tensor(self.not_dones[last_idxs], device=self.device)

        return obses, actions, rewards, next_obses, not_dones

class Namespace:
    """Minimal attribute bag: every keyword argument becomes an attribute.

    Example: ``Namespace(lr=1e-3).lr == 1e-3``.
    """

    def __init__(self, **kwargs):
        for key, value in kwargs.items():
            setattr(self, key, value)
