import torch
from torch import nn, optim
import pytorch_lightning as pl
import numpy as np
from tqdm import tqdm
from sklearn import datasets, preprocessing
from scipy.optimize import minimize_scalar, minimize
from utils import *
from policies import *

# Bound ``cdf`` method of a Normal(0, 1) distribution, i.e. the standard normal CDF.
# Not referenced by any class in the visible part of this file.
probit = torch.distributions.normal.Normal(0., 1.).cdf


##############################
######################## Exponential Smoothing
#############################


class Smoothing(SmoothingPolicy):
    """Upper-bound training objective built on ``SmoothingPolicy``.

    This class only adds the ``delta``-dependent constants and the loss.
    ``q_mean``, ``log_scale``, ``dev``, ``N``, ``normal_kl_div`` and the
    ``compute_*`` helpers are not defined here and are expected to be
    provided by the parent class.
    """

    def __init__(self, n_actions, context_dim, tau, N, diag = False, loc_weight=None, delta = 0.05, device = torch.device("cpu")):
        """Build the policy and precompute the two log terms used by the bound.

        Args:
            n_actions, context_dim, tau, N, diag, loc_weight, device:
                forwarded unchanged to the parent constructor, which is
                always called with ``lmbd = 1.``.
            delta: value appearing in the log terms below (presumably the
                failure probability of the bound -- TODO confirm).
        """
        super().__init__(n_actions = n_actions, context_dim = context_dim,
                         tau = tau, N = N, lmbd = 1., diag = diag, loc_weight = loc_weight, device = device)

        self.delta = delta
        # Constants added to the KL divergence inside ``upper_bound``.
        self.unc_kl1 = np.log((4. * self.N ** 0.5)/self.delta)
        self.unc_kl2 = np.log(4/self.delta)

    def upper_bound(self, x, a, ps, r, n_samples = 1000):
        """Return the bound used as the training loss for one batch.

        Args:
            x: contexts, multiplied with ``self.q_mean.T`` to obtain scores
                (assumes shape (batch, context_dim) -- TODO confirm).
            a: integer indices into dim 1 of the action distribution
                (presumably the logged actions).
            ps: at least 2-D tensor; a trailing axis is inserted at
                position 2 before it is handed to the helpers (presumably
                logging propensities -- TODO confirm).
            r: passed through to ``compute_empirical_second_moment`` and
                ``compute_risk`` (presumably rewards or costs).
            n_samples: number of Gaussian noise draws per batch row.

        Returns:
            ``risk + bias + kl_c + kl_a / lmbd + 0.5 * lmbd * second_moment``
            as a tensor; the second-moment term must hold a single element
            because ``.item()`` is called on it.
        """
        bsize = x.size(0)
        sigma = torch.exp(self.log_scale)
        scores = torch.matmul(x, self.q_mean.T).unsqueeze(-1)
        # NOTE(review): the noise has shape (bsize, 1, n_samples), so it is
        # broadcast along dim 1, the same dim the softmax normalises over.
        # If ``sigma`` does not vary along that dim the noise cancels inside
        # the softmax -- confirm the shape of ``log_scale``.
        scores_noised = scores + sigma * torch.randn(bsize, 1, n_samples).to(self.dev)
        dist_x = torch.softmax(scores_noised, dim = 1)
        # Probability of the indexed action for every row and noise draw.
        dist_x_a = dist_x[np.arange(bsize), a, :]

        ps_ = ps[:, :, None]
        mm1 = self.compute_mean_second_moment(dist_x, ps_)
        mm2 = self.compute_empirical_second_moment(dist_x_a, a, ps_, r)
        mm_mean = torch.mean(mm1 + mm2, dim=-1)
        bias_mean = torch.mean(self.compute_mean_bias(dist_x, ps_), dim=-1)
        risk_mean = torch.mean(torch.mean(self.compute_risk(dist_x_a, a, ps_, r), dim=0), dim=-1)

        # ``normal_kl_div`` takes no batch input, so evaluate it once and
        # reuse the result for both KL-based terms.
        kl = self.normal_kl_div()
        kl_c = ((kl + self.unc_kl1)/(2. * self.N)) ** 0.5
        kl_a = (kl + self.unc_kl2) / self.N

        # Plain Python float built from detached values: no gradient flows
        # through ``lmbd``.
        # NOTE(review): the minimiser of k/l + l*m/2 over l is sqrt(2*k/m),
        # whereas this uses 2*sqrt(k*m) -- confirm this choice is intended.
        lmbd = 2 * np.sqrt(kl_a.detach().item() * mm_mean.detach().item())

        first_part = risk_mean + bias_mean + kl_c
        second_part = kl_a / lmbd + 0.5 * lmbd * mm_mean
        our_bound = first_part + second_part

        return our_bound

    def training_step(self, train_batch, batch_idx):
        """Unpack ``(x, a, ps, r)`` and return the bound as the loss.

        ``batch_idx`` is accepted but not used.
        """
        x, a, ps, r = train_batch
        loss = self.upper_bound(x, a, ps, r)
        return loss
    
    
