import gym
import torch
import numpy as np
import d4rl
import pickle
import gc
import random

def cal_vec_cos(vec1, vec2):
    """Return the cosine similarity of ``vec1`` and ``vec2`` along their last axis.

    Both inputs must have exactly the same shape; the result has that shape minus the
    last axis. Rows with zero norm are not guarded and produce ``nan``.
    """
    assert vec1.shape == vec2.shape
    dot = (vec1 * vec2).sum(axis=-1)
    norm_product = np.linalg.norm(vec1, axis=-1) * np.linalg.norm(vec2, axis=-1)
    return dot / norm_product

def rotate_vectors(v1, v2, shrink_ratio):
    """Rotate each row of ``v1`` toward the matching row of ``v2``; return L1-normalised rows.

    With ``u`` = ``v1`` scaled to unit L2 norm, each output row is ``u + t * v2``. The scalar
    ``t`` is chosen so that the cosine between the output and ``u`` equals
    ``cos(shrink_ratio * angle(v1, v2))`` (up to the sign ambiguity introduced by squaring,
    see below). The result is then divided by its L1 norm.

    Args:
        v1: 2-D array, one vector per row (the ``"ij,ij->i"`` einsums require rank 2).
        v2: 2-D array with the same shape as ``v1`` (``cal_vec_cos`` asserts equal shapes).
        shrink_ratio: scalar fraction of the angle to rotate by. A value within 0.001 of 1
            returns ``v2`` and a value within 0.001 of 0 returns ``v1``, both L1-normalised.

    Returns:
        Array shaped like ``v1`` with unit L1 norm per row. Components may be negative;
        no ``abs`` is applied here.
    """

    # Degenerate ratios: skip the solve and return one of the inputs, L1-normalised.
    if np.abs(shrink_ratio-1)<0.001:
        return v2 / np.linalg.norm(v2, ord=1, axis=1, keepdims=True)
    
    if np.abs(shrink_ratio)<0.001:
        return v1 / np.linalg.norm(v1, ord=1, axis=1, keepdims=True)

    # Unit-L2 v1, so v1v1 == 1 (up to rounding); the coefficients below rely on this.
    v1 = v1 / np.linalg.norm(v1, ord=2, axis=1, keepdims=True)
    vec_cos_old = cal_vec_cos(v1, v2)
    # Target cosine: scale the angle between v1 and v2 by shrink_ratio. The clip keeps
    # arccos away from its +-1 endpoints.
    vec_cos_new = np.cos(shrink_ratio*np.arccos(np.clip(vec_cos_old, -0.99999, 0.99999)))
    
    # Squaring cos(v1 + t*v2, v1) = k for unit v1 gives a*t^2 + b*t + c = 0 with
    #   a = (v1.v2)^2 - k^2 |v2|^2,  b = 2 (v1.v2) (1 - k^2),  c = 1 - k^2.
    # The expressions below match this only because v1v1 == 1 (some v1v1 factors are omitted).
    vec_cos_new_square = np.square(vec_cos_new)
    v1v2 = np.einsum("ij,ij->i", v1, v2)
    v1v1 = np.einsum("ij,ij->i", v1, v1)
    v2v2 = np.einsum("ij,ij->i", v2, v2)
    a=np.square(v1v2)-v2v2*vec_cos_new_square
    b=2*v1v2*v1v1-2*v1v2*vec_cos_new_square
    c=np.square(v1v1)-v1v1*vec_cos_new_square
    # Clamp to 0 so rounding noise cannot make sqrt() produce nan.
    discriminant = np.maximum(b**2 - 4*a*c, 0)

    # +1e-8 guards the a == 0 case. NOTE(review): a value of a near -5e-9 can still make
    # the denominator vanish.
    root1 = (-b + np.sqrt(discriminant)) / (2*a+1e-8)
    root2 = (-b - np.sqrt(discriminant)) / (2*a+1e-8)
    # NOTE(review): squaring also admits solutions whose cosine is -k; the larger root is
    # taken. Confirm this is always the intended solution (e.g. for obtuse v1/v2 angles).
    root = np.maximum(root1, root2).reshape(-1, 1)
    
    new_vec = v1 + root * v2
    # Rows that collapsed to (near) zero fall back to v1 to avoid dividing by ~0 below.
    zero_vec_idx = np.all(np.abs(new_vec)<1e-4, 1)
    new_vec[zero_vec_idx] = v1[zero_vec_idx]
    new_vec = new_vec / np.linalg.norm(new_vec, ord=1, axis=1, keepdims=True)
    return new_vec

