import numpy as np
import os
from sklearn.neural_network import MLPClassifier
import argparse
from tqdm import tqdm
from copy import deepcopy
# Command-line interface. Every option has a default, so none is required.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--dataset",
    type=str,
    default="",
    help="path to dataset",
)
parser.add_argument(
    "--encoder",
    type=str,
    default="",
)
parser.add_argument(
    "--num_step",
    type=int,
    default=8,
    help="number of steps",
)
parser.add_argument(
    "--num_run",
    type=int,
    default=10,
    help="number of runs",
)
parser.add_argument(
    "--num_shot",
    type=int,
    default=16,
    help="number of shots",
)
parser.add_argument(
    "--feature_dir",
    type=str,
    default="clip_feat",
    help="feature dir path",
)
args = parser.parse_args()

# The feature archives are looked up under <feature_dir>/<dataset>/<encoder>/.
dataset, encoder = args.dataset, args.encoder
dataset_path = os.path.join(args.feature_dir, dataset, encoder)

print("-- loading features --------------------------------------------------------------")

# Each split is a NumPy .npz archive read for two entries: "feature_list" and
# "label_list".  np.load returns a lazily-read NpzFile that holds an open file
# handle, so each archive is opened in a `with` block and the arrays are pulled
# into memory before it closes.  After its block ends, the `*_file` name refers
# to a closed archive and must not be indexed again.
with np.load(os.path.join(dataset_path, "train.npz")) as train_file:
    train_feature, train_label = train_file["feature_list"], train_file["label_list"]

# NOTE(review): the validation split is loaded but not used in the visible part
# of this script.
with np.load(os.path.join(dataset_path, "val.npz")) as val_file:
    val_feature, val_label = val_file["feature_list"], val_file["label_list"]

with np.load(os.path.join(dataset_path, "test.npz")) as test_file:
    test_feature, test_label = test_file["feature_list"], test_file["label_list"]

# Make sure the report directory exists.
# NOTE(review): nothing in the visible part of this script writes into it.
os.makedirs("report", exist_ok=True)

# Supported shot counts mapped to the number of samples drawn per class.
# The drawn samples cover both the training shots and the validation samples
# that are carved out of them again via the fractions below.
train_and_val_shot_list = {1: 2, 2: 4, 4: 8, 8: 12, 16: 20, 32: 36, 64: 68}

# Share of the drawn samples to hold out for validation:
# (drawn - shots) / drawn, i.e. 1/2, 1/2, 1/2, 1/3, 1/5, 1/9 and 4/68.
validation_fraction_list = {
    shot: (drawn - shot) / drawn
    for shot, drawn in train_and_val_shot_list.items()
}

# A single shot count is evaluated per invocation; rerun the script with a
# different --num_shot to sweep over several settings.
num_shot = args.num_shot
if num_shot not in train_and_val_shot_list:
    # Fail early with a usage message instead of a bare KeyError deep inside
    # the run loop.
    parser.error(
        f"--num_shot must be one of {sorted(train_and_val_shot_list)}, got {num_shot}"
    )

# One row per run and one column per step.
# NOTE(review): this array is allocated here but never filled in the visible
# part of the script.
test_acc_step_list = np.zeros([args.num_run, args.num_step])
# These values do not depend on the seed, so compute them once up front.
all_label_list = np.unique(train_label)
samples_per_class = train_and_val_shot_list[num_shot]
vf = validation_fraction_list[num_shot]

# Repeat the few-shot experiment once per seed (seeds run from 1 to num_run)
# and print the test accuracy, in percent, of each run.
for seed in range(1, args.num_run + 1):
    # Seeding NumPy's global RNG makes the sampling below reproducible.  The
    # classifier is created without an explicit random_state, in which case
    # scikit-learn draws from the same global RNG.
    np.random.seed(seed)
    print(f"-- Seed: {seed} --------------------------------------------------------------")

    # Sampling: draw the same number of examples from every class, without
    # replacement.  np.random.choice raises ValueError when a class has fewer
    # than `samples_per_class` examples.
    selected_idx_list = np.concatenate([
        np.random.choice(
            np.where(train_label == label)[0],
            size=samples_per_class,
            replace=False,
        )
        for label in all_label_list
    ])

    # Indexing with an index array already returns fresh copies, so no extra
    # deepcopy is needed.
    fewshot_train_feature = train_feature[selected_idx_list]
    fewshot_train_label = train_label[selected_idx_list]

    # validation_fraction only takes effect when early_stopping=True.  Without
    # it the samples meant for validation would silently be trained on as
    # well, so a "16-shot" run would really use 20 shots per class.
    clf = MLPClassifier(
        hidden_layer_sizes=(10,),
        max_iter=10000,
        early_stopping=True,
        validation_fraction=vf,
    ).fit(fewshot_train_feature, fewshot_train_label)

    pred = clf.predict(test_feature)
    # Vectorised count of correct predictions, expressed as a percentage.
    test_acc = 100 * np.sum(pred == test_label) / len(pred)
    print(test_acc)
