from functions import *
from operators import *
from dists import get_params, generate
from CNN_functions import *
import warnings
warnings.filterwarnings("ignore")  # deliberate: silences every warning routed through `warnings` (incl. NumPy RuntimeWarnings)

# Fixed seed so the ground-truth filters are reproducible across runs.
rng = np.random.RandomState(42)

links = ['Linear', 'Non-linear', 'FCN', 'CNN']
dists = ['IID Gaussian', 't', 'Gamma', 'Correlated Gaussian']

# 1-based selectors into `links` and `dists`.
which_link = 1
which_dist = 1

# Guard the 1-based selectors: without this, 0 (or a negative value) would
# silently wrap around via negative indexing and pick the wrong option.
if not 1 <= which_link <= len(links):
    raise ValueError(f'which_link must be in 1..{len(links)}, got {which_link}')
if not 1 <= which_dist <= len(dists):
    raise ValueError(f'which_dist must be in 1..{len(dists)}, got {which_dist}')

link = links[which_link - 1]
dist = dists[which_dist - 1]

# Sample geometry and simulation size.
shape = (28, 28)
P, D = shape
N = 2000
noise_level = 0.1

# Patch geometry: the stride equals the kernel size.
kernel_size = (4, 4)
stride = kernel_size
R = 3  # number of ground-truth filters (columns of Theta)
d = np.prod(kernel_size)  # length of one flattened filter

# U[:, :R] would silently return only d columns when R > d, leaving fewer
# filters than R claims; fail loudly instead.
if R > d:
    raise ValueError(f'R ({R}) cannot exceed the flattened kernel size d ({d})')

# Ground-truth Theta: the first R left singular vectors of a random d x d
# matrix, so its columns are orthonormal. Only U is needed from the SVD.
M = rng.normal(size=(d, d))
U, _, _ = np.linalg.svd(M)
Theta = U[:, :R]

# Each column of Theta reshaped into a kernel_size filter; stacked -> (R, *kernel_size).
B = Theta.T.reshape((R, *kernel_size))

# Two parameter sets come back: one for sampling, one handed to the score estimator.
params, params_for_score = get_params(dist, shape, kernel_size, stride)
print('Link:', link)

# Draw N samples together with their score function.
X, score_function = generate(N, shape, dist, params)

# For every sample, apply conv_matrix with each ground-truth filter in B.
Features = np.array(
    [[conv_matrix(sample, filt, stride) for filt in B] for sample in X]
)

# Responses from the chosen link; Index_model's second output is unused here.
Y, _ = Index_model(Features, link=link, noise_level=noise_level)

# Per-sample R_opt transform on kernel_size blocks, using the same stride.
RX = np.array(
    [R_opt(sample, blockshape=kernel_size, stride=stride) for sample in X]
)

# Recover Theta (no truncation) and report compute_distance against the truth.
Theta_hat = estimate_Theta(
    RX, Y, score_function,
    R=R, truncate=False, params=params_for_score,
)
distance = compute_distance(Theta, Theta_hat)
print(f'Distance: {distance:.3f}')