class HeuristicSmoothing(HeuristicSmoothingPolicy):
    """Weighted-sum training objective built on ``HeuristicSmoothingPolicy``.

    The loss is ``risk + lmbd1 * KL + lmbd2 * bias + lmbd3 * second_moment``
    with user-supplied weights. ``q_mean``, ``N``, ``normal_kl_div`` and the
    ``compute_*`` helpers are not defined here and are expected to be
    provided by the parent class.
    """

    def __init__(self, n_actions, context_dim, tau, N, lmbd1, lmbd2, lmbd3, diag = False, loc_weight=None, delta = 0.05, device = torch.device("cpu")):
        """Build the policy and store the three term weights.

        Args:
            n_actions, context_dim, tau, N, diag, loc_weight, device:
                forwarded unchanged to the parent constructor, which is
                always called with ``lmbd = 1.``.
            lmbd1: weight of the KL term.
            lmbd2: weight of the bias term.
            lmbd3: weight of the second-moment term.
            delta: value appearing in the stored log constants.
        """
        super().__init__(
            n_actions = n_actions,
            context_dim = context_dim,
            tau = tau,
            N = N,
            lmbd = 1.,
            diag = diag,
            loc_weight = loc_weight,
            device = device,
        )

        self.lmbd1, self.lmbd2, self.lmbd3 = lmbd1, lmbd2, lmbd3
        self.delta = delta
        # These constants are stored but not read by this class's methods.
        self.unc_kl1 = np.log((4. * self.N ** 0.5) / self.delta)
        self.unc_kl2 = np.log(4 / self.delta)

    def upper_bound(self, x, a, ps, r):
        """Return the weighted objective for one batch.

        Args:
            x: contexts, multiplied with ``self.q_mean.T`` to obtain scores
                (assumes shape (batch, context_dim) -- TODO confirm).
            a: integer indices into dim 1 of the action distribution
                (presumably the logged actions).
            ps: at least 2-D tensor; a new axis is inserted at position 2
                (presumably logging propensities -- TODO confirm).
            r: passed through to the second-moment and risk helpers.

        Returns:
            ``risk + lmbd1 * KL + lmbd2 * bias + lmbd3 * second_moment``.
        """
        row_index = np.arange(x.size(0))

        # Noise-free softmax over dim 1, with a trailing singleton axis.
        logits = (x @ self.q_mean.T)[..., None]
        action_probs = logits.softmax(dim = 1)
        chosen_probs = action_probs[row_index, a, :]

        ps_expanded = ps[:, :, None]
        moment_pop = self.compute_mean_second_moment(action_probs, ps_expanded)
        moment_emp = self.compute_empirical_second_moment(chosen_probs, a, ps_expanded, r)
        second_moment = (moment_pop + moment_emp).mean(dim=-1)
        bias = self.compute_mean_bias(action_probs, ps_expanded).mean(dim=-1)
        per_row_risk = self.compute_risk(chosen_probs, a, ps_expanded, r)
        risk = per_row_risk.mean(dim=0).mean(dim=-1)

        kl_term = self.normal_kl_div()

        return risk + self.lmbd1*kl_term + self.lmbd2*bias + self.lmbd3*second_moment

    def training_step(self, train_batch, batch_idx):
        """Unpack the 4-tuple batch and return its objective as the loss.

        ``batch_idx`` is accepted but not used.
        """
        contexts, actions, propensities, rewards = train_batch
        return self.upper_bound(contexts, actions, propensities, rewards)


    
    
    
    
class GaussianSmoothing(GaussianSmoothingPolicy):
    """Upper-bound training objective built on ``GaussianSmoothingPolicy``.

    The action distribution is obtained from ``self.policy``. ``policy``,
    ``N``, ``normal_kl_div`` and the ``compute_*`` helpers are not defined
    here and are expected to be provided by the parent class.
    """

    def __init__(self, n_actions, context_dim, tau, N, diag = False, loc_weight=None, delta = 0.05, device = torch.device("cpu")):
        """Build the policy and precompute the two log terms used by the bound.

        Args:
            n_actions, context_dim, tau, N, diag, loc_weight, device:
                forwarded unchanged to the parent constructor, which is
                always called with ``lmbd = 1.``.
            delta: value appearing in the log terms below (presumably the
                failure probability of the bound -- TODO confirm).
        """
        super().__init__(n_actions = n_actions, context_dim = context_dim,
                         tau = tau, N = N, lmbd = 1., diag = diag, loc_weight = loc_weight, device = device)

        self.delta = delta
        # Constants added to the KL divergence inside ``upper_bound``.
        self.unc_kl1 = np.log((4. * self.N ** 0.5)/self.delta)
        self.unc_kl2 = np.log(4/self.delta)

    def upper_bound(self, x, a, ps, r, n_samples = 32):
        """Return the bound used as the training loss for one batch.

        Args:
            x: contexts, passed to ``self.policy``.
            a: integer indices into dim 1 of the action distribution
                (presumably the logged actions).
            ps: at least 2-D tensor; a trailing axis is inserted at
                position 2 (presumably logging propensities -- TODO confirm).
            r: passed through to the second-moment and risk helpers.
            n_samples: forwarded to ``self.policy``; defaults to 32, the
                value that was previously hard-coded.

        Returns:
            ``risk + kl_c + bias + kl_a / lmbd + 0.5 * lmbd * second_moment``
            as a tensor; the second-moment term must hold a single element
            because ``.item()`` is called on it.
        """
        bsize = x.size(0)
        # A trailing singleton axis is appended to the policy output.
        dist_x = self.policy(x, n_samples = n_samples).unsqueeze(-1)
        dist_x_a = dist_x[np.arange(bsize), a, :]

        ps_ = ps[:, :, None]
        mm1 = self.compute_mean_second_moment(dist_x, ps_)
        mm2 = self.compute_empirical_second_moment(dist_x_a, a, ps_, r)
        mm_mean = torch.mean(mm1 + mm2, dim=-1)
        bias_mean = torch.mean(self.compute_mean_bias(dist_x, ps_), dim=-1)
        risk_mean = torch.mean(torch.mean(self.compute_risk(dist_x_a, a, ps_, r), dim=0), dim=-1)

        # ``normal_kl_div`` takes no batch input, so evaluate it once and
        # reuse the result for both KL-based terms.
        kl = self.normal_kl_div()
        kl_c = ((kl + self.unc_kl1)/(2. * self.N)) ** 0.5
        kl_a = (kl + self.unc_kl2) / self.N

        # Plain Python float built from detached values: no gradient flows
        # through ``lmbd``.
        # NOTE(review): the minimiser of k/l + l*m/2 over l is sqrt(2*k/m),
        # whereas this uses 2*sqrt(k*m) -- confirm this choice is intended.
        lmbd = 2 * np.sqrt(kl_a.detach().item() * mm_mean.detach().item())

        first_part = risk_mean + kl_c + bias_mean
        second_part = kl_a / lmbd + 0.5 * lmbd * mm_mean
        our_bound = first_part + second_part

        return our_bound

    def training_step(self, train_batch, batch_idx):
        """Unpack ``(x, a, ps, r)`` and return the bound as the loss.

        ``batch_idx`` is accepted but not used.
        """
        x, a, ps, r = train_batch
        loss = self.upper_bound(x, a, ps, r)
        return loss
    
##############################
######################## Clipping
#############################


