import numpy as np
import torch
import os
import random
import imageio
from tqdm import trange


class ReplayBuffer(object):
    """Transition buffer stored in pre-allocated numpy arrays.

    Rows ``[0, size)`` hold the transitions written by :meth:`add` or loaded
    by :meth:`convert_D4RL` / :meth:`convert_D4RL_finetune` (the "D" region).
    After :meth:`pre_locate` has extended ``state``/``action``, rows
    ``[size, max_size)`` form a second "D_" region holding only
    (state, action) pairs, filled ring-buffer style by :meth:`addD_`.
    Sampling methods return ``torch.FloatTensor`` batches moved to ``device``.
    """

    def __init__(self, state_dim, action_dim, device, max_size=int(1e6)):
        self.max_size = max_size
        self.size = 0
        self.ptr = 0    # next write index for add()
        self.ptrD_ = 0  # next write index for addD_()

        self.state = np.empty((max_size, state_dim))
        self.action = np.empty((max_size, action_dim))
        self.next_state = np.empty((max_size, state_dim))
        self.reward = np.empty((max_size, 1))
        self.not_done = np.empty((max_size, 1))

        self.device = device

    def add(self, state, action, next_state, reward, done):
        """Write one transition at ``ptr``, overwriting the oldest row once full."""
        self.state[self.ptr] = state
        self.action[self.ptr] = action
        self.next_state[self.ptr] = next_state
        self.reward[self.ptr] = reward
        self.not_done[self.ptr] = 1. - done
        self.ptr = (self.ptr + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def addD_(self, state, action):
        """Append a batch of (state, action) rows to the D_ region.

        Writing starts at ``ptrD_``. Rows that would run past ``max_size``
        wrap around to index ``size`` (the start of the D_ region), so the
        region acts as a ring buffer. Both inputs are indexed along their
        first axis (batch size ``B``) and may be numpy arrays or CPU tensors
        that numpy can convert on assignment.

        NOTE(review): assumes ``ptrD_`` already lies in ``[size, max_size]``
        (convert_D4RL sets it to ``size``). It also assumes the wrapped tail
        fits in the region ``[size, max_size)``. If the tail is longer, the
        second slice below is too short and numpy raises ValueError.
        """
        B = state.shape[0]
        if self.ptrD_ + B > self.max_size:
            # Split the batch: head fills up to max_size, tail restarts at size.
            over = (self.ptrD_ + B) - self.max_size
            self.state[self.ptrD_:self.max_size] = state[:B - over]
            self.state[self.size:self.size + over] = state[B - over:]
            self.action[self.ptrD_:self.max_size] = action[:B - over]
            self.action[self.size:self.size + over] = action[B - over:]
            self.ptrD_ = self.ptrD_ + B - self.max_size + self.size  # == size + over
        else:
            self.state[self.ptrD_:self.ptrD_ + B] = state
            self.action[self.ptrD_:self.ptrD_ + B] = action
            self.ptrD_ = self.ptrD_ + B

    def sample_vae(self, batch_size, mod):  # VAE training sample
        """Sample (state, action) rows uniformly from one buffer region.

        Args:
            batch_size: number of rows to draw (with replacement).
            mod: ``'D'`` samples from ``[0, size)``, ``'D_'`` from
                ``[size, max_size)``, and ``'DUD_'`` from ``[0, max_size)``.

        Returns:
            ``(state, action, ind)``: float tensors on ``device`` plus the
            numpy index array that was drawn.

        Raises:
            ValueError: if ``mod`` is not one of the modes above.
        """
        if mod == 'D_':
            low, high = self.size, self.max_size
        elif mod == 'DUD_':
            low, high = 0, self.max_size
        elif mod == 'D':
            low, high = 0, self.size
        else:
            # Previously fell through and failed with UnboundLocalError on `ind`.
            raise ValueError(f"unknown sample mode {mod!r}; expected 'D', 'D_' or 'DUD_'")
        ind = np.random.randint(low, high, size=batch_size)
        return (
            torch.FloatTensor(self.state[ind]).to(self.device),
            torch.FloatTensor(self.action[ind]).to(self.device),
            ind,)

    def sample(self, batch_size, ind=None):  # RL training sample
        """Return full transitions for ``ind`` (or a uniform draw from ``[0, size)``).

        Returns a tuple of float tensors on ``device``:
        ``(state, action, next_state, reward, not_done)``.
        """
        ind = np.random.randint(0, self.size, size=batch_size) if ind is None else ind
        return (
            torch.FloatTensor(self.state[ind]).to(self.device),
            torch.FloatTensor(self.action[ind]).to(self.device),
            torch.FloatTensor(self.next_state[ind]).to(self.device),
            torch.FloatTensor(self.reward[ind]).to(self.device),
            torch.FloatTensor(self.not_done[ind]).to(self.device),)

    def pre_locate(self, buffer_size):
        """Append ``buffer_size`` rows to ``state`` and ``action`` to create the D_ region.

        The new rows are copies of existing rows. Each existing row is
        repeated ``num_repeats`` times consecutively (``np.repeat``), then the
        result is truncated to ``buffer_size`` rows. ``max_size`` becomes the
        new row count. ``next_state``, ``reward`` and ``not_done`` are NOT
        extended, so only ``state``/``action`` may be read beyond ``size``.

        NOTE(review): ``np.repeat`` yields s0, s0, s1, s1, ..., so when
        ``num_repeats > 1`` only a prefix of the data seeds the region. If
        cyclic tiling of the whole dataset was intended, ``np.tile`` would
        give that. Please confirm.
        """
        D = self.state.shape[0]
        num_repeats = buffer_size // D + 1
        state = np.repeat(self.state, num_repeats, axis=0)[:buffer_size]  # clip target size
        action = np.repeat(self.action, num_repeats, axis=0)[:buffer_size]

        self.state = np.concatenate([self.state, state], axis=0)
        self.action = np.concatenate([self.action, action], axis=0)
        self.max_size = self.state.shape[0]

    def convert_D4RL(self, dataset, standardize=False):
        """Replace the buffer contents with the arrays of a D4RL-style dict.

        Reads ``observations``, ``actions``, ``next_observations``, ``rewards``
        and ``terminals``. With ``standardize=True``, rewards are min-max
        scaled to ``[0, 1]``; constant rewards map to 0 instead of NaN.
        Afterwards ``size``, ``max_size`` and ``ptrD_`` all equal the number
        of loaded rows.
        """
        self.state = dataset['observations']
        self.action = dataset['actions']
        self.next_state = dataset['next_observations']
        rewards = dataset['rewards']
        if standardize:
            r_min = rewards.min()
            r_range = rewards.max() - r_min
            shifted = rewards - r_min
            # A zero range would give 0/0 = NaN; the shifted values are already all 0.
            rewards = shifted / r_range if r_range > 0 else shifted
        self.reward = rewards.reshape(-1, 1)
        self.not_done = 1. - dataset['terminals'].reshape(-1, 1)
        self.size = self.state.shape[0]
        # Keep capacity in sync with the arrays just installed; otherwise
        # add()/addD_()/sample_vae('DUD_') would index past their end.
        self.max_size = self.size
        self.ptrD_ = self.state.shape[0]

    def convert_D4RL_finetune(self, dataset):
        """Copy a D4RL-style dict into the front of the pre-allocated arrays.

        Unlike :meth:`convert_D4RL`, the existing arrays (and ``max_size``)
        are kept, so the dataset must have at most ``max_size`` rows.
        ``ptr`` and ``size`` are set to the dataset length, so later
        :meth:`add` calls append after it.
        """
        self.ptr = dataset['observations'].shape[0]
        self.size = dataset['observations'].shape[0]
        self.state[:self.ptr] = dataset['observations']
        self.action[:self.ptr] = dataset['actions']
        self.next_state[:self.ptr] = dataset['next_observations']
        self.reward[:self.ptr] = dataset['rewards'].reshape(-1, 1)
        self.not_done[:self.ptr] = 1. - dataset['terminals'].reshape(-1, 1)

    def normalize_states(self, eps=1e-3):
        """Standardize ``state``/``next_state`` with statistics of ``state[:size]``.

        Returns the ``(mean, std)`` arrays (shape ``(1, state_dim)``, ``eps``
        already added to ``std``) so the same transform can be reapplied.
        """
        mean = self.state[:self.size].mean(0, keepdims=True)
        std = self.state[:self.size].std(0, keepdims=True) + eps
        self.state = (self.state - mean) / std
        self.next_state = (self.next_state - mean) / std
        return mean, std

    def clip_to_eps(self, eps=1e-6): # due to tanh
        """Clip all stored actions to ``[-(1 - eps), 1 - eps]``.

        The original note says this is "due to tanh". Presumably it keeps
        actions strictly inside tanh's open range so an inverse is finite.
        Please confirm.
        """
        lim = 1 - eps
        self.action = np.clip(self.action, -lim, lim)


def make_dir(dir_path):
    """Ensure ``dir_path`` exists as a directory (creating parents) and return it.

    An already-existing directory is not an error. Unlike the previous
    ``os.mkdir`` + ``except OSError: pass``, real failures (missing
    permissions, the path existing as a regular file) now raise instead of
    silently returning a path that does not exist.
    """
    os.makedirs(dir_path, exist_ok=True)
    return dir_path


def set_seed_everywhere(env, args):
    """Seed every RNG the run uses with ``args.seed``.

    Seeds the environment and its action space, Python's ``random`` module,
    NumPy's global RNG and torch. When CUDA is available it also seeds all
    CUDA devices and disables TF32 for matmul and cuDNN, keeping full
    float32 precision.
    """
    seed = args.seed
    env.seed(seed)
    env.action_space.seed(seed)
    # The stdlib RNG was previously left unseeded, so any code using
    # `random` was not reproducible across runs.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cuda.matmul.allow_tf32 = False
        torch.backends.cudnn.allow_tf32 = False


def get_lr(optimizer):
    """Return the ``'lr'`` of the optimizer's first parameter group.

    Yields ``None`` when ``optimizer.param_groups`` is empty.
    """
    first_lrs = (group['lr'] for group in optimizer.param_groups)
    return next(first_lrs, None)


def snapshot_src(src, target, exclude_from):
    """Copy the source tree ``src`` into ``target`` with rsync (best effort).

    ``exclude_from`` is passed to rsync's ``--exclude-from`` option. The
    command is run as an argument list rather than a shell string, so paths
    containing spaces or shell metacharacters are passed through intact.
    Like the previous ``os.system`` call, a failing rsync does not raise.
    A missing rsync binary only emits a warning.
    """
    import subprocess
    import warnings

    make_dir(target)
    cmd = ['rsync', '-rv', f'--exclude-from={exclude_from}', src, target]
    try:
        subprocess.run(cmd, check=False)
    except FileNotFoundError:
        warnings.warn('rsync not found; source snapshot skipped')


class VideoRecorder(object):
    """Collect rendered frames from an env and write them out with imageio.

    Call :meth:`init` to start a recording, :meth:`record` to grab a frame
    and :meth:`save` to write the collected frames. Recording is always
    disabled when ``dir_name`` is None.
    """

    def __init__(self, dir_name, height=512, width=512, camera_id=0, fps=30):
        self.dir_name = dir_name
        self.height = height
        self.width = width
        self.camera_id = camera_id  # NOTE: stored but not passed to the renderer
        self.fps = fps
        self.frames = []
        # Disabled until init() is called, so record()/save() before init()
        # are no-ops instead of raising AttributeError.
        self.enabled = False

    def init(self, enabled=True):
        """Start a new recording: drop old frames and (re)set ``enabled``."""
        self.frames = []
        self.enabled = self.dir_name is not None and enabled

    def record(self, env):
        """Render one ``height`` x ``width`` frame via ``env.sim.render`` and store it."""
        if self.enabled:
            frame = env.sim.render(height=self.height, width=self.width)
            # Reverse the row axis (vertical flip). Presumably the simulator
            # renders bottom-up; please confirm.
            self.frames.append(frame[::-1, :, :])

    def save(self, file_name):
        """Write the collected frames to ``dir_name/file_name`` at ``fps``."""
        if self.enabled:
            path = os.path.join(self.dir_name, file_name)
            imageio.mimsave(path, self.frames, fps=self.fps)


def init_pi_beta(vae, vae_beta, vae_optimizer, policy, args):
    """Perform a one-time switch-over of the VAE setup.

    Only the first call has an effect; the run-once flag is kept on the
    function itself as ``init_pi_beta.has_run``. That first call:

    * copies ``vae``'s weights into ``vae_beta`` and puts it in eval mode;
    * sets every param group of ``vae_optimizer`` to ``args.vae_lr``;
    * replaces ``policy.ood_min`` / ``policy.ood_max`` (iterables before the
      call) with ``sum(...) / 200`` of their entries;
    * rescales ``args.ood`` by the new ``policy.ood_max``;
    * writes the three values to a text file in ``args.work_dir``. The
      file's name embeds the values; despite the ``.json`` suffix its
      content is plain text.
    """
    if getattr(init_pi_beta, 'has_run', False):
        return
    init_pi_beta.has_run = True

    vae_beta.load_state_dict(vae.state_dict())
    vae_beta.eval()

    new_lr = args.vae_lr
    for group in vae_optimizer.param_groups:
        group['lr'] = new_lr

    # The divisor 200 is hard-coded. Presumably it is the number of
    # collected entries; please confirm before changing it to len().
    policy.ood_min = sum(policy.ood_min) / 200
    policy.ood_max = sum(policy.ood_max) / 200
    args.ood = args.ood * policy.ood_max

    log_name = f'ood-{args.ood:.5f}_mu-{policy.ood_min:.5f}_max-{policy.ood_max:.5f}.json'
    with open(os.path.join(args.work_dir, log_name), 'w', encoding='utf-8') as fh:
        fh.write(f' args.ood: {args.ood}\n policy.ood_min: {policy.ood_min}\n policy.ood_max: {policy.ood_max}\n')


def store_data(state, action, action_hat, Adv, vae_beta, replay_buffer, args, t, logger):
    """Gate candidate actions with ``vae_beta`` and push the result into the D_ buffer.

    A row is "selected" when ``vae_beta.elbo_loss(state, action_hat)`` is below
    ``args.ood``. When ``args.Adv`` is truthy, the row must also be set in
    ``Adv.squeeze()``. Then:

    * ``args.mode == 2``: only selected rows of ``state`` / ``action_hat`` are stored;
    * ``args.mode == 1``: every state is stored, paired with ``action_hat``
      where selected and ``action`` otherwise.

    Rows go to ``replay_buffer.addD_`` as detached CPU tensors. The mean and
    max of ``neg_log_beta`` and the buffer pointer are logged at step ``t``.

    NOTE(review): assumes ``elbo_loss`` returns one value per row (shape
    ``(B,)``) so the mask can index ``state`` and be unsqueezed to the
    action width. Please confirm against the VAE implementation.

    Raises:
        ValueError: if ``args.mode`` is neither 1 nor 2.
    """
    neg_log_beta = vae_beta.elbo_loss(state, action_hat)
    pi_bool = neg_log_beta < args.ood
    if args.Adv:
        pi_bool = pi_bool & Adv.squeeze()
    if args.mode == 2:
        state, action_t = state[pi_bool], action_hat[pi_bool]
    elif args.mode == 1:
        use_hat = pi_bool.unsqueeze(1).expand(-1, action.size(1))
        action_t = torch.where(use_hat, action_hat, action)
    else:
        # Previously fell through and failed with UnboundLocalError on `action_t`.
        raise ValueError(f'unsupported args.mode {args.mode!r}; expected 1 or 2')
    # detach(): the buffer assigns into numpy arrays, and numpy cannot convert
    # tensors that still require grad.
    replay_buffer.addD_(state.detach().cpu(), action_t.detach().cpu())

    # Log
    logger.log('train/neg_log_beta', neg_log_beta.mean(), t)
    logger.log('train/neg_log_beta_max', neg_log_beta.max(), t)
    logger.log('train/buffer_ptr', replay_buffer.ptrD_, t)

