import numpy as np
import torch

from copy import deepcopy
from tqdm import trange
from sklearn.neighbors import KernelDensity
from torch.utils.data import Dataset, DataLoader

# Lower and upper bound shared by both axes of the 2-D evaluation grid.
INPUT_RANGE = (0, 1)

# Centres of the Gaussian bumps that make up the sampling density, and the
# kernel bandwidth used for each bump (paired by position).
SAMPLE_MODES = [(0.2, 0.8), (0.9, 0.4)]
SAMPLE_BAND = [0.20, 0.08]

# 25 evenly spaced abscissae on [0, 1].
XRANGE = np.linspace(0, 1, num=25)

# One (x, y) row per abscissa; the points lie on the downward parabola
# y = 0.65 - (x - 0.5)^2.
CTR_MODES = np.array([XRANGE, 0.65 - (XRANGE - 0.5) ** 2]).T


def generate_sample_density():
    """Build a normalised two-bump sampling density on a 100 x 100 grid.

    One Gaussian kernel density estimate is fitted per entry of
    ``SAMPLE_MODES`` (a single point each), using the matching bandwidth from
    ``SAMPLE_BAND``.  The estimates are summed on the grid and the result is
    divided by its total so that all cells sum to 1.

    Returns:
        A tuple ``(xrange, yrange, density)`` of three arrays with the same
        2-D shape: the meshgrid x coordinates, the meshgrid y coordinates and
        the normalised density value of every grid cell.
    """
    xrange, yrange = np.meshgrid(np.linspace(*INPUT_RANGE, num=100), np.linspace(*INPUT_RANGE, num=100))

    # The evaluation points do not depend on the mode, so build them once
    # instead of once per loop iteration.
    xy = np.vstack([xrange.ravel(), yrange.ravel()]).T

    density = 0.
    for mode, band in zip(SAMPLE_MODES, SAMPLE_BAND):
        kde = KernelDensity(bandwidth=band, kernel="gaussian")
        kde.fit([mode])
        # score_samples yields log-densities; exponentiate before summing.
        density += np.exp(kde.score_samples(xy)).reshape(xrange.shape)
    density /= density.sum()

    return xrange, yrange, density


def generate_ctr_density(lower_bound=0.1, upper_bound=0.9):
    """Evaluate a clipped, max-normalised density around ``CTR_MODES``.

    A Gaussian kernel density estimate (bandwidth 0.07) is fitted to the
    points in ``CTR_MODES`` and evaluated on a 100 x 100 grid spanning
    ``INPUT_RANGE`` on both axes.  The values are rescaled so the peak equals
    1 and then clipped into ``[lower_bound, upper_bound]``.

    Args:
        lower_bound: Smallest value allowed in the returned density.
        upper_bound: Largest value allowed in the returned density.

    Returns:
        A tuple ``(xrange, yrange, density)`` of three arrays with the same
        2-D shape: meshgrid x coordinates, meshgrid y coordinates and the
        clipped density per grid cell.
    """
    # Evaluation grid and its flattened list of (x, y) points.
    grid_x = np.linspace(*INPUT_RANGE, num=100)
    grid_y = np.linspace(*INPUT_RANGE, num=100)
    xrange, yrange = np.meshgrid(grid_x, grid_y)
    points = np.vstack([xrange.ravel(), yrange.ravel()]).T

    estimator = KernelDensity(bandwidth=0.07, kernel="gaussian")
    estimator.fit(CTR_MODES)

    # score_samples yields log-densities, hence the exponential.
    values = np.exp(estimator.score_samples(points)).reshape(xrange.shape)
    values = values / values.max()
    clipped = np.clip(values, lower_bound, upper_bound)

    return xrange, yrange, clipped


