from sklearn.datasets import fetch_openml
from sklearn.utils import shuffle
from sklearn.preprocessing import OrdinalEncoder
from sklearn.preprocessing import normalize
import numpy as np
import pandas as pd 

import torch
import torchvision
from torchvision import datasets, transforms
# from sklearn.cluster import KMeans
from collections import defaultdict


class load_mnist_1d:
    """Bandit environment that streams MNIST training images, one per step.

    Every step yields one context row per arm.  The flattened 784-pixel
    image is written into slot ``[784*a, 784*(a+1))`` of row ``a`` and the
    rest of the row is zero.  The arm whose index equals the image label
    gets reward 1, every other arm gets 0.
    """

    def __init__(self):
        # Fetch data
        batch_size = 1
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
        dataset1 = datasets.MNIST('./data', train=True, download=True,
                                  transform=transform)
        # The loader is kept so the stream can be restarted once exhausted.
        self.train_loader = torch.utils.data.DataLoader(dataset1, batch_size=batch_size,
                                                        shuffle=True, num_workers=2)
        self.dataiter = iter(self.train_loader)
        self.n_arm = 10
        self.dim = 7840

    def _next_sample(self):
        """Return the next (image, label) batch, restarting after a full pass."""
        try:
            # Built-in next(): loader iterators only guarantee __next__,
            # a .next() method is not part of the iterator protocol.
            return next(self.dataiter)
        except StopIteration:
            self.dataiter = iter(self.train_loader)
            return next(self.dataiter)

    def step(self):
        """Draw one image and return ``(X_n, rwd)``.

        Returns:
            X_n: float array of shape ``(n_arm, 784 * n_arm)``.
            rwd: float array of shape ``(n_arm,)``, one-hot at the label.
        """
        x, y = self._next_sample()
        d = x.numpy()[0].reshape(784)
        target = y.item()
        X_n = np.zeros((self.n_arm, 784 * self.n_arm))
        for arm in range(self.n_arm):
            X_n[arm, 784 * arm:784 * (arm + 1)] = d
        rwd = np.zeros(self.n_arm)
        rwd[target] = 1
        return X_n, rwd


class load_mnist_adv:
    """MNIST bandit environment whose digit distribution can be biased.

    One data loader per digit is kept.  Each step first samples a digit
    ``j`` (optionally biased toward a requested digit), then draws an image
    of that digit and builds the same block-shifted contexts as the plain
    MNIST environment: row ``a`` holds the 784 pixels in slot
    ``[784*a, 784*(a+1))``.
    """

    def __init__(self):
        # Fetch data
        batch_size = 1
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
        dataset = datasets.MNIST('./data', train=True, download=True,
                                 transform=transform)
        self.loaders = []   # kept so an exhausted digit stream can be restarted
        self.dataiter = []
        for digit in range(10):
            # `targets` is the label tensor; `train_labels` is its deprecated alias.
            indices = (dataset.targets == digit).nonzero().flatten().tolist()
            trainset = torch.utils.data.Subset(dataset, indices)
            trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                                      shuffle=True, num_workers=2)
            self.loaders.append(trainloader)
            self.dataiter.append(iter(trainloader))

        self.n_arm = 10
        self.dim = 7840

    def _next_sample(self, j):
        """Return the next (image, label) batch of digit ``j``, restarting if exhausted."""
        try:
            return next(self.dataiter[j])
        except StopIteration:
            self.dataiter[j] = iter(self.loaders[j])
            return next(self.dataiter[j])

    def step(self, i):
        """Sample a digit, draw one of its images and return ``(X_n, rwd, j)``.

        Args:
            i: digit to favour; it is drawn with probability 0.7 and the
                remaining 0.3 is split evenly over the other digits.
                ``-1`` requests a uniform draw.

        Returns:
            X_n: float array of shape ``(n_arm, 784 * n_arm)``.
            rwd: float array of shape ``(n_arm,)``, one-hot at the image label.
            j: the digit that was sampled.
        """
        if i != -1:
            prob = np.full(self.n_arm, 0.3 / (self.n_arm - 1))
            prob[i] = 0.7
        else:
            prob = np.full(self.n_arm, 1 / self.n_arm)

        j = np.random.choice(np.arange(0, self.n_arm), p=prob)

        x, y = self._next_sample(j)
        d = x.numpy()[0].reshape(784)
        target = y.item()
        X_n = np.zeros((self.n_arm, 784 * self.n_arm))
        # Loop variable is deliberately not called `i` so the argument is not shadowed.
        for arm in range(self.n_arm):
            X_n[arm, 784 * arm:784 * (arm + 1)] = d
        rwd = np.zeros(self.n_arm)
        rwd[target] = 1
        return X_n, rwd, j

    

