from util import *
from time import time
from tqdm import tqdm
import numpy as np
import torch

class NeuralBandit:
    """Contextual-bandit environment built from a labelled classification dataset.

    Each round presents one sample; the agent chooses one of ``K`` arms (one per
    class label) and the returned reward vector is 1 at the true label and 0
    elsewhere. Samples are served in order and the stream wraps around.

    Args:
        name: dataset name forwarded to ``import_data``.
        shuffle: forwarded to ``process_features`` as ``is_shuffle``.
        seed: forwarded to ``process_features``.
        mode: ``"MLP"`` for flat block-encoded arms (``step_MLP``) or ``"CNN"``
            for image-shaped arms (``step_CNN``).
        device: torch device for the stored data and all returned tensors.
    """

    def __init__(self, name, shuffle, seed, mode="MLP", device='cuda'):
        X, y = import_data(name)
        self.X, self.y = process_features(X, y, is_shuffle=shuffle, seed=seed)
        self.X = torch.tensor(self.X, dtype=torch.float32, device=device)
        self.y = torch.tensor(self.y, dtype=torch.long, device=device)
        # Count rows of the *processed* labels so it always matches what is indexed.
        self.num_data = self.y.shape[0]
        # Number of arms = max label + 1 (labels presumably 0..K-1). Kept as a
        # 0-d tensor, as before, in case callers rely on that type.
        self.K = torch.max(self.y) + 1
        self.round = 0
        self.data_d = self.X.shape[1]
        self.device = device
        # Flat arm dimension: K blocks of size data_d, arm i fills block i
        # ([x,0,...,0] ~ [0,...,0,x]).
        self.d = self.X.shape[1] * self.K
        self.mode = mode

    def step(self):
        """Return ``(arm_features, rewards)`` for the current sample and advance.

        Raises:
            ValueError: if ``self.mode`` is neither ``"MLP"`` nor ``"CNN"``.
        """
        if self.mode == "MLP":
            return self.step_MLP()
        if self.mode == "CNN":
            return self.step_CNN()
        # Previously this fell through and returned None, which only failed
        # later at the caller's tuple unpacking with a confusing error.
        raise ValueError(f"Unknown mode {self.mode!r}; expected 'MLP' or 'CNN'")

    def _label(self, idx):
        """Return the class label of sample ``idx`` as a Python int.

        Works whether labels are stored with shape ``(n,)`` or ``(n, 1)``.
        """
        return int(self.y[idx].reshape(-1)[0])

    def step_MLP(self):
        """Build flat, disjoint arm features for the current sample and advance.

        Returns:
            X: float tensor of shape ``(K, K * data_d)``. Row ``i`` is zero except
               for columns ``[i*data_d, (i+1)*data_d)``, which hold the sample.
            rwd: float tensor of shape ``(K,)``, 1 at the true label, else 0.
        """
        x = self.X[self.round]
        X = torch.zeros((self.K, self.d), device=self.device)
        for i in range(self.K):
            X[i, i * self.data_d:(i + 1) * self.data_d] = x
        rwd = torch.zeros((self.K,), device=self.device)
        rwd[self._label(self.round)] = 1
        # Wrap around so the environment can be stepped indefinitely.
        self.round = (self.round + 1) % self.num_data
        return X, rwd

    def step_CNN(self):
        """Build image-shaped arm features for the current sample and advance.

        The flat sample is reshaped into a square ``side x side`` image
        (28 x 28 for 784-dim inputs). Arm ``i`` is a ``K``-channel image whose
        channel ``i`` holds that image and whose other channels are zero.

        Returns:
            X: float tensor of shape ``(K, K, side, side)``.
            rwd: float tensor of shape ``(K,)``, 1 at the true label, else 0.

        Raises:
            ValueError: if ``data_d`` is not a perfect square.
        """
        side = int(round(self.data_d ** 0.5))
        if side * side != self.data_d:
            raise ValueError(
                f"CNN mode needs square images; got feature dim {self.data_d}")
        x_image = self.X[self.round].view(side, side)
        label = self._label(self.round)
        self.round = (self.round + 1) % self.num_data

        X = torch.zeros((self.K, self.K, side, side), device=self.device)
        for i in range(self.K):
            X[i, i] = x_image
        rwd = torch.zeros((self.K,), device=self.device)
        rwd[label] = 1
        return X, rwd

    def reset(self):
        """Restart the sample stream from the first row."""
        self.round = 0


def evaluate_one(Alg, params, env, T, device='cuda'):
    """Run a single ``Alg`` instance against ``env`` for ``T`` rounds.

    Each round the environment yields ``(arms, rewards)``. The algorithm is
    shown the arms, picks one index, and is updated with that arm's features
    and reward.

    Returns:
        (regret, alg): per-round instantaneous regret as a NumPy array of
        length ``T`` (best available reward minus obtained reward), and the
        algorithm instance after training.
    """
    agent = Alg(env, T, params)
    per_round = torch.zeros(T, device=device)

    for step_idx in tqdm(range(T), ncols=100, leave=False):
        arm_feats, arm_rewards = env.step()
        agent.get_arms(arm_feats)
        choice = agent.get_arm(step_idx)
        gained = arm_rewards[choice]
        agent.update(step_idx, arm_feats[choice], gained)
        best = torch.max(arm_rewards)
        per_round[step_idx] = best - gained

    return per_round.cpu().numpy(), agent


def evaluate(Alg, params, envs, T=10000, printout=True, device='cuda'):
    """Evaluate ``Alg`` on every environment in ``envs`` for ``T`` rounds each.

    Args:
        Alg: algorithm class, instantiated per environment by ``evaluate_one``;
            must expose a ``print()`` callable returning a display label.
        params: forwarded to the ``Alg`` constructor.
        envs: sized iterable of environments (``len(envs)`` is taken).
        T: rounds per environment.
        printout: if True, print a summary of total regret and wall time.
        device: torch device for the regret accumulator.

    Returns:
        (regret, algs): NumPy array of shape ``(T, len(envs))`` holding the
        per-round regret for each environment, and the list of trained
        algorithm instances (one per environment, same order).
    """
    label = Alg.print()
    if printout:
        print("Evaluating %s" % label, end=" ")

    start = time()
    num_episode = len(envs)
    regret = torch.zeros((T, num_episode), device=device)
    alg = [None] * num_episode

    # enumerate() has no len(), so tqdm needs `total` to draw a real progress bar.
    for i, env in tqdm(enumerate(envs), total=num_episode, desc=f"{label}",
                       ncols=100, leave=True, unit="iter"):
        episode_regret, alg[i] = evaluate_one(Alg, params, env, T, device)
        regret[:, i] = torch.as_tensor(episode_regret, device=device)

    if printout:
        # Cumulative regret of each episode; the "+/-" term is the standard error.
        total_regret = regret.sum(dim=0).cpu().numpy()
        if num_episode:  # stats of an empty array would warn / raise
            print("Regret: %.2f +/- %.2f (median: %.2f, max: %.2f)" %
                  (total_regret.mean(), total_regret.std() / (num_episode**0.5),
                   np.median(total_regret), total_regret.max()))
        print(" %.1f seconds" % (time() - start))
        print("++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")

    return regret.cpu().numpy(), alg