# Needed libraries:
# - numpy
# - scikit-learn (imported as sklearn)
# - tqdm
# - matplotlib
# - scipy

import numpy as np
import sklearn.cluster
from tqdm import tqdm
from collections import defaultdict
import matplotlib.pyplot as plt
import scipy

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

# Problem sizes read as globals by shift_sal():
n = 10**4   # rows of the random design matrix X (X is n x d1)
k = 10**2   # number of codes, i.e. columns of the sketch matrix H
d1 = 10**3  # rows of the target matrix T
d2 = 10**1  # columns of the target matrix T

# Number of independent repetitions averaged in the plot.
reps = 5

# Not referenced anywhere in the visible part of this file.
eps2 = 0


def code_and_learn(T, X, Y, k, top=1, method='kmeans'):
    """Re-estimate ``T`` as ``H @ M`` where ``H`` has ``k`` columns.

    A ``(d1, k)`` matrix ``H`` is derived from ``T`` and a codebook of ``k``
    vectors, ``M`` is fitted by least squares so that ``X @ H @ M`` matches
    ``Y``, and the product ``H @ M`` is returned.

    Parameters
    ----------
    T : ndarray of shape (d1, d2)
        Current estimate; its rows are what gets coded.
    X : ndarray with ``d1`` columns
        Regression inputs.
    Y : ndarray with as many rows as ``X``
        Regression targets.
    k : int
        Number of codes (columns of ``H``).
    top : int, 'exp' or 'raw'
        How code scores ``T @ code.T`` are turned into ``H``:
        an integer puts a 1 at the ``top`` highest-scoring codes of each row,
        ``'exp'`` applies a row-wise softmax of ``scores / sqrt(d2)``,
        ``'raw'`` uses the scores scaled to unit row norm.
        Ignored when ``method == 'noise'``.
    method : {'kmeans', 'normal', 'ortho', 'self', 'noise'}
        How the codebook is produced: k-means centres of the rows of ``T``,
        i.i.d. Gaussian codes, the first ``d2`` columns of a random
        orthogonal ``k x k`` matrix, a Gaussian projection ``G @ T``, or
        (``'noise'``) no codebook at all: ``H`` is ``T`` padded with
        ``k - d2`` Gaussian columns.

    Returns
    -------
    ndarray
        ``H @ M``, with ``d1`` rows and as many columns as ``Y``.

    Raises
    ------
    ValueError
        If ``method`` is not one of the names above, or if an integer ``top``
        is outside ``0..k``.
    """
    d1, d2 = T.shape

    if method == "noise":
        # H = [T | Gaussian noise]; no codebook is involved, so `top` is unused.
        H = np.zeros((d1, k))
        H[:, :d2] = T
        H[:, d2:] = np.random.randn(d1, k - d2)
        M, *_ = np.linalg.lstsq(X @ H, Y)
        return H @ M

    if method == "kmeans":
        code = sklearn.cluster.KMeans(n_clusters=k).fit(T).cluster_centers_
    elif method == "normal":
        code = np.random.randn(k, d2)
    elif method == "ortho":
        # Import the submodule explicitly: a bare `import scipy` is not
        # guaranteed to expose `scipy.stats` on every SciPy version.
        from scipy.stats import ortho_group
        code = ortho_group.rvs(k)[:, :d2]
    elif method == "self":
        G = np.random.randn(k, d1) / k**.5
        code = G @ T
    else:
        raise ValueError(
            f"unknown method {method!r}; expected one of "
            "'kmeans', 'normal', 'ortho', 'self', 'noise'"
        )

    # scores[i, j] = <row i of T, code j>
    scores = T @ code.T

    if top == 'exp':
        # I tried no normalization and dividing by k**.5 instead,
        # but they both eventually stalled as I increased the steps.
        # I think this normalization also matches what's used in
        # attention.
        logits = scores / d2**.5
        # Subtracting the row maximum leaves the softmax unchanged but keeps
        # np.exp from overflowing to inf (which would turn rows into nan).
        logits -= logits.max(axis=1, keepdims=True)
        H = np.exp(logits)
        H /= H.sum(axis=1, keepdims=True)
    elif top == 'raw':
        # NOTE (original author): this variant performed poorly.
        H = scores / (scores**2).sum(axis=1, keepdims=True)**.5
    else:
        if not 0 <= top <= k:
            raise ValueError(f"top must be between 0 and k={k}, got {top!r}")
        # argpartition requires kth < k, so clamp it; for top < k this is the
        # same partition as before, and for top == k every code is selected.
        r = np.argpartition(-scores, kth=min(top, k - 1), axis=1)[:, :top]
        H = np.zeros((d1, k))
        H[np.arange(d1)[:, None], r] = 1

    M, *_ = np.linalg.lstsq(X @ H, Y)
    return H @ M