class load_yelp:
    """Ten-arm bandit environment backed by the Yelp ``.npy`` files under ./data.

    Rows of the entry table are split by whether their third column equals 1
    ("pos") or not ("neg").  Each step shows nine distinct pos rows and one
    neg row; the context of a row is ``U[row[0]]`` concatenated with
    ``I[row[1]]`` (presumably user and item features, going by the file
    names).  The arm holding the neg row is the one given reward 1.
    """

    def __init__(self):
        # Fetch data
        self.m = np.load("./data/yelp_2000users_10000items_entry.npy")
        self.U = np.load("./data/yelp_2000users_10000items_features.npy")
        self.I = np.load("./data/yelp_10000items_2000users_features.npy")
        self.n_arm = 10
        self.dim = 20
        # Keep only the first two columns of each entry, bucketed by the third.
        self.pos_index = [(row[0], row[1]) for row in self.m if row[2] == 1]
        self.neg_index = [(row[0], row[1]) for row in self.m if not row[2] == 1]

        self.p_d = len(self.pos_index)
        self.n_d = len(self.neg_index)
        print(self.p_d, self.n_d)
        self.pos_index = np.array(self.pos_index)
        self.neg_index = np.array(self.neg_index)

    def step(self):
        """Return ``(X, rwd)``: ten context rows and a one-hot reward vector."""
        # The three draws below happen in this fixed order.
        slot = np.random.choice(range(10))
        pos_rows = self.pos_index[np.random.choice(range(self.p_d), 9, replace=False)]
        neg_row = self.neg_index[np.random.choice(range(self.n_d), replace=False)]
        # Insert the single neg row at position `slot` among the nine pos rows.
        pairs = np.vstack((pos_rows[:slot], neg_row, pos_rows[slot:]))
        contexts = [np.concatenate((self.U[pair[0]], self.I[pair[1]])) for pair in pairs]
        rwd = np.zeros(self.n_arm)
        rwd[slot] = 1
        return np.array(contexts), rwd

    

class load_movielen:
    """Ten-arm bandit environment backed by the MovieLens ``.npy`` files under ./data.

    Rows of the entry table are split by whether their third column equals 1
    ("pos") or not ("neg").  Each step shows nine distinct pos rows and one
    neg row; the context of a row is ``U[row[0]]`` concatenated with
    ``I[row[1]]`` (presumably user and item features, going by the file
    names).  The arm holding the neg row is the one given reward 1.
    """

    def __init__(self):
        # Fetch data
        self.m = np.load("./data/movie_2000users_10000items_entry.npy")
        self.U = np.load("./data/movie_2000users_10000items_features.npy")
        self.I = np.load("./data/movie_10000items_2000users_features.npy")
        self.n_arm = 10
        self.dim = 20
        # Keep only the first two columns of each entry, bucketed by the third.
        self.pos_index = [(row[0], row[1]) for row in self.m if row[2] == 1]
        self.neg_index = [(row[0], row[1]) for row in self.m if not row[2] == 1]

        self.p_d = len(self.pos_index)
        self.n_d = len(self.neg_index)
        print(self.p_d, self.n_d)
        self.pos_index = np.array(self.pos_index)
        self.neg_index = np.array(self.neg_index)

    def step(self):
        """Return ``(X, rwd)``: ten context rows and a one-hot reward vector."""
        # The three draws below happen in this fixed order.
        slot = np.random.choice(range(10))
        pos_rows = self.pos_index[np.random.choice(range(self.p_d), 9, replace=False)]
        neg_row = self.neg_index[np.random.choice(range(self.n_d), replace=False)]
        # Insert the single neg row at position `slot` among the nine pos rows.
        pairs = np.vstack((pos_rows[:slot], neg_row, pos_rows[slot:]))
        contexts = [np.concatenate((self.U[pair[0]], self.I[pair[1]])) for pair in pairs]
        rwd = np.zeros(self.n_arm)
        rwd[slot] = 1
        return np.array(contexts), rwd

