from collections import defaultdict
from time import time

import numpy as np

from utils.pytorch import random_crop


def make_buffer(shapes, buffer_size):
    buffer = {}
    for k, v in shapes.items():
        if isinstance(v, dict):
            buffer[k] = make_buffer(v, buffer_size)
        else:
            if len(v) >= 3:
                buffer[k] = np.empty((buffer_size, *v), dtype=np.uint8)
            else:
                buffer[k] = np.empty((buffer_size, *v), dtype=np.float32)
    return buffer


def add_rollout(buffer, rollout, idx: int):
    if isinstance(rollout, list):
        rollout = rollout[0]

    if isinstance(rollout, dict):
        for k in rollout.keys():
            add_rollout(buffer[k], rollout[k], idx)
    else:
        np.copyto(buffer[idx], rollout)


def get_batch(buffer: dict, idxs):
    batch = {}
    for k in buffer.keys():
        if isinstance(buffer[k], dict):
            batch[k] = get_batch(buffer[k], idxs)
        else:
            batch[k] = buffer[k][idxs]
    return batch


def augment_ob(batch, image_crop_size):
    for k, v in batch.items():
        if isinstance(batch[k], dict):
            augment_ob(batch[k], image_crop_size)
        elif len(batch[k].shape) > 3:
            batch[k] = random_crop(batch[k], image_crop_size)


class ReplayBufferPerStep(object):
    def __init__(self, shapes: dict, buffer_size: int, image_crop_size=84, absorbing_state=False):
        self._capacity = buffer_size

        if absorbing_state:
            shapes["ob"]["absorbing_state"] = [1]
            shapes["ob_next"]["absorbing_state"] = [1]

        self._shapes = shapes
        self._keys = list(shapes.keys())
        self._image_crop_size = image_crop_size
        self._absorbing_state = absorbing_state

        self._buffer = make_buffer(shapes, buffer_size)
        self._idx = 0
        self._full = False

    def clear(self):
        self._idx = 0
        self._full = False

    # store the episode
    def store_episode(self, rollout):
        for k in self._keys:
            add_rollout(self._buffer[k], rollout[k], self._idx)

        self._idx = (self._idx + 1) % self._capacity
        self._full = self._full or self._idx == 0

    # sample the data from the replay buffer
    def sample(self, batch_size):
        idxs = np.random.randint(
            0, self._capacity if self._full else self._idx, size=batch_size
        )
        batch = get_batch(self._buffer, idxs)

        # apply random crop to image
        augment_ob(batch, self._image_crop_size)

        return batch

    def state_dict(self):
        return {"buffer": self._buffer, "idx": self._idx, "full": self._full}

    def load_state_dict(self, state_dict):
        self._buffer = state_dict["buffer"]
        self._idx = state_dict["idx"]
        self._full = state_dict["full"]

### Buffer for storing online collected transitions
class ReplayBuffer(object):
    def __init__(self, keys, buffer_size, sample_func):
        self._capacity = buffer_size
        self._sample_func = sample_func

        # create the buffer to store info
        self._keys = keys
        self.clear()

    def clear(self):
        self._idx = 0
        self._current_size = 0
        self._buffer = defaultdict(list)

    # store transitions
    def store_episode(self, rollout):
        # @rollout can be any length of transitions
        for k in self._keys:
            if self._current_size < self._capacity:
                self._buffer[k].append(rollout[k])
            else:
                self._buffer[k][self._idx] = rollout[k]

        self._idx = (self._idx + 1) % self._capacity
        if self._current_size < self._capacity:
            self._current_size += 1

    # sample the data from the replay buffer
    def sample(self, batch_size):
        # sample transitions
        transitions = self._sample_func(self._buffer, batch_size)
        return transitions

    def state_dict(self):
        return self._buffer

    def load_state_dict(self, state_dict):
        self._buffer = state_dict
        self._current_size = len(self._buffer["ac"])
    
### Buffer for reach and unreach states
class DataReplayBuffer(object):
    # replay buffer that stores transitions instead of full trajectories

    def __init__(self, keys, buffer_size, sample_func, is_reachable_buffer=False):
        self._capacity = buffer_size
        self._sample_func = sample_func

        # create the buffer to store info
        self._keys = keys
        self.clear()

        # once reachable dataset fills, make sure expert demos don’t get overwritten
        self._expert_demo = False
        self._expert_demo_size = 0
        self._is_reachable = is_reachable_buffer


    def clear(self):
        self._idx = 0
        self._current_size = 0
        self._buffer = defaultdict(list)

    def store_transitions(self, transitions):
        if isinstance(transitions, list):
            if self._expert_demo==False:
                self._expert_demo = True
                self._expert_demo_size = len(transitions)
            
            for transition in transitions:
                self.store_transitions(transition)
            return
        if isinstance(transitions, dict):
            if isinstance(transitions["ac"], list):
                size = len(transitions["ac"])
                if size > 1:
                    for i in range(size):
                        transition = {}
                        for k in self._keys:
                            transition[k] = transitions[k][i]
                        self.store_transitions(transition)

                        if self._current_size >= self._capacity:
                            self._buffer["traj_start_indicator"][self._idx] = [1]
                    return
                
        for k in self._keys:
            if not isinstance(transitions[k], list):
                transitions[k] = [transitions[k]]
            if self._current_size < self._capacity:
                self._buffer[k].append(transitions[k])
            else:
                self._buffer[k][self._idx] = transitions[k]

        self._idx = (self._idx + 1) % self._capacity
        
        if self._current_size < self._capacity:        
            self._current_size += 1
        if self._current_size >= self._capacity:
            if self._is_reachable and self._idx < self._expert_demo_size:
                self._idx = self._expert_demo_size + self._idx
            
    
    # sample the data from the replay buffer
    def sample(self, batch_size):
        # sample transitions
        transitions, episode_indx, t_samples = self._sample_func(self._buffer, batch_size, return_episode_indx=True)
        return transitions, episode_indx, t_samples

    def state_dict(self):
        return self._buffer

    def load_state_dict(self, state_dict):
        self._buffer = state_dict
        self._current_size = len(self._buffer["ac"])
    
    def update_current_size(self):
        self._current_size = len(self._buffer["ac"])

