import os
# Cap the thread pools of the common numerical backends (OpenMP, OpenBLAS,
# MKL, vecLib/Accelerate, numexpr) at one thread each.
# NOTE(review): this runs before the numerical imports below, presumably so
# those libraries see the limits when they are loaded -- keep it ahead of them.
os.environ.update(
    {
        "OMP_NUM_THREADS": "1",
        "OPENBLAS_NUM_THREADS": "1",
        "MKL_NUM_THREADS": "1",
        "VECLIB_MAXIMUM_THREADS": "1",
        "NUMEXPR_NUM_THREADS": "1",
    }
)

import skdim
import numpy as np
import argparse
import torch
from torch.nn.utils import parameters_to_vector

from model import LeNet, TwoLayerNeuralNet, ThreeLayerNeuralNet, FourLayerNeuralNet

# Seed NumPy's global random number generator with a fixed value.
# NOTE(review): presumably for run-to-run reproducibility of the dimension
# estimates -- confirm the estimators actually draw from the global RNG.
np.random.seed(0)

def estimate_dimension(data):
    """Estimate the intrinsic dimension of ``data`` with two skdim estimators.

    Prints ``data.shape`` first, then each estimate as soon as it is computed.

    Args:
        data: Array handed unchanged to ``skdim.id.TwoNN().fit`` and
            ``skdim.id.lPCA().fit``; must expose ``.shape``.
            NOTE(review): presumably (n_samples, n_features) -- confirm.

    Returns:
        dict: ``{"TwoNN": ..., "lPCA": ...}`` holding the ``dimension_``
        attribute of each fitted estimator.
    """
    print(data.shape)

    # Two-nearest-neighbour estimator with default settings.
    fitted_twonn = skdim.id.TwoNN().fit(data)
    twonn_dim = fitted_twonn.dimension_
    print(f"TwoNN={twonn_dim}")

    # Local-PCA estimator with default settings.
    fitted_lpca = skdim.id.lPCA().fit(data)
    lpca_dim = fitted_lpca.dimension_
    print(f"lPCA={lpca_dim}")

    results = {"TwoNN": twonn_dim, "lPCA": lpca_dim}
    return results


_MODEL_NAMES = ("lenet", "four_layer", "three_layer", "two_layer")


def _build_model(model_name):
    """Return a freshly constructed network for one of the ``_MODEL_NAMES``."""
    if model_name == "lenet":
        return LeNet(1)
    if model_name == "four_layer":
        return FourLayerNeuralNet()
    if model_name == "three_layer":
        return ThreeLayerNeuralNet()
    if model_name == "two_layer":
        return TwoLayerNeuralNet()
    # Fail loudly instead of leaving the caller with an unbound model.
    raise ValueError(f"unknown model_name: {model_name!r}")


def main(argv=None):
    """Load trained checkpoints and estimate the dimension of their weights.

    For every seed in ``range(sample_num)`` the checkpoint
    ``large_experiment_result/sgd/{model_name}/{data_num}/model_weights_{seed}.pth``
    is loaded on the CPU, its ``"model_state"`` entry is loaded into the
    model, and all parameters are flattened into one vector. The stacked
    vectors (one row per seed, float64) are passed to ``estimate_dimension``
    and the returned dict is printed.

    Args:
        argv: Optional list of command-line arguments. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, as before.

    Command-line arguments (all required):
        --model_name: one of ``lenet``, ``four_layer``, ``three_layer``,
            ``two_layer``.
        --sample_num: number of checkpoints to load (seeds 0..sample_num-1);
            must be at least 1.
        --data_num: integer used as the ``{data_num}`` directory component of
            the checkpoint path.

    Raises:
        SystemExit: via argparse (exit status 2) when an argument is missing,
            ``--model_name`` is not a known model, or ``--sample_num`` < 1.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', type=str, required=True,
                        choices=_MODEL_NAMES)
    parser.add_argument('--sample_num', type=int, required=True)
    parser.add_argument('--data_num', type=int, required=True)

    args = parser.parse_args(argv)
    model_name = args.model_name
    sample_num = args.sample_num
    data_num = args.data_num

    # Zero samples would hand an empty array to the estimators, which then
    # fail far from the real cause; reject it up front.
    if sample_num < 1:
        parser.error("--sample_num must be a positive integer")

    model = _build_model(model_name)
    # The mode persists across load_state_dict calls, so set it once.
    model.eval()

    weight_vectors = []
    for seed in range(sample_num):
        # NOTE(security): torch.load unpickles the file; only point this
        # script at checkpoints from a trusted source.
        checkpoint = torch.load(
            f"large_experiment_result/sgd/{model_name}/{data_num}/model_weights_{seed}.pth",
            map_location=torch.device('cpu'),
        )
        model.load_state_dict(checkpoint["model_state"], strict=True)
        with torch.no_grad():
            # parameters_to_vector concatenates into a new tensor, so each
            # row stays valid after the next checkpoint overwrites the model.
            vec = parameters_to_vector(model.parameters()).cpu().numpy()
        weight_vectors.append(vec)

    data = np.array(weight_vectors, dtype='float64')
    estimated_dim = estimate_dimension(data)

    print(estimated_dim)


# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
