from sklearn.datasets import fetch_openml
from sklearn.utils import shuffle
from sklearn.preprocessing import OrdinalEncoder
from sklearn.preprocessing import normalize
import numpy as np
import pandas as pd 

import torch
import torchvision
from torchvision import datasets, transforms
from sklearn.cluster import KMeans
from collections import defaultdict

        
        
class load_notmnist_mnist_2:
    """Bandit environment mixing shuffled notMNIST samples with MNIST images.

    Each step flips a fair coin. On 0 it serves the next notMNIST sample, on 1
    the next MNIST training image. Every arm sees the same flattened image,
    written into a zero row starting at the arm's index. The reward vector is
    one-hot on the sample's class. The first value returned by ``step`` is the
    notMNIST class index, or ``10 + digit`` for MNIST samples.
    """

    def __init__(self):
        # Not mnist
        X = np.load('./not_mnist_data.npy', allow_pickle=True)
        y = np.load('./not_mnist_label.npy', allow_pickle=True)
        X = np.array([img.flatten() for img in X])
        # avoid nan, set nan as -1
        X[np.isnan(X)] = -1
        X = normalize(X)
        self.X, self.y = shuffle(X, y)
        # np.int was removed in NumPy 1.24; the builtin int is the same dtype.
        self.y_arm = OrdinalEncoder(
            dtype=int).fit_transform(self.y.reshape((-1, 1)))
        self.cursor = 0
        self.size = self.y.shape[0]
        self.n_arm = np.max(self.y_arm) + 1
        self.act_dim = self.X.shape[1]
        # Arm a's copy starts at column a, so n_arm - 1 extra columns are
        # needed (equal to the former hard-coded "+ 9" when there are 10 classes).
        self.dim = self.act_dim + int(self.n_arm) - 1

        #  mnist
        batch_size = 1
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
        dataset1 = datasets.MNIST('./data', train=True, download=True,
                                  transform=transform)
        train_loader = torch.utils.data.DataLoader(dataset1, batch_size=batch_size,
                                                   shuffle=True, num_workers=2)
        # Keep the loader so a new epoch can be started once it is exhausted.
        self._loader = train_loader
        self.dataiter = iter(train_loader)
        self.num_user = 20

    def _next_batch(self):
        """Return the next MNIST (image, label) batch, restarting the loader when exhausted."""
        try:
            # DataLoader iterators no longer provide ``.next()``; use the builtin.
            return next(self.dataiter)
        except StopIteration:
            self.dataiter = iter(self._loader)
            return next(self.dataiter)

    def _arm_contexts(self, features):
        """Return an (n_arm, dim) matrix whose row ``a`` holds ``features`` starting at column ``a``."""
        X = np.zeros((self.n_arm, self.dim))
        for a in range(self.n_arm):
            X[a, a:a + self.act_dim] = features
        return X

    def _one_hot(self, index):
        """Return a length-n_arm reward vector with a 1 at ``index``."""
        rwd = np.zeros((self.n_arm,))
        rwd[index] = 1
        return rwd

    def step(self):
        """Serve one round.

        Returns:
            (user, X, rwd): ``user`` is the notMNIST class index or
            ``10 + digit`` for an MNIST sample, ``X`` is the (n_arm, dim)
            context matrix and ``rwd`` the one-hot reward vector.

        Raises:
            AssertionError: when the notMNIST samples are exhausted.
        """
        choice = np.random.choice(2, 1, replace=True)[0]
        if choice == 0:
            assert self.cursor < self.size
            features = self.X[self.cursor]
            arm = self.y_arm[self.cursor][0]
            self.cursor += 1
            return arm, self._arm_contexts(features), self._one_hot(arm)

        x, y = self._next_batch()
        features = x.numpy()[0].reshape(self.act_dim)
        target = y.item()
        return 10 + target, self._arm_contexts(features), self._one_hot(target)

    

