import torch
import numpy as np
import torch.nn as nn
import gym
import os
from collections import deque
import random
from torch.utils.data import Dataset, DataLoader
import time
from skimage.util.shape import view_as_windows

class eval_mode(object):
    def __init__(self, *models):
        self.models = models

    def __enter__(self):
        self.prev_states = []
        for model in self.models:
            self.prev_states.append(model.training)
            model.train(False)

    def __exit__(self, *args):
        for model, state in zip(self.models, self.prev_states):
            model.train(state)
        return False


def soft_update_params(net, target_net, tau):
    """Blend the parameters of ``net`` into ``target_net`` in place.

    Each target parameter becomes ``tau * param + (1 - tau) * target_param``.
    Parameters are paired in the order returned by ``.parameters()``; pairing
    stops at the shorter of the two sequences.
    """
    keep = 1 - tau
    pairs = zip(net.parameters(), target_net.parameters())
    for source, target in pairs:
        blended = tau * source.data + keep * target.data
        target.data.copy_(blended)


def set_seed_everywhere(seed):
    """Seed the torch, CUDA (when available), NumPy and ``random`` generators."""
    seeders = [torch.manual_seed]
    if torch.cuda.is_available():
        seeders.append(torch.cuda.manual_seed_all)
    seeders.extend((np.random.seed, random.seed))
    # Applied in the order: torch, CUDA, NumPy, stdlib random.
    for apply_seed in seeders:
        apply_seed(seed)


def module_hash(module):
    """Return the sum of all entries of every tensor in ``module.state_dict()``.

    Tensors are accumulated in ``state_dict`` order; a module with an empty
    state dict yields ``0``.
    """
    return sum(entry.sum().item() for entry in module.state_dict().values())


def make_dir(dir_path):
    """Create ``dir_path`` on a best-effort basis and return the path.

    Only the final path component is created (``os.mkdir``). Any ``OSError``
    (for example because the directory already exists) is ignored, so the
    returned path is not guaranteed to exist.
    """
    try:
        os.mkdir(dir_path)
    except OSError:
        # Deliberately ignored: creation is best effort.
        return dir_path
    return dir_path


def preprocess_obs(obs, bits=5):
    """Preprocessing image, see https://arxiv.org/abs/1807.03039.

    For ``bits < 8`` the input is first quantised with
    ``floor(obs / 2**(8 - bits))``. The result is divided by ``2**bits``,
    uniform noise in ``[0, 1 / 2**bits)`` is added, and 0.5 is subtracted.

    ``obs`` must be a ``torch.float32`` tensor (checked with ``assert``).
    Presumably it holds 8-bit pixel intensities in ``[0, 255]`` -- TODO confirm
    against callers.
    """
    num_bins = 2**bits
    assert obs.dtype == torch.float32
    quantised = torch.floor(obs / 2**(8 - bits)) if bits < 8 else obs
    scaled = quantised / num_bins
    # Dequantisation noise spanning exactly one bin width.
    dithered = scaled + torch.rand_like(scaled) / num_bins
    return dithered - 0.5

class OneTrajectory():
    """Append-only storage for the transitions of a single trajectory.

    Rows are written sequentially by ``add``; there is no wrap-around, so the
    caller must ``clear`` the trajectory before more than ``capacity``
    transitions have been added.
    """

    def __init__(self, obs_shape, action_shape, capacity=1000):
        # 1-D observation shapes are stored as float32, anything else as uint8.
        is_flat_obs = len(obs_shape) == 1
        obs_dtype = np.float32 if is_flat_obs else np.uint8
        obs_dims = (capacity, *obs_shape)
        self.obses = np.empty(obs_dims, dtype=obs_dtype)
        self.next_obses = np.empty(obs_dims, dtype=obs_dtype)
        self.actions = np.empty((capacity, *action_shape), dtype=np.float32)
        self.rewards = np.empty((capacity, 1), dtype=np.float32)
        self.dones = np.empty((capacity, 1), dtype=np.float32)
        self.idx = 0

    def add(self, obs, action, reward, next_obs, done):
        """Write one transition into the next free row and advance ``idx``."""
        row = self.idx
        writes = (
            (self.obses, obs),
            (self.actions, action),
            (self.rewards, reward),
            (self.next_obses, next_obs),
            (self.dones, done),
        )
        for storage, item in writes:
            np.copyto(storage[row], item)
        self.idx = row + 1

    def get_all(self):
        """Return ``(obses, actions, rewards, next_obses, dones)`` for the rows written so far."""
        filled = slice(None, self.idx)
        return (
            self.obses[filled],
            self.actions[filled],
            self.rewards[filled],
            self.next_obses[filled],
            self.dones[filled],
        )

    def clear(self):
        """Forget the stored transitions by resetting the write position."""
        self.idx = 0