def shift_sal():
    """Run the iterated "cluster + learn" experiment and plot the MSE curves.

    Reads the module-level ``n``, ``k``, ``d1``, ``d2`` and ``reps``.
    For each of ``reps`` random problems ``Y = X @ T`` it records:

    * one baseline, ``code_and_learn`` applied to the true ``T`` with k-means
      codes, repeated at every step so it plots as a flat line;
    * one curve per entry of the sweep: a random ``H`` (Gaussian for
      ``'noise'``, otherwise ``splits`` random hash buckets per row) is
      fitted once, and the estimate is then refined ``nsals`` times with
      ``code_and_learn``.

    The mean regression error per step, with a +/- one standard error band,
    is plotted on a log x-axis. Progress and losses are printed along the
    way and the function ends with ``plt.show()``. Returns ``None``.
    """
    # res[kind, label, step] -> list with one loss per repetition,
    # where kind is "T" (parameter error) or "X" (prediction error).
    res = defaultdict(list)
    all_labels = []  # first-seen order; fixes the colour of each curve
    nsals = 100
    for _ in tqdm(range(reps)):
        T = np.random.randn(d1, d2)
        X = np.random.randn(n, d1)
        # Noiseless targets. A noisy variant tried earlier was:
        #   Y = X @ T + (d1/d2)**.5 * np.random.randn(n, d2)
        Y = X @ T

        def add_loss(Tt, sal, label):
            """Record both losses of estimate ``Tt`` at step ``sal``."""
            if label not in all_labels:
                all_labels.append(label)
            res["T", label, sal].append(((T - Tt)**2).mean())
            x = ((X @ Tt - Y)**2).mean()
            res["X", label, sal].append(x)
            print("X", label, sal, x)

        # Learning with "true" clustering
        Tt = code_and_learn(T, X, Y, k, top=1, method='kmeans')
        for sal in range(nsals+1):
            add_loss(Tt, sal, "KMeans on true T")

        for splits in ['noise', 2, 4, 8, 16, 32]:
            is_noise = isinstance(splits, str)

            # Sampling random H matrix
            if is_noise:
                H = np.random.randn(d1, k)
            else:
                H = np.zeros((d1, k))
                for _hash in range(splits):
                    r = np.random.randint(k, size=(d1,))
                    H[np.arange(d1), r] += 1
            M, *_ = np.linalg.lstsq(X @ H, Y)
            Tt = H @ M

            method = 'noise' if is_noise else 'kmeans'
            # The label used to be chosen by comparing against 'noise ' (with a
            # trailing space), so "Noise method" was never selected.
            label = "Noise method" if is_noise else f"{splits} hashes"

            for sal in range(nsals):
                add_loss(Tt, sal, label)
                # `top` is ignored by code_and_learn when method == 'noise'.
                Tt = code_and_learn(Tt, X, Y, k, top=splits, method=method)
            add_loss(Tt, nsals, label)

    plt.title(f"n={n}, d1={d1}, d2={d2}, k={k}")
    sals = list(range(nsals+1))
    for i, label in enumerate(all_labels):
        print(label)
        # Rows are steps, columns are repetitions.
        y = np.array([res["X", label, sal] for sal in sals])
        mean = y.mean(axis=1)
        s = y.std(axis=1) / reps**.5  # standard error of the mean
        plt.fill_between(sals, mean-s, mean+s, alpha=.3, color=f'C{i}')
        plt.plot(sals, mean, label=label, color=f'C{i}')
    plt.xlabel('Cluster + Learn iterations')
    plt.ylabel('MSE')
    # NOTE: step 0 has no position on a logarithmic axis.
    plt.xscale('log')
    plt.legend()
    plt.show()


# Run the experiment only when executed as a script, so that importing this
# module (e.g. to reuse code_and_learn) does not start the full sweep.
if __name__ == "__main__":
    shift_sal()