class load_cifar10_1d_10:
    """CIFAR-10 bandit with one arm per class.

    Each arm sees the flattened image (3072 values) shifted right by 10
    columns per arm index inside a zero vector of length
    ``3072 + 10 * (n_arm - 1)``. The reward is one-hot on the true class.
    """

    def __init__(self, is_shuffle=True):
        # ``is_shuffle`` is kept for interface compatibility; the loader always shuffles.
        batch_size = 1
        transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                                download=True, transform=transform)
        trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                                  shuffle=True, num_workers=2)
        # Keep the loader so a new epoch can be started once it is exhausted.
        self._loader = trainloader
        self.dataiter = iter(trainloader)
        self.n_arm = 10
        self.dim = 3072 + 10 * (self.n_arm - 1)
        self.num_user = 10

    def _next_batch(self):
        """Return the next (image, label) batch, restarting the loader when exhausted."""
        try:
            # DataLoader iterators no longer provide ``.next()``; use the builtin.
            return next(self.dataiter)
        except StopIteration:
            self.dataiter = iter(self._loader)
            return next(self.dataiter)

    def step(self):
        """Return ``(target, X_n, rwd)`` for one CIFAR-10 image.

        ``target`` is the class index, ``X_n`` the (n_arm, dim) context matrix
        and ``rwd`` the one-hot reward vector.
        """
        x, y = self._next_batch()
        d = x.numpy()[0].reshape(3072)
        target = y.item()
        shift = 10  # columns between consecutive arms' copies of the image
        X_n = np.zeros((self.n_arm, d.size + shift * (self.n_arm - 1)))
        for i in range(self.n_arm):
            X_n[i, shift * i:shift * i + d.size] = d
        rwd = np.zeros(self.n_arm)
        rwd[target] = 1
        return target, X_n, rwd
    
    
    
class load_cifar10_1d:
    """CIFAR-10 bandit with 3 arms formed by grouping classes.

    The rewarded arm is ``int(label / 4)`` (labels 0-3, 4-7 and 8-9). Each arm
    sees the flattened image (3072 values) shifted right by ``num_zeros``
    columns per arm index.
    """

    def __init__(self, is_shuffle=True):
        # ``is_shuffle`` is kept for interface compatibility; the loader always shuffles.
        batch_size = 1
        transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                                download=True, transform=transform)
        trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                                  shuffle=True, num_workers=2)
        # Keep the loader so a new epoch can be started once it is exhausted.
        self._loader = trainloader
        self.dataiter = iter(trainloader)
        self.n_arm = 3
        self.num_zeros = 100
        self.dim = 3072 + self.num_zeros * (self.n_arm - 1)
        self.num_user = 10

    def _next_batch(self):
        """Return the next (image, label) batch, restarting the loader when exhausted."""
        try:
            # DataLoader iterators no longer provide ``.next()``; use the builtin.
            return next(self.dataiter)
        except StopIteration:
            self.dataiter = iter(self._loader)
            return next(self.dataiter)

    def step(self):
        """Return ``(label, X_n, rwd)``.

        ``label`` is the original CIFAR-10 class (used as the user id), ``X_n``
        the (n_arm, dim) context matrix and ``rwd`` one-hot on the class group.
        """
        x, y = self._next_batch()
        d = x.numpy()[0].reshape(3072)
        target = int(y.item() / 4.0)
        shift = self.num_zeros
        # Loop over self.n_arm (previously hard-coded as 3 / "2 - i").
        X_n = np.zeros((self.n_arm, d.size + shift * (self.n_arm - 1)))
        for i in range(self.n_arm):
            X_n[i, shift * i:shift * i + d.size] = d
        rwd = np.zeros(self.n_arm)
        rwd[target] = 1
        return y.item(), X_n, rwd
    



    
   