class ReplayBufferEpisode(object):
    def __init__(self, keys, buffer_size, sample_func):
        self._capacity = buffer_size
        self._sample_func = sample_func

        # create the buffer to store info
        self._keys = keys
        self.clear()

    def clear(self):
        self._idx = 0
        self._current_size = 0
        self._new_episode = True
        self._buffer = defaultdict(list)

    # store the episode
    def store_episode(self, rollout):
        if self._new_episode:
            self._new_episode = False
            for k in self._keys:
                if self._current_size < self._capacity:
                    self._buffer[k].append(rollout[k])
                else:
                    self._buffer[k][self._idx] = rollout[k]
        else:
            for k in self._keys:
                self._buffer[k][self._idx].extend(rollout[k])

        if rollout["done"][-1]:
            self._idx = (self._idx + 1) % self._capacity
            if self._current_size < self._capacity:
                self._current_size += 1
            self._new_episode = True

    # sample the data from the replay buffer
    def sample(self, batch_size):
        # sample transitions
        transitions = self._sample_func(self._buffer, batch_size)
        return transitions

    def state_dict(self):
        return self._buffer

    def load_state_dict(self, state_dict):
        self._buffer = state_dict
        self._current_size = len(self._buffer["ac"])


class RandomSampler(object):
    def __init__(self, image_crop_size=84):
        self._image_crop_size = image_crop_size

    def sample_func(self, episode_batch, batch_size_in_transitions, return_episode_indx=False):
        rollout_batch_size = len(episode_batch["ac"])
        batch_size = batch_size_in_transitions

        episode_idxs = np.random.randint(0, rollout_batch_size, batch_size) #randomly select batch_size number of episodes
        t_samples = [ #randomly select a time step from each episode
            np.random.randint(len(episode_batch["ac"][episode_idx]))
            for episode_idx in episode_idxs
        ]

        transitions = {}
        for key in episode_batch.keys():
            transitions[key] = [
                episode_batch[key][episode_idx][t]
                for episode_idx, t in zip(episode_idxs, t_samples)
            ]

        transitions["ob_next"] = [
            episode_batch["ob_next"][episode_idx][t]
            for episode_idx, t in zip(episode_idxs, t_samples)
        ]
        # dict-list --> dict-array
        new_transitions = {}
        for k, v in transitions.items():
            try:
                if isinstance(v[0], dict):
                    sub_keys = v[0].keys()
                    new_transitions[k] = {
                        sub_key: np.stack([v_[sub_key] for v_ in v]) for sub_key in sub_keys
                    }
                else:
                    new_transitions[k] = np.stack(v)
            except:
                print("all data should have the same shape", k)
                raise

        for k, v in new_transitions["ob"].items():
            if len(v.shape) in [4, 5]:
                new_transitions["ob"][k] = random_crop(v, self._image_crop_size)

        for k, v in new_transitions["ob_next"].items():
            if len(v.shape) in [4, 5]:
                new_transitions["ob_next"][k] = random_crop(v, self._image_crop_size)
        
        if return_episode_indx:
            return new_transitions, episode_idxs, t_samples
        else:
            return new_transitions

