import argparse
from copy import deepcopy
from functools import cache
import math
from pathlib import Path
import os

import numpy as np
from scipy.cluster.vq import kmeans
import torch
import torch.nn as nn

import nsn_tools
from src.utils import set_all_seeds

# Length of each full vector; the learned-codebook loss measures cosine
# similarity over vectors of this size.
D: int = 128
# Length of each sub-vector / codeword: data is reshaped to (-1, N) before
# quantization and k-means clusters N-dimensional points.
N: int = 8
# Number of centroids requested from k-means.
M: int = 256

def get_kmeans_codebook(abs=False, *, num_samples=16384, dim=None, codebook_size=None):
    """Build a codebook by running k-means on synthetic standard-normal samples.

    Args:
        abs: If true, cluster the absolute values of the samples instead of the
            raw signed samples. (The name shadows the builtin ``abs``; it is kept
            for backward compatibility, so magnitudes are taken with ``np.abs``.)
        num_samples: Number of random sample vectors to cluster.
        dim: Dimension of each sample / codeword. Defaults to the module-level ``N``.
        codebook_size: Requested number of centroids. Defaults to ``M``.

    Returns:
        A ``torch.Tensor`` of shape ``(k, dim)`` holding the centroids, converted
        with ``torch.from_numpy`` (so it keeps the dtype scipy returns). ``k`` may
        be smaller than ``codebook_size`` because scipy's ``kmeans`` drops
        centroids that end up with no assigned observations.
    """
    dim = N if dim is None else dim
    codebook_size = M if codebook_size is None else codebook_size
    data = np.random.normal(size=(num_samples, dim))
    if abs:
        # ``abs`` is the boolean flag here, not the builtin: ``abs(data)`` would
        # raise ``TypeError: 'bool' object is not callable``.
        data = np.abs(data)
    centroids, _distortion = kmeans(data, codebook_size)
    return torch.from_numpy(centroids)


def get_learned_codebook(abs=False):
    """Refine a k-means codebook with Adam to maximise cosine similarity.

    The codebook is initialised with ``get_kmeans_codebook(abs)`` and then
    trained for 10000 steps on fresh Gaussian batches generated on CUDA. Each
    batch is reshaped into ``N``-length chunks, and each chunk is replaced by the
    codeword index returned by ``nsn_tools.dist_argmin_half``. That is presumably
    the nearest codeword computed in half precision; the exact metric is defined
    in ``nsn_tools``. The loss is ``1 - mean cosine similarity`` between the
    reconstructed and original ``D``-length vectors.

    Args:
        abs: If true, both the initialisation and the training data use the
            absolute values of the Gaussian samples.

    Returns:
        A CPU tensor holding the codebook snapshot taken at the end of the
        100-step window with the lowest mean loss.
    """
    # Initialise on the same distribution that is trained on below. A signed
    # initialisation for abs-valued data leaves codewords in a region the data
    # never reaches, where they are presumably never selected and never updated.
    init_codebook = get_kmeans_codebook(abs)

    class CodeBookLearner(nn.Module):
        """Holds a trainable codebook and returns the reconstruction loss."""

        def __init__(self, init_codebook):
            super().__init__()
            self.codebook = nn.Parameter(init_codebook)
            self.cossim = nn.CosineSimilarity(dim=-1)

        def forward(self, x):
            # Codeword selection is an index lookup, so it has no gradient;
            # gradients reach the codebook only through the gather below.
            index = nsn_tools.dist_argmin_half(x.half(), self.codebook.half())
            res = self.codebook[index.int()]
            # Compare whole D-length vectors, not individual N-length chunks.
            return 1 - self.cossim(res.reshape(-1, D), x.reshape(-1, D)).mean()

    steps = 10000
    log_every = 100
    codebook_model = CodeBookLearner(init_codebook).cuda()
    optimizer = torch.optim.Adam(codebook_model.parameters(), lr=0.001)

    window_losses = []
    best_loss = float("inf")
    # Clone so the fallback can never share storage with the trainable parameter.
    best_codebook = init_codebook.clone()
    for i in range(steps):
        data = torch.randn(100000, D, device="cuda")
        if abs:
            data = data.abs()
        data = data.reshape(-1, N)
        optimizer.zero_grad()
        loss = codebook_model(data)
        # Keep per-step losses on the device; calling .item() every step would
        # force a GPU->host synchronisation on each iteration.
        window_losses.append(loss.detach())
        if (i + 1) % log_every == 0:
            mean_loss = torch.stack(window_losses).double().mean().item()
            window_losses = []
            lr = optimizer.param_groups[0]['lr']
            if mean_loss < best_loss:
                best_loss = mean_loss
                best_codebook = codebook_model.codebook.detach().cpu().clone()
                print("Best", mean_loss, lr)
            else:
                print(mean_loss, lr)
        loss.backward()
        optimizer.step()

    return best_codebook


def main():
    """Command-line entry point: build a codebook and optionally save it.

    Flags:
        --method       "learned" (k-means init refined with Adam on CUDA) or "kmeans".
        --output_path  Destination passed to ``torch.save`` when ``--save`` is given.
        --save         Write the codebook; without it nothing is written to disk.
        --abs          Build the codebook for absolute-valued data.
        --seed         Seed passed to ``set_all_seeds`` (default 42).

    Raises:
        NotImplementedError: if ``--method`` names no known builder. argparse's
            ``choices`` already rejects such values, so this is only a safeguard.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--method", type=str, default="learned", choices=["learned", "kmeans"])
    parser.add_argument("--output_path", type=str, default="codebooks/codebook.pt")
    parser.add_argument("--save", action="store_true")
    parser.add_argument("--abs", action="store_true")
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()
    set_all_seeds(args.seed)

    builders = {
        "learned": get_learned_codebook,
        "kmeans": get_kmeans_codebook,
    }
    try:
        build = builders[args.method]
    except KeyError:
        raise NotImplementedError(f"No such method: {args.method}") from None
    codebook = build(args.abs)

    if args.save:
        # Only create the output directory when a file is actually written.
        Path(args.output_path).parent.mkdir(parents=True, exist_ok=True)
        torch.save(codebook, args.output_path)


if __name__ == "__main__":
    main()