from scipy.stats import norm
class load_movielen_new:
    """MovieLens bandit built from precomputed user/item feature files.

    Users (rows of ``U``) are clustered into 50 groups with k-means. Rows of
    ``m`` are read as (user index, item index, flag): column 0 indexes
    ``groups``, column 1 indexes ``I``. Each row is filed per cluster under
    ``pos_index`` when the flag is 1 and ``neg_index`` otherwise.

    Each round holds ``n_arm`` items:
      * ``n_arm - 1`` drawn (with replacement) from the cluster's ``pos_index``;
      * one item drawn from its ``neg_index``, placed at a random arm.

    The arm holding the ``neg_index`` item is the only one with reward 1.
    """

    def __init__(self):
        # Fetch data
        self.m = np.load("../movielens/movie_2000users_10000items_entry.npy")
        self.U = np.load("../movielens/movie_2000users_10000items_features.npy")
        self.I = np.load("../movielens/movie_10000items_2000users_features.npy")
        kmeans = KMeans(n_clusters=50, random_state=0).fit(self.U)
        self.groups = kmeans.labels_
        self.n_arm = 10
        self.dim = 10  # NOTE(review): presumably equals self.I.shape[1]; confirm
        self.num_user = 50
        self.pos_index = defaultdict(list)
        self.neg_index = defaultdict(list)
        for row in self.m:
            bucket = self.pos_index if row[2] == 1 else self.neg_index
            bucket[self.groups[row[0]]].append((row[0], row[1]))

    def step(self):
        """Sample one round.

        Returns:
            (g, contexts, rwd): the sampled user's cluster id, ``norm.pdf``
            (mean 0, scale 0.5) of the ``n_arm`` items' feature rows, and a
            one-hot reward on the arm holding the ``neg_index`` item.

        Raises:
            ValueError: if the cluster has no ``pos_index`` or no
                ``neg_index`` entries.
        """
        # Draw sizes come from the data/config rather than hard-coded 2000/10/9;
        # the random stream is identical when those sizes match.
        u = np.random.choice(len(self.groups))
        g = self.groups[u]
        arm = np.random.choice(self.n_arm)
        pos_pairs = self.pos_index[g]
        neg_pairs = self.neg_index[g]
        if not pos_pairs or not neg_pairs:
            raise ValueError(f"cluster {g} has no positive or no negative entries")
        pos = np.array(pos_pairs)[np.random.choice(len(pos_pairs), self.n_arm - 1, replace=True)]
        neg = np.array(neg_pairs)[np.random.choice(len(neg_pairs))]
        # Insert the neg_index pair as row ``arm`` among the pos_index pairs.
        X_ind = np.insert(pos, arm, neg, axis=0)
        contexts = norm.pdf(self.I[X_ind[:, 1]], loc=0, scale=0.5)
        rwd = np.zeros(self.n_arm)
        rwd[arm] = 1
        return g, contexts, rwd
    
    