class SeqSampler(object):
    def __init__(self, seq_length, image_crop_size=84):
        self._seq_length = seq_length
        self._image_crop_size = image_crop_size

    def sample_func(self, episode_batch, batch_size_in_transitions):
        rollout_batch_size = len(episode_batch["ac"])
        batch_size = batch_size_in_transitions

        episode_idxs = np.random.randint(0, rollout_batch_size, batch_size)
        t_samples = [
            np.random.randint(len(episode_batch["ac"][episode_idx]))
            for episode_idx in episode_idxs
        ]

        transitions = {}
        for key in episode_batch.keys():
            transitions[key] = [
                episode_batch[key][episode_idx][t]
                for episode_idx, t in zip(episode_idxs, t_samples)
            ]

        transitions["ob_next"] = [
            episode_batch["ob_next"][episode_idx][t]
            for episode_idx, t in zip(episode_idxs, t_samples)
        ]

        #Create a key that stores the specified future fixed length of sequences, pad last states if necessary

        print(episode_idxs)
        print(t_samples)

        #List of dictionaries is created here..., flatten it out?
        transitions["following_sequences"] = [
            episode_batch["ob"][episode_idx][t:t+ self._seq_length]
            for episode_idx, t in zip(episode_idxs, t_samples)
        ]

        #something's wrong here... should use index episode_idx to episode_batch, not transitions

        # # Pad last states
        # for episode_idx in episode_idxs:
        #     # curr_ep = episode_batch["ob"][episode_idx]
        #     # curr_ep.extend(curr_ep[-1:] * (self._seq_length - len(curr_ep)))
        #
        #     #all list should have 10 dictionaries now
        #     if isinstance(transitions["following_sequences"][episode_idx], dict):
        #         continue
        #     transitions["following_sequences"][episode_idx].extend(transitions["following_sequences"][episode_idx][-1:] * (self._seq_length - len(transitions["following_sequences"][episode_idx])))
        #
        #     #turn transitions["following_sequences"] to a dictionary
        #     fs_list = transitions["following_sequences"][episode_idx]
        #     container = {}
        #     container["ob"] = []
        #     for i in fs_list:
        #         container["ob"].extend(i["ob"])
        #     container["ob"] = np.array(container["ob"])
        #     transitions["following_sequences"][episode_idx] = container

        # Pad last states
        for i in range(len(transitions["following_sequences"])):
            # curr_ep = episode_batch["ob"][episode_idx]
            # curr_ep.extend(curr_ep[-1:] * (self._seq_length - len(curr_ep)))

            #all list should have 10 dictionaries now
            if isinstance(transitions["following_sequences"][i], dict):
                continue
            transitions["following_sequences"][i].extend(transitions["following_sequences"][i][-1:] * (self._seq_length - len(transitions["following_sequences"][i])))

            #turn transitions["following_sequences"] to a dictionary
            fs_list = transitions["following_sequences"][i]
            container = {}
            container["ob"] = []
            for j in fs_list:
                container["ob"].extend(j["ob"])
            container["ob"] = np.array(container["ob"])
            transitions["following_sequences"][i] = container



        new_transitions = {}
        for k, v in transitions.items():
            if isinstance(v[0], dict):
                sub_keys = v[0].keys()
                new_transitions[k] = {
                    sub_key: np.stack([v_[sub_key] for v_ in v]) for sub_key in sub_keys
                }
            else:
                new_transitions[k] = np.stack(v)

        for k, v in new_transitions["ob"].items():
            if len(v.shape) in [4, 5]:
                new_transitions["ob"][k] = random_crop(v, self._image_crop_size)

        for k, v in new_transitions["ob_next"].items():
            if len(v.shape) in [4, 5]:
                new_transitions["ob_next"][k] = random_crop(v, self._image_crop_size)

        return new_transitions


class HERSampler(object):
    def __init__(self, replay_strategy, replace_future, reward_func=None):
        self.replay_strategy = replay_strategy
        if self.replay_strategy == "future":
            self.future_p = replace_future
        else:
            self.future_p = 0
        self.reward_func = reward_func

    def sample_her_transitions(self, episode_batch, batch_size_in_transitions):
        rollout_batch_size = len(episode_batch["ac"])
        batch_size = batch_size_in_transitions

        # select which rollouts and which timesteps to be used
        episode_idxs = np.random.randint(0, rollout_batch_size, batch_size)
        t_samples = [
            np.random.randint(len(episode_batch["ac"][episode_idx]))
            for episode_idx in episode_idxs
        ]

        transitions = {}
        for key in episode_batch.keys():
            transitions[key] = [
                episode_batch[key][episode_idx][t]
                for episode_idx, t in zip(episode_idxs, t_samples)
            ]

        transitions["ob_next"] = [
            episode_batch["ob"][episode_idx][t + 1]
            for episode_idx, t in zip(episode_idxs, t_samples)
        ]
        transitions["r"] = np.zeros((batch_size,))

        # hindsight experience replay
        for i, (episode_idx, t) in enumerate(zip(episode_idxs, t_samples)):
            replace_goal = np.random.uniform() < self.future_p
            if replace_goal:
                future_t = np.random.randint(
                    t + 1, len(episode_batch["ac"][episode_idx]) + 1
                )
                future_ag = episode_batch["ag"][episode_idx][future_t]
                if (
                    self.reward_func(
                        episode_batch["ag"][episode_idx][t], future_ag, None
                    )
                    < 0
                ):
                    transitions["g"][i] = future_ag
            transitions["r"][i] = self.reward_func(
                episode_batch["ag"][episode_idx][t + 1], transitions["g"][i], None
            )

        new_transitions = {}
        for k, v in transitions.items():
            if isinstance(v[0], dict):
                sub_keys = v[0].keys()
                new_transitions[k] = {
                    sub_key: np.stack([v_[sub_key] for v_ in v]) for sub_key in sub_keys
                }
            else:
                new_transitions[k] = np.stack(v)

        return new_transitions

