import numpy as np
import torch

from f012 import *

# Total sampling budget. The allocation below scales its per-level weights so
# that sum_l N_l * weight_l equals T.
T = 200

# Closed-form ("cf") inputs per level: (coefficient, evaluation point, function).
# NOTE(review): f0 / f10 / f21 come from the `f012` star import; presumably
# level 0 and the level-1 / level-2 corrections -- confirm there.
_cf_inputs = (
    (10., [0.1, 0.5], f0),
    (3., [0.3, 0.7], f10),
    (1., [0.1, 0.3], f21),
)
_cf_norms = []
for _coef, _point, _level_fn in _cf_inputs:
    # Rebinding the shared module-level names on every pass leaves them holding
    # the last level's values, exactly as straight-line assignments would.
    constant_1 = torch.tensor([_coef])
    constant_vec2 = torch.tensor(_point)
    # The point is passed to the level function as a batch of one row.
    _cf_norms.append(
        torch.sqrt(constant_1 * _level_fn(X=constant_vec2.unsqueeze(0)))
    )
norm0, norm1, norm2 = _cf_norms

# Per-level weights (presumably per-sample costs -- confirm), used both in the
# weight formula and in the budget normaliser.
_cf_costs = (1., 3., 5.)
# Unscaled weight per level: norm**(2/3) * cost**(-1/3). The exponents are kept
# in their original unsimplified form.
_norm_exp = 2 / (1 + 2)
_cost_exp = -2 / (2 * 1 + 2 * 2)
n0_cf, n1_cf, n2_cf = (
    _nrm ** _norm_exp * _cost ** _cost_exp
    for _nrm, _cost in zip(_cf_norms, _cf_costs)
)
# Normaliser: total weighted cost of the unscaled weights.
R_cf = sum(_w * _c for _w, _c in zip((n0_cf, n1_cf, n2_cf), _cf_costs))

# Scale so that sum_l N_l * cost_l == T. Counts are left fractional.
N0_cf, N1_cf, N2_cf = (T * _w / R_cf for _w in (n0_cf, n1_cf, n2_cf))
N_cf = np.array([_count.item() for _count in (N0_cf, N1_cf, N2_cf)])
np.save("nsamples_cf.npy", N_cf)


# Pilot-sample setup: input dimension and number of pilot points per level.
d = 2
n = 100

# Draw n points uniformly from [0, 1]^d for each level. All three draws happen
# before any level function runs, in the order X0, X1, X2.
_unit_uniform = torch.distributions.Uniform(0, 1)
X0, X1, X2 = (_unit_uniform.sample((n, d)) for _ in range(3))
X_mc = np.vstack((X0, X1, X2))

# Evaluate each level's function on its own pilot set (left to right).
Y0, Y10, Y21 = f0(X0), f10(X1), f21(X2)
Y_mc = np.vstack((Y0, Y10, Y21))

# Sample variance of each level's outputs (torch.var default correction).
var0, var1, var2 = (torch.var(_y).item() for _y in (Y0, Y10, Y21))

# Monte Carlo allocation: weight_l = sqrt(var_l / cost_l), i.e. the usual
# multilevel Monte Carlo rule N_l proportional to sqrt(V_l / C_l). The
# per-level weights (1, 3, 5) are presumed to be per-sample costs -- confirm.
_mc_costs = (1., 3., 5.)
n0_mc, n1_mc, n2_mc = (
    np.sqrt(_v / _c) for _v, _c in zip((var0, var1, var2), _mc_costs)
)
# Normaliser: total weighted cost of the unscaled weights.
R_mc = sum(_w * _c for _w, _c in zip((n0_mc, n1_mc, n2_mc), _mc_costs))

# Scale so that sum_l N_l * cost_l == T. Counts are left fractional.
N0_mc, N1_mc, N2_mc = (T * _w / R_mc for _w in (n0_mc, n1_mc, n2_mc))
N_mc = np.array([N0_mc, N1_mc, N2_mc])
np.save("nsamples_mc.npy", N_mc)