class Sampler:
    """Draw random grid cells from a 2-D table of cell probabilities.

    Args:
        density: 2-D array whose entries are used directly as the ``p``
            argument of ``np.random.choice`` and therefore must already sum
            to 1.  Rows are treated as the y axis and columns as the x axis.
    """

    def __init__(self, density):
        self.shape = density.shape
        # flatten() copies, so later changes to `density` do not affect the sampler.
        self.prob = density.flatten()
        self.indices = np.arange(self.prob.size)

    def sample(self, size=10):
        """Sample ``size`` cells according to the stored probabilities.

        Args:
            size: Number of cells to draw (forwarded to ``np.random.choice``).

        Returns:
            A tuple ``(coords, cells)`` where ``coords`` has one ``(x, y)``
            row per draw, scaled into ``INPUT_RANGE``, and ``cells`` is a
            2-row array holding the row indices followed by the column
            indices of the drawn cells.
        """
        n_rows, n_cols = self.shape[0], self.shape[1]
        idx = np.random.choice(self.indices, p=self.prob, size=size)

        # The probabilities were flattened in row-major order, so the stride
        # between consecutive rows is the number of columns.
        y = idx // n_cols
        x = idx % n_cols

        low = INPUT_RANGE[0]
        r = INPUT_RANGE[1] - INPUT_RANGE[0]
        # Column index maps to x and row index to y; each is scaled by the
        # length of its own axis and shifted by the lower bound of the range.
        # NOTE(review): index / n never reaches the upper bound, whereas a
        # linspace grid would use index / (n - 1) - confirm which is wanted.
        coords = np.vstack((low + x / n_cols * r, low + y / n_rows * r)).T
        return coords, np.vstack((y, x))


class Bootstrapper(Dataset):
    """Random split of a dataset into a reference part and a query part.

    A subset of ``int(base_ratio * n)`` rows, drawn without replacement,
    forms the reference ("base") data; the remaining rows form the query
    ("boot") data.  The ``Dataset`` protocol (``len`` / indexing) exposes the
    query rows only.

    Args:
        X_data: Features, indexable by integer and boolean index arrays.
        y_data: Targets, same length as ``X_data``.
        prob_data: Per-row auxiliary values; ``None`` is replaced by
            ``torch.zeros_like(y_data)``.
        usage_ratio: When below 1, only a random fraction of this size is
            retained before splitting.
        base_ratio: Fraction of the retained rows assigned to the reference part.
        seed: Seed for the NumPy generator driving both random choices.
    """

    def __init__(self, X_data, y_data, prob_data, usage_ratio=1, base_ratio=0.6, seed=None):
        super().__init__()
        assert len(X_data) == len(y_data)
        prob_data = torch.zeros_like(y_data) if prob_data is None else prob_data

        rng = np.random.default_rng(seed=seed)

        if usage_ratio < 1:
            # Keep the leading slice of a random permutation of all rows.
            order = np.arange(len(X_data))
            rng.shuffle(order)
            keep = order[:int(usage_ratio * len(X_data))]

            X_data = X_data[keep]
            y_data = y_data[keep]
            prob_data = prob_data[keep]

        n_rows = len(X_data)
        n_base = int(base_ratio * n_rows)

        # Boolean membership mask of the reference part; its complement is
        # the query part.
        in_base = np.zeros(n_rows, dtype=bool)
        in_base[rng.choice(n_rows, size=n_base, replace=False)] = True
        in_boot = ~in_base

        self.X_base = X_data[in_base]
        self.y_base = y_data[in_base]
        self.prob_base = prob_data[in_base]

        self.X_boot = X_data[in_boot]
        self.y_boot = y_data[in_boot]
        self.prob_boot = prob_data[in_boot]

    def get_ref_data(self):
        """Return the reference part as ``(X, y, prob)``."""
        return self.X_base, self.y_base, self.prob_base

    def get_query_data(self):
        """Return the query part as ``(X, y, prob)``."""
        return self.X_boot, self.y_boot, self.prob_boot

    def __len__(self):
        """Number of rows in the query part."""
        return len(self.X_boot)

    def __getitem__(self, item):
        """Return one query row; its features gain a leading axis of size 1."""
        features = self.X_boot[item].unsqueeze(dim=0)
        return features, self.y_boot[item], self.prob_boot[item]