class ReplayBuffer(Dataset):
    """Buffer to store environment transitions.

    Every backing array has ``capacity + self_imitate_num`` rows, used as two
    independent regions:

    * rows ``[0, capacity)`` are a ring buffer written one transition at a
      time by :meth:`add`.  When the ring wraps, writing resumes at row
      ``protect_num`` so the first ``protect_num`` rows are never overwritten;
      :meth:`load_peg2` sets ``protect_num`` to the number of loaded rows.
    * rows ``[capacity, capacity + self_imitate_num)`` are a second ring
      buffer written in batches by :meth:`add_batch`.
    """

    def __init__(self, obs_shape, action_shape, capacity, batch_size, device,
                 image_size=84, transform=None, save_buffer=False,
                 demo_ratio=0.25, demo_decay=True, self_imitate_num=0):
        self.capacity = capacity
        self.self_imitate_num = self_imitate_num
        self.batch_size = batch_size
        self.device = device
        self.image_size = image_size
        self.transform = transform
        self.last_save = 0
        # the proprioceptive obs is stored as float32, pixels obs as uint8
        obs_dtype = np.float32 if len(obs_shape) == 1 else np.uint8

        rows = capacity + self_imitate_num
        self.obses = np.empty((rows, *obs_shape), dtype=obs_dtype)
        self.next_obses = np.empty((rows, *obs_shape), dtype=obs_dtype)
        self.actions = np.empty((rows, *action_shape), dtype=np.float32)
        self.rewards = np.empty((rows, 1), dtype=np.float32)
        self.not_dones = np.empty((rows, 1), dtype=np.float32)

        # idx / full track the main ring, idx2 / full2 the self-imitation ring.
        self.idx = 0
        self.idx2 = 0
        self.full = False
        self.full2 = False

        self.obs_shape = obs_shape[0]
        self.action_shape = action_shape[0]

        # Number of leading rows that `add` skips after wrapping around.
        self.protect_num = 0
        self.save_buffer = save_buffer
        # demo_ratio is decayed by `add`; demo_ratio_origin keeps the start value.
        self.demo_ratio = demo_ratio
        self.demo_ratio_origin = demo_ratio
        self.demo_decay = demo_decay

    def add(self, obs, action, reward, next_obs, done):
        """Store one transition in the main ring and decay ``demo_ratio``."""
        np.copyto(self.obses[self.idx], obs)
        np.copyto(self.actions[self.idx], action)
        np.copyto(self.rewards[self.idx], reward)
        np.copyto(self.next_obses[self.idx], next_obs)
        np.copyto(self.not_dones[self.idx], not done)

        self.idx = (self.idx + 1) % self.capacity
        self.full = self.full or self.idx == 0
        if self.idx == 0:
            # Skip the protected rows instead of overwriting them.
            self.idx += self.protect_num

        # Linear decay of the demo share, stopping once it reaches ~0.001.
        if self.demo_decay and self.demo_ratio > 0.001:
            self.demo_ratio -= self.demo_ratio_origin / 300000

    def add_batch(self, obs, action, reward, next_obs, done):  # for self imitate data
        """Append a batch of transitions to the self-imitation ring.

        ``done.shape[0]`` is taken as the batch size.  A batch that runs past
        the end of the region wraps around to its start; if the batch is
        larger than the whole region only its newest ``self_imitate_num``
        rows are kept, which is what row-by-row insertion would leave behind.

        Raises:
            ValueError: if the buffer was created with ``self_imitate_num <= 0``.
        """
        region = self.self_imitate_num
        if region <= 0:
            raise ValueError(
                'add_batch needs a buffer created with self_imitate_num > 0'
            )
        size = done.shape[0]
        base = self.capacity
        fields = (
            (self.obses, obs),
            (self.actions, action),
            (self.rewards, reward),
            (self.next_obses, next_obs),
            (self.not_dones, 1 - done),
        )

        if self.idx2 + size <= region:
            # The whole batch fits before the end of the region.
            lo = base + self.idx2
            for storage, values in fields:
                np.copyto(storage[lo:lo + size], values)
        else:
            # Drop rows that would be overwritten by later rows of this batch.
            skip = max(size - region, 0)
            start = (self.idx2 + skip) % region
            kept = size - skip
            head = min(kept, region - start)
            for storage, values in fields:
                np.copyto(storage[base + start:base + start + head],
                          values[skip:skip + head])
                if kept > head:
                    # Wrapped tail goes to the beginning of the region.
                    np.copyto(storage[base:base + kept - head],
                              values[skip + head:size])

        self.full2 = self.full2 or self.idx2 + size >= region
        self.idx2 = (self.idx2 + size) % region

    def load_peg2(self, load_dir, norm=False, double=False, sqil=False,
                  imitation_learning=False, r_lambda=1):
        """Load transitions from an ``.npz`` archive and mark them as protected.

        The archive is read through keys ``arr_0`` .. ``arr_4`` as
        observations, actions, rewards, next observations and done flags.
        The rows are written starting at the current ``self.idx`` and
        ``protect_num`` is set to their count.

        Args:
            load_dir: path of the archive passed to ``np.load``.
            norm: standardise observations with their per-feature mean/std.
            double: concatenate every observation with itself along axis 1.
            sqil: if exactly ``True``, add ``r_lambda`` to every reward and a
                further ``r_lambda`` to rewards that are ``>= 10``.
            imitation_learning: replace all rewards by 0 (before ``sqil``).
            r_lambda: reward offset used by ``sqil``.

        Returns:
            ``(mean, std)`` of the observations along axis 0, computed on the
            data as stored, i.e. after ``norm``/``double`` were applied.
            NOTE(review): with ``norm=True`` these are therefore ~0 and ~1
            rather than the raw statistics -- confirm callers expect that.
        """
        # Context manager closes the underlying archive once arrays are read.
        with np.load(load_dir) as data:
            obs = data['arr_0']
            action = data['arr_1']
            r = data['arr_2']
            next_o = data['arr_3']
            done = data['arr_4']
        if imitation_learning:
            r = 0
        if sqil is True:
            index = r >= 10
            r = r + index * r_lambda + r_lambda
        # Loaded transitions are always stored as non-terminal.
        done *= 0
        num = done.shape[0]
        print('load_buffer_num:{}'.format(num))
        print('obs_shape:{}'.format(self.obses.shape))
        print('demo_ratio:{}, decay:{}'.format(self.demo_ratio, self.demo_decay))
        if norm:
            # NOTE(review): a zero std yields inf/nan here.
            mean = np.mean(obs, axis=0)
            std = np.std(obs, axis=0)
            obs = (obs - mean) / std
            next_o = (next_o - mean) / std
        if double:
            obs = np.concatenate((obs, obs), axis=1)
            next_o = np.concatenate((next_o, next_o), axis=1)
        lo, hi = self.idx, self.idx + num
        np.copyto(self.obses[lo:hi], obs)
        np.copyto(self.actions[lo:hi], action)
        np.copyto(self.rewards[lo:hi], r)
        np.copyto(self.next_obses[lo:hi], next_o)
        np.copyto(self.not_dones[lo:hi], 1 - done)

        self.idx += num
        self.protect_num = num

        return np.mean(obs, axis=0), np.std(obs, axis=0)

    def _as_tensors(self, idxs):
        """Gather rows ``idxs`` and return them as tensors on ``self.device``."""
        obses = torch.as_tensor(self.obses[idxs], device=self.device).float()
        actions = torch.as_tensor(self.actions[idxs], device=self.device)
        rewards = torch.as_tensor(self.rewards[idxs], device=self.device)
        next_obses = torch.as_tensor(
            self.next_obses[idxs], device=self.device
        ).float()
        not_dones = torch.as_tensor(self.not_dones[idxs], device=self.device)
        return obses, actions, rewards, next_obses, not_dones

    def sample_expert(self):
        """Sample ``batch_size / 2`` transitions from the protected rows only."""
        idxs = np.random.randint(
            0, self.protect_num, size=(int(self.batch_size / 2))
        )
        return self._as_tensors(idxs)

    def sample_proprio(self):
        """Sample a batch mixing online, protected and self-imitation rows.

        The batch is first drawn from the main ring.  If protected rows exist,
        the first ``int(demo_ratio * batch_size)`` entries are redrawn from
        them; if a self-imitation region exists, up to the same number of
        trailing entries are redrawn from its filled part.
        """
        # online interact, demo, self imitation
        idxs = np.random.randint(
            0, self.capacity if self.full else self.idx, size=self.batch_size
        )
        quota = int(self.demo_ratio * self.batch_size)
        if self.protect_num > 0:
            idxs[:quota] = np.random.randint(0, self.protect_num, size=quota)
        if self.self_imitate_num > 0:
            if self.full2:
                filled = self.self_imitate_num
                size_self = quota
            else:
                filled = self.idx2
                size_self = min(quota, int(self.idx2))
            # Guard: idxs[-0:] would address the whole batch.
            if size_self > 0:
                idxs[-size_self:] = np.random.randint(
                    self.capacity, self.capacity + filled, size=size_self
                )
        return self._as_tensors(idxs)

    def sample_cpc(self):
        """Sample a batch with independently random-cropped anchor/positive views.

        Returns the usual five tensors plus a ``cpc_kwargs`` dict whose
        ``obs_anchor`` is the cropped observation batch and ``obs_pos`` is a
        second random crop of the same observations.
        """
        idxs = np.random.randint(
            0, self.capacity if self.full else self.idx, size=self.batch_size
        )

        obses = self.obses[idxs]
        next_obses = self.next_obses[idxs]
        pos = obses.copy()

        obses = random_crop(obses, self.image_size)
        next_obses = random_crop(next_obses, self.image_size)
        pos = random_crop(pos, self.image_size)

        obses = torch.as_tensor(obses, device=self.device).float()
        next_obses = torch.as_tensor(next_obses, device=self.device).float()
        actions = torch.as_tensor(self.actions[idxs], device=self.device)
        rewards = torch.as_tensor(self.rewards[idxs], device=self.device)
        not_dones = torch.as_tensor(self.not_dones[idxs], device=self.device)

        pos = torch.as_tensor(pos, device=self.device).float()
        cpc_kwargs = dict(obs_anchor=obses, obs_pos=pos,
                          time_anchor=None, time_pos=None)

        return obses, actions, rewards, next_obses, not_dones, cpc_kwargs

    def save(self, save_dir):
        """Write rows ``[last_save, idx)`` to ``<save_dir>/<last_save>_<idx>.pt``.

        Does nothing when no row was added since the previous save.
        """
        if self.idx == self.last_save:
            return
        path = os.path.join(save_dir, '%d_%d.pt' % (self.last_save, self.idx))
        chunk = slice(self.last_save, self.idx)
        payload = [
            self.obses[chunk],
            self.next_obses[chunk],
            self.actions[chunk],
            self.rewards[chunk],
            self.not_dones[chunk]
        ]
        self.last_save = self.idx
        torch.save(payload, path)

    def load(self, save_dir):
        """Restore chunks written by :meth:`save`, in order of their start row.

        Every file in ``save_dir`` must be named ``<start>_<end>.<ext>`` and
        the chunks must be contiguous from the current ``self.idx``.
        """
        names = sorted(os.listdir(save_dir), key=lambda x: int(x.split('_')[0]))
        for name in names:
            start, end = [int(x) for x in name.split('.')[0].split('_')]
            path = os.path.join(save_dir, name)
            # torch.load unpickles the file: only load trusted data.
            payload = torch.load(path)
            assert self.idx == start
            self.obses[start:end] = payload[0]
            self.next_obses[start:end] = payload[1]
            self.actions[start:end] = payload[2]
            self.rewards[start:end] = payload[3]
            self.not_dones[start:end] = payload[4]
            self.idx = end

    def __getitem__(self, idx):
        """Return one uniformly sampled transition; the ``idx`` argument is ignored."""
        idx = np.random.randint(
            0, self.capacity if self.full else self.idx, size=1
        )
        idx = idx[0]
        obs = self.obses[idx]
        action = self.actions[idx]
        reward = self.rewards[idx]
        next_obs = self.next_obses[idx]
        not_done = self.not_dones[idx]

        if self.transform:
            obs = self.transform(obs)
            next_obs = self.transform(next_obs)

        return obs, action, reward, next_obs, not_done

    def __len__(self):
        """Return ``capacity`` (not the number of stored transitions)."""
        return self.capacity

