import random
import numpy as np

import glob
from collections import defaultdict
from torch.utils.data import IterableDataset, Dataset
from typing import List, Dict, Any


def pad_along_axis(arr, pad_to, axis=0, fill_value=0.0):
    """Pad ``arr`` at the end of ``axis`` with ``fill_value`` up to length ``pad_to``.

    Args:
        arr: Input NumPy array.
        pad_to: Desired length along ``axis``.
        axis: Axis to extend; negative values index from the end.
        fill_value: Constant written into the padded region.

    Returns:
        ``arr`` itself (no copy) when it is already at least ``pad_to`` long
        along ``axis``; otherwise a new, padded array.
    """
    shortfall = pad_to - arr.shape[axis]
    if shortfall > 0:
        # Only the trailing side of the chosen axis grows; every other axis is untouched.
        widths = [(0, 0) for _ in range(arr.ndim)]
        widths[axis] = (0, shortfall)
        return np.pad(arr, pad_width=widths, mode="constant", constant_values=fill_value)
    return arr


def load_learning_histories(path: str) -> List[Dict[str, Any]]:
    """Load every ``*.npz`` learning history found directly inside ``path``.

    Files are visited in sorted filename order so that the index of a history
    in the returned list is reproducible across runs and machines
    (``glob.glob`` itself returns entries in arbitrary filesystem order).

    Args:
        path: Directory containing the ``.npz`` archives.

    Returns:
        One dict per archive with the keys ``states``, ``actions``,
        ``rewards``, ``dones``, ``goal`` and ``returns``. The list is empty
        when no archive matches.

    Raises:
        KeyError: If an archive is missing one of the expected keys.
    """
    keys = ("states", "actions", "rewards", "dones", "goal", "returns")

    learning_histories = []
    for filename in sorted(glob.glob(f"{path}/*.npz")):
        # NOTE(security): allow_pickle=True can execute arbitrary code while
        # unpickling object arrays; only load archives from trusted sources.
        with np.load(filename, allow_pickle=True) as archive:
            # Indexing the archive materialises each array, so the data stays
            # valid after the file handle is closed.
            learning_histories.append({key: archive[key] for key in keys})

    return learning_histories


def split_to_episodes(learning_history, discount: float = 0.99):
    """Cut a flat learning history into per-episode trajectories.

    An episode ends at every step whose ``dones`` entry is truthy. Steps that
    follow the last truthy ``dones`` entry (an unfinished trailing episode)
    are not returned.

    Args:
        learning_history: Mapping with step-aligned ``states``, ``actions``,
            ``rewards`` and ``dones`` sequences; the number of steps is taken
            from ``dones``.
        discount: Per-step discount factor used for the episode return.
            Defaults to 0.99, the value that was previously hard-coded.

    Returns:
        A list of dicts, one per completed episode, with NumPy arrays under
        ``states``, ``actions``, ``rewards``, ``dones``, ``steps`` (index
        within the episode), ``global_steps`` (index within the whole history)
        and ``return`` (discounted sum of the episode's rewards).
    """
    states = learning_history["states"]
    actions = learning_history["actions"]
    rewards = learning_history["rewards"]
    dones = learning_history["dones"]

    trajectories = []
    traj_data = defaultdict(list)
    cur_step = 0
    cur_discount = 1.0
    cur_return = 0.0
    for global_step in range(len(dones)):
        traj_data["states"].append(states[global_step])
        traj_data["actions"].append(actions[global_step])
        traj_data["rewards"].append(rewards[global_step])
        traj_data["dones"].append(dones[global_step])
        traj_data["steps"].append(cur_step)
        traj_data["global_steps"].append(global_step)

        cur_return += rewards[global_step] * cur_discount
        cur_step += 1
        cur_discount *= discount

        if dones[global_step]:
            traj_data["return"] = cur_return
            trajectories.append({k: np.array(v) for k, v in traj_data.items()})
            # Start a fresh episode: only the global step counter carries over.
            traj_data = defaultdict(list)
            cur_step = 0
            cur_discount = 1.0
            cur_return = 0.0

    return trajectories