class MultiBootstrapper(Dataset):
    """A collection of ``Bootstrapper`` splits, one per input block.

    Item ``i`` pairs a ``DataLoader`` over the query rows of block ``i`` with
    the reference data of that block.

    Args:
        X_list: Sequence of feature blocks.
        y_list: Sequence of target blocks, same length as ``X_list``.
        prob_list: Sequence of per-row auxiliary blocks, or ``None`` to use
            zeros shaped like each target block.
        usage_ratio: Forwarded to every ``Bootstrapper``.
        base_ratio: Forwarded to every ``Bootstrapper``.
        seed: Base seed; block ``i`` is seeded with ``seed + i``.  ``None``
            leaves every block unseeded.
        batch_size: Batch size of the per-block query loaders.
        shuffle: Whether the per-block query loaders shuffle their rows.
    """

    def __init__(self, X_list, y_list, prob_list, usage_ratio=1, base_ratio=0.6, seed=None,
                 batch_size=16, shuffle=True):
        super().__init__()
        assert len(X_list) == len(y_list)
        assert len(X_list[0]) == len(y_list[0])
        # NOTE: blocks are not required to share one length; block_size and
        # total_length below are derived from the first block only.

        if prob_list is None:
            # NOTE(review): Bootstrapper's own default is torch.zeros_like;
            # this NumPy default is kept as-is to avoid changing return types.
            prob_list = [np.zeros_like(y) for y in y_list]

        self.block_size = len(X_list[0])
        self.list_len = len(X_list)
        self.total_length = self.block_size * self.list_len

        self.boots = [Bootstrapper(X_data, y_data, prob_data,
                                   usage_ratio=usage_ratio, base_ratio=base_ratio,
                                   seed=None if seed is None else (seed + i)) for
                      i, (X_data, y_data, prob_data) in
                      enumerate(zip(X_list, y_list, prob_list))]

        # Only ask for pinned (page-locked) batches when CUDA is actually
        # present; unconditionally requesting pinning to "cuda" can fail once
        # batches are fetched on a CPU-only host.
        use_pinned = torch.cuda.is_available()
        self.dls = [DataLoader(ds, batch_size=batch_size, shuffle=shuffle,
                               pin_memory=use_pinned,
                               pin_memory_device="cuda" if use_pinned else "")
                    for ds in self.boots]

    def __len__(self):
        """Number of blocks."""
        return self.list_len

    def __getitem__(self, item):
        """Return ``(query_loader, (X_ref, y_ref, prob_ref))`` for one block."""
        return self.dls[item], self.boots[item].get_ref_data()

    def get_dl(self):
        """Return a loader that visits the blocks one at a time in random order.

        The collate function unwraps the singleton batch so each iteration
        yields exactly what ``__getitem__`` returns.
        """
        return DataLoader(self, batch_size=1, shuffle=True, collate_fn=lambda x: x[0])