class Clipping(ClippingPolicy):
    """Upper-bound training objective built on ``ClippingPolicy``.

    This class only adds the ``delta``-dependent constants and the loss.
    ``q_mean``, ``log_scale``, ``dev``, ``N``, ``normal_kl_div`` and the
    ``compute_*`` helpers are not defined here and are expected to be
    provided by the parent class.
    """

    def __init__(self, n_actions, context_dim, tau, N, diag = False, loc_weight=None, delta = 0.05, device = torch.device("cpu")):
        """Build the policy and precompute the two log terms used by the bound.

        Args:
            n_actions, context_dim, tau, N, diag, loc_weight, device:
                forwarded unchanged to the parent constructor, which is
                always called with ``lmbd = 1.``.
            delta: value appearing in the log terms below (presumably the
                failure probability of the bound -- TODO confirm).
        """
        super().__init__(n_actions = n_actions, context_dim = context_dim,
                         tau = tau, N = N, lmbd = 1., diag = diag, loc_weight = loc_weight, device = device)

        self.delta = delta
        # Constants added to the KL divergence inside ``upper_bound``.
        self.unc_kl1 = np.log((4. * self.N ** 0.5)/self.delta)
        self.unc_kl2 = np.log(4/self.delta)

    def upper_bound(self, x, a, ps, r, n_samples = 1000):
        """Return the bound used as the training loss for one batch.

        Args:
            x: contexts, multiplied with ``self.q_mean.T`` to obtain scores
                (assumes shape (batch, context_dim) -- TODO confirm).
            a: integer indices into dim 1 of the action distribution
                (presumably the logged actions).
            ps: at least 2-D tensor; a trailing axis is inserted at
                position 2 before it is handed to the helpers (presumably
                logging propensities -- TODO confirm).
            r: passed through to ``compute_empirical_second_moment`` and
                ``compute_risk`` (presumably rewards or costs).
            n_samples: number of Gaussian noise draws per batch row.

        Returns:
            ``risk + bias + kl_c + kl_a / lmbd + 0.5 * lmbd * second_moment``
            as a tensor; the second-moment term must hold a single element
            because ``.item()`` is called on it.
        """
        bsize = x.size(0)
        sigma = torch.exp(self.log_scale)
        scores = torch.matmul(x, self.q_mean.T).unsqueeze(-1)
        # NOTE(review): the noise has shape (bsize, 1, n_samples), so it is
        # broadcast along dim 1, the same dim the softmax normalises over.
        # If ``sigma`` does not vary along that dim the noise cancels inside
        # the softmax -- confirm the shape of ``log_scale``.
        scores_noised = scores + sigma * torch.randn(bsize, 1, n_samples).to(self.dev)
        dist_x = torch.softmax(scores_noised, dim = 1)
        # Probability of the indexed action for every row and noise draw.
        dist_x_a = dist_x[np.arange(bsize), a, :]

        ps_ = ps[:, :, None]
        mm1 = self.compute_mean_second_moment(dist_x, ps_)
        mm2 = self.compute_empirical_second_moment(dist_x_a, a, ps_, r)
        mm_mean = torch.mean(mm1 + mm2, dim=-1)
        bias_mean = torch.mean(self.compute_mean_bias(dist_x, ps_), dim=-1)
        risk_mean = torch.mean(torch.mean(self.compute_risk(dist_x_a, a, ps_, r), dim=0), dim=-1)

        # ``normal_kl_div`` takes no batch input, so evaluate it once and
        # reuse the result for both KL-based terms.
        kl = self.normal_kl_div()
        kl_c = ((kl + self.unc_kl1)/(2. * self.N)) ** 0.5
        kl_a = (kl + self.unc_kl2) / self.N

        # Plain Python float built from detached values: no gradient flows
        # through ``lmbd``.
        # NOTE(review): the minimiser of k/l + l*m/2 over l is sqrt(2*k/m),
        # whereas this uses 2*sqrt(k*m) -- confirm this choice is intended.
        lmbd = 2 * np.sqrt(kl_a.detach().item() * mm_mean.detach().item())

        first_part = risk_mean + bias_mean + kl_c
        second_part = kl_a / lmbd + 0.5 * lmbd * mm_mean
        our_bound = first_part + second_part

        return our_bound

    def training_step(self, train_batch, batch_idx):
        """Unpack ``(x, a, ps, r)`` and return the bound as the loss.

        ``batch_idx`` is accepted but not used.
        """
        x, a, ps, r = train_batch
        loss = self.upper_bound(x, a, ps, r)
        return loss
    

class HeuristicClipping(HeuristicClippingPolicy):
    """Weighted-sum training objective built on ``HeuristicClippingPolicy``.

    The loss is ``risk + lmbd1 * KL + lmbd2 * bias + lmbd3 * second_moment``
    with user-supplied weights. ``q_mean``, ``N``, ``normal_kl_div`` and the
    ``compute_*`` helpers are not defined here and are expected to be
    provided by the parent class.
    """

    def __init__(self, n_actions, context_dim, tau, N, lmbd1, lmbd2, lmbd3, diag = False, loc_weight=None, delta = 0.05, device = torch.device("cpu")):
        """Build the policy and store the three term weights.

        Args:
            n_actions, context_dim, tau, N, diag, loc_weight, device:
                forwarded unchanged to the parent constructor, which is
                always called with ``lmbd = 1.``.
            lmbd1: weight of the KL term.
            lmbd2: weight of the bias term.
            lmbd3: weight of the second-moment term.
            delta: value appearing in the stored log constants.
        """
        super().__init__(
            n_actions = n_actions,
            context_dim = context_dim,
            tau = tau,
            N = N,
            lmbd = 1.,
            diag = diag,
            loc_weight = loc_weight,
            device = device,
        )

        self.lmbd1, self.lmbd2, self.lmbd3 = lmbd1, lmbd2, lmbd3
        self.delta = delta
        # These constants are stored but not read by this class's methods.
        self.unc_kl1 = np.log((4. * self.N ** 0.5) / self.delta)
        self.unc_kl2 = np.log(4 / self.delta)

    def upper_bound(self, x, a, ps, r):
        """Return the weighted objective for one batch.

        Args:
            x: contexts, multiplied with ``self.q_mean.T`` to obtain scores
                (assumes shape (batch, context_dim) -- TODO confirm).
            a: integer indices into dim 1 of the action distribution
                (presumably the logged actions).
            ps: at least 2-D tensor; a new axis is inserted at position 2
                (presumably logging propensities -- TODO confirm).
            r: passed through to the second-moment and risk helpers.

        Returns:
            ``risk + lmbd1 * KL + lmbd2 * bias + lmbd3 * second_moment``.
        """
        row_index = np.arange(x.size(0))

        # Noise-free softmax over dim 1, with a trailing singleton axis.
        logits = (x @ self.q_mean.T)[..., None]
        action_probs = logits.softmax(dim = 1)
        chosen_probs = action_probs[row_index, a, :]

        ps_expanded = ps[:, :, None]
        moment_pop = self.compute_mean_second_moment(action_probs, ps_expanded)
        moment_emp = self.compute_empirical_second_moment(chosen_probs, a, ps_expanded, r)
        second_moment = (moment_pop + moment_emp).mean(dim=-1)
        bias = self.compute_mean_bias(action_probs, ps_expanded).mean(dim=-1)
        per_row_risk = self.compute_risk(chosen_probs, a, ps_expanded, r)
        risk = per_row_risk.mean(dim=0).mean(dim=-1)

        kl_term = self.normal_kl_div()

        return risk + self.lmbd1*kl_term + self.lmbd2*bias + self.lmbd3*second_moment

    def training_step(self, train_batch, batch_idx):
        """Unpack the 4-tuple batch and return its objective as the loss.

        ``batch_idx`` is accepted but not used.
        """
        contexts, actions, propensities, rewards = train_batch
        return self.upper_bound(contexts, actions, propensities, rewards)
    
    
    