def subsample_history(learning_history, subsample, random_order=False, sorted_order=False, keep_split=False):
    """Keep every ``subsample``-th episode of a learning history.

    The history is split into episodes, the episodes are optionally
    re-ordered, every ``subsample``-th one is kept, and the kept episodes are
    concatenated back into one flat history.

    Args:
        learning_history: Mapping accepted by ``split_to_episodes`` that also
            carries a ``goal`` entry, which is copied through unchanged.
        subsample: Stride applied to the (re-ordered) episode list.
        random_order: Shuffle the episodes (using the ``random`` module)
            before striding.
        sorted_order: Order the episodes by ascending return before striding.
            Takes precedence over ``random_order`` when both are set.
        keep_split: Also return the kept episodes as per-key lists.

    Returns:
        The flat sub-sampled history dict, or a
        ``(flat_history, split_history)`` tuple when ``keep_split`` is true.
        ``split_history`` holds one list entry per kept episode under each key
        and has no ``goal`` entry.

    Raises:
        ValueError: If the history contains no completed episode.
    """
    trajectories = split_to_episodes(learning_history)
    if not trajectories:
        # np.concatenate would fail below with a far less helpful message.
        raise ValueError("learning history contains no completed episodes to subsample")

    order = list(range(len(trajectories)))
    if random_order:
        random.shuffle(order)
    if sorted_order:
        returns = np.array([traj["return"] for traj in trajectories]).reshape(-1)
        order = np.argsort(returns)
    selected = [trajectories[i] for i in order][::subsample]

    per_step_keys = ("states", "actions", "rewards", "dones", "steps", "global_steps")
    # An episode return is a 0-d array when rewards are scalars per step, and
    # np.concatenate rejects 0-d inputs; lift each one to at least 1-D.
    episode_returns = [np.atleast_1d(traj["return"]) for traj in selected]

    subsampled_history = {
        key: np.concatenate([traj[key] for traj in selected]) for key in per_step_keys
    }
    subsampled_history["goal"] = learning_history["goal"]
    subsampled_history["returns"] = np.concatenate(episode_returns)

    if not keep_split:
        return subsampled_history

    subsampled_history_split = {key: [traj[key] for traj in selected] for key in per_step_keys}
    subsampled_history_split["returns"] = episode_returns
    return subsampled_history, subsampled_history_split


class VanillaIterableDataset(IterableDataset):
    """Endless stream of random ``seq_len``-step windows from learning histories.

    Each item is a ``(states, actions, rewards)`` tuple of flattened arrays
    cut from one randomly chosen history at a randomly chosen start step.
    """

    def __init__(self, data_path: str, seq_len: int = 60, subsample: int = 1, goal_dim: int = 4):
        """Load the histories found in ``data_path``.

        Args:
            data_path: Directory passed to ``load_learning_histories``.
            seq_len: Number of steps in every yielded window.
            subsample: When greater than 1, keep only every ``subsample``-th
                episode of each history.
            goal_dim: Width used to reshape the stacked goals before
                de-duplicating them. Defaults to 4, the value that was
                previously hard-coded.
        """
        self.seq_len = seq_len
        print("Loading training histories...")
        self.histories = load_learning_histories(data_path)
        print("Num histories:", len(self.histories))

        self.goals = np.vstack([trajectory["goal"] for trajectory in self.histories])
        self.unique_goals = np.unique(self.goals.reshape(-1, goal_dim), axis=0)

        if subsample > 1:
            self.histories = [subsample_history(hist, subsample) for hist in self.histories]

    def __prepare_sample(self, history_idx, start_idx):
        """Cut the window ``[start_idx, start_idx + seq_len)`` from one history."""
        history = self.histories[history_idx]

        assert history["states"].shape[0] == history["actions"].shape[0] == history["rewards"].shape[0]
        window = slice(start_idx, start_idx + self.seq_len)
        states = history["states"][window].flatten()
        actions = history["actions"][window].flatten()
        rewards = history["rewards"][window].flatten()
        # After flatten() this only holds when each state is a single scalar.
        assert states.shape[0] == self.seq_len

        return states, actions, rewards

    def __iter__(self):
        """Yield random windows forever; this iterator never terminates."""
        while True:
            history_idx = random.randint(0, len(self.histories) - 1)
            num_steps = self.histories[history_idx]["rewards"].shape[0]
            # random.randint is inclusive on both ends and raises ValueError
            # when the history has fewer than seq_len + 1 steps.
            start_idx = random.randint(0, num_steps - self.seq_len - 1)
            yield self.__prepare_sample(history_idx, start_idx)


