import numpy as np
import os
os.environ["CUBLAS_WORKSPACE_CONFIG"] = "4096:8"
import argparse
import torch
import tqdm
import pickle
from al_query import our_query, random_query, our_query_fea, GAAL_query, ACGAN_query, RandomTxt_query, ActiveGAN_query
from model_train import MLP, train_classifier, get_cla_clip_models, get_data_loader, get_ext_dataloader, get_train_lr, get_features, train_with_fea_val
from support.utils import check_all_class, init_seeds, load_gen_fea
from support.inat21_supclass import inat21_class_to_idx, inat21_idx_to_class

# Use the first CUDA device when PyTorch reports one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"


if __name__ == "__main__":
    # Each invocation performs ONE active-learning round, selected by --iter:
    #   1. train a classifier on the round's labeled (plus any external) features,
    #   2. write the resulting `acc` value to <output_dir>/<iter>/acc.txt,
    #   3. run the chosen query method,
    #   4. write the index files for the next round to <output_dir>/<iter+1>/.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataset",
        type=str,
        default="cifar10",
        choices=[
            "cifar10",
            "inaturalist21",
        ],
    )
    parser.add_argument(
        "--method",
        type=str,
        default="ours_gen",
        choices=[
            "randomtxt",
            "ours_gen",
            "our_query_fea",
        ],
    )
    parser.add_argument("--iter", type=int, default=0)
    parser.add_argument("--valid_size_multiplier", type=int, default=10)
    parser.add_argument("--use_cache", action="store_true")
    args = parser.parse_args()

    # Dataset-dependent sizes. For inaturalist21 the validation size is fixed at
    # NCLASSES * 100 and --valid_size_multiplier is ignored.
    if args.dataset == 'inaturalist21':
        NCLASSES = 11
        validset_size = NCLASSES * 100
    else:
        NCLASSES = 10
        validset_size = NCLASSES * args.valid_size_multiplier

    # One query per class in every round.
    args.query_size = NCLASSES

    # The two "ours" methods keep separate result folders per validation size.
    if args.method in ("ours_gen", "our_query_fea"):
        method_dir = f"{args.method}_{validset_size}"
    else:
        method_dir = args.method

    dataset_dir = f"./model_save/{args.dataset}"
    output_dir = f"./model_save/{args.dataset}/{method_dir}"
    cur_iter_dir = os.path.join(output_dir, str(args.iter))
    next_iter_dir = os.path.join(output_dir, str(args.iter+1))

    # NOTE(review): the marker file checked here lives in the NEXT round's
    # directory while the message names the current round -- confirm intent.
    if args.use_cache and os.path.exists(os.path.join(next_iter_dir, 'acc.txt')):
        print(f"Already trained model exists in {output_dir}/{args.iter}.")
        # `exit()` is a helper injected by the `site` module and is not always
        # available; raising SystemExit is the import-free equivalent.
        raise SystemExit(0)

    os.makedirs(cur_iter_dir, exist_ok=True)
    os.makedirs(next_iter_dir, exist_ok=True)

    # NOTE(review): the models are created before init_seeds(0) is called;
    # confirm that their construction does not rely on the seed.
    cla_model, clip_model, clip_preprocess = get_cla_clip_models(NCLASSES)
    init_seeds(0)

    lab_loader, unlab_loader, val_loader, test_loader, lab_idx, unlab_idx, valid_idx = \
        get_data_loader(iter=args.iter, dataset=args.dataset, output_dir=output_dir, clip_preprocess=clip_preprocess,
                        init_lab_size=NCLASSES*10,
                        validset_size=validset_size,
                        split_valid_from_train=True)
    print(f"Initial labeled set size: {len(lab_idx)}")
    print(f"Initial unlabeled set size: {len(unlab_idx)}")
    print(f"Validation set size: {len(valid_idx)}")
    print(f"Test set size: {len(test_loader.dataset)}")

    # Only the very first round verifies that the labeled and validation sets
    # cover every class.
    if args.iter == 0:
        cflg1 = check_all_class(lab_loader, NCLASSES)
        cflg2 = check_all_class(val_loader, NCLASSES)
        if not (cflg1 and cflg2):
            raise ValueError("There are some classes with 0 instances in initial labeled set or validation set.")

    if "cifar" in args.dataset:
        class_to_idx = lab_loader.dataset.class_to_idx
    elif "inaturalist21" in args.dataset:
        class_to_idx = inat21_class_to_idx
    else:
        raise NotImplementedError()

    # Build the training features for this round.
    train_lr = get_train_lr(lab_loader)
    ext_train_labels = None
    train_features, train_labels = get_features(clip_model, lab_loader)
    if args.iter > 0:
        # From the second round on, data produced by earlier rounds is appended
        # to the labeled features; where it comes from depends on the method.
        if args.method in ("ours_gen", "randomtxt"):
            ext_loader = get_ext_dataloader(args.dataset, method_dir_name=method_dir, iter=args.iter, query_batch_size=args.query_size, clip_preprocess=clip_preprocess, class_to_idx=class_to_idx, ext_type=args.method)
            ext_train_features, ext_train_labels = get_features(clip_model, ext_loader)
            train_features = np.vstack((train_features, ext_train_features))
            train_labels = np.hstack((train_labels, ext_train_labels))
            print(len(train_labels))
        elif args.method == "our_query_fea":
            ext_train_features, ext_train_labels = load_gen_fea(output_dir, args.iter)
            train_features = np.vstack((train_features, ext_train_features))
            train_labels = np.hstack((train_labels, ext_train_labels))

    # Test features are cached per dataset and shared by every method. Both
    # files must be present before the cache is trusted: a run interrupted
    # between the two saves would otherwise make every later run fail on load.
    test_features_path = os.path.join(dataset_dir, 'test_features.npy')
    test_labels_path = os.path.join(dataset_dir, 'test_labels.npy')
    if os.path.exists(test_features_path) and os.path.exists(test_labels_path):
        test_features = np.load(test_features_path)
        test_labels = np.load(test_labels_path)
    else:
        test_features, test_labels = get_features(clip_model, test_loader)
        np.save(test_features_path, test_features)
        np.save(test_labels_path, test_labels)

    # Train the classifier for this round and record its `acc` value.
    cla_model, test_loss, acc, loss_fun, train_fea, train_lab, test_fea, test_label = train_with_fea_val(
        cla_model, train_features, train_labels, test_features, test_labels, train_lr)
    np.savetxt(os.path.join(cur_iter_dir, 'acc.txt'), [acc], fmt="%f")

    # Query step (a single round per invocation, not a loop).
    if args.method == "ours_gen":
        selected_idx = our_query(args.dataset, lab_idx, unlab_idx, lab_loader, unlab_loader, val_loader,
           clip_model, cla_model, loss_fun, train_fea, train_lab, args.query_size, class_to_idx_dict=class_to_idx,
           gen_class_txt_num=NCLASSES, save_img_flag=True, save_dir=next_iter_dir, use_cache=args.use_cache,
           query_iter=args.iter, ext_labels=ext_train_labels, method_dir_name=method_dir)
    elif args.method == "our_query_fea":
        selected_idx = our_query_fea(lab_loader, val_loader, clip_model, cla_model,
              loss_fun, train_fea, train_lab, class_to_idx, gen_class_txt_num=NCLASSES,
              use_cache=args.use_cache, save_dir=next_iter_dir)
    elif args.method == "randomtxt":
        selected_idx = RandomTxt_query(class_to_idx, NCLASSES, args.dataset, str(args.iter), method_dir_name=method_dir)
    else:
        raise NotImplementedError()

    # Sanity-check the selection with explicit raises: `assert` statements are
    # removed under `python -O`, which would let a bad selection silently
    # corrupt the index files written below.
    if np.intersect1d(selected_idx, lab_idx).size != 0:
        raise ValueError("Query returned indices that are already in the labeled set.")
    # intersect1d returns unique values, so duplicated selections fail here too.
    if np.intersect1d(selected_idx, unlab_idx).size != len(selected_idx):
        raise ValueError("Query returned duplicate indices or indices outside the unlabeled set.")

    # Move the selected indices from the unlabeled pool to the labeled set.
    lab_idx = np.hstack((lab_idx, selected_idx))
    unlab_idx = np.setdiff1d(unlab_idx, lab_idx)
    np.random.shuffle(unlab_idx)

    # Save the splits that the next round will start from.
    np.savetxt(os.path.join(next_iter_dir, 'valid_idx.txt'), valid_idx, fmt="%d")
    np.savetxt(os.path.join(next_iter_dir, 'unlab_idx.txt'), unlab_idx, fmt="%d")
    np.savetxt(os.path.join(next_iter_dir, 'lab_idx.txt'), lab_idx, fmt="%d")
    # selected_idx.txt is written here only for methods other than the two "ours" variants.
    if args.method not in ["ours_gen", "our_query_fea"]:
        np.savetxt(os.path.join(next_iter_dir, 'selected_idx.txt'), selected_idx, fmt="%d")