class GaussianClipping(GaussianClippingPolicy):
    """Upper-bound training objective built on ``GaussianClippingPolicy``.

    The action distribution is obtained from ``self.policy``. ``policy``,
    ``N``, ``normal_kl_div`` and the ``compute_*`` helpers are not defined
    here and are expected to be provided by the parent class.
    """

    def __init__(self, n_actions, context_dim, tau, N, diag = False, loc_weight=None, delta = 0.05, device = torch.device("cpu")):
        """Build the policy and precompute the two log terms used by the bound.

        Args:
            n_actions, context_dim, tau, N, diag, loc_weight, device:
                forwarded unchanged to the parent constructor, which is
                always called with ``lmbd = 1.``.
            delta: value appearing in the log terms below (presumably the
                failure probability of the bound -- TODO confirm).
        """
        super().__init__(n_actions = n_actions, context_dim = context_dim,
                         tau = tau, N = N, lmbd = 1., diag = diag, loc_weight = loc_weight, device = device)

        self.delta = delta
        # Constants added to the KL divergence inside ``upper_bound``.
        self.unc_kl1 = np.log((4. * self.N ** 0.5)/self.delta)
        self.unc_kl2 = np.log(4/self.delta)

    def upper_bound(self, x, a, ps, r, n_samples = 32):
        """Return the bound used as the training loss for one batch.

        Args:
            x: contexts, passed to ``self.policy``.
            a: integer indices into dim 1 of the action distribution
                (presumably the logged actions).
            ps: at least 2-D tensor; a trailing axis is inserted at
                position 2 (presumably logging propensities -- TODO confirm).
            r: passed through to the second-moment and risk helpers.
            n_samples: forwarded to ``self.policy``; defaults to 32, the
                value that was previously hard-coded.

        Returns:
            ``risk + kl_c + bias + kl_a / lmbd + 0.5 * lmbd * second_moment``
            as a tensor; the second-moment term must hold a single element
            because ``.item()`` is called on it.
        """
        bsize = x.size(0)
        # A trailing singleton axis is appended to the policy output.
        dist_x = self.policy(x, n_samples = n_samples).unsqueeze(-1)
        dist_x_a = dist_x[np.arange(bsize), a, :]

        ps_ = ps[:, :, None]
        mm1 = self.compute_mean_second_moment(dist_x, ps_)
        mm2 = self.compute_empirical_second_moment(dist_x_a, a, ps_, r)
        mm_mean = torch.mean(mm1 + mm2, dim=-1)
        bias_mean = torch.mean(self.compute_mean_bias(dist_x, ps_), dim=-1)
        risk_mean = torch.mean(torch.mean(self.compute_risk(dist_x_a, a, ps_, r), dim=0), dim=-1)

        # ``normal_kl_div`` takes no batch input, so evaluate it once and
        # reuse the result for both KL-based terms.
        kl = self.normal_kl_div()
        kl_c = ((kl + self.unc_kl1)/(2. * self.N)) ** 0.5
        kl_a = (kl + self.unc_kl2) / self.N

        # Plain Python float built from detached values: no gradient flows
        # through ``lmbd``.
        # NOTE(review): the minimiser of k/l + l*m/2 over l is sqrt(2*k/m),
        # whereas this uses 2*sqrt(k*m) -- confirm this choice is intended.
        lmbd = 2 * np.sqrt(kl_a.detach().item() * mm_mean.detach().item())

        first_part = risk_mean + kl_c + bias_mean
        second_part = kl_a / lmbd + 0.5 * lmbd * mm_mean
        our_bound = first_part + second_part

        return our_bound

    def training_step(self, train_batch, batch_idx):
        """Unpack ``(x, a, ps, r)`` and return the bound as the loss.

        ``batch_idx`` is accepted but not used.
        """
        x, a, ps, r = train_batch
        loss = self.upper_bound(x, a, ps, r)
        return loss

    

    

##############################
######################## Harmonic
#############################