def flatten_except_first(arr):
    """Collapse every axis after the first into a single axis.

    Returns a 2-D ``(arr.shape[0], -1)`` array, except when the collapsed
    trailing axis has length 1, in which case a 1-D array of length
    ``arr.shape[0]`` is returned instead.
    """
    leading = arr.shape[0]
    collapsed = arr.reshape(leading, -1)
    # One value per row means the data is effectively 1-D; drop the extra axis.
    return collapsed.ravel() if collapsed.shape[1] == 1 else collapsed


class VanillaMapDataset(Dataset):
    """Map-style dataset over pre-enumerated ``seq_len``-step history windows.

    Every ``(history, start step)`` pair is listed up front in ``self.slices``;
    an item is the ``(states, actions, rewards, dones, mask)`` tuple cut from
    that window, where ``mask`` is all ones.
    """

    def __init__(self, data_path: str, seq_len: int = 60, subsample: int = 1, goal_dim: int = 4):
        """Load histories from ``data_path`` and enumerate all windows.

        Args:
            data_path: Directory passed to ``load_learning_histories``.
            seq_len: Number of steps in every window.
            subsample: When greater than 1, keep only every ``subsample``-th
                episode of each history.
            goal_dim: Width used to reshape the stacked goals before
                de-duplicating them.
        """
        self.seq_len = seq_len
        print("Loading training histories...")
        self.histories = load_learning_histories(data_path)
        print("Num histories:", len(self.histories))

        stacked_goals = np.vstack([trajectory["goal"] for trajectory in self.histories])
        self.goals = stacked_goals
        self.unique_goals = np.unique(stacked_goals.reshape(-1, goal_dim), axis=0)

        if subsample > 1:
            self.histories = [subsample_history(hist, subsample) for hist in self.histories]

        # Enumerate every (history, start) pair once so __getitem__ is a plain lookup.
        self.slices = [
            (hist_idx, start_idx)
            for hist_idx, hist in enumerate(self.histories)
            for start_idx in range(0, len(hist["states"]) - self.seq_len)
        ]

    def __len__(self):
        """Return the number of pre-enumerated windows."""
        return len(self.slices)

    def __getitem__(self, idx):
        """Return ``(states, actions, rewards, dones, mask)`` for window ``idx``."""
        history_idx, start_idx = self.slices[idx]
        history = self.histories[history_idx]
        assert history["states"].shape[0] == history["actions"].shape[0] == history["rewards"].shape[0]

        window = slice(start_idx, start_idx + self.seq_len)
        # States and actions keep a per-step feature axis (unless it has size 1);
        # rewards and dones are flattened completely.
        states = flatten_except_first(history["states"][window])
        actions = flatten_except_first(history["actions"][window])
        rewards = history["rewards"][window].flatten()
        dones = history["dones"][window].flatten()
        assert states.shape[0] == actions.shape[0] == rewards.shape[0] == self.seq_len

        # Windows are never padded here, so every position is marked valid.
        mask = np.ones_like(rewards)
        return states, actions, rewards, dones, mask


