# Copyright (c) 2022 Nikhil Barhate
# Adapted from https://raw.githubusercontent.com/nikhilbarhate99/min-decision-transformer/refs/heads/master/decision_transformer/utils.py
# Modifications Copyright (c) 2025 King.com Ltd
import random
import pickle
import numpy as np
import torch
from torch.utils.data import Dataset
from sys import platform


def get_device(use_cuda=True):
    """Select the torch device to run on.

    Args:
        use_cuda: When falsy, the CPU device is returned unconditionally.
            When truthy, the platform's accelerator is used if available:
            CUDA on Linux and Windows, MPS on macOS.

    Returns:
        A ``torch.device`` for the accelerator, or for the CPU when the
        accelerator is not available or ``use_cuda`` is falsy.

    Raises:
        ValueError: If ``use_cuda`` is truthy and ``sys.platform`` is not one
            of ``linux``, ``linux2``, ``darwin`` or ``win32``.
    """
    cpu_device = torch.device("cpu")
    if not use_cuda:
        return cpu_device

    # Availability is only queried for the backend matching the platform.
    if platform in ("linux", "linux2", "win32"):
        backend_name = "cuda"
        backend_available = torch.cuda.is_available()
    elif platform == "darwin":
        backend_name = "mps"
        backend_available = torch.backends.mps.is_available()
    else:
        raise ValueError(f"Unknown platform: {platform}")

    if backend_available:
        return torch.device(backend_name)
    return cpu_device


def discount_cumsum(x, gamma):
    """Return the discounted cumulative sum of ``x`` along its first axis.

    The result satisfies ``out[-1] = x[-1]`` and
    ``out[t] = x[t] + gamma * out[t + 1]`` for every earlier ``t``.

    Args:
        x: Array-like of per-step values (e.g. rewards); the recursion runs
            over axis 0.
        gamma: Discount factor.

    Returns:
        An array with the same shape as ``x``. Floating-point (and complex)
        inputs keep their dtype; integer and boolean inputs are promoted to
        ``float64`` so that a fractional ``gamma`` is not silently truncated.
        An empty input yields an empty array.
    """
    x = np.asarray(x)

    # zeros_like(x) alone would inherit an integer dtype and truncate every
    # discounted partial sum on assignment.
    out_dtype = x.dtype if np.issubdtype(x.dtype, np.inexact) else np.float64
    disc_cumsum = np.zeros_like(x, dtype=out_dtype)

    if x.shape[0] == 0:
        return disc_cumsum

    disc_cumsum[-1] = x[-1]
    for t in reversed(range(x.shape[0] - 1)):
        disc_cumsum[t] = x[t] + gamma * disc_cumsum[t + 1]
    return disc_cumsum


def _nearest_prompt_dataset(train_datasets, target_radius, target_angle):
    """Return the dataset whose path encodes the radius/angle closest to the target.

    Each ``dataset_path`` is parsed for a ``radius<value>-`` and an
    ``angle<value>_`` fragment; the angle value is multiplied by pi before it
    is compared with ``target_angle``.
    """
    ds_distances = []
    for ds in train_datasets:
        ds_radius = float(ds.dataset_path.split("radius")[1].split("-")[0])
        ds_angle = float(ds.dataset_path.split("angle")[1].split("_")[0]) * np.pi
        ds_distances.append(abs(ds_radius - target_radius) + abs(ds_angle - target_angle))
    return train_datasets[np.argmin(ds_distances)]


