import random
import pickle
import torch
import numpy as np
from torch.utils.data import Dataset
import os

# Exclusive upper bound on the timestep indices produced by
# D4RLTrajectoryDataset.__getitem__: any index at or above this value is
# clamped to MAX_EPISODE_LEN - 1.
MAX_EPISODE_LEN = 1000000


def discount_cumsum(x, gamma):
    """Return the discounted cumulative sum of ``x`` along its first axis.

    The result satisfies ``out[t] = x[t] + gamma * out[t + 1]`` with
    ``out[-1] = x[-1]``.

    Args:
        x: Array-like whose first axis is the time axis.
        gamma: Discount factor applied once per step.

    Returns:
        An array with the same shape as ``x``. Floating-point (and complex)
        inputs keep their dtype; any other dtype is promoted to ``float64``
        so that fractional discounts are not truncated by integer storage.
        An input with zero length along the first axis yields an empty array.
    """
    x = np.asarray(x)
    # Integer/bool storage would silently truncate x[t] + gamma * out[t + 1].
    out_dtype = x.dtype if np.issubdtype(x.dtype, np.inexact) else np.float64
    disc_cumsum = np.zeros(x.shape, dtype=out_dtype)
    if x.shape[0] == 0:
        # Nothing to accumulate; avoids indexing x[-1] on an empty array.
        return disc_cumsum
    disc_cumsum[-1] = x[-1]
    for t in reversed(range(x.shape[0] - 1)):
        disc_cumsum[t] = x[t] + gamma * disc_cumsum[t + 1]
    return disc_cumsum