class Harmonic(HarmonicPolicy):
    """Upper-bound training objective built on ``HarmonicPolicy``.

    This class only adds the ``delta``-dependent constants and the loss.
    ``q_mean``, ``log_scale``, ``dev``, ``N``, ``normal_kl_div`` and the
    ``compute_*`` helpers are not defined here and are expected to be
    provided by the parent class.
    """

    def __init__(self, n_actions, context_dim, tau, N, diag = False, loc_weight=None, delta = 0.05, device = torch.device("cpu")):
        """Build the policy and precompute the two log terms used by the bound.

        Args:
            n_actions, context_dim, tau, N, diag, loc_weight, device:
                forwarded unchanged to the parent constructor, which is
                always called with ``lmbd = 1.``.
            delta: value appearing in the log terms below (presumably the
                failure probability of the bound -- TODO confirm).
        """
        super().__init__(n_actions = n_actions, context_dim = context_dim,
                         tau = tau, N = N, lmbd = 1., diag = diag, loc_weight = loc_weight, device = device)

        self.delta = delta
        # Constants added to the KL divergence inside ``upper_bound``.
        self.unc_kl1 = np.log((4. * self.N ** 0.5)/self.delta)
        self.unc_kl2 = np.log(4/self.delta)

    def upper_bound(self, x, a, ps, r, n_samples = 1000):
        """Return the bound used as the training loss for one batch.

        Args:
            x: contexts, multiplied with ``self.q_mean.T`` to obtain scores
                (assumes shape (batch, context_dim) -- TODO confirm).
            a: integer indices into dim 1 of the action distribution
                (presumably the logged actions).
            ps: at least 2-D tensor; a trailing axis is inserted at
                position 2 before it is handed to the helpers (presumably
                logging propensities -- TODO confirm).
            r: passed through to ``compute_empirical_second_moment`` and
                ``compute_risk`` (presumably rewards or costs).
            n_samples: number of Gaussian noise draws per batch row.

        Returns:
            ``risk + bias + kl_c + kl_a / lmbd + 0.5 * lmbd * second_moment``
            as a tensor; the second-moment term must hold a single element
            because ``.item()`` is called on it.
        """
        bsize = x.size(0)
        sigma = torch.exp(self.log_scale)
        scores = torch.matmul(x, self.q_mean.T).unsqueeze(-1)
        # NOTE(review): the noise has shape (bsize, 1, n_samples), so it is
        # broadcast along dim 1, the same dim the softmax normalises over.
        # If ``sigma`` does not vary along that dim the noise cancels inside
        # the softmax -- confirm the shape of ``log_scale``.
        scores_noised = scores + sigma * torch.randn(bsize, 1, n_samples).to(self.dev)
        dist_x = torch.softmax(scores_noised, dim = 1)
        # Probability of the indexed action for every row and noise draw.
        dist_x_a = dist_x[np.arange(bsize), a, :]

        ps_ = ps[:, :, None]
        mm1 = self.compute_mean_second_moment(dist_x, ps_)
        mm2 = self.compute_empirical_second_moment(dist_x_a, a, ps_, r)
        mm_mean = torch.mean(mm1 + mm2, dim=-1)
        bias_mean = torch.mean(self.compute_mean_bias(dist_x, ps_), dim=-1)
        risk_mean = torch.mean(torch.mean(self.compute_risk(dist_x_a, a, ps_, r), dim=0), dim=-1)

        # ``normal_kl_div`` takes no batch input, so evaluate it once and
        # reuse the result for both KL-based terms.
        kl = self.normal_kl_div()
        kl_c = ((kl + self.unc_kl1)/(2. * self.N)) ** 0.5
        kl_a = (kl + self.unc_kl2) / self.N

        # Plain Python float built from detached values: no gradient flows
        # through ``lmbd``.
        # NOTE(review): the minimiser of k/l + l*m/2 over l is sqrt(2*k/m),
        # whereas this uses 2*sqrt(k*m) -- confirm this choice is intended.
        lmbd = 2 * np.sqrt(kl_a.detach().item() * mm_mean.detach().item())

        first_part = risk_mean + bias_mean + kl_c
        second_part = kl_a / lmbd + 0.5 * lmbd * mm_mean
        our_bound = first_part + second_part

        return our_bound

    def training_step(self, train_batch, batch_idx):
        """Unpack ``(x, a, ps, r)`` and return the bound as the loss.

        ``batch_idx`` is accepted but not used.
        """
        x, a, ps, r = train_batch
        loss = self.upper_bound(x, a, ps, r)
        return loss
    
    
class HeuristicHarmonic(HeuristicHarmonicPolicy):
    """Weighted-sum training objective built on ``HeuristicHarmonicPolicy``.

    The loss is ``risk + lmbd1 * KL + lmbd2 * bias + lmbd3 * second_moment``
    with user-supplied weights. ``q_mean``, ``N``, ``normal_kl_div`` and the
    ``compute_*`` helpers are not defined here and are expected to be
    provided by the parent class.
    """

    def __init__(self, n_actions, context_dim, tau, N, lmbd1, lmbd2, lmbd3, diag = False, loc_weight=None, delta = 0.05, device = torch.device("cpu")):
        """Build the policy and store the three term weights.

        Args:
            n_actions, context_dim, tau, N, diag, loc_weight, device:
                forwarded unchanged to the parent constructor, which is
                always called with ``lmbd = 1.``.
            lmbd1: weight of the KL term.
            lmbd2: weight of the bias term.
            lmbd3: weight of the second-moment term.
            delta: value appearing in the stored log constants.
        """
        super().__init__(
            n_actions = n_actions,
            context_dim = context_dim,
            tau = tau,
            N = N,
            lmbd = 1.,
            diag = diag,
            loc_weight = loc_weight,
            device = device,
        )

        self.lmbd1, self.lmbd2, self.lmbd3 = lmbd1, lmbd2, lmbd3
        self.delta = delta
        # These constants are stored but not read by this class's methods.
        self.unc_kl1 = np.log((4. * self.N ** 0.5) / self.delta)
        self.unc_kl2 = np.log(4 / self.delta)

    def upper_bound(self, x, a, ps, r):
        """Return the weighted objective for one batch.

        Args:
            x: contexts, multiplied with ``self.q_mean.T`` to obtain scores
                (assumes shape (batch, context_dim) -- TODO confirm).
            a: integer indices into dim 1 of the action distribution
                (presumably the logged actions).
            ps: at least 2-D tensor; a new axis is inserted at position 2
                (presumably logging propensities -- TODO confirm).
            r: passed through to the second-moment and risk helpers.

        Returns:
            ``risk + lmbd1 * KL + lmbd2 * bias + lmbd3 * second_moment``.
        """
        row_index = np.arange(x.size(0))

        # Noise-free softmax over dim 1, with a trailing singleton axis.
        logits = (x @ self.q_mean.T)[..., None]
        action_probs = logits.softmax(dim = 1)
        chosen_probs = action_probs[row_index, a, :]

        ps_expanded = ps[:, :, None]
        moment_pop = self.compute_mean_second_moment(action_probs, ps_expanded)
        moment_emp = self.compute_empirical_second_moment(chosen_probs, a, ps_expanded, r)
        second_moment = (moment_pop + moment_emp).mean(dim=-1)
        bias = self.compute_mean_bias(action_probs, ps_expanded).mean(dim=-1)
        per_row_risk = self.compute_risk(chosen_probs, a, ps_expanded, r)
        risk = per_row_risk.mean(dim=0).mean(dim=-1)

        kl_term = self.normal_kl_div()

        return risk + self.lmbd1*kl_term + self.lmbd2*bias + self.lmbd3*second_moment

    def training_step(self, train_batch, batch_idx):
        """Unpack the 4-tuple batch and return its objective as the loss.

        ``batch_idx`` is accepted but not used.
        """
        contexts, actions, propensities, rewards = train_batch
        return self.upper_bound(contexts, actions, propensities, rewards)

    
    

    
##############################
######################## Shrinkage
#############################