class Preprocessor:
    """Clip actions, rescale rewards, derive preferences and normalise states of trajectories.

    Statistics (reward scale, maximum cost return, state mean/std) are fitted on calls with
    ``is_test_dataset=False`` and reused on calls with ``is_test_dataset=True``. Trajectory
    dicts are modified in place.
    """

    def __init__(self, args):
        self.args = args
        # Fitted statistics: per-objective reward scale, maximum cost return over the safe
        # objectives, and state mean/std (identity defaults until fitted).
        self.reward_scale_weight, self.max_cost_return, self.mean, self.std = None, None, 0.0, 1.0

    def __call__(self, trajectories, env, dataset_type, is_test_dataset=False):
        """Preprocess ``trajectories`` and return the (possibly filtered) list.

        Args:
            trajectories: list of dicts with 'actions' and 'raw_rewards' arrays, plus
                'observations'/'next_observations' when ``args.normalize_states`` is set.
            env: object exposing ``safe_obj_list`` (truthy entries mark cost objectives) and,
                when test-data exclusion is enabled, ``get_test_dataset()`` returning dicts
                with a 'selected_idx' iterable.
            dataset_type: dataset identifier; 'amateur_uniform' skips test-data exclusion.
            is_test_dataset: reuse previously fitted statistics instead of fitting new ones.

        Returns:
            The trajectory list, with entries selected by the test datasets removed when
            exclusion applies. Each dict gains a 'preference' array (one row per step).
        """
        # Clip actions into [-max_action, max_action].
        for traj in trajectories:
            traj['actions'] = np.clip(traj['actions'], -self.args.max_action, self.args.max_action)

        # Align the scales of reward and cost, then invert the cost objectives.
        safe_obj_idx = np.where(env.safe_obj_list)[0]
        if not is_test_dataset:
            # NOTE(review): an objective whose mean reward is exactly 0 gets a zero scale and
            # the division below produces inf/nan.
            self.reward_scale_weight = np.abs(np.mean(np.concatenate([x['raw_rewards'] for x in trajectories], axis=0), axis=0))
            # A negative scale flips the sign of the cost objectives.
            self.reward_scale_weight[safe_obj_idx] *= -1.0
            self.max_cost_return = np.max(np.array([np.sum(x['raw_rewards'][:, safe_obj_idx], axis=0) for x in trajectories]), axis=0)
        for traj in trajectories:
            traj['raw_rewards'] = traj['raw_rewards'] / self.reward_scale_weight.reshape(1, -1)
            ret = np.sum(traj['raw_rewards'], axis=0)
            # Cost objectives become (max_cost_return - cost_return) / |scale|, i.e. >= 0 for
            # the trajectories the maximum was fitted on.
            ret[safe_obj_idx] = -self.max_cost_return / self.reward_scale_weight[safe_obj_idx] + ret[safe_obj_idx]
            # Preference = L1-normalised return, repeated for every step of the trajectory.
            traj['preference'] = (ret / np.linalg.norm(ret, ord=1)).reshape(1, -1).repeat(len(traj['raw_rewards']), 0)

        # Exclude the trajectories that were selected as expert demos in the test datasets.
        if (not is_test_dataset and self.args.dataset_preprocess == 'exclude_test_data'
                and dataset_type != 'amateur_uniform'):
            # A set gives O(1) membership tests instead of scanning a list per trajectory.
            selected_idx = {idx for demo in env.get_test_dataset() for idx in demo['selected_idx']}
            trajectories = [traj for i, traj in enumerate(trajectories) if i not in selected_idx]

        if self.args.normalize_states:
            if not is_test_dataset:
                state = np.concatenate([x['observations'] for x in trajectories], axis=0)
                self.mean = state.mean(0, keepdims=True)
                # NOTE(review): 1e-30 only prevents exact 0/0; near-constant features can
                # still be blown up by rounding noise in (x - mean).
                self.std = state.std(0, keepdims=True) + 1e-30
            for traj in trajectories:
                traj['observations'] = (traj['observations'] - self.mean) / self.std
                traj['next_observations'] = (traj['next_observations'] - self.mean) / self.std

        return trajectories


