import numpy as np
from rlf.il.il_dataset import convert_to_tensors
from goal_prox.method.goal_traj_dataset import GoalTrajDataset
import torch
from goal_prox.envs.gw_helper import *

def exp_discounted(T, t, delta):
    """Return ``delta`` raised to the number of remaining steps ``T - t``.

    Args:
        T: Reference (final) step index.
        t: Current step index.
        delta: Base of the exponential discount.

    Returns:
        The value of ``np.power(delta, T - t)``.
    """
    steps_remaining = T - t
    return np.power(delta, steps_remaining)


def exp_discounted_subsample(T, t, delta, subsample_freq=10):
    """Exponential discount whose exponent only grows every ``subsample_freq`` steps.

    Args:
        T: Reference (final) step index.
        t: Current step index.
        delta: Base of the exponential discount.
        subsample_freq: Number of raw steps that count as a single discount step.

    Returns:
        ``delta`` raised to the floor-divided step count
        ``(T - t) // subsample_freq``.
    """
    coarse_steps = (T - t) // subsample_freq
    return np.power(delta, coarse_steps)

def linear_discounted(T, t, delta):
    """Linearly decaying value that starts at 1.0 and is floored at 0.0.

    Args:
        T: Reference (final) step index.
        t: Current step index.
        delta: Amount subtracted per remaining step.

    Returns:
        ``1.0 - (T - t) * delta``, or ``0.0`` when that quantity is smaller.
    """
    steps_remaining = T - t
    decayed = 1.0 - (steps_remaining * delta)
    return max(decayed, 0.0)

def big_discounted(T, t, delta, start_val):
    """Linearly decaying value that starts at ``start_val`` and is floored at 0.0.

    Args:
        T: Reference (final) step index.
        t: Current step index.
        delta: Amount subtracted per remaining step.
        start_val: Value returned when no steps remain (``t == T``).

    Returns:
        ``start_val - (T - t) * delta``, or ``0.0`` when that quantity is
        smaller.
    """
    steps_remaining = T - t
    decayed = start_val - (steps_remaining * delta)
    return max(decayed, 0.0)

def compute_discounted_prox(T, compute_prox_fn):
    """Evaluate ``compute_prox_fn`` at steps 1..T and pack the results.

    Args:
        T: Number of steps; the function is evaluated for ``t = 1, ..., T``.
        compute_prox_fn: Callable invoked as ``compute_prox_fn(T, t)``.

    Returns:
        A ``float32`` NumPy array of length ``T`` with one value per step.
    """
    per_step = []
    for step in range(1, T + 1):
        per_step.append(compute_prox_fn(T, step))
    return np.array(per_step, dtype=np.float32)


class ValueTrajDataset(GoalTrajDataset):
    """Dataset of per-step ``(state, proximity, action)`` samples.

    Each trajectory is expanded into one sample per state, where the
    proximity value comes from ``compute_prox_fn``.
    """

    def __init__(self, load_path, compute_prox_fn, args):
        """Store the proximity function and args, then run the base constructor.

        Args:
            load_path: Forwarded unchanged to ``GoalTrajDataset.__init__``.
            compute_prox_fn: Callable invoked as ``compute_prox_fn(T, t)``.
            args: Stored on the instance as ``self.args``.
        """
        # Attributes are assigned before delegating to the base class.
        # NOTE(review): presumably the base constructor ends up calling
        # _gen_data, which reads self.compute_prox_fn -- confirm.
        self.compute_prox_fn = compute_prox_fn
        self.args = args
        super().__init__(load_path)

    def get_prox_stats(self):
        """Return ``(min, max)`` of the proximity values stored in ``self.data``."""
        prox_values = [sample[1] for sample in self.data]
        lowest = np.min(prox_values)
        highest = np.max(prox_values)
        return lowest, highest

    def _gen_data(self, trajs):
        """Flatten ``(states, actions)`` trajectories into per-step tuples.

        Args:
            trajs: Iterable of ``(states, actions)`` pairs.

        Returns:
            A list of ``(state, prox, action)`` tuples.
        """
        samples = []
        for traj_states, traj_actions in trajs:
            horizon = len(traj_states)
            prox_values = compute_discounted_prox(horizon, self.compute_prox_fn)
            prox_targets = torch.tensor(prox_values)
            # Append one all-zero action row after the recorded actions
            # (the last action is all 0).
            pad_shape = (1, *traj_actions.shape[1:])
            zero_action = torch.zeros(*pad_shape)
            padded_actions = torch.cat([traj_actions, zero_action], dim=0)
            for sample in zip(traj_states, prox_targets, padded_actions):
                samples.append(sample)
        return samples

    def __getitem__(self, i):
        """Return sample ``i`` as a dict with keys 'state', 'prox', 'actions'."""
        sample = self.data[i]
        item = {
            'state': sample[0],
            'prox': sample[1],
            'actions': sample[2],
        }
        return item