class load_notmnist_mnist_2:
    """MNIST bandit environment with compact, shift-by-one arm encoding.

    Row ``a`` of the context matrix holds the flattened image at columns
    ``[a, a + act_dim)`` and zeros elsewhere, so the row width is
    ``act_dim + n_arm - 1``.  The arm equal to the image label gets reward 1.
    """

    def __init__(self):
        #  mnist
        batch_size = 1
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
        dataset1 = datasets.MNIST('./data', train=True, download=True,
                                  transform=transform)
        # The loader is kept so the stream can be restarted once exhausted.
        self.train_loader = torch.utils.data.DataLoader(dataset1, batch_size=batch_size,
                                                        shuffle=True, num_workers=2)
        self.dataiter = iter(self.train_loader)
        # NOTE(review): the previous constructor read self.y_arm and self.X,
        # which this class never defines, and never set self.act_dim.  The
        # values below are inferred from the MNIST data loaded above (10
        # classes, 28x28 images) and the former "+ 9" offset -- confirm.
        self.n_arm = 10
        self.act_dim = 28 * 28
        self.dim = self.act_dim + self.n_arm - 1

    def _next_sample(self):
        """Return the next (image, label) batch, restarting after a full pass."""
        try:
            return next(self.dataiter)
        except StopIteration:
            self.dataiter = iter(self.train_loader)
            return next(self.dataiter)

    def step(self):
        """Draw one image and return ``(X, rwd)``.

        Returns:
            X: float array of shape ``(n_arm, dim)``.
            rwd: float array of shape ``(n_arm,)``, one-hot at the label.
        """
        x, y = self._next_sample()
        d = x.numpy()[0].reshape(self.act_dim)
        target = y.item()
        X = np.zeros((self.n_arm, self.dim))
        for a in range(self.n_arm):
            # Arm a sees the same features shifted right by a columns.
            X[a, a:a + self.act_dim] = d
        rwd = np.zeros(self.n_arm)
        rwd[target] = 1
        return X, rwd


