
import torch
import numpy

def KL(a, b, state_visitation, eps=0.0001):
    """Return the state-visitation-weighted KL divergence KL(a || b).

    For every row i this computes
        d_i = sum_j a[i, j] * (log a[i, j] - log b[i, j])
    and returns sum_i state_visitation[i] * d_i.

    Args:
        a: 2-D tensor with N rows, presumably one probability distribution
            per row. Clipped to [eps, 1.0] before taking logs.
        b: tensor combined element-wise with ``a``; clipped the same way.
        state_visitation: per-row weights, expected shape (N, 1).
            NOTE(review): a 1-D (N,) tensor would broadcast against the
            (N, 1) column of row divergences to (N, N) and give a different
            (likely wrong) total -- confirm callers pass (N, 1).
        eps: lower clip bound; keeps ``log`` finite for zero entries.
            Defaults to the previously hard-coded 0.0001.

    Returns:
        0-dim tensor holding the weighted sum.
    """
    # Clip so log(0) can never produce -inf / NaN.
    a = torch.clip(a, eps, 1.0)
    b = torch.clip(b, eps, 1.0)
    # Per-row divergence as an (N, 1) column. ``-1`` infers N instead of the
    # former hard-coded 16, so any number of rows is supported.
    per_row = torch.sum(a * torch.log(a) - a * torch.log(b), dim=1).reshape((-1, 1))
    return torch.sum(state_visitation * per_row)

# d=2
# NOTE(review): the meaning of "d" is not visible in this file -- confirm with
# the code that produced / consumes these tables.
# Frozen (requires_grad=False) 16 x 4 tensor of stored values. Rows are
# presumably states and columns presumably actions -- TODO confirm.
# The last row is essentially zero in every column (|x| < 3e-5), presumably a
# state whose entries were never updated (e.g. a terminal state) -- verify.
training_meta_theta_2=torch.tensor([[-1.3453e+00,  8.4050e-01,  1.8522e+00, -1.3474e+00],
        [-1.2883e+00,  2.5834e+00, -4.1027e-01, -8.8482e-01],
        [-3.8953e-01,  1.7065e+00, -9.8510e-01, -3.3190e-01],
        [-2.3574e-02,  1.6721e-01, -7.3353e-02, -7.0283e-02],
        [-7.2375e-01, -9.0497e-01,  2.6835e+00, -1.0547e+00],
        [-1.2105e+00, -8.5522e-01,  3.2741e+00, -1.2084e+00],
        [-1.4597e+00,  1.4396e+00,  1.5939e+00, -1.5737e+00],
        [-7.2481e-01,  2.3825e+00, -7.2065e-01, -9.3709e-01],
        [-1.9936e-01,  1.6557e-01,  1.0272e-01, -6.8928e-02],
        [-5.6761e-01, -7.8136e-02,  1.1161e+00, -4.7039e-01],
        [-1.3467e+00,  2.7498e+00, -1.0227e-01, -1.3007e+00],
        [-1.1379e+00,  3.1249e+00, -8.2362e-01, -1.1634e+00],
        [-1.1827e-01, -1.1756e-01,  3.6775e-01, -1.3191e-01],
        [-4.6190e-01, -2.5093e-01,  1.2050e+00, -4.9215e-01],
        [-1.1204e+00, -7.6465e-01,  2.9529e+00, -1.0678e+00],
        [ 3.1300e-07,  2.1723e-05, -2.4879e-05,  1.3542e-07]],
       requires_grad=False)


# Frozen (requires_grad=False) 16 x 4 tensor, same layout as
# training_meta_theta_2. The "no_hole" suffix presumably denotes an
# environment variant without a hole -- TODO confirm.
# The last row is exactly zero in every column.
training_meta_theta_no_hole_2=torch.tensor([[-2.6676,  2.5738, -0.2887, -2.6750],
        [-1.9896,  2.7592, -1.1140, -2.1067],
        [-1.6433,  2.0680, -1.2569, -1.8562],
        [-1.4591,  1.5675, -1.6353, -1.6191],
        [-2.7840, -1.8724,  3.6894, -2.5779],
        [-2.5824, -0.3216,  3.7105, -2.5521],
        [-2.5513,  0.8343,  3.8499, -2.5250],
        [-2.8509,  2.8984, -3.2379, -2.7482],
        [-2.2198, -1.9536,  2.2616, -1.9592],
        [-2.0926,  0.9336,  2.3797, -2.1446],
        [-2.1235,  0.8650,  2.8642, -2.1439],
        [-2.7618,  2.8367, -3.1221, -2.7379],
        [-1.5794, -1.5709,  1.5708, -1.5178],
        [-1.9727, -2.3238,  2.0810, -2.0525],
        [-2.1341, -2.4092,  2.2148, -2.1739],
        [ 0.0000,  0.0000,  0.0000,  0.0000]], requires_grad=False)

# d=1
# NOTE(review): the meaning of "d" is not visible in this file -- confirm.
# Frozen (requires_grad=False) 16 x 4 tensor; presumably the d=1 counterpart
# of training_meta_theta_2 (same shape).
# The last row is near-uniform (about 0.05-0.06 in every column).

training_meta_theta_1=torch.tensor([[-1.7254,  2.4021,  0.8041, -1.7259],
        [-1.2183,  2.8369, -0.5626, -1.6981],
        [-0.1186,  2.9944, -1.2270, -0.8516],
        [ 0.4858,  0.0343, -0.2044, -0.2096],
        [-2.1482, -0.4993,  2.3559, -1.8796],
        [-1.7470, -1.1979,  3.2517, -1.8906],
        [-2.6758,  0.7135,  0.9822, -3.1263],
        [-2.8828,  3.2549, -3.4645, -2.7275],
        [-1.6217,  1.6527, -0.4064, -0.3221],
        [-1.3422, -0.4186,  1.8679, -1.1695],
        [-2.6960,  3.7982, -1.6422, -2.7242],
        [-3.1232,  3.4268, -3.7953, -3.0336],
        [-1.8402, -1.8146,  1.5786, -0.9834],
        [-2.2327, -2.7928,  2.5982, -2.2197],
        [-2.8831, -3.6055,  3.2536, -2.9632],
        [ 0.0527,  0.0549,  0.0556,  0.0596]], requires_grad=False)


# Frozen (requires_grad=False) 16 x 4 tensor; presumably the d=1 counterpart
# of training_meta_theta_no_hole_2 (same shape).
# The last row is exactly 0.1 in every column.
training_meta_theta_no_hole_1=torch.tensor([[-2.2226,  2.7543, -0.5609, -2.2160],
        [-1.7479,  2.6798, -1.2242, -1.8018],
        [-1.4730,  2.1132, -1.0276, -1.6941],
        [-1.3414,  1.7173, -1.5965, -1.5958],
        [-2.5484, -1.3991,  3.5703, -2.4142],
        [-2.4656,  0.3753,  3.6660, -2.3586],
        [-2.5463,  1.5573,  3.7354, -2.5226],
        [-2.7225,  3.0002, -3.0842, -2.6988],
        [-2.2126, -2.0037,  2.4275, -1.9208],
        [-2.1521, -0.1312,  2.7542, -2.1845],
        [-2.2049,  1.1670,  3.0820, -2.2031],
        [-2.6433,  2.8864, -3.0347, -2.5518],
        [-1.4972, -1.4934,  1.7256, -1.4578],
        [-1.8945, -2.1776,  2.1899, -1.9805],
        [-2.0644, -2.3894,  2.3555, -2.1225],
        [ 0.1000,  0.1000,  0.1000,  0.1000]], requires_grad=False)