class FrameStack(gym.Wrapper):
    """Gym wrapper whose observation is the last ``k`` frames joined along axis 0.

    The advertised ``observation_space`` is built from a fixed per-frame shape
    of ``(3, 100, 100)``; it is not derived from the wrapped environment.
    """

    def __init__(self, env, k):
        gym.Wrapper.__init__(self, env)
        self._k = k
        self._frames = deque(maxlen=k)
        channels, height, width = 3, 100, 100
        self.observation_space = gym.spaces.Box(
            low=0,
            high=1,
            shape=(channels * k, height, width),
            dtype=env.observation_space.dtype,
        )
        self._max_episode_steps = env._max_episode_steps

    def reset(self):
        """Reset the env and fill the whole stack with the first observation."""
        first_obs = self.env.reset()
        self._frames.extend([first_obs] * self._k)
        return self._get_obs()

    def step(self, action):
        """Step the env, push the new frame and return the stacked observation."""
        frame, reward, done, info = self.env.step(action)
        self._frames.append(frame)
        return self._get_obs(), reward, done, info

    def _get_obs(self):
        # The stack must be full: reset() has to be called before step().
        assert len(self._frames) == self._k
        stacked = np.concatenate(list(self._frames), axis=0)
        return stacked


def random_crop(imgs, output_size):
    """
    Vectorized way to do random crop using sliding windows
    and picking out random ones

    args:
        imgs, batch images with shape (B,C,H,W)
        output_size, side length of the square crop

    returns:
        array with shape (B,C,output_size,output_size); every image gets its
        own uniformly sampled crop position

    raises:
        ValueError if output_size is larger than the image height or width
    """
    # batch size
    n = imgs.shape[0]
    # Largest valid top/left offsets, tracked per axis so that non-square
    # images are cropped correctly.
    max_top = imgs.shape[-2] - output_size
    max_left = imgs.shape[-1] - output_size
    if max_top < 0 or max_left < 0:
        raise ValueError(
            'output_size {} exceeds image size {}x{}'.format(
                output_size, imgs.shape[-2], imgs.shape[-1]
            )
        )
    imgs = np.transpose(imgs, (0, 2, 3, 1))
    # randint's upper bound is exclusive: +1 makes the last window reachable
    # and keeps output_size == image size (a single window) valid.
    tops = np.random.randint(0, max_top + 1, n)
    lefts = np.random.randint(0, max_left + 1, n)
    # creates all sliding windows combinations of size (output_size);
    # result is indexed as (B, top, left, C, output_size, output_size)
    windows = view_as_windows(
        imgs, (1, output_size, output_size, 1))[..., 0, :, :, 0]
    # selects a random window for each batch element
    cropped_imgs = windows[np.arange(n), tops, lefts]
    return cropped_imgs

