import os
import argparse
import numpy as np
import torch
import utils
import random
import torchvision
from utils import load_pretrain, MLP
from utils import classify, generate_FP, verify_FP

def get_args() -> argparse.Namespace:
    """Build the command-line parser and parse ``sys.argv``.

    The options fall into three groups: downstream fine-tuning settings,
    fingerprint-generation settings, and suspect-model verification settings.

    Returns:
        argparse.Namespace: one attribute per option declared below.
    """
    parser = argparse.ArgumentParser()
    # Downstream settings
    parser.add_argument('--arch', type=str, default='res18',
                        help='Architecture of the victim model')
    parser.add_argument('--dataset', type=str, default='stl10',
                        help='Which dataset to evaluate')
    parser.add_argument('--pretrain_style', type=str, default='supervised',
                        help='The training method of pre-trained models')
    parser.add_argument('--FT_mode', type=str, default='FTLL',
                        help='The fine-tuning method for downstream tasks')
    parser.add_argument('--device', type=str, default='cuda',
                        choices=['cuda', 'cpu', 'mps'],
                        help='Supported devices')
    # Seeds must be integers: parsing as float would hand e.g. 7.0 to the
    # seeding routines whenever the option is given on the command line.
    parser.add_argument('--random', type=int, default=42,
                        help='random seed')
    parser.add_argument('--batch_size', type=int, default=32,
                        help='The amount of training data per iteration')
    parser.add_argument('--epochs', type=int, default=30,
                        help='Training epochs')
    parser.add_argument('--lr', type=float, default=0.005,  # 0.0005 for FTAL
                        help='Learning rate')
    parser.add_argument('--weight_decay', type=float, default=5e-4,
                        help='The parameter of the Adam optimizer')

    # Fingerprint settings
    parser.add_argument('--num_anchor', type=int, default=10,
                        help='The number of anchors')
    parser.add_argument('--num_member', type=int, default=10,
                        help='The number of members in a group')
    parser.add_argument('--num_group', type=int, default=10,
                        help='The number of groups')
    parser.add_argument('--fp_path', type=str, default='./Fingerprints',
                        help='The path where the fingerprint sample is saved')
    parser.add_argument('--num_iters', type=int, default=200,
                        help='The number of iterations')
    parser.add_argument('--member_lr', type=float, default=1e-3,
                        help='The learning rate of members')
    parser.add_argument('--eps_l2', type=float, default=16/255,
                        help='L2 restriction of perturbations')
    parser.add_argument('--member_optimizer', type=str, default='sgd',
                        help='Optimizer of members')
    parser.add_argument('--cluster_method', type=str, default='sc',
                        help='Cluster method for deciding anchors')

    # Suspect-model (verification) settings
    parser.add_argument('--sus_num_classes', type=int, default=100,
                        help='The number of classes of the suspect model')
    parser.add_argument('--sus_arch', type=str, default='res50',
                        help='Architecture of the suspect model')
    parser.add_argument('--sus_model_dir', type=str, default='./cifar100_res50_FTLL',
                        help='The path where the suspect model is located')
    parser.add_argument('--fp_dir', type=str, default='./Fingerprints/cifar100_res50_FTLL',
                        help='The path where the fingerprint sample is located')

    return parser.parse_args()

# Parse the command line once at import time; the entry guard at the bottom of
# the file passes this namespace to main().
args = get_args()

# Number of output classes for each supported downstream dataset.
_DATASET_NUM_CLASSES = {
    'cifar100': 100,
    'stl10': 10,
    'gtsrb': 43,
}

if args.dataset not in _DATASET_NUM_CLASSES:
    # Fail fast with a clear message instead of leaving args.num_classes unset,
    # which would only surface later as an AttributeError.
    raise ValueError(
        f"Unsupported dataset {args.dataset!r}; "
        f"expected one of {sorted(_DATASET_NUM_CLASSES)}"
    )
args.num_classes = _DATASET_NUM_CLASSES[args.dataset]
    


def main(args):
    """Run the three pipeline stages in order and return the verification result.

    Stages (all implemented in ``utils``): fine-tune on the downstream task,
    generate fingerprints, then verify them against the suspect model.

    Args:
        args: parsed command-line namespace; must also carry ``num_classes``.

    Returns:
        Whatever ``verify_FP`` returns (the original code named it
        ``matching_rate``); previously this value was discarded.
    """
    # Train downstream task
    encoder = load_pretrain(args.pretrain_style, args.device, args.arch)
    MLP_layers = MLP(args.num_classes)
    # The return value of classify() is not needed by the later stages here.
    classify(args.dataset, encoder, MLP_layers, args, mode=args.FT_mode)

    # Extract Fingerprints
    generate_FP(args)

    # Verify
    matching_rate = verify_FP(args.sus_model_dir, args.fp_dir, args)
    return matching_rate


# Script entry point: seed via utils.random_seed with the --random value,
# then run the pipeline on the module-level parsed arguments.
if __name__ == '__main__':
    utils.random_seed(args.random)
    main(args)