class load_yelp_new:
    """Yelp bandit built from precomputed user/item feature files.

    Users (rows of ``U``) are clustered into 50 groups with k-means. Rows of
    ``m`` are read as (user index, item index, flag): column 0 indexes
    ``groups``, column 1 indexes ``I``. Each row is filed per cluster under
    ``pos_index`` when the flag is 1 and ``neg_index`` otherwise.

    Each round holds ``n_arm`` items:
      * ``n_arm - 1`` drawn (with replacement) from the cluster's ``pos_index``;
      * one item drawn from its ``neg_index``, placed at a random arm.

    The arm holding the ``neg_index`` item is the only one with reward 1.
    """

    def __init__(self):
        # Fetch data
        self.m = np.load("../Yelp/yelp_2000users_10000items_entry.npy")
        self.U = np.load("../Yelp/yelp_2000users_10000items_features.npy")
        self.I = np.load("../Yelp/yelp_10000items_2000users_features.npy")
        kmeans = KMeans(n_clusters=50, random_state=0).fit(self.U)
        self.groups = kmeans.labels_
        self.n_arm = 10
        self.dim = 10  # NOTE(review): presumably equals self.I.shape[1]; confirm
        self.num_user = 50
        self.pos_index = defaultdict(list)
        self.neg_index = defaultdict(list)
        for row in self.m:
            bucket = self.pos_index if row[2] == 1 else self.neg_index
            bucket[self.groups[row[0]]].append((row[0], row[1]))

    def step(self):
        """Sample one round.

        Returns:
            (g, contexts, rwd): the sampled user's cluster id, ``norm.pdf``
            (mean 0, scale 0.5) of the ``n_arm`` items' feature rows, and a
            one-hot reward on the arm holding the ``neg_index`` item.

        Raises:
            ValueError: if the cluster has no ``pos_index`` or no
                ``neg_index`` entries.
        """
        # Draw sizes come from the data/config rather than hard-coded 2000/10/9;
        # the random stream is identical when those sizes match.
        u = np.random.choice(len(self.groups))
        g = self.groups[u]
        arm = np.random.choice(self.n_arm)
        pos_pairs = self.pos_index[g]
        neg_pairs = self.neg_index[g]
        if not pos_pairs or not neg_pairs:
            raise ValueError(f"cluster {g} has no positive or no negative entries")
        pos = np.array(pos_pairs)[np.random.choice(len(pos_pairs), self.n_arm - 1, replace=True)]
        neg = np.array(neg_pairs)[np.random.choice(len(neg_pairs))]
        # Insert the neg_index pair as row ``arm`` among the pos_index pairs.
        X_ind = np.insert(pos, arm, neg, axis=0)
        contexts = norm.pdf(self.I[X_ind[:, 1]], loc=0, scale=0.5)
        rwd = np.zeros(self.n_arm)
        rwd[arm] = 1
        return g, contexts, rwd

    