def center_crop_image(image, output_size):
    """Return the central ``output_size`` x ``output_size`` window of ``image``.

    ``image`` is indexed as ``(channel, row, column)``; the first axis is left
    untouched. When the margin is odd, the extra row/column stays on the
    bottom/right side.
    """
    height, width = image.shape[1:]
    first_row = (height - output_size) // 2
    first_col = (width - output_size) // 2
    rows = slice(first_row, first_row + output_size)
    cols = slice(first_col, first_col + output_size)
    return image[:, rows, cols]



        # NOTE(review): dead, commented-out code. It appears to be an earlier
        # version of the index selection in ReplayBuffer.sample_proprio --
        # confirm before deleting.
        # if self.protect_num == 0 and self.self_imitate_num == 0:
        #     idxs = np.random.randint(
        #         0, self.capacity if self.full else self.idx, size=self.batch_size
        #     )
        # elif self.protect_num>0 and self.self_imitate_num == 0:
        #     idxs1 = np.random.randint(0, self.protect_num, size=int(self.batch_size*self.demo_ratio))
        #     idxs2 = np.random.randint(0, self.capacity if self.full else self.idx, size=self.batch_size-int(self.batch_size*self.demo_ratio))
        #     idxs = np.concatenate((idxs1, idxs2))

class value(nn.Module):
    """MLP state-value function: maps an observation to one scalar per row.

    Args:
        obs_dim: size of the last dimension of the observation input.
        hidden_dim: width of the two hidden layers.
    """
    def __init__(self, obs_dim, hidden_dim=32):
        super().__init__()

        self.trunk = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, obs):
        """Return the value estimate for ``obs`` with shape ``(..., 1)``."""
        # The trunk is applied exactly once: its output has size 1, so it
        # cannot be fed back into a first layer that expects ``obs_dim``.
        return self.trunk(obs)