class TrialsMapDataset(Dataset):
    """Map-style dataset that builds each sample from a random subset of episodes.

    Every item picks ``trials_split`` episodes of one history (kept in their
    original relative order), concatenates them, truncates to ``seq_len``
    values and zero-pads on the right when shorter. The returned ``mask``
    marks real (1) versus padded (0) positions.
    """

    def __init__(self, data_path: str, seq_len: int = 60, size: int = 1000000, trials_split: int = None, goal_dim: int = 4):
        """Load histories from ``data_path`` and split each into episodes.

        Args:
            data_path: Directory passed to ``load_learning_histories``.
            seq_len: Length of every returned sequence.
            size: Nominal dataset length; the actual length is
                ``int(size * num_histories / 5000)``.
            trials_split: Number of episodes sampled (without replacement)
                per item. ``None`` uses every episode of the history.
            goal_dim: Width used to reshape the stacked goals before
                de-duplicating them.
        """
        self.seq_len = seq_len
        self.trials_split = trials_split
        print("Loading training histories...")
        self.histories = load_learning_histories(data_path)
        # NOTE(review): 5000 looks like a reference number of histories that
        # `size` was tuned for - confirm with the training configs.
        self.scaling_coef = len(self.histories) / 5000
        self.size = int(size * self.scaling_coef)
        print("Num histories:", len(self.histories))

        self.goals = np.vstack([trajectory["goal"] for trajectory in self.histories])
        self.unique_goals = np.unique(self.goals.reshape(-1, goal_dim), axis=0)

        self.histories = [split_to_episodes(hist) for hist in self.histories]
        # Fix, once, which history each dataset index draws its episodes from.
        self.idx2history = np.random.randint(0, len(self.histories), size=self.size)

    def __len__(self):
        """Return the scaled dataset length."""
        return self.size

    def __getitem__(self, idx):
        """Return ``(states, actions, rewards, mask)``, each of length ``seq_len``.

        Raises:
            ValueError: From NumPy, if ``trials_split`` exceeds the number of
                episodes in the selected history.
        """
        full_history = self.histories[self.idx2history[idx]]

        if self.trials_split is None:
            # The default used to crash: np.random.choice(size=None) returns a
            # scalar that np.sort rejects. Treat None as "use all episodes".
            indices = np.arange(len(full_history))
        else:
            indices = np.random.choice(len(full_history), size=self.trials_split, replace=False)
            # Sorting keeps the chosen episodes in their original order.
            indices = np.sort(indices)

        episodes = [full_history[i] for i in indices]
        history = {
            key: np.concatenate([traj[key] for traj in episodes])
            for key in ("states", "actions", "rewards")
        }

        assert history["states"].shape[0] == history["actions"].shape[0] == history["rewards"].shape[0]

        states = history["states"].flatten()[:self.seq_len]
        actions = history["actions"].flatten()[:self.seq_len]
        rewards = history["rewards"].flatten()[:self.seq_len]

        num_valid = states.shape[0]
        num_padded = self.seq_len - num_valid
        mask = np.array([1] * num_valid + [0] * num_padded)

        pad_width = (0, num_padded)
        states = np.pad(states, pad_width=pad_width, mode="constant", constant_values=0)
        actions = np.pad(actions, pad_width=pad_width, mode="constant", constant_values=0)
        rewards = np.pad(rewards, pad_width=pad_width, mode="constant", constant_values=0)

        assert states.shape[0] == actions.shape[0] == rewards.shape[0] == self.seq_len == mask.shape[0]

        return states, actions, rewards, mask