class Shrinkage(ShrinkagePolicy):
    """Upper-bound training objective built on ``ShrinkagePolicy``.

    This class only adds the ``delta``-dependent constants and the loss.
    ``q_mean``, ``log_scale``, ``dev``, ``N``, ``normal_kl_div`` and the
    ``compute_*`` helpers are not defined here and are expected to be
    provided by the parent class.
    """

    def __init__(self, n_actions, context_dim, tau, N, diag = False, loc_weight=None, delta = 0.05, device = torch.device("cpu")):
        """Build the policy and precompute the two log terms used by the bound.

        Args:
            n_actions, context_dim, tau, N, diag, loc_weight, device:
                forwarded unchanged to the parent constructor, which is
                always called with ``lmbd = 1.``.
            delta: value appearing in the log terms below (presumably the
                failure probability of the bound -- TODO confirm).
        """
        super().__init__(n_actions = n_actions, context_dim = context_dim,
                         tau = tau, N = N, lmbd = 1., diag = diag, loc_weight = loc_weight, device = device)

        self.delta = delta
        # Constants added to the KL divergence inside ``upper_bound``.
        self.unc_kl1 = np.log((4. * self.N ** 0.5)/self.delta)
        self.unc_kl2 = np.log(4/self.delta)

    def upper_bound(self, x, a, ps, r, n_samples = 1000):
        """Return the bound used as the training loss for one batch.

        Args:
            x: contexts, multiplied with ``self.q_mean.T`` to obtain scores
                (assumes shape (batch, context_dim) -- TODO confirm).
            a: integer indices into dim 1 of the action distribution
                (presumably the logged actions).
            ps: at least 2-D tensor; a trailing axis is inserted at
                position 2 before it is handed to the helpers (presumably
                logging propensities -- TODO confirm).
            r: passed through to ``compute_empirical_second_moment`` and
                ``compute_risk`` (presumably rewards or costs).
            n_samples: number of Gaussian noise draws per batch row.

        Returns:
            ``risk + bias + kl_c + kl_a / lmbd + 0.5 * lmbd * second_moment``
            as a tensor; the second-moment term must hold a single element
            because ``.item()`` is called on it.
        """
        bsize = x.size(0)
        sigma = torch.exp(self.log_scale)
        scores = torch.matmul(x, self.q_mean.T).unsqueeze(-1)
        # NOTE(review): the noise has shape (bsize, 1, n_samples), so it is
        # broadcast along dim 1, the same dim the softmax normalises over.
        # If ``sigma`` does not vary along that dim the noise cancels inside
        # the softmax -- confirm the shape of ``log_scale``.
        scores_noised = scores + sigma * torch.randn(bsize, 1, n_samples).to(self.dev)
        dist_x = torch.softmax(scores_noised, dim = 1)
        # Probability of the indexed action for every row and noise draw.
        dist_x_a = dist_x[np.arange(bsize), a, :]

        ps_ = ps[:, :, None]
        mm1 = self.compute_mean_second_moment(dist_x, ps_)
        mm2 = self.compute_empirical_second_moment(dist_x_a, a, ps_, r)
        mm_mean = torch.mean(mm1 + mm2, dim=-1)
        bias_mean = torch.mean(self.compute_mean_bias(dist_x, ps_), dim=-1)
        risk_mean = torch.mean(torch.mean(self.compute_risk(dist_x_a, a, ps_, r), dim=0), dim=-1)

        # ``normal_kl_div`` takes no batch input, so evaluate it once and
        # reuse the result for both KL-based terms.
        kl = self.normal_kl_div()
        kl_c = ((kl + self.unc_kl1)/(2. * self.N)) ** 0.5
        kl_a = (kl + self.unc_kl2) / self.N

        # Plain Python float built from detached values: no gradient flows
        # through ``lmbd``.
        # NOTE(review): the minimiser of k/l + l*m/2 over l is sqrt(2*k/m),
        # whereas this uses 2*sqrt(k*m) -- confirm this choice is intended.
        lmbd = 2 * np.sqrt(kl_a.detach().item() * mm_mean.detach().item())

        first_part = risk_mean + bias_mean + kl_c
        second_part = kl_a / lmbd + 0.5 * lmbd * mm_mean
        our_bound = first_part + second_part

        return our_bound

    def training_step(self, train_batch, batch_idx):
        """Unpack ``(x, a, ps, r)`` and return the bound as the loss.

        ``batch_idx`` is accepted but not used.
        """
        x, a, ps, r = train_batch
        loss = self.upper_bound(x, a, ps, r)
        return loss



##############################
######################## Heuristic Shrinkage
#############################


class HeuristicShrinkage(HeuristicShrinkagePolicy):
    """Weighted-sum training objective built on ``HeuristicShrinkagePolicy``.

    The loss is ``risk + lmbd1 * KL + lmbd2 * bias + lmbd3 * second_moment``
    with user-supplied weights. ``q_mean``, ``N``, ``normal_kl_div`` and the
    ``compute_*`` helpers are not defined here and are expected to be
    provided by the parent class.
    """

    def __init__(self, n_actions, context_dim, tau, N, lmbd1, lmbd2, lmbd3, diag = False, loc_weight=None, delta = 0.05, device = torch.device("cpu")):
        """Build the policy and store the three term weights.

        Args:
            n_actions, context_dim, tau, N, diag, loc_weight, device:
                forwarded unchanged to the parent constructor, which is
                always called with ``lmbd = 1.``.
            lmbd1: weight of the KL term.
            lmbd2: weight of the bias term.
            lmbd3: weight of the second-moment term.
            delta: value appearing in the stored log constants.
        """
        super().__init__(
            n_actions = n_actions,
            context_dim = context_dim,
            tau = tau,
            N = N,
            lmbd = 1.,
            diag = diag,
            loc_weight = loc_weight,
            device = device,
        )

        self.lmbd1, self.lmbd2, self.lmbd3 = lmbd1, lmbd2, lmbd3
        self.delta = delta
        # These constants are stored but not read by this class's methods.
        self.unc_kl1 = np.log((4. * self.N ** 0.5) / self.delta)
        self.unc_kl2 = np.log(4 / self.delta)

    def upper_bound(self, x, a, ps, r):
        """Return the weighted objective for one batch.

        Args:
            x: contexts, multiplied with ``self.q_mean.T`` to obtain scores
                (assumes shape (batch, context_dim) -- TODO confirm).
            a: integer indices into dim 1 of the action distribution
                (presumably the logged actions).
            ps: at least 2-D tensor; a new axis is inserted at position 2
                (presumably logging propensities -- TODO confirm).
            r: passed through to the second-moment and risk helpers.

        Returns:
            ``risk + lmbd1 * KL + lmbd2 * bias + lmbd3 * second_moment``.
        """
        row_index = np.arange(x.size(0))

        # Noise-free softmax over dim 1, with a trailing singleton axis.
        logits = (x @ self.q_mean.T)[..., None]
        action_probs = logits.softmax(dim = 1)
        chosen_probs = action_probs[row_index, a, :]

        ps_expanded = ps[:, :, None]
        moment_pop = self.compute_mean_second_moment(action_probs, ps_expanded)
        moment_emp = self.compute_empirical_second_moment(chosen_probs, a, ps_expanded, r)
        second_moment = (moment_pop + moment_emp).mean(dim=-1)
        bias = self.compute_mean_bias(action_probs, ps_expanded).mean(dim=-1)
        per_row_risk = self.compute_risk(chosen_probs, a, ps_expanded, r)
        risk = per_row_risk.mean(dim=0).mean(dim=-1)

        kl_term = self.normal_kl_div()

        return risk + self.lmbd1*kl_term + self.lmbd2*bias + self.lmbd3*second_moment

    def training_step(self, train_batch, batch_idx):
        """Unpack the 4-tuple batch and return its objective as the loss.

        ``batch_idx`` is accepted but not used.
        """
        contexts, actions, propensities, rewards = train_batch
        return self.upper_bound(contexts, actions, propensities, rewards)