class DataStreamer:
    """Serve a dataset as a stream of random, equally sized groups of rows.

    Args:
        X: Feature array; deep-copied on construction.
        y: Target array with the same leading length; deep-copied on
            construction.
    """

    def __init__(self, X, y):
        self.X = deepcopy(X)
        self.y = deepcopy(y)

        # Populated by init_multi(); None means "not initialized yet".
        self.X_stream = None
        self.y_stream = None

    def init_multi(self, K=4, seed=None):
        """Shuffle the rows and partition them into groups of ``K``.

        Rows that do not fill a complete group are dropped.

        Args:
            K: Number of rows per group.
            seed: Seed for the NumPy generator used for shuffling.
        """
        rng = np.random.default_rng(seed)
        n_groups = len(self.X) // K

        train_idx = np.arange(len(self.X))
        rng.shuffle(train_idx)
        train_idx = train_idx[:n_groups * K].reshape(n_groups, K)

        # Indexing with the (n_groups, K) index matrix adds a group axis.
        self.X_stream = self.X[train_idx]
        self.y_stream = self.y[train_idx]

    def _require_initialized(self):
        """Raise ``RuntimeError`` unless ``init_multi`` has been called."""
        if self.X_stream is None:
            raise RuntimeError("Streamer is not initialized.")

    def __len__(self):
        """Number of groups in the stream.

        Raises:
            RuntimeError: If ``init_multi`` has not been called.
        """
        self._require_initialized()
        return len(self.X_stream)

    def __iter__(self):
        """Yield ``(X, y, optimal_R)`` per group, ``optimal_R`` being ``y.max()``.

        Raises:
            RuntimeError: On the first step if ``init_multi`` has not been
                called (same error as ``__len__``).
        """
        self._require_initialized()
        for X, y in zip(self.X_stream, self.y_stream):
            optimal_R = y.max()
            yield X, y, optimal_R


def sample_genrl(env, num_rows, seed=None):
    """Collect a logged dataset by playing uniformly random actions in ``env``.

    Each row of the feature matrix is the context observed before acting,
    followed by a one-hot encoding of the chosen action; the matching target
    row holds the reward returned for that action.

    Args:
        env: Environment exposing ``context_dim``, ``n_actions``, ``reset()``
            (returns the first context) and ``step(action)`` (returns the
            next context and the reward).
        num_rows: Number of interaction steps to record.
        seed: Seed for the ``RandomState`` that picks the actions.

    Returns:
        ``(X_data, y_data)`` as float32 arrays of shape
        ``(num_rows, context_dim + n_actions)`` and ``(num_rows, 1)``.
    """
    rng = np.random.RandomState(seed=seed)

    n_context = env.context_dim
    n_actions = env.n_actions
    width = n_context + n_actions

    # Row a of the identity matrix is the one-hot code of action a.
    one_hot = np.eye(n_actions)
    X_data = np.empty((num_rows, width), dtype=np.float32)
    y_data = np.empty((num_rows, 1), dtype=np.float32)

    context = env.reset()
    for row in trange(num_rows):
        action = rng.choice(n_actions)
        next_context, reward = env.step(action)

        # Record the context that was visible when the action was chosen.
        X_data[row, :n_context] = context
        X_data[row, n_context:] = one_hot[action]
        y_data[row] = reward

        context = next_context

    return X_data, y_data


class BernoulliBandit:
    """Multi-armed bandit whose arms pay 0/1 rewards with fixed probabilities.

    Args:
        n_arms: Number of arms.
        prob_range: Bounds passed to the uniform draw of the per-arm success
            probabilities.
        seed: Seed for the ``RandomState`` created by ``reset``.
    """

    def __init__(self, n_arms, prob_range=(0.1, 0.9), seed=None):
        self.n_arms = n_arms
        self.prob_range = prob_range

        # Placeholders; reset() fills all three.
        self.probs = self.rng = self.best_arm_prob = None

        self.reset(seed=seed)

    def reset(self, seed=None):
        """Create a fresh generator and redraw every arm's success probability."""
        self.rng = np.random.RandomState(seed=seed)
        arm_probs = self.rng.uniform(*self.prob_range, size=self.n_arms)
        self.probs = arm_probs
        self.best_arm_prob = arm_probs.max()

    def step(self, action):
        """Pull arm ``action``.

        Returns:
            ``(None, reward, regret)``: the first slot is always ``None``,
            ``reward`` is a single Bernoulli draw for the arm, and ``regret``
            is the gap between the best arm's success probability and this
            arm's (it does not depend on the realised reward).
        """
        reward = self.rng.binomial(1, self.probs[action])
        gap = self.best_arm_prob - self.probs[action]
        return None, reward, gap