class Bandit_multi:
    """Contextual-bandit environment built from a classification dataset.

    Samples are served in shuffled order and wrap around at the end. Every
    arm sees the same feature vector, written into a zero vector of length
    ``act_dim + n_arm`` starting at the arm's index. The rewarded arm is
    ``int(label / 2)``.
    """

    def __init__(self, name):
        """Load, clean and shuffle dataset ``name``.

        Raises:
            RuntimeError: if ``name`` is not a supported dataset.
        """
        X, y = self._load_dataset(name)
        # Shuffle data
        self.X, self.y = shuffle(X, y)
        # np.asarray works for pandas Series (OpenML datasets) and for the
        # plain ndarray labels of 'notmnist', which have no ``.values``.
        # np.int was removed in NumPy 1.24; the builtin int is equivalent.
        self.y_arm = np.asarray(self.y).astype(int)
        # cursor and other variables
        self.cursor = 0
        self.size = self.y.shape[0]
        # Pairs of adjacent labels share one arm (see step()).
        self.n_arm = int(np.max(self.y_arm) / 2 + 1)
        self.dim = self.X.shape[1] + self.n_arm
        self.act_dim = self.X.shape[1]
        self.num_user = np.max(self.y_arm) + 1
        print(self.dim)
        print(self.n_arm)

    @staticmethod
    def _encode_labels(y):
        """Map each distinct label in Series ``y`` to 1..K in sorted label order.

        Sorting makes the mapping reproducible: iteration order of a ``set``
        of strings depends on the per-process hash seed.
        """
        label_map = {value: i + 1 for i, value in enumerate(sorted(set(y.values)))}
        return y.map(label_map)

    @staticmethod
    def _clean_features(X):
        """Replace NaNs with -1 and L2-normalise each row."""
        X[np.isnan(X)] = -1
        return normalize(X)

    @classmethod
    def _load_dataset(cls, name):
        """Fetch dataset ``name`` and return ``(X, y)`` with cleaned features.

        Raises:
            RuntimeError: if ``name`` is not one of the supported datasets.
        """
        if name == 'covertype':
            # class: 1-7
            X, y = fetch_openml('covertype', version=3, return_X_y=True)
            X = pd.get_dummies(X)
        elif name == 'MagicTelescope':
            # class: h, g
            X, y = fetch_openml('MagicTelescope', version=1, return_X_y=True)
            y = cls._encode_labels(y)
        elif name == 'shuttle':
            X, y = fetch_openml('shuttle', version=1, return_X_y=True)
        elif name == 'adult':
            X, y = fetch_openml('adult', version=2, return_X_y=True)
            X = pd.get_dummies(X)
            y = cls._encode_labels(y)
        elif name == 'mushroom':
            X, y = fetch_openml('mushroom', version=1, return_X_y=True)
            X = pd.get_dummies(X)
            y = cls._encode_labels(y)
        elif name == 'fashion':
            X, y = fetch_openml('Fashion-MNIST', version=1, return_X_y=True)
        elif name == 'notmnist':
            X = np.load('./nomnist/imagedat.npy', allow_pickle=True)
            y = np.load('./nomnist/labeldata.npy', allow_pickle=True)
            X = np.array([img.flatten() for img in X])
            print('notmnist', X.shape)
        else:
            raise RuntimeError('Dataset does not exist')
        return cls._clean_features(X), y

    def step(self):
        """Serve the next sample, wrapping to the start after the last one.

        Returns:
            (label, X_n, rwd): the raw integer label, the
            (n_arm, act_dim + n_arm) context matrix and the one-hot reward.
        """
        if self.cursor >= len(self.X):
            self.cursor = 0

        x = self.X[self.cursor]
        label = self.y_arm[self.cursor]
        target = int(label.item() / 2.0)
        width = x.shape[0]
        X_n = np.zeros((self.n_arm, width + self.n_arm))
        for i in range(self.n_arm):
            X_n[i, i:i + width] = x
        rwd = np.zeros(self.n_arm)
        rwd[target] = 1
        self.cursor += 1
        return label.item(), X_n, rwd

    
class load_emnist_letter_1d:
    """EMNIST-letters bandit with one arm per letter class.

    Each arm sees the flattened 28x28 image shifted right by ``num_zeros``
    columns per arm index. The reward is one-hot on ``label - 1``; the letters
    split appears to use 1-based labels, but confirm against the dataset.
    """

    def __init__(self, is_shuffle=True):
        # ``is_shuffle`` is kept for interface compatibility; the loader always shuffles.
        batch_size = 1
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
        trainset = torchvision.datasets.EMNIST(root='./data', split="letters", train=True,
                                               download=True, transform=transform)
        trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                                  shuffle=True, num_workers=2)
        # Keep the loader so a new epoch can be started once it is exhausted.
        self._loader = trainloader
        self.dataiter = iter(trainloader)

        self.n_arm = 26
        self.num_zeros = 10
        self.num_class = 26
        self.num_user = 26
        self.dim = 28*28 + self.num_zeros*(self.num_class - 1)

    def _next_batch(self):
        """Return the next (image, label) batch, restarting the loader when exhausted."""
        try:
            # DataLoader iterators no longer provide ``.next()``; use the builtin.
            return next(self.dataiter)
        except StopIteration:
            self.dataiter = iter(self._loader)
            return next(self.dataiter)

    def step(self):
        """Return ``(target, X_n, rwd)`` where ``target = label - 1``."""
        x, y = self._next_batch()
        d = x.numpy()[0][0].reshape(28*28)
        target = y.item() - 1
        shift = self.num_zeros
        X_n = np.zeros((self.n_arm, d.size + shift * (self.num_class - 1)))
        for i in range(self.n_arm):
            X_n[i, shift * i:shift * i + d.size] = d
        rwd = np.zeros(self.n_arm)
        rwd[target] = 1
        return target, X_n, rwd
   
        