##############################
######################## Implicit Exploration
#############################


class IX(IXPolicy):
    """Upper-bound training objective built on ``IXPolicy``.

    This class only adds the ``delta``-dependent constants and the loss.
    ``q_mean``, ``log_scale``, ``dev``, ``N``, ``normal_kl_div`` and the
    ``compute_*`` helpers are not defined here and are expected to be
    provided by the parent class.
    """

    def __init__(self, n_actions, context_dim, tau, N, diag = False, loc_weight=None, delta = 0.05, device = torch.device("cpu")):
        """Build the policy and precompute the two log terms used by the bound.

        Args:
            n_actions, context_dim, tau, N, diag, loc_weight, device:
                forwarded unchanged to the parent constructor, which is
                always called with ``lmbd = 1.``.
            delta: value appearing in the log terms below (presumably the
                failure probability of the bound -- TODO confirm).
        """
        super().__init__(n_actions = n_actions, context_dim = context_dim,
                         tau = tau, N = N, lmbd = 1., diag = diag, loc_weight = loc_weight, device = device)

        self.delta = delta
        # Constants added to the KL divergence inside ``upper_bound``.
        self.unc_kl1 = np.log((4. * self.N ** 0.5)/self.delta)
        self.unc_kl2 = np.log(4/self.delta)

    def upper_bound(self, x, a, ps, r, n_samples = 1000):
        """Return the bound used as the training loss for one batch.

        Args:
            x: contexts, multiplied with ``self.q_mean.T`` to obtain scores
                (assumes shape (batch, context_dim) -- TODO confirm).
            a: integer indices into dim 1 of the action distribution
                (presumably the logged actions).
            ps: at least 2-D tensor; a trailing axis is inserted at
                position 2 before it is handed to the helpers (presumably
                logging propensities -- TODO confirm).
            r: passed through to ``compute_empirical_second_moment`` and
                ``compute_risk`` (presumably rewards or costs).
            n_samples: number of Gaussian noise draws per batch row.

        Returns:
            ``risk + bias + kl_c + kl_a / lmbd + 0.5 * lmbd * second_moment``
            as a tensor; the second-moment term must hold a single element
            because ``.item()`` is called on it.
        """
        bsize = x.size(0)
        sigma = torch.exp(self.log_scale)
        scores = torch.matmul(x, self.q_mean.T).unsqueeze(-1)
        # NOTE(review): the noise has shape (bsize, 1, n_samples), so it is
        # broadcast along dim 1, the same dim the softmax normalises over.
        # If ``sigma`` does not vary along that dim the noise cancels inside
        # the softmax -- confirm the shape of ``log_scale``.
        scores_noised = scores + sigma * torch.randn(bsize, 1, n_samples).to(self.dev)
        dist_x = torch.softmax(scores_noised, dim = 1)
        # Probability of the indexed action for every row and noise draw.
        dist_x_a = dist_x[np.arange(bsize), a, :]

        ps_ = ps[:, :, None]
        mm1 = self.compute_mean_second_moment(dist_x, ps_)
        mm2 = self.compute_empirical_second_moment(dist_x_a, a, ps_, r)
        mm_mean = torch.mean(mm1 + mm2, dim=-1)
        bias_mean = torch.mean(self.compute_mean_bias(dist_x, ps_), dim=-1)
        risk_mean = torch.mean(torch.mean(self.compute_risk(dist_x_a, a, ps_, r), dim=0), dim=-1)

        # ``normal_kl_div`` takes no batch input, so evaluate it once and
        # reuse the result for both KL-based terms.
        kl = self.normal_kl_div()
        kl_c = ((kl + self.unc_kl1)/(2. * self.N)) ** 0.5
        kl_a = (kl + self.unc_kl2) / self.N

        # Plain Python float built from detached values: no gradient flows
        # through ``lmbd``.
        # NOTE(review): the minimiser of k/l + l*m/2 over l is sqrt(2*k/m),
        # whereas this uses 2*sqrt(k*m) -- confirm this choice is intended.
        lmbd = 2 * np.sqrt(kl_a.detach().item() * mm_mean.detach().item())

        first_part = risk_mean + bias_mean + kl_c
        second_part = kl_a / lmbd + 0.5 * lmbd * mm_mean
        our_bound = first_part + second_part

        return our_bound

    def training_step(self, train_batch, batch_idx):
        """Unpack ``(x, a, ps, r)`` and return the bound as the loss.

        ``batch_idx`` is accepted but not used.
        """
        x, a, ps, r = train_batch
        loss = self.upper_bound(x, a, ps, r)
        return loss
    
    