class D4RLTrajectoryDataset(Dataset):
    """Dataset of left-padded, fixed-length windows cut from pickled trajectories.

    The pickle at ``dataset_path`` is expected to hold a sequence of dicts.
    This class reads the keys ``"observations"``, ``"next_observations"``,
    ``"actions"``, ``"rewards"`` and either ``"terminals"`` or ``"dones"``
    from each dict, and adds ``"returns_to_go"``,
    ``"returns_to_go_discounted"``, ``"advantage"``, ``"traj_returns"`` and
    (when a critic is given) ``"q"`` to it.

    Each item is a window of at most ``context_len`` steps starting at a
    random offset inside a trajectory; shorter windows are padded on the
    left. Which trajectory backs each index is drawn once at construction
    time, with probability proportional to trajectory length.
    """

    def __init__(self, dataset_path, context_len, rtg_scale, data_ratio=1.0,
                 reward_tune=None, sample_size=1000, normalize=True, critic=None):
        """Load trajectories, derive per-step targets and normalize observations.

        Args:
            dataset_path: Path to the pickle file containing the trajectories.
            context_len: Number of steps in every returned window.
            rtg_scale: Divisor applied to both return-to-go sequences.
            data_ratio: When below 1.0, only the first
                ``int(len(trajectories) * data_ratio)`` trajectories are kept.
            reward_tune: Stored on the instance; not otherwise used by this
                class.
            sample_size: Number of trajectory indices drawn up front; this is
                the value reported by ``len(dataset)``.
            normalize: When true, observations are standardized with the
                dataset-wide mean and standard deviation; otherwise a zero
                mean and unit standard deviation are used.
            critic: Optional callable invoked as ``critic(obs, actions)`` and
                returning two tensors; their element-wise minimum is stored
                per trajectory under ``"q"``.

        Raises:
            ValueError: If no trajectories remain after loading/truncation.
        """
        self.context_len = context_len
        self.reward_tune = reward_tune
        self.rtg_scale = rtg_scale

        # NOTE(security): pickle.load can execute arbitrary code while
        # deserializing; only point this at trusted dataset files.
        with open(dataset_path, "rb") as f:
            self.trajectories = pickle.load(f)

        if data_ratio < 1.0:
            new_size = int(len(self.trajectories) * data_ratio)
            self.trajectories = self.trajectories[:new_size]

        if len(self.trajectories) == 0:
            # Fail early with context instead of a bare np.concatenate error.
            raise ValueError(
                f"no trajectories available from {dataset_path!r} "
                f"(data_ratio={data_ratio})"
            )

        # Per-trajectory targets: returns-to-go (undiscounted and discounted)
        # and the advantage signal derived from the undiscounted one.
        states = []
        advantages = []
        for traj in self.trajectories:
            observations = traj["observations"]
            states.append(observations)
            traj["returns_to_go"] = discount_cumsum(traj["rewards"], 1.0) / rtg_scale
            traj["returns_to_go_discounted"] = discount_cumsum(traj["rewards"], 0.99) / rtg_scale
            # Unscaled return-to-go divided by the trajectory length.
            traj["advantage"] = (traj["returns_to_go"] / observations.shape[0]) * rtg_scale
            advantages.append(traj["advantage"])
            if critic:
                # NOTE(review): the critic sees observations *before* the
                # normalization applied further down — confirm this is intended.
                with torch.no_grad():
                    q1, q2 = critic(
                        torch.FloatTensor(observations),
                        torch.FloatTensor(traj["actions"]),
                    )
                    traj["q"] = torch.minimum(q1, q2).detach().numpy().flatten()

        # Center the advantages on their dataset-wide mean.
        self.mean_advantage = np.mean(np.concatenate(advantages, axis=0))
        for traj in self.trajectories:
            traj["advantage"] = traj["advantage"] - self.mean_advantage

        # Statistics used for input normalization.
        states = np.concatenate(states, axis=0)
        self.state_dim = states.shape[-1]
        self.act_dim = self.trajectories[0]["actions"].shape[-1]
        if normalize:
            self.state_mean = np.mean(states, axis=0)
            # Small epsilon keeps constant dimensions from dividing by zero.
            self.state_std = np.std(states, axis=0) + 1e-6
        else:
            self.state_mean = np.zeros(states.shape[1])
            self.state_std = np.ones(states.shape[1])

        # Normalize observations in place and collect summary statistics.
        print(f"num of trajs: {len(self.trajectories)}")
        traj_lens, returns = [], []
        for traj in self.trajectories:
            traj["observations"] = (traj["observations"] - self.state_mean) / self.state_std
            traj["next_observations"] = (traj["next_observations"] - self.state_mean) / self.state_std
            traj_lens.append(len(traj["observations"]))
            returns.append(traj['rewards'].sum())

        traj_lens, returns = np.array(traj_lens), np.array(returns)
        num_timesteps = traj_lens.sum()
        self.max_return = np.max(returns)

        print('=' * 50)
        print(f'Starting new experiment: {dataset_path}')
        print(f'{len(traj_lens)} trajectories, {num_timesteps} timesteps found')
        print(f'Average return: {np.mean(returns):.2f}, std: {np.std(returns):.2f}')
        print(f'Max return: {np.max(returns):.2f}, min: {np.min(returns):.2f}')
        print('=' * 50)

        self.sampling_ind = self.sample_trajs(self.trajectories, sample_size)

    def sample_trajs(self, trajectories, sample_size):
        """Draw ``sample_size`` trajectory indices, weighted by trajectory length.

        Sampling is with replacement and uses NumPy's global random state.

        Returns:
            1-D array of indices into ``trajectories``.
        """
        traj_lens = np.array([len(traj["observations"]) for traj in trajectories])
        p_sample = traj_lens / np.sum(traj_lens)

        inds = np.random.choice(
            np.arange(len(trajectories)),
            size=sample_size,
            replace=True,
            p=p_sample,
        )
        return inds

    def get_state_stats(self):
        """Return the ``(state_mean, state_std)`` pair used for normalization."""
        return self.state_mean, self.state_std

    def __len__(self):
        """Return the number of pre-drawn trajectory indices."""
        return len(self.sampling_ind)

    @staticmethod
    def _left_pad(values, pad_len, fill=0.0):
        """Prepend ``pad_len`` rows filled with ``fill`` along the first axis."""
        values = np.asarray(values)
        padding = np.full((pad_len,) + values.shape[1:], fill)
        return np.concatenate([padding, values])

    def _bootstrap_window(self, values, start, tlen):
        """Slice ``context_len + 1`` steps as a column, zero-extended to ``tlen + 1`` rows."""
        window = values[start: start + self.context_len + 1].reshape(-1, 1)
        if window.shape[0] <= tlen:
            # The trajectory ended inside the window: a zero stands in for
            # the value one step past the last real step.
            window = np.concatenate([window, np.zeros((1, 1))])
        return window

    def __getitem__(self, idx):
        """Return one randomly positioned, left-padded window as a dict of tensors.

        The trajectory is the one pre-assigned to ``idx``; the start offset
        inside it is drawn with ``random.randint`` on every call.

        Returns:
            Dict with keys ``states``/``next_states`` of shape
            ``(context_len, state_dim)``, ``actions``/``target_a`` of shape
            ``(context_len, act_dim)``, ``rewards``/``traj_returns`` of shape
            ``(context_len, 1)``, ``returns_to_go``/
            ``returns_to_go_discounted``/``advantage`` of shape
            ``(context_len + 1, 1)``, and ``terminals``/``timesteps``/
            ``traj_mask`` with ``context_len`` entries along the first axis.
            ``traj_mask`` is 0 on padded steps and 1 on real ones;
            ``terminals`` is 1 on padded steps.

        Raises:
            ValueError: If the action window length differs from the
                observation window length.
        """
        traj = self.trajectories[self.sampling_ind[idx]]
        traj_len = traj['observations'].shape[0]

        # Total return broadcast over every step; refreshed on each access
        # from the current rewards and stored back on the trajectory dict.
        traj["traj_returns"] = np.full(len(traj["rewards"]), traj["rewards"].sum())

        si = random.randint(0, traj_len - 1)
        window = slice(si, si + self.context_len)

        # Slice the window out of every per-step sequence.
        states = traj["observations"][window].reshape(-1, self.state_dim)
        next_states = traj["next_observations"][window].reshape(-1, self.state_dim)
        actions = traj["actions"][window].reshape(-1, self.act_dim)
        rewards = traj["rewards"][window].reshape(-1, 1)
        tr = traj["traj_returns"][window].reshape(-1, 1)

        # Number of real (unpadded) steps in this window.
        tlen = states.shape[0]

        returns_to_go = self._bootstrap_window(traj["returns_to_go"], si, tlen)
        advantage = self._bootstrap_window(traj["advantage"], si, tlen)
        returns_to_go_discounted = self._bootstrap_window(
            traj["returns_to_go_discounted"], si, tlen
        )

        if "terminals" in traj:
            terminals = traj["terminals"][window]
        else:
            terminals = traj["dones"][window]

        # Absolute step indices, clamped so they never reach MAX_EPISODE_LEN.
        timesteps = np.minimum(np.arange(si, si + tlen), MAX_EPISODE_LEN - 1)

        if actions.shape[0] != tlen:
            raise ValueError(
                f"window has {tlen} observations but {actions.shape[0]} actions"
            )

        # Left-pad everything up to the fixed context length.
        padd_len = self.context_len - tlen
        states = self._left_pad(states, padd_len)
        next_states = self._left_pad(next_states, padd_len)
        actions = self._left_pad(actions, padd_len)
        rewards = self._left_pad(rewards, padd_len)
        terminals = self._left_pad(terminals, padd_len, fill=1.0)
        tr = self._left_pad(tr, padd_len)
        returns_to_go = self._left_pad(returns_to_go, padd_len)
        returns_to_go_discounted = self._left_pad(returns_to_go_discounted, padd_len)
        advantage = self._left_pad(advantage, padd_len)
        timesteps = self._left_pad(timesteps, padd_len)
        traj_mask = np.concatenate([np.zeros(padd_len), np.ones(tlen)])

        states = torch.from_numpy(states).to(dtype=torch.float32)
        next_states = torch.from_numpy(next_states).to(dtype=torch.float32)
        actions = torch.from_numpy(actions).to(dtype=torch.float32)
        rewards = torch.from_numpy(rewards).to(dtype=torch.float32)
        terminals = torch.from_numpy(terminals).to(dtype=torch.long)
        tr = torch.from_numpy(tr).to(dtype=torch.float32)
        returns_to_go = torch.from_numpy(returns_to_go).to(dtype=torch.float32)
        returns_to_go_discounted = torch.from_numpy(returns_to_go_discounted).to(dtype=torch.float32)
        advantage = torch.from_numpy(advantage).to(dtype=torch.float32)
        timesteps = torch.from_numpy(timesteps).to(dtype=torch.long)
        # The mask keeps NumPy's default float dtype, as before.
        traj_mask = torch.from_numpy(traj_mask)

        target_a = actions.clone()
        return {
            "states": states,
            "actions": actions,
            "rewards": rewards,
            "target_a": target_a,
            "terminals": terminals,
            "returns_to_go": returns_to_go,
            "timesteps": timesteps,
            "traj_mask": traj_mask,
            "next_states": next_states,
            "traj_returns": tr,
            "returns_to_go_discounted": returns_to_go_discounted,
            "advantage": advantage,
        }