class TuplesIterableDataset(IterableDataset):
    """Endless stream of random windows paired with previous-step context.

    Each item is ``(states, prev_actions, prev_rewards, target_actions)``
    where the ``prev_*`` arrays are the same window shifted one step back.
    """

    def __init__(self, data_path: str, seq_len: int = 60, subsample: int = 1, goal_dim: int = 4):
        """Load the histories found in ``data_path``.

        Args:
            data_path: Directory passed to ``load_learning_histories``.
            seq_len: Number of steps in every yielded window.
            subsample: When greater than 1, keep only every ``subsample``-th
                episode of each history.
            goal_dim: Width used to reshape the stacked goals before
                de-duplicating them. Defaults to 4, the value that was
                previously hard-coded.
        """
        self.seq_len = seq_len
        print("Loading training histories...")
        self.histories = load_learning_histories(data_path)
        print("Num histories:", len(self.histories))

        self.goals = np.vstack([trajectory["goal"] for trajectory in self.histories])
        self.unique_goals = np.unique(self.goals.reshape(-1, goal_dim), axis=0)

        if subsample > 1:
            self.histories = [subsample_history(hist, subsample) for hist in self.histories]

    def __prepare_sample(self, history_idx, start_idx):
        """Cut the current and the one-step-earlier window from one history."""
        history = self.histories[history_idx]
        assert history["states"].shape[0] == history["actions"].shape[0] == history["rewards"].shape[0]

        current = slice(start_idx, start_idx + self.seq_len)
        previous = slice(start_idx - 1, start_idx - 1 + self.seq_len)

        states = history["states"][current].flatten()
        prev_actions = history["actions"][previous].flatten()
        prev_rewards = history["rewards"][previous].flatten()
        # Target actions to predict from the context.
        target_actions = history["actions"][current].flatten()

        assert states.shape[0] == prev_actions.shape[0] == prev_rewards.shape[0] == self.seq_len
        assert target_actions.shape[0] == self.seq_len

        return states, prev_actions, prev_rewards, target_actions

    def __iter__(self):
        """Yield random windows forever; this iterator never terminates."""
        while True:
            history_idx = random.randint(0, len(self.histories) - 1)
            num_steps = self.histories[history_idx]["rewards"].shape[0]
            # Start at 1 so the previous-step window exists, and stop early
            # enough that no padding is needed on either side. random.randint
            # raises ValueError when the history is too short for that.
            start_idx = random.randint(1, num_steps - self.seq_len - 1)
            yield self.__prepare_sample(history_idx, start_idx)


