import pickle

import numpy as np
import torch
from torch import Tensor
from torch.utils.data import Dataset as TorchDataset

from bandit.corrupt import get_corrupt_params
from utils import convert_to_tensor

# Default torch device for this module: CUDA when a GPU is visible, else CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


class Dataset(TorchDataset):
    """In-context trajectory dataset loaded from one or more pickle files.

    Each pickle holds a list of trajectory dicts with the keys
    ``context_states``, ``context_actions``, ``context_next_states``,
    ``context_rewards``, ``query_state`` and ``optimal_action``. All
    trajectories are stacked into tensors once, at construction time.

    Config keys read: ``shuffle``, ``H``, ``store_gpu``, ``state_dim``,
    ``action_dim``.
    """

    def __init__(self, path, config):
        self.shuffle = config["shuffle"]
        self.horizon = config["H"]
        self.store_gpu = config["store_gpu"]
        self.config = config

        # Accept a single path prefix or a list of them; ".pkl" is appended.
        paths = path if isinstance(path, list) else [path]

        self.trajs = []
        for p in paths:
            filename = p + ".pkl"
            # NOTE(review): pickle.load can execute arbitrary code -- only load
            # trajectory files that come from a trusted source.
            with open(filename, "rb") as f:
                print(filename)
                self.trajs += pickle.load(f)

        def _stack(key):
            """Stack ``traj[key]`` across all trajectories into one array."""
            return np.array([traj[key] for traj in self.trajs])

        context_states = _stack("context_states")
        context_actions = _stack("context_actions")
        context_next_states = _stack("context_next_states")
        context_rewards = _stack("context_rewards")
        if context_rewards.ndim < 3:
            # Rewards without a trailing feature axis (presumably (N, H)) get
            # one, so they concatenate with the other per-step features.
            context_rewards = context_rewards[:, :, None]
        query_states = _stack("query_state")
        optimal_actions = _stack("optimal_action")

        arrays = {
            "query_states": query_states,
            "optimal_actions": optimal_actions,
            "context_states": context_states,
            "context_actions": context_actions,
            "context_next_states": context_next_states,
            "context_rewards": context_rewards,
        }
        self.dataset = {
            key: convert_to_tensor(value, store_gpu=self.store_gpu)
            for key, value in arrays.items()
        }

        # NOTE(review): a context row built in __getitem__ is
        # state + action + next_state + reward, i.e. 2 * state_dim + action_dim + 1
        # wide; ``state_dim ** 2`` only coincides with that for e.g. state_dim == 2.
        # Unused in this file -- confirm intent before relying on it.
        self.zeros = convert_to_tensor(
            np.zeros(config["state_dim"] ** 2 + config["action_dim"] + 1),
            store_gpu=self.store_gpu,
        )

    def __len__(self) -> int:
        """Return the number of stored trajectories (samples)."""
        return len(self.dataset["query_states"])

    def __getitem__(self, index, *, return_perm=False) -> tuple[Tensor, ...]:
        """Build the model input for trajectory ``index``.

        Returns ``(x, optimal_actions)``, or ``(x, optimal_actions, perm)`` when
        ``return_perm`` is true. ``x`` is a query row (the query state,
        zero-padded to the context width) stacked on top of the context rows
        ``[state, action, next_state, reward]`` selected/reordered by ``perm``.
        """
        # Random reorder of the context steps when shuffling, identity otherwise.
        if self.shuffle:
            perm = torch.randperm(self.horizon)
        else:
            perm = torch.arange(self.horizon)

        data = self.dataset
        c_states = data["context_states"][index][perm]
        c_actions = data["context_actions"][index][perm]
        c_next_states = data["context_next_states"][index][perm]
        c_rewards = data["context_rewards"][index][perm]
        query_states = data["query_states"][index][None, :]
        optimal_actions = data["optimal_actions"][index]

        ctx = torch.cat((c_states, c_actions, c_next_states, c_rewards), dim=1)
        # Allocate the query row on the context's own device. The module-level
        # ``device`` is CUDA whenever a GPU exists, while tensors built with
        # store_gpu=False are presumably on the CPU; mixing them made torch.cat fail.
        query_line = torch.zeros((1, ctx.shape[-1]), device=ctx.device)
        query_line[:, : query_states.shape[-1]] = query_states
        x = torch.cat((query_line, ctx), dim=0)

        if return_perm:
            return x, optimal_actions, perm

        return x, optimal_actions