class ReplayBuffer(object):
    """Flat transition buffer built from trajectory dicts, sampling perturbed preferences."""

    def __init__(self, args, max_size=int(2e6)):
        self.max_size = max_size  # NOTE(review): not enforced by any method of this class
        self.ptr = 0  # NOTE(review): not used by the methods of this class
        self.size = 0
        self.device = args.device
        self.args = args
        self.mean, self.std = 0, 1
        # When set, sample() returns exactly these indices instead of random ones.
        self.candidate_ind = None

    def fix_sample_ind(self, candidate_ind):
        """Make subsequent ``sample`` calls use ``candidate_ind`` (``None`` restores random draws)."""
        self.candidate_ind = candidate_ind

    def sample(self, batch_size):
        """Return (state, action, next_state, reward, not_done, preference) float tensors.

        Indices are drawn uniformly from the loaded transitions unless ``fix_sample_ind`` set
        fixed indices, in which case ``batch_size`` is ignored. Each preference is the
        transition's trajectory preference perturbed by ``gen_random_perturbed_preference``.
        """
        if self.candidate_ind is None:
            ind = np.random.randint(0, self.size, size=batch_size)
        else:
            ind = self.candidate_ind
        preference = self.gen_random_perturbed_preference(self.ori_preference[ind])
        return (
            torch.FloatTensor(self.state[ind]).to(self.device),
            torch.FloatTensor(self.action[ind]).to(self.device),
            torch.FloatTensor(self.next_state[ind]).to(self.device),
            torch.FloatTensor(self.reward[ind]).to(self.device),
            torch.FloatTensor(self.not_done[ind]).to(self.device),
            torch.FloatTensor(preference).to(self.device),
        )

    def gen_random_preference(self, size):
        """Draw ``size`` random non-negative preferences with unit L1 norm.

        Returns:
            A pair whose two elements are the same ``(size, args.reward_size)`` array object.
        """
        w_batch_rnd = np.random.randn(size, self.args.reward_size)
        w_batch_obj = np.abs(w_batch_rnd) / np.linalg.norm(w_batch_rnd, ord=1, axis=1, keepdims=True)
        return w_batch_obj, w_batch_obj

    def gen_random_perturbed_preference(self, ori_prefs):
        """Perturb each row of ``ori_prefs`` toward a random direction.

        A Gaussian direction is drawn per row and L1-normalised. ``rotate_vectors`` then turns
        the original preference toward it by ratio ``args.pref_perturb_theta``, and the result
        is made non-negative with unit L1 norm. If ``args.fixed_pref1`` is not None, every row
        is overwritten with it.
        """
        size = len(ori_prefs)
        w_batch_rnd = np.random.randn(size, self.args.reward_size)
        w_batch_obj = w_batch_rnd / np.linalg.norm(w_batch_rnd, ord=1, axis=1, keepdims=True)
        w_batch_obj = rotate_vectors(ori_prefs, w_batch_obj, self.args.pref_perturb_theta)
        w_batch_obj = np.abs(w_batch_obj) / np.linalg.norm(w_batch_obj, ord=1, axis=1, keepdims=True)
        if self.args.fixed_pref1 is not None:
            w_batch_obj[:] = self.args.fixed_pref1.reshape(1, -1)
        return w_batch_obj

    def load_from_dataset(self, env, dataset):
        """Flatten ``dataset`` (a list of trajectory dicts) into preallocated transition arrays.

        Also records:
        - per-trajectory discounted returns as ``v_min``/``v_max``/``v_mean``/``v_std``;
        - ``(first-step preference, undiscounted return)`` pairs in ``pref_ret_pairs``;
        - the per-dimension preference range in ``min_pref``/``max_pref``.

        ``env`` is unused.

        Raises:
            ValueError: if ``dataset`` is empty.
        """
        trajectories = dataset
        if len(trajectories) == 0:
            raise ValueError("load_from_dataset needs at least one trajectory")
        max_episode_len = self.args.max_episode_len
        first = trajectories[0]
        obs_dim = len(first['observations'][0])

        self.size = np.sum([len(traj['observations']) for traj in trajectories])
        self.state = np.zeros((self.size, obs_dim))
        self.next_state = np.zeros((self.size, obs_dim))
        self.action = np.zeros((self.size, len(first['actions'][0])))
        self.reward = np.zeros((self.size, len(first['raw_rewards'][0])))
        self.ori_preference = np.zeros((self.size, len(first['preference'][0])))
        self.not_done = np.ones((self.size, 1), dtype=bool)
        returns = []
        self.pref_ret_pairs = []
        # Size the discount table for the longest trajectory. Trajectories longer than
        # max_episode_len previously failed to broadcast against gamma[:traj_len].
        longest = max(len(traj['observations']) for traj in trajectories)
        gamma = self.args.gamma ** np.expand_dims(np.arange(max(max_episode_len, longest)), 1)
        cnt = 0
        for traj in trajectories:
            traj_len = len(traj['observations'])
            end = cnt + traj_len
            self.state[cnt:end] = traj['observations']
            self.next_state[cnt:end] = traj['next_observations']
            self.action[cnt:end] = traj['actions']
            self.reward[cnt:end] = traj['raw_rewards']
            self.ori_preference[cnt:end] = traj['preference']
            # The last step stays not_done only when the trajectory length equals
            # max_episode_len (presumably a time-limit truncation).
            self.not_done[end - 1] = (traj_len == max_episode_len)
            returns.append(np.sum(self.reward[cnt:end] * gamma[:traj_len], 0))
            self.pref_ret_pairs.append((traj['preference'][0], np.sum(self.reward[cnt:end], 0)))
            cnt = end

        min_pref, max_pref = np.min(self.ori_preference, axis=0), np.max(self.ori_preference, axis=0)
        self.min_pref, self.max_pref = min_pref, max_pref
        self.v_min, self.v_max = np.min(returns, 0, keepdims=True), np.max(returns, 0, keepdims=True)
        self.v_mean, self.v_std = np.mean(returns, 0, keepdims=True), np.std(returns, 0, keepdims=True)

        print("num of original trajectory: ", len(trajectories))
        print("original pref range: ", min_pref, max_pref)
        print('episode_num:', len(returns))
        print('min_max:', self.v_min, self.v_max)
        print('mean, std:', self.v_mean, self.v_std)

    
    