# Dataset statistics (state mean/std and max return) computed from a D4RL trajectory pickle.
def get_d4rl_dataset_stats(data_file_name, data_dir="data"):
    """Compute observation statistics and the best return of a pickled dataset.

    Args:
        data_file_name: Dataset file name without the ``.pkl`` extension.
        data_dir: Directory containing the pickle file. Defaults to
            ``"data"``, the location that was previously hard-coded.

    Returns:
        Dict with ``"state_mean"`` and ``"state_std"`` (per-dimension
        statistics over all observations, with ``1e-6`` added to the standard
        deviation) and ``"max_returns"`` (largest per-trajectory reward sum).

    Raises:
        ValueError: If the file contains no trajectories.
    """
    dataset_path = os.path.join(data_dir, f"{data_file_name}.pkl")

    # NOTE(security): pickle.load can execute arbitrary code while
    # deserializing; only point this at trusted dataset files.
    with open(dataset_path, "rb") as f:
        trajectories = pickle.load(f)

    if len(trajectories) == 0:
        raise ValueError(f"no trajectories found in {dataset_path!r}")

    # Statistics used for input normalization.
    states = np.concatenate([traj["observations"] for traj in trajectories], axis=0)
    state_mean = np.mean(states, axis=0)
    state_std = np.std(states, axis=0) + 1e-6

    max_returns = np.max([traj["rewards"].sum() for traj in trajectories])

    return {"state_mean": state_mean, "state_std": state_std, "max_returns": max_returns}

# get_d4rl_dataset_stats("halfcheetah-expert-v2")