class Bandit_multi:
    """Multi-class OpenML dataset exposed as a contextual bandit.

    The dataset is fetched, optionally one-hot encoded, NaNs are replaced by
    -1 and rows are normalised.  Samples are grouped by class; each step
    samples a class ``j``, takes the next sample of that class (cycling
    through the group) and builds one context row per arm in which the
    feature vector is shifted right by the arm index.

    Attributes set by the constructor: ``X``, ``y``, ``y_arm`` (0-based class
    ids), ``n_arm``, ``dim`` (= features + n_arm), ``act_dim`` (= features),
    ``num_user``, ``baseline`` (a random arm index), ``size``, ``cursor`` and
    ``input_`` (per-class sample arrays).
    """

    # name -> (OpenML name, version, one-hot encode X, remap labels to 1..K)
    _DATASETS = {
        'covertype': ('covertype', 3, True, False),
        'MagicTelescope': ('MagicTelescope', 1, False, True),
        'shuttle': ('shuttle', 1, False, False),
        'adult': ('adult', 2, True, True),
        'mushroom': ('mushroom', 1, True, True),
        'fashion': ('Fashion-MNIST', 1, True, False),
        'nursery': ('nursery', 1, True, True),
        # NOTE(review): 'Plants' fetches the 'nursery' data, unchanged from
        # the previous code; this looks like a placeholder -- confirm which
        # OpenML dataset was intended.
        'Plants': ('nursery', 1, True, True),
        'leaf': ('leaf', 1, True, True),
        'eucalyptus': ('eucalyptus', 1, True, True),
    }

    def __init__(self, name):
        """Load and preprocess the dataset called ``name``.

        Raises:
            RuntimeError: if ``name`` is not one of the supported datasets.
        """
        # Fetch data
        config = self._DATASETS.get(name) if isinstance(name, str) else None
        if config is None:
            raise RuntimeError('Dataset does not exist')
        openml_name, version, one_hot, remap_labels = config

        X, y = fetch_openml(openml_name, version=version, return_X_y=True)
        if one_hot:
            X = pd.get_dummies(X)
        # avoid nan, set nan as -1
        X[np.isnan(X)] = -1
        X = normalize(X)
        if remap_labels:
            # Sorted so the label -> arm assignment is the same on every run;
            # plain set iteration order of strings varies between processes.
            unique_values = sorted(set(y.values))
            label_map = {value: k + 1 for k, value in enumerate(unique_values)}
            y = y.map(label_map)

        # Shuffle data
        self.X, self.y = shuffle(X, y)
        # generate one_hot coding:
        self.y_arm = np.array(self.y.values).astype(np.int64)
        if name != 'fashion':
            # Every dataset except 'fashion' has its labels shifted down by one.
            self.y_arm = self.y_arm - 1
        # cursor and other variables
        self.cursor = 0
        self.size = self.y.shape[0]
        self.n_arm = int(np.max(self.y_arm) + 1)
        self.baseline = np.random.randint(self.n_arm)
        print(np.unique(self.y_arm), self.n_arm)
        print(self.X.shape[1])
        self.dim = self.X.shape[1] + self.n_arm
        self.act_dim = self.X.shape[1]
        self.num_user = np.max(self.y_arm) + 1
        self.input_ = [self.X[self.y_arm == i] for i in range(self.n_arm)]
        # Position of the next unseen sample inside each class group.
        self._next_row = [0] * self.n_arm

    def step(self, i=-1):
        """Sample a class, emit its next sample and return the round's data.

        Args:
            i: class to favour.  It is drawn with probability 0.6 when there
                are two arms and 0.4 otherwise, the remainder being split
                evenly over the other classes.  ``-1`` (default) draws the
                class uniformly.

        Returns:
            ``(X_n, rwd, rwd, baseline)`` where ``X_n`` has shape
            ``(n_arm, n_features + n_arm)`` and ``rwd`` is 1 for every arm
            except the sampled class, which gets 0.01 (presumably a cost
            rather than a reward -- confirm against the caller).  The same
            ``rwd`` array is returned twice.
        """
        if self.cursor > (len(self.X) - 1):
            self.cursor = 0

        if i != -1:
            p = 0.6 if self.n_arm == 2 else 0.4
            prob = np.full(self.n_arm, (1 - p) / (self.n_arm - 1))
            prob[i] = p
        else:
            prob = np.full(self.n_arm, 1 / self.n_arm)

        j = np.random.choice(np.arange(0, self.n_arm), p=prob)

        # Advance a per-class cursor.  The previous np.roll call discarded its
        # result, so the first sample of the class was returned on every step.
        pool = self.input_[j]
        row = self._next_row[j]
        x = pool[row]
        self._next_row[j] = (row + 1) % len(pool)

        target = int(j)
        width = len(x)
        X_n = np.zeros((self.n_arm, width + self.n_arm))
        for arm in range(self.n_arm):
            # arm leading zeros, the features, then n_arm - arm trailing zeros
            X_n[arm, arm:arm + width] = x
        rwd = np.ones(self.n_arm)
        rwd[target] = 0.01
        self.cursor += 1
        return X_n, rwd, rwd, self.baseline


