import numpy as np
from matplotlib import pyplot as plt
import seaborn as sns
from nn.cola_nn import define_multi_vec
from learning.fns import gen_core_info
from learning.fns import gen_cores
from learning.fns import get_einsum_expr
from ops.operators import EinOpVec

# Operator sizes to sweep; each operator maps a D-dim input to a D-dim output.
# (For a quick smoke run, a shorter list such as [10, 50] can be used instead.)
Ds = [10, 50, 100, 500, 1000]
frob = np.zeros(len(Ds))  # Frobenius norm measured at each size in Ds

for idx, size in enumerate(Ds):
    # Core layout with 2 + 2 cores, and the multi-vector spec for a square operator.
    info = gen_core_info(total_cores=2 + 2)
    spec = define_multi_vec(size, size)
    raw_cores, active = gen_cores(spec, info)

    # Rescale every core by size ** -0.25 so the resulting norm can be compared
    # against the sqrt(D) reference curve drawn in the plot.
    # NOTE(review): presumably gen_cores draws random cores; no seed is set here,
    # so the measured norms may vary between runs — confirm.
    scale = 1 / (size**0.25)
    scaled = [scale * core for core in raw_cores]

    expr = get_einsum_expr(active, info)
    # Only the interior cores enter the operator; the first and last cores are
    # consulted solely for their shapes.
    outer_shapes = (scaled[0].shape, scaled[-1].shape)
    operator = EinOpVec(scaled[1:-1], expr, outer_shapes, allow_padding=False)
    frob[idx] = np.linalg.norm(operator.to_dense(), ord="fro")

    # Disabled dense Gaussian baseline for comparison:
    # mult = 1 / np.sqrt(D)
    # A = mult * np.random.normal(size=(D, D))
    # frob[jdx] = np.linalg.norm(A, ord="fro")

# Global plot styling: white grid, enlarged fonts, thick lines, Set2 colours.
sns.set(style="whitegrid", font_scale=2.0, rc={"lines.linewidth": 3.0})
sns.set_palette("Set2")

# Reference growth curve to compare the measured norms against.
sqrt_reference = np.sqrt(Ds)

plt.figure(dpi=100, figsize=(20, 10))
plt.plot(Ds, frob, label="case")
# A linear reference could be added with: plt.plot(Ds, Ds, label="y=x")
plt.plot(Ds, sqrt_reference, label="y=sqrt(x)")
plt.ylabel("Norm")
plt.xlabel("Size")
# Anchor the legend's upper-left corner at the axes' top-right corner,
# i.e. just outside the plotting area.
plt.legend(loc="upper left", bbox_to_anchor=(1, 1))
plt.tight_layout()
plt.show()