class ValueTrajIncludeNextStateDataset(GoalTrajDataset):
    """Dataset of per-step ``(state, proximity, action, next_state)`` samples.

    Trajectories are rebuilt from flat transition data by
    ``_generate_trajectories`` and then expanded into per-step samples by
    ``_gen_data``.
    """

    def __init__(self, load_path, compute_prox_fn, args):
        """Store the proximity function and args, then run the base constructor.

        Args:
            load_path: Forwarded unchanged to ``GoalTrajDataset.__init__``.
            compute_prox_fn: Callable invoked as ``compute_prox_fn(T, t)``.
            args: Stored on the instance as ``self.args``.
        """
        # Attributes are assigned before delegating to the base class.
        # NOTE(review): presumably the base constructor ends up calling
        # _gen_data, which reads self.compute_prox_fn -- confirm.
        self.compute_prox_fn = compute_prox_fn
        self.args = args
        super().__init__(load_path)

    def get_prox_stats(self):
        """Return ``(min, max)`` of the proximity values stored in ``self.data``."""
        proxs = [x[1] for x in self.data]
        return np.min(proxs), np.max(proxs)

    def _gen_data(self, trajs):
        """Flatten ``(states, actions, next_states)`` trajectories into tuples.

        Args:
            trajs: Iterable of ``(states, actions, next_states)`` triples.

        Returns:
            A list of ``(state, prox, action, next_state)`` tuples.
        """
        data = []
        for states, actions, next_states in trajs:
            T = len(states)
            proxs = torch.tensor(compute_discounted_prox(T, self.compute_prox_fn))
            # The last action is all 0.
            use_actions = torch.cat([actions, torch.zeros(1, *actions.shape[1:])], dim=0)
            data.extend(zip(states, proxs, use_actions, next_states))
        return data

    def _generate_trajectories(self, trajs):
        """Group flat transition data into per-trajectory triples.

        Args:
            trajs: Mapping with keys ``"obs"``, ``"next_obs"``, ``"done"`` and
                ``"actions"``. ``"obs"`` / ``"next_obs"`` are either tensors
                or dicts of tensors.

        Returns:
            The result of ``self._transform_dem_dataset_fn`` applied to a list
            of ``(states, actions, next_states)`` triples. ``states`` and
            ``next_states`` each hold one more entry than ``actions``: the
            final observation is appended to the states, and the final next
            observation is repeated at the end of the next states.
        """
        # True when the observations are NOT a single tensor (dict case).
        is_tensor_dict = not isinstance(trajs["obs"], torch.Tensor)
        if not is_tensor_dict:
            trajs = convert_to_tensors(trajs)

        # Get by trajectory instead of transition
        # NOTE(review): `rutils` is not imported explicitly in the visible
        # part of this file; presumably it comes from the star import --
        # verify.
        if is_tensor_dict:
            for name in ["obs", "next_obs"]:
                for k in trajs[name]:
                    # Cast each entry of its OWN dict. Reading from
                    # trajs["obs"] here would overwrite next_obs with obs.
                    trajs[name][k] = trajs[name][k].float()
            obs = rutils.transpose_dict_arr(trajs["obs"])
            next_obs = rutils.transpose_dict_arr(trajs["next_obs"])
        else:
            obs = trajs["obs"].float()
            next_obs = trajs["next_obs"].float()

        done = trajs["done"].float()
        actions = trajs["actions"].float()

        raw_trajs = []

        num_samples = done.shape[0]
        print("Collecting trajectories")
        start_j = 0
        j = 0
        while j < num_samples:
            if self.should_terminate_traj(j, obs, next_obs, done, actions):
                # States: every observation of the segment plus the
                # observation reached after the last transition.
                combined_obs = [*obs[start_j : j + 1], next_obs[j]]
                # Next states: same length as the states; the final entry
                # repeats the last next observation.
                combined_next_obs = [*next_obs[start_j : j + 1], next_obs[j]]

                raw_trajs.append(
                    (combined_obs, actions[start_j : j + 1], combined_next_obs)
                )
                # Move to where this episode ends
                while j < num_samples and not done[j]:
                    j += 1
                start_j = j + 1

            if j < num_samples and done[j]:
                start_j = j + 1

            j += 1

        # Collate each trajectory's list of observations into one container.
        ret_trajs = []
        for traj_states, traj_actions, traj_next_states in raw_trajs:
            if is_tensor_dict:
                traj_states = rutils.transpose_arr_dict(traj_states)
                traj_next_states = rutils.transpose_arr_dict(traj_next_states)
            else:
                traj_states = torch.stack(traj_states, dim=0)
                traj_next_states = torch.stack(traj_next_states, dim=0)
            ret_trajs.append((traj_states, traj_actions, traj_next_states))

        ret_trajs = self._transform_dem_dataset_fn(ret_trajs, trajs)
        return ret_trajs

    def __getitem__(self, i):
        """Return sample ``i`` as a dict.

        The dict has keys 'state', 'prox', 'actions' and 'next_state'.
        """
        return {
                'state': self.data[i][0],
                'prox': self.data[i][1],
                'actions': self.data[i][2],
                'next_state': self.data[i][3]
                }