def evaluate_on_pointDirEnv(model, device, context_len, env, rtg_target, rtg_scale,
                            num_eval_ep=10, max_test_ep_len=1000,
                            state_mean=None, state_std=None, render=False,
                            state_dim=None, act_dim=None,
                            env_id=None, n_traj_prompt_segments=None, traj_prompt_seg_len=None,
                            train_datasets=None, use_state_dims=None, scale_actions=False,
                            ):
    """Roll out ``model`` in ``env`` and collect evaluation statistics.

    Args:
        model: Policy exposing ``which_model``, ``eval()`` and
            ``forward(timesteps, states, actions, returns_to_go,
            traj_prompt_timesteps, traj_prompt_states, traj_prompt_actions,
            traj_prompt_rtgs)``. ``forward`` must return a 7-tuple whose second
            element holds the action predictions.
        device: Torch device on which the rollout buffers are allocated.
        context_len: Number of most recent steps passed to the model.
        env: Environment whose ``reset()`` returns either an observation
            (classic gym API) or an ``(observation, info)`` tuple (gymnasium
            API); ``step()`` is unpacked accordingly (4- or 5-tuple).
        rtg_target: Return-to-go the policy is conditioned on.
        rtg_scale: Divisor applied to ``rtg_target`` and to step rewards.
        num_eval_ep: Number of episodes to roll out.
        max_test_ep_len: Maximum number of steps per episode.
        state_mean: Optional numpy array subtracted from every state
            (defaults to zeros).
        state_std: Optional numpy array every state is divided by
            (defaults to ones).
        render: Call ``env.render()`` after every step when true.
        state_dim: State dimension; defaults to ``env.observation_space.shape[0]``.
        act_dim: Action dimension; defaults to ``env.action_space.shape[0]``.
        env_id: Unused by this function.
        n_traj_prompt_segments: Number of segments in a trajectory prompt
            (only used when ``model.which_model == "traj_pdt"``).
        traj_prompt_seg_len: Steps per prompt segment (``"traj_pdt"`` only).
        train_datasets: Datasets to sample trajectory prompts from
            (required for ``"traj_pdt"``); each must expose ``dataset_path``
            and ``_sample_traj_prompt()``.
        use_state_dims: Optional index list selecting the observation
            dimensions fed to the model; ``None`` keeps the full observation.
        scale_actions: When true, the action sent to the environment is
            ``low + (high - low) * act`` using the env's action-space bounds.

    Returns:
        dict with keys ``'eval/avg_reward'``, ``'eval/avg_ep_len'``,
        ``'eval_returns'`` (per-episode return), ``'trajectories'``
        (per-episode list of ``info["agent_pos"]`` values) and
        ``'target_angles'``.

    Raises:
        ValueError: If the model is a ``"traj_pdt"`` model but no
            ``train_datasets`` were supplied.
    """
    eval_batch_size = 1

    uses_traj_prompt = model.which_model == "traj_pdt"
    if uses_traj_prompt and (train_datasets is None or len(train_datasets) == 0):
        # Without datasets there is nothing to build the prompt from; fail
        # early instead of hitting an unbound variable at the first forward call.
        raise ValueError("train_datasets must be provided to evaluate a 'traj_pdt' model")

    results = {}
    total_reward = 0
    total_timesteps = 0

    state_dim = env.observation_space.shape[0] if state_dim is None else state_dim
    act_dim = env.action_space.shape[0] if act_dim is None else act_dim

    if state_mean is None:
        state_mean = torch.zeros((state_dim,)).to(device)
    else:
        state_mean = torch.from_numpy(state_mean).to(device)

    if state_std is None:
        state_std = torch.ones((state_dim,)).to(device)
    else:
        state_std = torch.from_numpy(state_std).to(device)

    timesteps = torch.arange(start=0, end=max_test_ep_len, step=1)
    timesteps = timesteps.repeat(eval_batch_size, 1).to(device)

    model.eval()

    with torch.no_grad():

        print(f"Starting to collect eval rollouts for rtg {rtg_target}...")
        trajectories = []
        returns = []
        target_angles = []
        for _ in range(num_eval_ep):

            # zero-filled placeholders for the whole episode
            actions = torch.zeros((eval_batch_size, max_test_ep_len, act_dim), dtype=torch.float32, device=device)
            states = torch.zeros((eval_batch_size, max_test_ep_len, state_dim), dtype=torch.float32, device=device)
            rewards_to_go = torch.zeros((eval_batch_size, max_test_ep_len, 1), dtype=torch.float32, device=device)

            # init episode; `info` defaults to empty so the classic gym API
            # (reset() returning only the observation) is supported as well
            gymnasium_api = False
            info = {}
            running_state = env.reset()
            if (isinstance(running_state, tuple) and len(running_state) >= 2
                    and isinstance(running_state[0], np.ndarray)
                    and isinstance(running_state[1], dict)):
                gymnasium_api = True
                info = running_state[1]
                # dimension selection happens once per step inside the
                # rollout loop, so the raw observation is kept here
                running_state = running_state[0]

            agent_pos_hist = [info["agent_pos"]] if "agent_pos" in info else []
            reward_sum = 0
            running_reward = 0
            running_rtg = rtg_target / rtg_scale

            if "target_angle_rad" in info:
                target_angle = info["target_angle_rad"]
                target_radius = info["target_radius"]
            else:
                print("PDT eval: No target angle provided, this shouldn't happen when training on 2D point envs!")
                target_angle = 0
                target_radius = 0

            prompt_kwargs = {
                "traj_prompt_timesteps": None,
                "traj_prompt_states": None,
                "traj_prompt_actions": None,
                "traj_prompt_rtgs": None,
            }
            if uses_traj_prompt:
                # sample the prompt from the dataset matching the evaluated env
                target_dataset = _nearest_prompt_dataset(train_datasets, target_radius, target_angle)
                prompt_timesteps, prompt_states, prompt_actions, prompt_rtgs, _, _ = target_dataset._sample_traj_prompt()

                prompt_len = n_traj_prompt_segments * traj_prompt_seg_len
                prompt_kwargs = {
                    "traj_prompt_timesteps": prompt_timesteps.reshape(eval_batch_size, prompt_len).to(device),
                    "traj_prompt_states": prompt_states.reshape(eval_batch_size, prompt_len, state_dim).to(device),
                    "traj_prompt_actions": prompt_actions.reshape(eval_batch_size, prompt_len, act_dim).to(device),
                    "traj_prompt_rtgs": prompt_rtgs.reshape(eval_batch_size, prompt_len, 1).to(device),
                }

            for t in range(max_test_ep_len):

                total_timesteps += 1
                if use_state_dims is not None:
                    running_state = running_state[use_state_dims]

                # add state in placeholder and normalize
                states[0, t] = torch.from_numpy(running_state).to(torch.float32).to(device)
                states[0, t] = (states[0, t] - state_mean) / state_std

                # calculate running rtg and add it in placeholder
                running_rtg = running_rtg - (running_reward / rtg_scale)
                rewards_to_go[0, t] = torch.tensor(running_rtg, device=rewards_to_go.device, dtype=rewards_to_go.dtype)

                if t < context_len:
                    # context not yet full: feed the first context_len slots
                    # and read the prediction at the current step
                    window = slice(0, context_len)
                    act_index = t
                else:
                    # sliding window over the last context_len steps
                    window = slice(t - context_len + 1, t + 1)
                    act_index = -1

                _, act_preds, _, _, _, _, _ = model.forward(
                    timesteps=timesteps[:, window],
                    states=states[:, window],
                    actions=actions[:, window],
                    returns_to_go=rewards_to_go[:, window],
                    **prompt_kwargs,
                )
                act = act_preds[0, act_index].detach()

                if act_dim == 1 and act.dtype == torch.float32:
                    # snap the scalar prediction to the nearest of the
                    # discrete action ids 0..3
                    int_actions = torch.tensor([0, 1, 2, 3]).to(torch.float32).to(act.device)
                    act = torch.argmin(torch.abs(int_actions - act), keepdim=True)

                do_action = act.cpu().numpy()
                if scale_actions:
                    low = env.action_space.low
                    high = env.action_space.high
                    do_action = low + (high - low) * act.detach().cpu().numpy()

                if gymnasium_api:
                    running_state, running_reward, done, trunc, info = env.step(do_action)
                    done = done or trunc
                else:
                    running_state, running_reward, done, info = env.step(do_action)

                # add action in placeholder
                actions[0, t] = act

                total_reward += running_reward
                reward_sum += running_reward

                if "agent_pos" in info:
                    agent_pos_hist.append(info["agent_pos"])

                if render:
                    env.render()
                if done:
                    break

            trajectories.append(agent_pos_hist)
            returns.append(reward_sum)
            target_angles.append(target_angle)

    print("Finished collecting eval rollouts...")
    results['eval/avg_reward'] = total_reward / num_eval_ep
    results['eval/avg_ep_len'] = total_timesteps / num_eval_ep
    results['eval_returns'] = returns
    results['trajectories'] = trajectories
    results['target_angles'] = target_angles

    return results