class HeuristicIX(HeuristicIXPolicy):
    """Weighted-sum training objective built on ``HeuristicIXPolicy``.

    The loss is ``risk + lmbd1 * KL + lmbd2 * bias + lmbd3 * second_moment``
    with user-supplied weights. ``q_mean``, ``N``, ``normal_kl_div`` and the
    ``compute_*`` helpers are not defined here and are expected to be
    provided by the parent class.
    """

    def __init__(self, n_actions, context_dim, tau, N, lmbd1, lmbd2, lmbd3, diag = False, loc_weight=None, delta = 0.05, device = torch.device("cpu")):
        """Build the policy and store the three term weights.

        Args:
            n_actions, context_dim, tau, N, diag, loc_weight, device:
                forwarded unchanged to the parent constructor, which is
                always called with ``lmbd = 1.``.
            lmbd1: weight of the KL term.
            lmbd2: weight of the bias term.
            lmbd3: weight of the second-moment term.
            delta: value appearing in the stored log constants.
        """
        super().__init__(
            n_actions = n_actions,
            context_dim = context_dim,
            tau = tau,
            N = N,
            lmbd = 1.,
            diag = diag,
            loc_weight = loc_weight,
            device = device,
        )

        self.lmbd1, self.lmbd2, self.lmbd3 = lmbd1, lmbd2, lmbd3
        self.delta = delta
        # These constants are stored but not read by this class's methods.
        self.unc_kl1 = np.log((4. * self.N ** 0.5) / self.delta)
        self.unc_kl2 = np.log(4 / self.delta)

    def upper_bound(self, x, a, ps, r):
        """Return the weighted objective for one batch.

        Args:
            x: contexts, multiplied with ``self.q_mean.T`` to obtain scores
                (assumes shape (batch, context_dim) -- TODO confirm).
            a: integer indices into dim 1 of the action distribution
                (presumably the logged actions).
            ps: at least 2-D tensor; a new axis is inserted at position 2
                (presumably logging propensities -- TODO confirm).
            r: passed through to the second-moment and risk helpers.

        Returns:
            ``risk + lmbd1 * KL + lmbd2 * bias + lmbd3 * second_moment``.
        """
        row_index = np.arange(x.size(0))

        # Noise-free softmax over dim 1, with a trailing singleton axis.
        logits = (x @ self.q_mean.T)[..., None]
        action_probs = logits.softmax(dim = 1)
        chosen_probs = action_probs[row_index, a, :]

        ps_expanded = ps[:, :, None]
        moment_pop = self.compute_mean_second_moment(action_probs, ps_expanded)
        moment_emp = self.compute_empirical_second_moment(chosen_probs, a, ps_expanded, r)
        second_moment = (moment_pop + moment_emp).mean(dim=-1)
        bias = self.compute_mean_bias(action_probs, ps_expanded).mean(dim=-1)
        per_row_risk = self.compute_risk(chosen_probs, a, ps_expanded, r)
        risk = per_row_risk.mean(dim=0).mean(dim=-1)

        kl_term = self.normal_kl_div()

        return risk + self.lmbd1*kl_term + self.lmbd2*bias + self.lmbd3*second_moment

    def training_step(self, train_batch, batch_idx):
        """Unpack the 4-tuple batch and return its objective as the loss.

        ``batch_idx`` is accepted but not used.
        """
        contexts, actions, propensities, rewards = train_batch
        return self.upper_bound(contexts, actions, propensities, rewards)
    
class GaussianIX(GaussianIXPolicy):
    """Upper-bound training objective built on ``GaussianIXPolicy``.

    The action distribution is obtained from ``self.policy``. ``policy``,
    ``N``, ``normal_kl_div`` and the ``compute_*`` helpers are not defined
    here and are expected to be provided by the parent class.
    """

    def __init__(self, n_actions, context_dim, tau, N, diag = False, loc_weight=None, delta = 0.05, device = torch.device("cpu")):
        """Build the policy and precompute the two log terms used by the bound.

        Args:
            n_actions, context_dim, tau, N, diag, loc_weight, device:
                forwarded unchanged to the parent constructor, which is
                always called with ``lmbd = 1.``.
            delta: value appearing in the log terms below (presumably the
                failure probability of the bound -- TODO confirm).
        """
        super().__init__(n_actions = n_actions, context_dim = context_dim,
                         tau = tau, N = N, lmbd = 1., diag = diag, loc_weight = loc_weight, device = device)

        self.delta = delta
        # Constants added to the KL divergence inside ``upper_bound``.
        self.unc_kl1 = np.log((4. * self.N ** 0.5)/self.delta)
        self.unc_kl2 = np.log(4/self.delta)

    def upper_bound(self, x, a, ps, r, n_samples = 32):
        """Return the bound used as the training loss for one batch.

        Args:
            x: contexts, passed to ``self.policy``.
            a: integer indices into dim 1 of the action distribution
                (presumably the logged actions).
            ps: at least 2-D tensor; a trailing axis is inserted at
                position 2 (presumably logging propensities -- TODO confirm).
            r: passed through to the second-moment and risk helpers.
            n_samples: forwarded to ``self.policy``; defaults to 32, the
                value that was previously hard-coded.

        Returns:
            ``risk + kl_c + bias + kl_a / lmbd + 0.5 * lmbd * second_moment``
            as a tensor; the second-moment term must hold a single element
            because ``.item()`` is called on it.
        """
        bsize = x.size(0)
        # A trailing singleton axis is appended to the policy output.
        dist_x = self.policy(x, n_samples = n_samples).unsqueeze(-1)
        dist_x_a = dist_x[np.arange(bsize), a, :]

        ps_ = ps[:, :, None]
        mm1 = self.compute_mean_second_moment(dist_x, ps_)
        mm2 = self.compute_empirical_second_moment(dist_x_a, a, ps_, r)
        mm_mean = torch.mean(mm1 + mm2, dim=-1)
        bias_mean = torch.mean(self.compute_mean_bias(dist_x, ps_), dim=-1)
        risk_mean = torch.mean(torch.mean(self.compute_risk(dist_x_a, a, ps_, r), dim=0), dim=-1)

        # ``normal_kl_div`` takes no batch input, so evaluate it once and
        # reuse the result for both KL-based terms.
        kl = self.normal_kl_div()
        kl_c = ((kl + self.unc_kl1)/(2. * self.N)) ** 0.5
        kl_a = (kl + self.unc_kl2) / self.N

        # Plain Python float built from detached values: no gradient flows
        # through ``lmbd``.
        # NOTE(review): the minimiser of k/l + l*m/2 over l is sqrt(2*k/m),
        # whereas this uses 2*sqrt(k*m) -- confirm this choice is intended.
        lmbd = 2 * np.sqrt(kl_a.detach().item() * mm_mean.detach().item())

        first_part = risk_mean + kl_c + bias_mean
        second_part = kl_a / lmbd + 0.5 * lmbd * mm_mean
        our_bound = first_part + second_part

        return our_bound

    def training_step(self, train_batch, batch_idx):
        """Unpack ``(x, a, ps, r)`` and return the bound as the loss.

        ``batch_idx`` is accepted but not used.
        """
        x, a, ps, r = train_batch
        loss = self.upper_bound(x, a, ps, r)
        return loss
    
    