class TuplesMapDataset(Dataset):
    """Map-style dataset of windows paired with previous-step context.

    Two sampling modes exist:

    * default - every ``(history, start step)`` window of the flat,
      sub-sampled histories is enumerated in ``self.slices``;
    * ``sample_ordered`` - each item re-draws a random subset of episodes
      from a history, orders them by ascending return, concatenates them and
      cuts a random window from the result (see ``build_sample``).
    """

    def __init__(self, data_path: str, seq_len: int = 60, subsample: int = 1, goal_dim: int = 4, return_goals: bool = False, uniform_sample: bool = False,
                 random_order=False, sorted_order=False, sample_ordered=False):
        """Load, sub-sample and index the histories found in ``data_path``.

        Args:
            data_path: Directory passed to ``load_learning_histories``.
            seq_len: Number of steps in every window.
            subsample: Episode stride passed to ``subsample_history``.
            goal_dim: Width used to reshape goals; ``None`` skips building
                ``self.goals`` / ``self.unique_goals``.
            return_goals: Also return the per-step goal and global steps.
            uniform_sample: Ignore the requested index and draw one uniformly
                at random instead.
            random_order: Forwarded to ``subsample_history``.
            sorted_order: Forwarded to ``subsample_history``.
            sample_ordered: Enable the return-ordered episode sampling mode.
        """
        self.sample_ordered = sample_ordered
        self.seq_len = seq_len
        print("Loading training histories...")
        raw_histories = load_learning_histories(data_path)
        print("Num histories:", len(raw_histories))
        self.return_goals = return_goals
        self.uniform_sample = uniform_sample

        if goal_dim is not None:
            self.goals = np.vstack([trajectory["goal"] for trajectory in raw_histories])
            self.unique_goals = np.unique(self.goals.reshape(-1, goal_dim), axis=0)
        self.goal_dim = goal_dim

        self.histories = []
        self.histories_split = []
        for hist in raw_histories:
            result = subsample_history(hist, subsample, random_order=random_order,
                                       sorted_order=sorted_order, keep_split=sample_ordered)
            if sample_ordered:
                # keep_split makes subsample_history return (flat, per-episode).
                hist_ss, hist_ss_split = result
                self.histories_split.append(hist_ss_split)
            else:
                hist_ss = result
            self.histories.append(hist_ss)

        # Precompute all the slices. Starts begin at 1 so that the
        # previous-step window always exists.
        self.slices = [
            (hist_idx, start_idx)
            for hist_idx, hist in enumerate(self.histories)
            for start_idx in range(1, len(hist["states"]) - self.seq_len)
        ]

    def __len__(self):
        """Return the number of pre-enumerated windows."""
        return len(self.slices)

    def build_sample(self, history, n_episodes=20):
        """Draw episodes, order them by return and pick a random window start.

        Args:
            history: Per-episode history (dict of lists, one entry per
                episode) that includes a ``returns`` entry.
            n_episodes: Number of episodes to draw without replacement;
                clamped to the number of episodes available.

        Returns:
            ``(flat_history, start_index)``, or ``(None, None)`` when the
            concatenated episodes are not longer than ``seq_len`` (or there
            are no episodes at all).
        """
        num_episodes = len(history['states'])
        if num_episodes == 0:
            return None, None
        # Sampling without replacement cannot draw more episodes than exist;
        # without the clamp, histories with < n_episodes episodes crash.
        n_episodes = min(n_episodes, num_episodes)
        subset_idx = np.random.choice(np.arange(num_episodes), size=n_episodes, replace=False)
        picked = {k: [history[k][idx] for idx in subset_idx] for k in history}

        # Ascending return: worst episode first, best episode last.
        traj_order = np.argsort(np.array(picked['returns']).flatten())
        # atleast_1d lets scalar (0-d) per-episode returns be concatenated too;
        # it is a no-op for the per-step arrays.
        subsampled_history = {
            k: np.concatenate([np.atleast_1d(picked[k][idx]) for idx in traj_order])
            for k in picked
        }

        traj_len = subsampled_history['states'].shape[0]
        if traj_len <= self.seq_len:
            return None, None
        start_index = np.random.randint(1, traj_len - self.seq_len + 1)
        return subsampled_history, start_index

    def __getitem__(self, idx):
        """Return the context/target tuple for window ``idx``.

        Returns:
            ``(states, prev_actions, prev_rewards, prev_dones, target_actions,
            rewards, dones, steps)``; when ``return_goals`` is set the tuple is
            extended with the goal repeated ``seq_len`` times and the
            ``global_steps`` of the window.
        """
        if self.uniform_sample:
            idx = random.randint(0, len(self.slices) - 1)
        history_idx, start_idx = self.slices[idx]

        history = self.histories[history_idx]
        assert history["states"].shape[0] == history["actions"].shape[0] == history["rewards"].shape[0]

        if self.sample_ordered:
            # NOTE: this retries forever, moving on to the next history each
            # time, if no history can yield more than seq_len steps.
            while True:
                history, start_idx = self.build_sample(self.histories_split[history_idx])
                if history is not None:
                    break
                history_idx = (history_idx + 1) % len(self.histories)

        current = slice(start_idx, start_idx + self.seq_len)
        previous = slice(start_idx - 1, start_idx - 1 + self.seq_len)

        # Sampling state, prev_actions, prev_rewards.
        states = flatten_except_first(history["states"][current])
        steps = history["steps"][current].flatten()
        global_steps = history["global_steps"][current].flatten()
        prev_actions = flatten_except_first(history["actions"][previous])
        prev_rewards = history["rewards"][previous].flatten()
        prev_dones = history["dones"][previous].flatten()
        # Target actions to predict from the context.
        target_actions = flatten_except_first(history["actions"][current])
        rewards = history["rewards"][current].flatten()
        dones = history["dones"][current].flatten()

        assert states.shape[0] == prev_actions.shape[0] == prev_rewards.shape[0] == self.seq_len
        assert target_actions.shape[0] == self.seq_len

        base = (states, prev_actions, prev_rewards, prev_dones, target_actions, rewards, dones, steps)
        if not self.return_goals:
            return base

        # The history rebuilt by build_sample has no "goal" entry, so always
        # read the goal from the flat history that history_idx points at.
        goal = self.histories[history_idx]["goal"]
        goals = np.array([goal.reshape(self.goal_dim)] * self.seq_len)
        return base + (goals, global_steps)