class TrajReplayBuffer():
    """Trajectory buffer that samples fixed-length (here single-step) padded segments.

    Observations are widened in place with the trajectory's preference vector, which is
    why ``state_dim`` is ``obs_shape + reward_size``.
    """

    def __init__(
        self,
        args,
        avg_rtg = True,
    ):
        # Segment length returned per sample.
        self.max_len = 1
        self.max_ep_len = args.max_episode_len
        self.state_dim = args.obs_shape + args.reward_size
        self.act_dim = args.action_shape
        self.pref_dim = self.rtg_dim = args.reward_size
        self.scale = 1.0
        self.device = args.device
        self.avg_rtg = avg_rtg
        self.gamma = 1.0

    def load_from_dataset(self, env, dataset):
        """Keep ``dataset`` and append each trajectory's preference to its observations.

        The dicts are modified in place, so loading the same dicts twice appends the
        preference twice. Trajectories are later drawn with probability proportional to
        their length. ``env`` is unused.
        """
        self.trajectories = dataset
        lengths = np.array([len(t['observations']) for t in dataset])
        self.p_sample = lengths / lengths.sum()
        for t in self.trajectories:
            t['observations'] = np.concatenate((t['observations'], t['preference']), axis=1)

    def discount_cumsum(self, x):
        """Discounted reward-to-go of a 1-D sequence, using ``self.gamma``."""
        out = np.zeros_like(x)
        out[-1] = x[-1]
        for t in range(x.shape[0] - 2, -1, -1):
            out[t] = x[t] + self.gamma * out[t + 1]
        return out

    def discount_cumsum_mo(self, x_mo):
        """Column-wise ``discount_cumsum`` of a (steps, objectives) array."""
        per_objective = [self.discount_cumsum(x_mo[:, k]) for k in range(x_mo.shape[1])]
        return np.array(per_objective).T

    def find_avg_rtg(self, x):
        """Mean of ``x`` over all elements."""
        return np.mean(x)

    def find_avg_rtg_mo(self, x_mo):
        """Per-objective mean of a (steps, objectives) array."""
        return np.mean(x_mo, axis=0)

    def _sample_segment(self, traj):
        """Cut one left-padded segment from ``traj`` starting at a uniformly random step.

        Returns (obs, act, raw_rewards, rtg, timesteps, mask, pref) arrays with a leading
        axis of size 1.
        """
        start = random.randint(0, traj['observations'].shape[0] - 1)
        stop = start + self.max_len

        obs = traj['observations'][start:stop].reshape(1, -1, self.state_dim)
        act = traj['actions'][start:stop].reshape(1, -1, self.act_dim)
        rew = traj['raw_rewards'][start:stop].reshape(1, -1, self.pref_dim)
        prf = traj['preference'][start:stop].reshape(1, -1, self.pref_dim)
        n = obs.shape[1]
        steps = np.arange(start, start + n).reshape(1, -1)
        # Clamp timesteps so they stay below max_ep_len.
        steps[steps >= self.max_ep_len] = self.max_ep_len - 1

        if self.avg_rtg:
            # Average future reward from `start` up to max_ep_len.
            if self.rtg_dim == 1:
                target = self.find_avg_rtg(traj['rewards'][start:self.max_ep_len])
            else:
                target = self.find_avg_rtg_mo(traj['raw_rewards'][start:self.max_ep_len])
            target = target.reshape(1, -1, self.rtg_dim)
        else:
            # Discounted reward-to-go within the segment, followed by one zero step.
            if self.rtg_dim == 1:
                target = self.discount_cumsum(traj['rewards'][start:stop])
            else:
                target = self.discount_cumsum_mo(traj['raw_rewards'][start:stop])
            target = target.reshape(1, -1, self.rtg_dim)
            if target.shape[1] <= n:
                target = np.concatenate([target, np.zeros((1, 1, self.rtg_dim))], axis=1)

        # Left-pad everything up to max_len steps.
        pad = self.max_len - n
        obs = np.concatenate([np.zeros((1, pad, self.state_dim)), obs], axis=1)
        act = np.concatenate([np.full((1, pad, self.act_dim), -0.0), act], axis=1)
        rew = np.concatenate([np.zeros((1, pad, self.pref_dim)), rew], axis=1)
        prf = np.concatenate([np.zeros((1, pad, self.pref_dim)), prf], axis=1)
        target = np.concatenate([np.zeros((1, pad, self.rtg_dim)), target], axis=1)
        steps = np.concatenate([np.zeros((1, pad)), steps], axis=1)
        mask = np.concatenate([np.zeros((1, pad)), np.ones((1, n))], axis=1)
        return obs, act, rew, target, steps, mask, prf

    def sample(self, batch_size):
        """Draw ``batch_size`` segments, choosing trajectories with probability ``p_sample``.

        Returns:
            (states, actions, raw_rewards, rtg, timesteps, mask, preferences) torch tensors
            stacked along the first axis; ``timesteps`` is long and ``mask`` keeps NumPy's
            float64 dtype.
        """
        chosen = np.random.choice(
            np.arange(len(self.trajectories)),
            size=batch_size,
            replace=True,
            p=self.p_sample,
        )
        segments = [self._sample_segment(self.trajectories[i]) for i in chosen]
        s, a, raw_r, rtg, timesteps, mask, pref = (np.concatenate(part, axis=0) for part in zip(*segments))

        as_float = dict(dtype=torch.float32, device=self.device)
        return (
            torch.from_numpy(s).to(**as_float),
            torch.from_numpy(a).to(**as_float),
            torch.from_numpy(raw_r).to(**as_float) / self.scale,
            torch.from_numpy(rtg).to(**as_float) / self.scale,
            torch.from_numpy(timesteps).to(dtype=torch.long, device=self.device),
            torch.from_numpy(mask).to(device=self.device),
            torch.from_numpy(pref).to(**as_float),
        )