class CorruptedBanditDataset(Dataset):
    """Bandit dataset that also returns a reward-corrupted copy of each sample.

    Which samples are corrupted (``self.mask``) and the corrupted arm means are
    fixed once at construction via ``get_corrupt_params``; the random noise is
    redrawn on every ``__getitem__`` call.
    """

    def __init__(self, path, config):
        super().__init__(path, config)
        self.var = config["var"]

        # Per-trajectory arm means, stacked in trajectory order.
        means = np.array([traj["means"] for traj in self.trajs])

        self.corrupt_train = config["corrupt_train"]
        corrupt_params = get_corrupt_params(self.corrupt_train, means, len(self))
        (
            self.corrupt_type,
            self.corrupted_steps,
            self.corrupt_magnitude,
            _,
            corrupted_means,
        ) = corrupt_params
        self.corrupted_means = convert_to_tensor(corrupted_means, store_gpu=self.store_gpu)

        # One boolean flag per sample: True for samples selected for corruption.
        flags = torch.zeros(len(self), dtype=torch.bool)
        flags[self.corrupted_steps] = True
        self.mask = flags

    def __getitem__(self, index) -> tuple[Tensor, Tensor, Tensor]:
        """Return ``(x, optimal_actions, x_corrupted)`` for sample ``index``.

        ``x_corrupted`` equals ``x`` except that the reward column of the
        context rows is replaced by the corrupted rewards.
        """
        x, optimal_actions, perm = super().__getitem__(index, return_perm=True)

        # Re-gather actions/rewards with the same permutation used to build ``x``.
        actions = self.dataset["context_actions"][index][perm]
        rewards = self.dataset["context_rewards"][index][perm]
        is_corrupted = self.mask[index]

        kind = self.corrupt_type
        if kind == "gaussian":
            noise = torch.randn_like(rewards) * self.corrupt_magnitude
            corrupted_rewards = rewards + noise * is_corrupted
        elif kind.startswith(("change", "special")):
            if not is_corrupted:
                corrupted_rewards = rewards
            else:
                chosen_arms = actions.argmax(dim=1)
                corrupted_rewards = (
                    self.corrupted_means[index][chosen_arms, None]
                    + torch.randn_like(rewards) * self.var
                )
        else:
            raise RuntimeError("Unsupported corrupt type")

        x_corrupted = x.detach().clone()
        x_corrupted[1:, -1] = corrupted_rewards[:, 0]

        return x, optimal_actions, x_corrupted


class ImageDataset(Dataset):
    """Trajectory dataset whose samples also carry image observations.

    Query images are transformed once at construction. Context images are
    read with ``np.load`` from the path stored under each trajectory's
    ``context_images`` key and transformed on every ``__getitem__`` call.
    """

    def __init__(self, paths, config, transform):
        # Force store_gpu off on a shallow copy so the caller's (possibly
        # shared) config dict is not mutated as a side effect.
        config = {**config, "store_gpu": False}
        super().__init__(paths, config)
        self.transform = transform
        self.config = config

        context_filepaths = [traj["context_images"] for traj in self.trajs]
        query_images = [self.transform(traj["query_image"]).float() for traj in self.trajs]

        self.dataset.update(
            {
                "context_filepaths": context_filepaths,  # type: ignore[arg-type]
                "query_images": torch.stack(query_images),
            }
        )

    def __getitem__(self, index) -> tuple[tuple[Tensor, Tensor], Tensor]:
        """Return ``((x_images, x), optimal_actions)`` for sample ``index``.

        ``x_images`` is the query image followed by the (permuted) context
        images; ``x`` is a zero-padded query-state row followed by the context
        rows ``[state, action, reward]`` under the same permutation.
        """
        filepath: str = self.dataset["context_filepaths"][index]  # type: ignore[arg-type]
        context_images = np.load(filepath)
        context_images = [self.transform(images) for images in context_images]
        context_images = torch.stack(context_images).float()

        query_images = self.dataset["query_images"][index][None, :]

        # Random reorder of the context steps when shuffling, identity otherwise.
        if self.shuffle:
            perm = torch.randperm(self.horizon)
        else:
            perm = torch.arange(self.horizon)

        c_images = context_images[perm]
        c_states = self.dataset["context_states"][index][perm]
        c_actions = self.dataset["context_actions"][index][perm]
        c_rewards = self.dataset["context_rewards"][index][perm]
        if len(c_rewards.shape) == 1:
            c_rewards = c_rewards[:, None]
        query_states = self.dataset["query_states"][index][None, :]
        optimal_actions = self.dataset["optimal_actions"][index]

        x_images = torch.cat([query_images, c_images], dim=0)

        ctx = torch.cat((c_states, c_actions, c_rewards), dim=1)
        # Allocate on the context's device so the concatenation below never mixes devices.
        query_line = torch.zeros((1, ctx.shape[-1]), device=ctx.device)
        query_line[:, : query_states.shape[-1]] = query_states
        x = torch.cat((query_line, ctx), dim=0)

        return (x_images, x), optimal_actions