class TrajectoryDataset(Dataset):
    """Offline trajectory dataset that pairs every sample with a trajectory prompt.

    Trajectories are loaded from a pickle file as a sequence of dicts with at
    least the keys ``'observations'``, ``'actions'`` and ``'rewards'``
    (``'infos'`` is read for sparse rewards, one-hot goals and prompt
    sampling). Each item is a context window of one trajectory together with a
    prompt of ``traj_prompt_j`` segments of ``traj_prompt_h`` steps sampled
    from the highest-return trajectories.
    """

    def __init__(self,
                 dataset_path,
                 context_len,
                 rtg_scale,
                 traj_prompt_j=None,  # the number of trajectory segments in the traj prompt
                 traj_prompt_h=None,  # the number of steps per segment in the traj prompt
                 use_state_dims=None,  # a list to specify which dimensions of the state to use, dims not in the list will be discarded
                 traj_prompt_segment_start_lower_bound=None,  # for controlling the segments sampled for traj prompts
                 traj_prompt_segment_start_upper_bound=None,  # for controlling the segments sampled for traj prompts
                 traj_prompt_noise_scale=0.0,  # scale for traj prompt noise
                 use_sparse_reward=False,  # override logged reward with info['sparse_reward']
                 use_every_nth_traj=None,  # only use every n-th trajectory to save RAM
                 ):
        """Load the trajectories and precompute statistics and prompt pools.

        Args:
            dataset_path: Path of the pickle file holding the trajectories.
            context_len: Length of the window returned by ``__getitem__``.
            rtg_scale: Divisor applied to the undiscounted returns-to-go.
            traj_prompt_j: Number of segments per trajectory prompt (required, > 0).
            traj_prompt_h: Number of steps per prompt segment (required, > 0).
            use_state_dims: Optional list of observation dimensions to keep.
            traj_prompt_segment_start_lower_bound: Stored on the instance;
                not used by the code in this class.
            traj_prompt_segment_start_upper_bound: Stored on the instance;
                not used by the code in this class.
            traj_prompt_noise_scale: Standard deviation of the Gaussian noise
                added to sampled prompt states, actions and returns-to-go.
            use_sparse_reward: Replace each step reward with
                ``infos[t]['sparse_reward']`` when that key exists.
            use_every_nth_traj: Keep only every n-th trajectory.

        Raises:
            ValueError: If no trajectories remain after loading/subsampling.
        """
        self.dataset_path = dataset_path
        self.context_len = context_len

        assert traj_prompt_j is not None
        assert traj_prompt_h is not None
        assert traj_prompt_j > 0, "traj_prompt_j should be greater than 0"
        assert traj_prompt_h > 0, "traj_prompt_h should be greater than 0"

        self.traj_prompt_j = traj_prompt_j
        self.traj_prompt_h = traj_prompt_h

        self.traj_prompt_segment_start_lower_bound = traj_prompt_segment_start_lower_bound
        self.traj_prompt_segment_start_upper_bound = traj_prompt_segment_start_upper_bound
        self.traj_prompt_noise_scale = traj_prompt_noise_scale

        # NOTE(security): unpickling can execute arbitrary code; only load
        # dataset files from trusted sources.
        with open(dataset_path, 'rb') as f:
            self.trajectories = pickle.load(f)
        print(f"Loaded {len(self.trajectories)} trajectories from {dataset_path}")

        if use_every_nth_traj is not None:
            self.trajectories = self.trajectories[::use_every_nth_traj]

        if len(self.trajectories) == 0:
            raise ValueError(f"No trajectories available in dataset: {dataset_path}")

        states = []
        returns = []
        for traj_idx, traj in enumerate(self.trajectories):
            traj_len = traj['observations'].shape[0]

            if use_state_dims is not None and len(use_state_dims) < traj['observations'].shape[1]:
                if traj_idx == 0:
                    print(f"Using only dims {use_state_dims} from {traj['observations'].shape[1]} dims in the state!")
                traj['observations'] = traj['observations'][:, use_state_dims]

            states.append(traj['observations'])

            if use_sparse_reward:
                for t in range(traj_len):
                    try:
                        traj['rewards'][t] = traj['infos'][t]['sparse_reward']
                    except KeyError:
                        print(f"No sparse_reward in trajectory infos from dataset: {self.dataset_path}, therefore using logged step rewards")
                        break

            # undiscounted (gamma = 1) returns-to-go, scaled
            traj['returns_to_go'] = discount_cumsum(traj['rewards'], 1.0) / rtg_scale

            traj['return'] = np.sum(traj['rewards'])
            returns.append(traj['return'])

            try:
                onehot_prompts = np.array([traj["infos"][i]["which_goal_oneHot"] for i in range(traj_len)])
            except KeyError:
                # one-hot goal prompts are optional; trajectories without
                # them simply get no 'onehot_prompt' entry
                pass
            else:
                assert np.all(onehot_prompts[0] == onehot_prompts), "oneHot prompts should not change during trajectory!"
                traj["onehot_prompt"] = onehot_prompts

            traj["trajectory_prompt"] = []

        # used for input normalization
        states = np.concatenate(states, axis=0)
        self.state_mean, self.state_std = np.mean(states, axis=0, dtype=np.float32), np.std(states, axis=0, dtype=np.float32) + 1e-6

        # gather high return trajectories to be used as expert prompts
        self.return_mean, self.return_std = np.mean(returns), np.std(returns) + 1e-6
        self.all_returns = returns
        self.sorted_returns = sorted(self.all_returns)
        num_returns = len(self.sorted_returns)
        expert_return_threshold = self.sorted_returns[int(0.95 * num_returns)]
        bottom_percentile_return = self.sorted_returns[int(0.01 * num_returns)]

        self.expert_prompt_trajs = []
        self.novice_prompt_trajs = []
        for traj in self.trajectories:
            # a prompt segment needs more than traj_prompt_h steps
            if len(traj["rewards"]) <= self.traj_prompt_h:
                continue
            if traj['return'] >= expert_return_threshold:
                self.expert_prompt_trajs.append(traj)
            elif traj['return'] <= bottom_percentile_return:
                self.novice_prompt_trajs.append(traj)

        self.mixture_datasets = self._build_mixture_datasets()

    def _build_mixture_datasets(self, mixture_dataset_size=10):
        """Build shuffled expert/novice prompt pools for 0%, 10%, ..., 100% expert share.

        Keys have the form ``"mixture-<p>percent-expert"``. When a pool holds
        fewer trajectories than requested, all of them are used instead of
        failing, so a mixture can be smaller than ``mixture_dataset_size``.
        """
        mixture_datasets = {}
        for expert_subset_size in range(mixture_dataset_size + 1):
            novice_subset_size = mixture_dataset_size - expert_subset_size
            expert_percent = expert_subset_size * 100 // mixture_dataset_size
            mixture_dataset_name = f"mixture-{expert_percent}percent-expert"

            expert_subset = random.sample(
                self.expert_prompt_trajs,
                min(expert_subset_size, len(self.expert_prompt_trajs)),
            )
            novice_subset = random.sample(
                self.novice_prompt_trajs,
                min(novice_subset_size, len(self.novice_prompt_trajs)),
            )

            mixture_dataset = expert_subset + novice_subset
            random.shuffle(mixture_dataset)

            mixture_datasets[mixture_dataset_name] = mixture_dataset
        return mixture_datasets

    def normalize_states(self):
        """Standardize every trajectory's observations with the dataset mean/std.

        Observations are replaced on the trajectory dicts themselves, so
        calling this more than once normalizes repeatedly.
        """
        for traj in self.trajectories:
            traj['observations'] = (traj['observations'] - self.state_mean) / self.state_std

    def get_state_stats(self):
        """Return ``(state_mean, state_std)`` computed over all loaded observations."""
        return self.state_mean, self.state_std

    def __len__(self):
        """Return the number of trajectories."""
        return len(self.trajectories)

    def _sample_traj_prompt(self, which_data="expert"):
        """Sample ``traj_prompt_j`` random segments of ``traj_prompt_h`` steps.

        Args:
            which_data: ``"expert"`` to sample from the expert pool, or a key
                of ``self.mixture_datasets``.

        Returns:
            Tuple ``(timesteps, states, actions, rtgs, infos, segment_idxs)``:
            ``timesteps`` and ``rtgs`` have shape ``(j * h,)``, ``states`` and
            ``actions`` have shape ``(j * h, -1)``, ``infos`` is the list of
            per-step info entries and ``segment_idxs`` the per-step index of
            the source trajectory within the chosen pool. Gaussian noise
            scaled by ``traj_prompt_noise_scale`` is added to states, actions
            and rtgs.

        Raises:
            ValueError: If ``which_data`` names no known pool.
        """
        if which_data == "expert":
            use_trajs = self.expert_prompt_trajs
        elif which_data in self.mixture_datasets:
            use_trajs = self.mixture_datasets[which_data]
        else:
            raise ValueError(f"which_data should be 'expert' or one of {self.mixture_datasets.keys()}")

        assert len(use_trajs) > 0, "No expert prompt trajectories found!"

        traj_prompt_state_segments = []
        traj_prompt_action_segments = []
        traj_prompt_rtg_segments = []
        traj_prompt_time_segments = []
        traj_prompt_info_segments = []
        traj_prompt_segment_idxs = []

        valid_segments_sampled = 0
        while valid_segments_sampled < self.traj_prompt_j:
            # sample a random trajectory for the segment
            traj_prompt_segment_idx = np.random.randint(len(use_trajs))
            prompt_traj = use_trajs[traj_prompt_segment_idx]

            # sample a random slice of the trajectory, making sure that we don't start too close to the end
            prompt_segment_traj_len = prompt_traj['observations'].shape[0]
            upper_limit_segment_start = prompt_segment_traj_len - self.traj_prompt_h - 1
            if upper_limit_segment_start < 0:
                # if the trajectory is too short, we can't sample a segment of h steps
                continue

            # randint's upper bound is exclusive; +1 makes the limit inclusive
            # so a trajectory of exactly h + 1 steps (limit 0) is sampled
            # from instead of raising ValueError
            seg_start = np.random.randint(0, upper_limit_segment_start + 1)
            seg_end = seg_start + self.traj_prompt_h

            traj_prompt_action_segments.append(
                torch.from_numpy(prompt_traj['actions'][seg_start:seg_end]).to(torch.float32))
            traj_prompt_state_segments.append(
                torch.from_numpy(prompt_traj['observations'][seg_start:seg_end]).to(torch.float32))
            traj_prompt_rtg_segments.append(
                torch.from_numpy(prompt_traj['returns_to_go'][seg_start:seg_end]).to(torch.float32))
            traj_prompt_time_segments.append(torch.arange(start=seg_start, end=seg_end, step=1))

            traj_prompt_info_segments.extend(prompt_traj['infos'][i] for i in range(seg_start, seg_end))
            traj_prompt_segment_idxs.extend([traj_prompt_segment_idx] * self.traj_prompt_h)

            valid_segments_sampled += 1

        prompt_len = self.traj_prompt_j * self.traj_prompt_h
        traj_prompt_states = torch.stack(traj_prompt_state_segments, dim=0).reshape(prompt_len, -1)
        traj_prompt_actions = torch.stack(traj_prompt_action_segments, dim=0).reshape(prompt_len, -1)
        traj_prompt_rtgs = torch.stack(traj_prompt_rtg_segments, dim=0).reshape(prompt_len)
        traj_prompt_timesteps = torch.stack(traj_prompt_time_segments, dim=0).reshape(prompt_len)

        # torch.stack produced fresh tensors, so the in-place noise below
        # does not touch the stored trajectories
        traj_prompt_states += torch.randn_like(traj_prompt_states) * self.traj_prompt_noise_scale
        traj_prompt_actions += torch.randn_like(traj_prompt_actions) * self.traj_prompt_noise_scale
        traj_prompt_rtgs += torch.randn_like(traj_prompt_rtgs) * self.traj_prompt_noise_scale

        return traj_prompt_timesteps, traj_prompt_states, traj_prompt_actions, traj_prompt_rtgs, traj_prompt_info_segments, traj_prompt_segment_idxs

    def __getitem__(self, idx):
        """Return one context window of trajectory ``idx`` plus a fresh expert prompt.

        Trajectories with at least ``context_len`` steps yield a random window
        of that length; shorter ones are zero-padded at the end and the
        padding is marked with zeros in ``traj_mask``.

        Returns:
            Tuple ``(timesteps, states, actions, returns_to_go, traj_mask,
            traj_prompt_timesteps, traj_prompt_states, traj_prompt_actions,
            traj_prompt_rtgs)``.
        """
        traj = self.trajectories[idx]
        traj_len = traj['observations'].shape[0]
        assert traj['observations'].shape[0] == traj['actions'].shape[0] == traj['returns_to_go'].shape[0]

        traj_prompt_timesteps, traj_prompt_states, traj_prompt_actions, traj_prompt_rtgs, _, _ = self._sample_traj_prompt()

        if traj_len >= self.context_len:
            # sample random index to slice trajectory
            si = random.randint(0, traj_len - self.context_len)
            window = slice(si, si + self.context_len)

            states = torch.from_numpy(traj['observations'][window]).to(torch.float32)
            actions = torch.from_numpy(traj['actions'][window]).to(torch.float32)
            returns_to_go = torch.from_numpy(traj['returns_to_go'][window]).to(torch.float32)
            timesteps = torch.arange(start=si, end=si + self.context_len, step=1)

            # all ones since no padding
            traj_mask = torch.ones(self.context_len, dtype=torch.long)

        else:
            padding_len = self.context_len - traj_len

            def pad_with_zeros(array):
                # append padding_len zero rows along the time axis
                tensor = torch.from_numpy(array).to(torch.float32)
                zeros = torch.zeros([padding_len] + list(tensor.shape[1:]), dtype=tensor.dtype)
                return torch.cat([tensor, zeros], dim=0)

            states = pad_with_zeros(traj['observations'])
            actions = pad_with_zeros(traj['actions'])
            returns_to_go = pad_with_zeros(traj['returns_to_go'])

            timesteps = torch.arange(start=0, end=self.context_len, step=1)

            traj_mask = torch.cat([torch.ones(traj_len, dtype=torch.long), torch.zeros(padding_len, dtype=torch.long)], dim=0)

        assert traj_prompt_states.shape[-1] == states.shape[-1], "Prompt states and states should have same dim!"
        assert traj_prompt_actions.shape[-1] == actions.shape[-1], "Prompt actions and actions should have same dim!"

        return timesteps, states, actions, returns_to_go, traj_mask, traj_prompt_timesteps, traj_prompt_states, traj_prompt_actions, traj_prompt_rtgs