class PromptTrajReplayBuffer():
    """Sample prompt-conditioned trajectory segments grouped by preference label.

    ``load_from_dataset`` assigns every trajectory to the closest row of ``pref_label``.
    Within each group, the highest-utility trajectories are set aside as demonstrations.
    ``sample`` returns rows made of demonstration segment(s) followed by a segment of another
    trajectory from the same group, left-padded to ``max_len + adpt_batch_size`` steps.
    """

    def __init__(
        self,
        args,
        pref_label,
        adpt_batch_size,
        avg_rtg = False,
    ):
        self.max_len = 20  # length cap of the non-demo segment in each row
        self.max_ep_len = args.max_episode_len
        self.state_dim = args.obs_shape
        self.act_dim = args.action_shape
        self.pref_dim = args.reward_size
        self.rtg_dim = args.reward_size
        self.scale = 1.0
        self.device = args.device
        self.avg_rtg = avg_rtg
        self.gamma = 1.0
        self.pref_label = pref_label
        self.batch_per_task = 8  # rows sampled per preference group
        self.demo_per_task = 1  # demo segments prepended to each row
        self.adpt_batch_size = adpt_batch_size  # total step budget of the demo segment(s)
        self.max_task_per_batch = 16  # preference groups drawn per sample() call

    def load_from_dataset(self, env, dataset):
        """Group ``dataset`` by closest preference label and split each group into demos/data.

        Grouping and splitting rules:
        - A trajectory joins the label with the highest cosine similarity to its first-step
          preference.
        - Groups with at most one trajectory are dropped.
        - In each remaining group, the top 10% (at least one) by utility
          ``label . sum(raw_rewards)`` become demonstrations; the rest form the sampling
          data.

        ``env`` is unused.
        """
        self.trajectories = dataset
        # asarray also accepts pref_label given as a nested list (cal_vec_cos needs .shape).
        pref_label = np.asarray(self.pref_label)
        n_pref = len(pref_label)
        trajs_in_pref = [[] for _ in range(n_pref)]
        self.sub_dataset = []
        self.sub_demo = []
        self.sub_dataset_p_sample = []
        self.sub_demo_p_sample = []
        for traj in self.trajectories:
            pref = traj['preference'][0].reshape(1, -1).repeat(n_pref, axis=0)
            trajs_in_pref[np.argmax(cal_vec_cos(pref, pref_label))].append(traj)

        for pref, group in zip(pref_label, trajs_in_pref):
            # A group needs at least one demo and one non-demo trajectory.
            if len(group) <= 1:
                continue
            n_demo = max(int(len(group) * 0.1), 1)
            utility = np.array([np.sum(pref * np.sum(traj['raw_rewards'], axis=0)) for traj in group])
            sort_idx = np.argsort(utility)
            self.sub_demo.append([group[x] for x in sort_idx[-n_demo:]])
            self.sub_dataset.append([group[x] for x in sort_idx[:-n_demo]])

            # Trajectories are drawn with probability proportional to their length.
            traj_lens = np.array([len(traj['observations']) for traj in self.sub_dataset[-1]])
            self.sub_dataset_p_sample.append(traj_lens / sum(traj_lens))
            traj_lens = np.array([len(traj['observations']) for traj in self.sub_demo[-1]])
            self.sub_demo_p_sample.append(traj_lens / sum(traj_lens))

    def discount_cumsum(self, x):
        """Discounted reward-to-go of a 1-D sequence, using ``self.gamma``."""
        discount_cumsum = np.zeros_like(x)
        discount_cumsum[-1] = x[-1]
        for t in reversed(range(x.shape[0]-1)):
            discount_cumsum[t] = x[t] + self.gamma * discount_cumsum[t+1]
        return discount_cumsum

    def discount_cumsum_mo(self, x_mo):
        """Column-wise ``discount_cumsum`` of a (steps, objectives) array."""
        return np.transpose(np.array([self.discount_cumsum(x_mo[:, i]) for i in range(x_mo.shape[1])]))

    def find_avg_rtg(self, x):
        """Mean of ``x`` over all elements."""
        return np.mean(x)

    def find_avg_rtg_mo(self, x_mo):
        """Per-objective mean of a (steps, objectives) array."""
        return np.mean(x_mo, axis=0)

    def get_traj_seg(self, traj, seg_len, maximize_seglen=False):
        """Cut an unpadded segment of up to ``seg_len`` steps from ``traj`` at a random start.

        With ``maximize_seglen`` the start is restricted so the segment has the full
        ``seg_len`` steps whenever the trajectory is long enough.

        Returns:
            (s, a, raw_r, rtg, timesteps, pref), each with a leading axis of size 1. In the
            non-averaged mode ``rtg`` has one extra trailing zero step.
        """
        max_step_start = traj['observations'].shape[0]-1 if not maximize_seglen else max(0, traj['observations'].shape[0]-seg_len)
        step_start = random.randint(0, max_step_start)
        step_end = step_start + seg_len

        s = traj['observations'][step_start:step_end].reshape(1, -1, self.state_dim)
        a = traj['actions'][step_start:step_end].reshape(1, -1, self.act_dim)
        raw_r = traj['raw_rewards'][step_start:step_end].reshape(1, -1, self.pref_dim)
        pref = traj['preference'][step_start:step_end].reshape(1, -1, self.pref_dim)
        timesteps = np.arange(step_start, step_start + s.shape[1]).reshape(1, -1)
        # Clamp timesteps so they stay below max_ep_len.
        timesteps[timesteps >= self.max_ep_len] = self.max_ep_len-1

        if not self.avg_rtg:
            # Discounted reward-to-go within the segment, followed by one zero step.
            if self.rtg_dim == 1:
                rtg = self.discount_cumsum(traj['rewards'][step_start:step_end]).reshape(1, -1, self.rtg_dim)
            else:
                rtg = self.discount_cumsum_mo(traj['raw_rewards'][step_start:step_end]).reshape(1, -1, self.rtg_dim)
            if rtg.shape[1] <= s.shape[1]:
                rtg = np.concatenate([rtg, np.zeros((1, 1, self.rtg_dim))], axis=1)
        else:
            # Average future reward from step_start up to max_ep_len.
            if self.rtg_dim == 1:
                rtg = self.find_avg_rtg(traj['rewards'][step_start:self.max_ep_len]).reshape(1, -1, self.rtg_dim)
            else:
                rtg = self.find_avg_rtg_mo(traj['raw_rewards'][step_start:self.max_ep_len]).reshape(1, -1, self.rtg_dim)
        return s, a, raw_r, rtg, timesteps, pref

    def sample(self, batch_size):
        """Draw prompt-conditioned rows.

        ``batch_size`` is ignored: the batch holds ``batch_per_task`` rows for each of
        ``min(max_task_per_batch, number of groups)`` groups drawn without replacement.

        Each row consists of, in order:
        - zero padding on the left;
        - the demo segment(s), with their trailing rtg entry dropped;
        - a segment of up to ``max_len`` steps from a non-demo trajectory.

        The total is ``max_len + adpt_batch_size`` steps; in the default non-averaged mode
        ``rtg`` is one step longer. NOTE(review): with ``avg_rtg=True`` the demo rtg slice is
        empty and rtg no longer lines up with the states; that mode looks unsupported here.

        Returns:
            (states, actions, raw_rewards, rtg, timesteps, mask, preferences) torch tensors;
            ``timesteps`` is long and ``mask`` keeps NumPy's float64 dtype.

        Raises:
            ValueError: if no preference group is available to sample from.
        """
        if not self.sub_dataset:
            raise ValueError("PromptTrajReplayBuffer has no preference group with more than one trajectory to sample from")

        s_batch, a_batch, r_batch, pref_batch, rtg_batch, timesteps_batch, mask_batch = [], [], [], [], [], [], []
        seg_len = self.max_len + self.adpt_batch_size

        task_idx = np.random.choice(np.arange(0, len(self.sub_dataset)), size=min(self.max_task_per_batch, len(self.sub_dataset)), replace=False)

        for idx in task_idx:
            batch_inds = np.random.choice(
                np.arange(len(self.sub_dataset[idx])),
                size=self.batch_per_task,
                replace=True,
                p=self.sub_dataset_p_sample[idx],
            )
            for traj_idx in batch_inds:
                s_sub_batch, a_sub_batch, r_sub_batch, pref_sub_batch, rtg_sub_batch, timesteps_sub_batch = [], [], [], [], [], []
                demo_inds = np.random.choice(
                    np.arange(len(self.sub_demo[idx])),
                    size=self.demo_per_task,
                    replace=True,
                    p=self.sub_demo_p_sample[idx],
                )
                demo_len = self.adpt_batch_size // len(demo_inds)
                for demo_idx in demo_inds:
                    s, a, raw_r, rtg, timesteps, pref = self.get_traj_seg(self.sub_demo[idx][demo_idx], demo_len)
                    s_sub_batch.append(s)
                    a_sub_batch.append(a)
                    r_sub_batch.append(raw_r)
                    pref_sub_batch.append(pref)
                    # Drop the trailing rtg entry; only the final segment keeps one.
                    rtg_sub_batch.append(rtg[:, :-1, :])
                    timesteps_sub_batch.append(timesteps)

                s, a, raw_r, rtg, timesteps, pref = self.get_traj_seg(self.sub_dataset[idx][traj_idx], self.max_len)
                s_sub_batch.append(s)
                a_sub_batch.append(a)
                r_sub_batch.append(raw_r)
                pref_sub_batch.append(pref)
                rtg_sub_batch.append(rtg)
                timesteps_sub_batch.append(timesteps)

                # Left-pad the concatenated segments up to seg_len steps.
                tlen = np.sum([x.shape[1] for x in s_sub_batch])
                pad = seg_len - tlen
                s_batch.append(np.concatenate([np.zeros((1, pad, self.state_dim))] + s_sub_batch, axis=1))
                a_batch.append(np.concatenate([np.ones((1, pad, self.act_dim))] + a_sub_batch, axis=1))
                r_batch.append(np.concatenate([np.zeros((1, pad, self.pref_dim))] + r_sub_batch, axis=1))
                pref_batch.append(np.concatenate([np.zeros((1, pad, self.pref_dim))] + pref_sub_batch, axis=1))
                rtg_batch.append(np.concatenate([np.zeros((1, pad, self.rtg_dim))] + rtg_sub_batch, axis=1))
                timesteps_batch.append(np.concatenate([np.zeros((1, pad))] + timesteps_sub_batch, axis=1))
                mask_batch.append(np.concatenate([np.zeros((1, pad)), np.ones((1, tlen))], axis=1))

        s_batch = torch.from_numpy(np.concatenate(s_batch, axis=0)).to(dtype=torch.float32, device=self.device)
        a_batch = torch.from_numpy(np.concatenate(a_batch, axis=0)).to(dtype=torch.float32, device=self.device)
        r_batch = torch.from_numpy(np.concatenate(r_batch, axis=0)).to(dtype=torch.float32, device=self.device) / self.scale
        pref_batch = torch.from_numpy(np.concatenate(pref_batch, axis=0)).to(dtype=torch.float32, device=self.device)
        rtg_batch = torch.from_numpy(np.concatenate(rtg_batch, axis=0)).to(dtype=torch.float32, device=self.device) / self.scale
        timesteps_batch = torch.from_numpy(np.concatenate(timesteps_batch, axis=0)).to(dtype=torch.long, device=self.device)
        mask_batch = torch.from_numpy(np.concatenate(mask_batch, axis=0)).to(device=self.device)

        return s_batch, a_batch, r_batch, rtg_batch, timesteps_batch, mask_batch, pref_batch
