import numpy as np
import torch
import copy
from tqdm import tqdm
import os


class NeuralTS:
    """Neural Thompson Sampling agent that picks a price from a discrete grid.

    Each round the agent scores ``K`` candidate prices evenly spaced in
    [0, 1] with ``oracle(x, a)``. It perturbs each score with Gaussian noise
    whose scale is derived from the gradient of that score w.r.t. the network
    parameters, then plays the price with the largest sampled score.

    Only the diagonal of the exploration matrix ``U`` is ever used, so ``U``
    is stored as a 1-D tensor of length ``num_param``. A dense
    ``num_param x num_param`` matrix would waste quadratic memory.
    """

    def __init__(self, nu, T, oracle, generator, K=20, lam=1):
        """Create the agent.

        Args:
            nu: exploration scale multiplying the posterior standard deviation.
            T: number of rounds per repetition.
            oracle: torch module called as ``oracle(x, a)``. ``run`` also uses
                its ``layers``, ``reset()`` and ``update(x, price, reward)``.
            generator: stored as-is; not used by this class.
            K: number of candidate prices on the grid ``linspace(0, 1, K)``.
            lam: regularization; the diagonal of ``U`` starts at ``lam``.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.lam = lam
        self.nu = nu
        self.T = T
        self.oracle = oracle
        oracle.to(self.device)
        self.K = K
        self.generator = generator

        # NOTE(review): for an ``nn.Linear`` first layer, ``weight.size(1)`` is
        # the input dimension, whereas NeuralTS normalizes by the network
        # width (``size(0)``). Kept unchanged; confirm which one is intended.
        self.m = self.oracle.layers[0].weight.data.size(1)
        self.num_param = sum(p.numel() for p in self.oracle.parameters())

        # Per-round logs for the current repetition (overwritten each run).
        self.reward = np.zeros((T,))
        self.optimal_reward = np.zeros((T,))

        self.reset()

    def reset(self):
        """Start a fresh repetition: zero the round counter and reset ``U``.

        ``U`` must be reset together with the oracle. Otherwise gradient
        statistics from the previous repetition would shrink the exploration
        of the next one.
        """
        self.t = 0
        # Diagonal of the design matrix, initialized to lam * I.
        self.U = self.lam * torch.ones(self.num_param, device=self.device)

    def _exploration_std(self, g):
        """Return the sampling std ``nu * sqrt(lam * g^T diag(U)^-1 g / m)``."""
        variance = self.lam * torch.sum(g * g / self.U) / self.m
        return self.nu * torch.sqrt(variance)

    def sample(self, x):
        """Choose a price for context ``x`` by Thompson sampling and update ``U``.

        Args:
            x: context tensor; it is tiled ``K`` times along dim 0 and passed
               to the oracle together with the ``(K, 1)`` price grid.

        Returns:
            The chosen price as a Python float in [0, 1].
        """
        x = x.to(self.device)
        a = torch.linspace(0, 1, self.K, device=self.device).reshape(-1, 1)
        pred = self.oracle(torch.tile(x, (self.K, 1)), a).flatten()

        params = list(self.oracle.parameters())
        g_list = []
        for k in range(self.K):
            # Gradient of the k-th score w.r.t. all parameters; autograd.grad
            # does not accumulate into ``p.grad``.
            grads = torch.autograd.grad(pred[k], params, retain_graph=True)
            g_list.append(torch.cat([gr.flatten() for gr in grads]))
        # Leave no stale gradients behind for ``oracle.update``.
        self.oracle.zero_grad()

        sigma = torch.stack([self._exploration_std(g) for g in g_list])
        sampled_reward = torch.normal(pred.detach(), sigma)
        idx = torch.argmax(sampled_reward).item()

        # Rank-one update restricted to the diagonal: diag(g g^T) = g * g.
        g = g_list[idx]
        self.U = self.U + g * g / self.m

        return a[idx].item()

    def run(self, rep, env, basedir):
        """Run ``rep`` independent repetitions of ``T`` rounds and save results.

        Each repetition resets the environment, the oracle and this agent.
        Then, for every round, it:

        - draws a context and chooses a price;
        - feeds ``price * realization`` to the oracle;
        - logs ``price * probability`` next to the optimal reward reported by
          the environment.

        The ``(rep, T)`` arrays are saved as ``reward.npy`` and
        ``optimal_reward.npy`` inside ``basedir``, which is created if missing.
        """
        rewards = np.zeros((rep, self.T))
        optimal_rewards = np.zeros((rep, self.T))

        for r in range(rep):
            print(f'run {r}')

            env.reset()
            self.oracle.reset()
            self.reset()

            for t in tqdm(range(self.T)):
                # receive a context
                x = env.gen_context()

                # sample an action
                price = self.sample(x)

                # take the action and receive response of env
                realization, probability = env.act(x, price)

                # update the oracle with the observed revenue
                self.oracle.update(x, price, price * realization)

                # log data (price * probability: presumably expected revenue)
                self.reward[t] = price * probability
                _, self.optimal_reward[t] = env.optimal_action(x)

            # Row assignment copies the values out of the per-run buffers.
            rewards[r] = self.reward
            optimal_rewards[r] = self.optimal_reward

        # save results
        os.makedirs(basedir, exist_ok=True)
        np.save(os.path.join(basedir, 'reward.npy'), rewards)
        np.save(os.path.join(basedir, 'optimal_reward.npy'), optimal_rewards)

