import os
import datetime
import numpy as np
import pandas as pd
from torch.utils.data import Dataset
import torch

class SyntheticDatasets(Dataset):
    """Synthetic per-round losses indexed by (round, agent, arm).

    A mean loss is drawn once for every (agent, arm) pair with
    ``rng.uniform(low=0.0, high=1.0)``.  The loss of each round is that mean
    plus an independent ``rng.normal(loc=0.0, scale=0.01)`` perturbation.
    The result is kept in ``self.data`` as a float32 array of shape
    ``(n_epochs, n_agents, n_arms)``.
    """

    def __init__(self, n_epochs, n_agents, n_arms, rng):
        """Draw the loss tensor.

        Args:
            n_epochs: Number of rounds (first axis of ``self.data``); must be
                non-negative.
            n_agents: Number of agents (second axis).
            n_arms: Number of arms (third axis).
            rng: Random source exposing NumPy-style ``uniform`` and ``normal``
                methods (presumably a ``numpy.random.Generator`` -- confirm
                against the caller).
        """
        super().__init__()
        # The draw order (means first, then noise) is kept fixed so that a
        # seeded generator always reproduces the same tensor.
        means = rng.uniform(low=0.0, high=1.0, size=(n_agents, n_arms))
        noise = rng.normal(
            loc=0.0, scale=0.01, size=(n_epochs, n_agents, n_arms)
        ).astype(np.float32)
        # Broadcasting the (n_agents, n_arms) means over the leading round
        # axis avoids materialising n_epochs copies in a Python list and keeps
        # the 3-D shape even when n_epochs == 0.
        self.data = means.astype(np.float32)[np.newaxis, :, :] + noise

    def __len__(self):
        """Return the number of rounds stored in ``self.data``."""
        return self.data.shape[0]

    def __getitem__(self, idx):
        """Return the ``(n_agents, n_arms)`` loss slice of round ``idx``."""
        return self.data[idx]

    def _cumulative_mean_losses(self):
        """Return the running sum over rounds of the agent-averaged losses.

        The result has shape ``(n_epochs, n_arms)``.
        """
        mean_over_agents = np.mean(self.data, axis=1)
        return np.cumsum(mean_over_agents, axis=0)

    def cumloss_of_best_arm(self):
        """Return the cumulative agent-averaged loss curve of the best arm.

        The best arm is the one with the smallest cumulative loss after the
        final round.
        """
        cum_losses = self._cumulative_mean_losses()
        return cum_losses[:, np.argmin(cum_losses[-1])]

    def best_arm(self):
        """Return the index of the arm with the smallest total loss.

        Losses are averaged over agents and summed over all rounds.
        """
        return np.argmin(self._cumulative_mean_losses()[-1])

class MovieLensDatasets(Dataset):
    """Per-round loss tensor built from a pickled MovieLens loss table.

    The table is expected to have a multi-level index whose second entry is a
    user id and whose third entry is a genre name, and whose first column
    holds a loss value (the meaning of the first index entry is not used
    here).  For every (user, genre) pair the ``T`` rounds are split into as
    many consecutive windows as the pair has rows, and each window is filled
    with the loss of the corresponding row, in table order.

    Attributes set by the constructor:
        T: number of rounds.
        data_old: the table as loaded from the pickle.
        genres: list of genre names; position in the list is the arm index.
        userID: distinct user ids in order of first appearance; position in
            the list is the agent index.
        N, K: number of users / number of genres.
        data: float32 array of shape ``(T, N, K)``; entries for (user, genre)
            pairs without rows stay 0.
    """

    def __init__(self, n_epochs, data_path=None) -> None:
        """Load the loss table and expand it to a ``(T, N, K)`` tensor.

        Args:
            n_epochs: Number of rounds ``T``.
            data_path: Optional path of the pickled table.  When omitted, the
                file ``../MovieLens_loss.pkl`` relative to this module's
                directory is used.

        Raises:
            ValueError: If a row names a genre that is not in ``self.genres``.
        """
        self.T = n_epochs
        super().__init__()
        if data_path is None:
            data_path = os.path.join(
                os.path.dirname(__file__),
                '../MovieLens_loss.pkl'
            )
        # NOTE(security): unpickling can execute arbitrary code; only point
        # this at trusted files.
        self.data_old = pd.read_pickle(data_path)

        self.genres = ['Action', 'Adventure', 'Animation', 'Children', 'Comedy', 'Crime',
            'Documentary', 'Drama', 'Fantasy', 'Film-Noir', 'Horror', 'IMAX', 'Musical',
            'Mystery', 'Romance', 'Sci-Fi', 'Thriller', 'War', 'Western', '(no genres listed)'
        ]

        # Read the index and value arrays once instead of once per row.
        index_entries = self.data_old.index.values
        loss_rows = self.data_old.values

        # dict.fromkeys keeps first-appearance order while de-duplicating in
        # linear time.
        row_users = [entry[1] for entry in index_entries]
        self.userID = list(dict.fromkeys(row_users))

        self.N = len(self.userID)
        self.K = len(self.genres)

        # Constant-time position lookups replace repeated list.index() scans.
        user_pos = {user: pos for pos, user in enumerate(self.userID)}
        genre_pos = {genre: pos for pos, genre in enumerate(self.genres)}

        user_idx = [user_pos[user] for user in row_users]
        genre_idx = []
        for entry in index_entries:
            try:
                genre_idx.append(genre_pos[entry[2]])
            except KeyError:
                # list.index() raised ValueError here; keep that contract.
                raise ValueError(
                    f"unknown genre {entry[2]!r} in loss table"
                ) from None

        # First pass: number of rows per (user, genre) pair.
        totals = np.zeros((self.N, self.K), dtype=np.int64)
        for i, k in zip(user_idx, genre_idx):
            totals[i, k] += 1

        # Second pass: the j-th of n rows of a pair fills the rounds
        # [(j - 1) * T // n, j * T // n).
        self.data = np.zeros([self.T, self.N, self.K], dtype=np.float32)
        seen = np.zeros_like(totals)
        for i, k, row in zip(user_idx, genre_idx, loss_rows):
            seen[i, k] += 1
            done = int(seen[i, k])
            total = int(totals[i, k])
            start = int((done - 1) * self.T // total)
            end = int(done * self.T // total)
            self.data[start:end, i, k] = row[0]

    def __len__(self):
        """Return the number of rounds stored in ``self.data``."""
        return self.data.shape[0]

    def _cumulative_mean_losses(self):
        """Return the running sum over rounds of the user-averaged losses.

        The result has shape ``(T, K)``.
        """
        mean_over_users = np.mean(self.data, axis=1)
        return np.cumsum(mean_over_users, axis=0)

    def cumloss_of_best_arm(self):
        """Return the cumulative user-averaged loss curve of the best arm.

        The best arm is the one with the smallest cumulative loss after the
        final round.
        """
        cum_losses = self._cumulative_mean_losses()
        return cum_losses[:, np.argmin(cum_losses[-1])]

    def best_arm(self):
        """Return the index of the arm with the smallest total loss.

        Losses are averaged over users and summed over all rounds.
        """
        return np.argmin(self._cumulative_mean_losses()[-1])

    def __getitem__(self, idx):
        """Return the ``(N, K)`` loss slice of round ``idx``."""
        return self.data[idx]