class load_emnist_letter_1d:
    """Bandit environment streaming EMNIST "letters" training images.

    There are 26 arms.  Row ``a`` of the context matrix holds the flattened
    784-pixel image preceded by ``num_zeros * a`` zeros and followed by
    ``num_zeros * (num_class - a - 1)`` zeros.  The dataset label minus one
    is the rewarded arm.
    """

    def __init__(self, is_shuffle=True):
        """Build the data stream.

        Args:
            is_shuffle: whether the loader shuffles the training set.
                Defaults to True, which is what the loader always did before
                this flag was wired in.
        """
        # Fetch data
        batch_size = 1
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
        trainset = torchvision.datasets.EMNIST(root='./data', split="letters", train=True,
                                               download=True, transform=transform)
        # The loader is kept so the stream can be restarted once exhausted.
        self.trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                                       shuffle=is_shuffle, num_workers=2)
        self.dataiter = iter(self.trainloader)

        self.n_arm = 26
        self.num_zeros = 10
        self.num_class = 26
        self.num_user = 26
        self.dim = 28 * 28 + self.num_zeros * (self.num_class - 1)

    def _next_sample(self):
        """Return the next (image, label) batch, restarting after a full pass."""
        try:
            return next(self.dataiter)
        except StopIteration:
            self.dataiter = iter(self.trainloader)
            return next(self.dataiter)

    def step(self):
        """Draw one image and return ``(X_n, rwd)``.

        Returns:
            X_n: float array of shape
                ``(n_arm, 784 + num_zeros * (num_class - 1))``.
            rwd: float array of shape ``(n_arm,)``, one-hot at ``label - 1``.
        """
        x, y = self._next_sample()
        d = x.numpy()[0][0].reshape(28 * 28)
        # The label is shifted down by one to obtain a 0-based arm index.
        target = y.item() - 1
        width = d.shape[0] + self.num_zeros * (self.num_class - 1)
        X_n = np.zeros((self.n_arm, width))
        for arm in range(self.n_arm):
            start = self.num_zeros * arm
            X_n[arm, start:start + d.shape[0]] = d
        rwd = np.zeros(self.n_arm)
        rwd[target] = 1
        return X_n, rwd

class synthetic:
    """Four-arm synthetic bandit with unit-norm Gaussian contexts.

    The expected reward of a context ``x`` depends on ``name``:

    * ``'cos'``:    ``cos(3 * x.a)``
    * ``'square'``: ``10 * (x.a)**2``
    * anything else: ``x^T A^T A x``

    where ``a`` (shape ``(dim, 1)``) and ``A`` (shape ``(dim, dim)``) are
    drawn once from a standard normal in the constructor.
    """

    def __init__(self, name):
        self.name = name
        self.n_arm = 4
        self.dim = 20
        self.a = np.random.randn(self.dim, 1)
        self.A = np.random.normal(0, 1, (self.dim, self.dim))
        self.baseline = np.random.randint(self.n_arm)

    def step(self):
        """Sample one round and return ``(X_n, rwd, true_rwd, baseline)``.

        Returns:
            X_n: array of shape ``(n_arm, dim)``; every row has unit L2 norm.
            rwd: ``true_rwd`` plus independent N(0, 0.1**2) noise per arm.
            true_rwd: noiseless reward of each arm, shape ``(n_arm,)``.
            baseline: the arm index drawn in the constructor.
        """
        # Contexts are drawn first, the noise afterwards.
        X_n = np.random.randn(self.n_arm, self.dim)
        X_n /= np.linalg.norm(X_n, axis=1, keepdims=True)

        if self.name == 'cos':
            # ravel() turns the (n_arm, 1) projection into plain scalars per
            # arm; storing a shape-(1,) array into a scalar slot is deprecated
            # in NumPy.
            true_rwd = np.cos(3 * (X_n @ self.a).ravel())
        elif self.name == 'square':
            true_rwd = 10 * (X_n @ self.a).ravel() ** 2
        else:
            AtA = np.matmul(np.transpose(self.A), self.A)
            # Row-wise quadratic form x^T (A^T A) x.
            true_rwd = np.sum((X_n @ AtA) * X_n, axis=1)

        rwd = true_rwd + np.random.normal(0, 0.1, size=self.n_arm)
        return X_n, rwd, true_rwd, self.baseline

    

    

