from collections import defaultdict
from typing import List, Optional, Sequence, Union

import numpy as np
import torch
import tqdm
from torch.utils.data.dataset import Dataset


class TrajectoryDataset(Dataset):
    """Flat store of recorded trajectories from which K-step state windows are sampled.

    Every trajectory is front-padded with ``K - 1`` dummy steps and appended to
    a set of flat per-step arrays (``states``, ``actions``, ``rewards``, ...).
    ``traj_begs[i]`` is the offset of trajectory ``i`` inside those arrays,
    ``traj_lens[i]`` its unpadded length and ``padded_traj_lens[i]`` its length
    including the padding.

    Indexing the dataset selects *trajectories*; the time step inside each
    selected trajectory is drawn at random on every access (see
    ``__getitem__``).
    """

    # Trajectories with fewer recorded steps than this are skipped on ingestion.
    # Must stay >= 2: sampling needs at least one step that has a successor.
    MIN_TRAJECTORY_LENGTH = 10

    def __init__(
        self,
        trajectories: List[List[dict]],
        state_dim: Union[int, Sequence[int]],
        act_dim: int,
        K: int = 4,
        batch_size: int = 128,
        use_kl: bool = False,
    ):
        """Create the dataset and ingest ``trajectories``.

        Args:
            trajectories: Sequence of trajectories, each an iterable of
                per-step dicts (see ``add_trajectories`` for the keys read).
            state_dim: Shape of a single state. A bare int ``n`` is treated as
                the shape ``(n,)``.
            act_dim: Width of the per-step ``action_prob`` vector (only used
                when ``use_kl`` is true).
            K: Number of consecutive states in a sampled window.
            batch_size: Default batch size used by ``sample_batch``.
            use_kl: When true, per-step ``action_prob`` vectors are stored and
                returned alongside the other fields.
        """
        super().__init__()
        self.use_kl = use_kl
        # The rest of the class unpacks ``state_dim`` as a shape, so normalise
        # a bare int into a one-element tuple.
        if isinstance(state_dim, (int, np.integer)):
            state_dim = (int(state_dim),)
        self.state_dim = tuple(state_dim)
        self.act_dim = act_dim
        self.K = K

        self._batch_size = batch_size

        # Empty, correctly typed buffers; add_trajectories() appends to them.
        self.states = np.zeros((0, *self.state_dim), dtype=np.float32)
        self.actions_prob = np.zeros((0, self.act_dim), dtype=np.float32)
        self.actions = np.zeros((0,), dtype=np.int64)
        self.rewards = np.zeros((0,), dtype=np.float32)
        self.dones = np.zeros((0,), dtype=np.int64)
        self.real_dones = np.zeros((0,), dtype=np.int64)
        self.current_rooms = np.zeros((0,), dtype=np.int64)
        self.mask = np.zeros((0,), dtype=np.int64)

        self.traj_lens = np.array([], dtype=np.int64)
        self.padded_traj_lens = np.array([], dtype=np.int64)
        self.traj_begs = np.array([], dtype=np.int64)
        self.p_sample = np.array([], dtype=np.float64)

        self.add_trajectories(trajectories)

    def process_trajectory(self, t):
        """Front-pad one trajectory with ``K - 1`` dummy steps and fix dtypes.

        Args:
            t: Mapping with array-like entries ``"state"``, ``"action"``,
                ``"reward"``, ``"done"``, ``"real_done"`` and
                ``"current_room"`` (plus ``"action_prob"`` when ``use_kl`` is
                set), each with one row per step.

        Returns:
            Tuple ``(states, actions, rewards, dones, real_dones,
            current_rooms, mask, actions_probs)``. Padding rows hold zeros,
            except ``actions`` / ``actions_probs`` which hold ``-10``.
            ``mask`` is 0 on padding rows and 1 on real steps.
            ``actions_probs`` is ``None`` when ``use_kl`` is false.
        """
        tlen = len(t["reward"])
        pad = self.K - 1

        states = np.concatenate([np.zeros((pad, *self.state_dim)), t["state"]])
        actions = np.concatenate([np.ones(pad) * -10, t["action"]], 0)
        rewards = np.concatenate([np.zeros(pad), t["reward"]])
        dones = np.concatenate([np.zeros(pad), t["done"]])
        real_dones = np.concatenate([np.zeros(pad), t["real_done"]])
        current_rooms = np.concatenate([np.zeros(pad), t["current_room"]])
        mask = np.concatenate([np.zeros(pad), np.ones(tlen)])

        # Only touch "action_prob" when it was requested; without use_kl the
        # slot is None and must not be cast.
        actions_probs = None
        if self.use_kl:
            actions_probs = np.concatenate(
                [np.ones((pad, self.act_dim)) * -10, t["action_prob"]], 0
            ).astype(np.float32)

        return (
            states.astype(np.float32),
            actions.astype(np.int64),
            rewards.astype(np.float32),
            dones.astype(np.int64),
            real_dones.astype(np.int64),
            current_rooms.astype(np.int64),
            mask.astype(np.int64),
            actions_probs,
        )

    def add_trajectories(self, new_trajectories):
        """Append trajectories to the data already held by the dataset.

        Each trajectory is an iterable of per-step dicts; the values of every
        key are stacked into one array per key before padding. Trajectories
        with fewer than ``MIN_TRAJECTORY_LENGTH`` ``"reward"`` entries are
        skipped with a printed notice.

        Args:
            new_trajectories: Sized collection of trajectories. An empty
                collection leaves the dataset untouched.
        """
        if len(new_trajectories) == 0:
            return

        # Seed every accumulator with the current contents so the final
        # concatenation appends instead of replacing, and never receives an
        # empty list even if every new trajectory gets skipped.
        states, actions, rewards = [self.states], [self.actions], [self.rewards]
        dones, real_dones = [self.dones], [self.real_dones]
        current_rooms, masks = [self.current_rooms], [self.mask]
        actions_prob = [self.actions_prob] if self.use_kl else None
        traj_lens, padded_traj_lens = [], []

        for trajectory in tqdm.tqdm(new_trajectories, desc="Processing trajectories", total=len(new_trajectories)):
            # Transpose list-of-step-dicts into dict-of-columns.
            columns = defaultdict(list)
            for step in trajectory:
                for key, value in step.items():
                    columns[key].append(value)

            tlen = len(columns["reward"])
            if tlen < self.MIN_TRAJECTORY_LENGTH:
                print(f"Skipping trajectory of length: {tlen}")
                continue

            # np.asarray rather than np.array(..., copy=False): the latter
            # raises on NumPy 2.x whenever a copy is unavoidable, which is
            # always the case for a Python list.
            trajectory = {k: np.asarray(v) for k, v in columns.items()}
            s, a, r, d, rd, cr, m, a_prob = self.process_trajectory(trajectory)
            states.append(s)
            actions.append(a)
            rewards.append(r)
            dones.append(d)
            real_dones.append(rd)
            current_rooms.append(cr)
            masks.append(m)
            if self.use_kl:
                actions_prob.append(a_prob)
            traj_lens.append(tlen)
            padded_traj_lens.append(len(r))

        self.states = np.concatenate(states)
        self.actions = np.concatenate(actions)
        self.rewards = np.concatenate(rewards)
        self.dones = np.concatenate(dones)
        self.real_dones = np.concatenate(real_dones)
        self.current_rooms = np.concatenate(current_rooms)
        self.mask = np.concatenate(masks)

        if self.use_kl:
            self.actions_prob = np.concatenate(actions_prob)

        self.traj_lens = np.concatenate(
            [self.traj_lens, np.asarray(traj_lens, dtype=np.int64)]
        )
        self.padded_traj_lens = np.concatenate(
            [self.padded_traj_lens, np.asarray(padded_traj_lens, dtype=np.int64)]
        )
        # Exclusive prefix sum: start offset of every padded trajectory
        # (stays empty when there are no trajectories).
        self.traj_begs = np.cumsum(self.padded_traj_lens) - self.padded_traj_lens

        total_steps = self.traj_lens.sum()
        if total_steps > 0:
            self.p_sample = self.traj_lens / total_steps
        else:
            self.p_sample = np.zeros(0, dtype=np.float64)

    def __len__(self):
        """Return the number of stored trajectories (not time steps)."""
        return len(self.traj_begs)

    def __getitem__(self, idx):
        """Draw one random K-step window from each selected trajectory.

        Args:
            idx: Trajectory index or array-like of trajectory indices.

        Returns:
            Dict of NumPy arrays with a leading batch axis (a scalar ``idx``
            yields a batch of one). ``"states"`` and ``"next_states"`` each
            hold ``K`` consecutive states, the latter shifted one step
            forward, both divided by 255. The remaining entries are taken at
            the step of the last state in ``"states"``. ``"actions_prob"`` is
            ``None`` unless ``use_kl`` is set.

        Note:
            The window position is sampled with ``np.random`` on every call,
            so repeated access with the same ``idx`` returns different data.
        """
        idx = np.array(idx)

        offset_idx = self.traj_begs[idx]
        # The final real step has no successor, so it cannot be a window end.
        valid_idx = self.traj_lens[idx] - 1
        si = np.random.randint(0, valid_idx, size=idx.size) + offset_idx

        # K + 1 consecutive states per sample: K for the window plus one more
        # so next_states can be sliced from the same gather.
        state_indexes = (si[:, None] + np.arange(self.K + 1)).astype(int)
        indexes = si + self.K - 1

        states = self.states[state_indexes] / 255.

        actions_prob = None
        if self.use_kl:
            actions_prob = self.actions_prob[indexes]

        return {
            "states": states[:, :-1],
            "actions": self.actions[indexes],
            "rewards": self.rewards[indexes],
            "next_states": states[:, 1:],
            "dones": self.dones[indexes],
            "real_dones": self.real_dones[indexes],
            "current_rooms": self.current_rooms[indexes],
            "mask": self.mask[indexes],
            "actions_prob": actions_prob,
        }

    def sample_batch(self, batch_size: Optional[int] = None):
        """Sample a batch of windows and return it as torch tensors.

        Trajectories are drawn with replacement with probability proportional
        to their length.

        Args:
            batch_size: Number of windows; defaults to the ``batch_size``
                given at construction.

        Returns:
            Dict with the same keys as ``__getitem__``; array values are
            converted with ``torch.as_tensor`` and ``None`` entries (i.e.
            ``"actions_prob"`` without ``use_kl``) are passed through as
            ``None``.

        Raises:
            AssertionError: If the dataset holds no trajectories.
        """
        assert len(self) > 0, "tried to sample batch from empty dataset"

        if batch_size is None:
            batch_size = self._batch_size

        idxs = np.random.choice(
            np.arange(len(self)),
            size=batch_size,
            replace=True,
            p=self.p_sample,  # reweights so we sample according to timesteps
        )
        batch = self[idxs]
        # torch.as_tensor cannot convert None, so leave absent fields as None.
        return {
            k: (torch.as_tensor(v) if v is not None else None)
            for k, v in batch.items()
        }
