import numpy as np
import torch
import torch.optim as optim
import tqdm


class NeuralUCBDiag:
    
    """
    Adapted from https://github.com/uclaml/NeuralUCB/blob/master/learner_diag.py
    """
    
    def __init__(self, dim, func, mean_func=None, lamdba=1, nu=1, lr=0.1, train_max_steps=1000, train_min_loss=1e-6, device='cuda'):
        self.device = device
        # ensure that if x is 1D tensor, then func(x) and mean_func(x) outputs scalar (not 1D tensor)
        self.func = func.float().to(self.device)  
        self.mean_func = None if (mean_func is None) else mean_func.float().to(self.device)
        self.context_list = []
        self.reward = []
        self.lamdba = lamdba
        self.total_param = sum(p.numel() for p in self.func.parameters() if p.requires_grad)
        self.U = lamdba * torch.ones((self.total_param,)).float().to(self.device)
        self.nu = nu
        self.lr = lr
        self.train_max_steps = train_max_steps
        self.train_min_loss = train_min_loss
        self.dim = dim
        
        def forward_call(x):
            if self.mean_func is None:
                return self.func(x)
            else:
                m = self.mean_func(x)
                a = self.func(x)
                return m + a
            
        l = len(forward_call(torch.zeros(size=(dim,)).float().to(self.device)  ).shape)
        assert l == 0, l
        self.forward_call = forward_call
    
    def generate_acq_func(self):
        def acq(tensor):
            params_ = dict(self.func.named_parameters())
            f = lambda params, inputs: torch.func.functional_call(self.func, params, (inputs,))
            jacobians = torch.func.jacrev(f)(params_, tensor)
            g = torch.concatenate([j.flatten() for j in jacobians.values()])
            sigma2 = self.lamdba * self.nu * g * g / self.U
            sigma = torch.sqrt(torch.sum(sigma2))
            mu = self.forward_call(tensor)
            assert mu.shape == sigma.shape
            sample_r = mu + sigma
            return sample_r, mu, sigma
        return acq
    
    def _update_U(self, x):
        self.func.zero_grad()
        if self.mean_func is not None:
            self.mean_func.zero_grad()
        mu = self.forward_call(x.float().to(self.device))
        mu.backward(retain_graph=True)
        g = torch.cat([p.grad.flatten().detach() for p in self.func.parameters()])
        assert g.shape == self.U.shape
        self.U += g * g
        
    def manual_select(self, x):
        self._update_U(x)
        return x, self.generate_acq_func()(x)[0].tolist()
        
    def select_from_samples_sequentially(self, contexts):
        acq = self.generate_acq_func()
        scores = [acq(x.float().to(self.device))[0] for x in tqdm.tqdm(contexts)]
        best_idx = np.argmax(scores)
        x = contexts[best_idx]
        self._update_U(x)
        return x, scores[best_idx]
    
    def select_from_samples_in_parallel(self, contexts):
        acq = self.generate_acq_func()
        scores, _, _ = torch.func.vmap(acq)(contexts.float().to(self.device))
        best_idx = torch.argmax(scores)
        x = contexts[best_idx]
        self._update_U(x)
        return x, scores[best_idx]
    
    def select_using_gd(self, init_c, max_norm=float('inf'), lr=None):
        x = init_c.float().to(self.device)
        # optimizer = torch.optim.SGD([x], lr=self.lr)
        acq = self.generate_acq_func()
        print(acq(x))
        grad = torch.func.grad(lambda x_: acq(x_)[0])
        lr_ = self.lr if lr is None else lr
        for _ in range(self.train_max_steps):
            # optimizer.zero_grad()
            # loss = 0. - acq(x)[0]
            # loss.backward()
            # optimizer.step()
            x = x + lr_ * grad(x)
            norm = x.norm(p=2)
            if norm > max_norm:
                x = x * (max_norm / norm)
        self._update_U(x)
        return x, acq(x)[0].tolist()
    
    def get_mu_and_sigma(self, contexts):
        acq = self.generate_acq_func()
        _, mu, sigma = torch.func.vmap(acq)(contexts.float().to(self.device))
        return mu, sigma

    def train(self, context, reward):
        assert len(context.shape) == 1
        self.context_list.append(context.reshape(1, -1).float())
        self.reward.append(reward)
        
        for i in [0, 1]:
            
            if i == 0 and self.mean_func is None:
                # no mean_func so skip to just training full network
                continue
            elif i == 0:
                print('Training mean function only')
                self.mean_func.train()
                params = self.mean_func.parameters()
                f = self.mean_func.forward
                wd = 0.
            else:
                print('Training NN only')
                self.func.train()
                params = self.func.parameters()
                f = self.forward_call
                wd = self.lamdba
            
            optimizer = optim.SGD(params, lr=self.lr, weight_decay=wd)
            train_x = torch.concatenate(self.context_list, dim=0).float().to(self.device)
            train_y = torch.tensor(self.reward).float().to(self.device)
            N = self.train_max_steps
            tot_loss = 0.
            for cnt in range(1, N+1):
                optimizer.zero_grad()
                output = f(train_x)
                loss = ((output.flatten() - train_y.flatten())**2).mean()
                loss.backward()
                optimizer.step()
                tot_loss += loss.item()
                if loss.item() < self.train_min_loss:
                    break
            print(f'Train to step {cnt}, loss={loss.item()}, batch_loss={tot_loss / cnt}')
