import argparse
import numpy as np
import torch

from data import get_dataset, ParamDiffAug
from models.wrapper import get_model
from utils import evaluate, evaluate_debug, Logger, default_args
from generator import SyntheticImageGenerator

def main(args):
    args.device = torch.device(f"cuda:{args.gpu_id}")
    args.dsa_param = ParamDiffAug()
    if args.dataset == 'SVHN':
        if args.mix_p > 0.0 or args.mixup:
            args.dsa_strategy = 'color_crop_scale_rotate'
        else:
            args.dsa_strategy = 'color_crop_cutout_scale_rotate'
    else:
        if args.mix_p > 0.0 or args.mixup:
            args.dsa_strategy = 'color_crop_flip_scale_rotate'  
        else:
            args.dsa_strategy = 'color_crop_cutout_flip_scale_rotate'  
    ''' data set '''
    channel, im_size, num_classes, normalize, _, _, testloader = get_dataset(args.dataset, args.data_path, False)
    args.num_classes = num_classes

    ''' save path '''
    save_path = f'{args.save_path}/{args.dataset}/{args.exp_name}'
    state_dict = torch.load(f'{save_path}/generator.pth', map_location='cpu')

    ''' initialize '''
    generator = SyntheticImageGenerator(
        num_classes, im_size, args.num_seed_vec, args.num_decoder, args.hdims,
        args.kernel_size, args.stride, args.padding)
    del generator.encoders
    generator.load_state_dict(state_dict)
    generator = generator.to(args.device)

    ''' Evaluate synthetic data '''
    image_syn, label_syn = generator.get_all_cpu()
    image_syn, label_syn = image_syn.detach(), label_syn.detach()
    del generator    
    
    ''' model eval pool'''
    if args.model_eval_pool is None:
        args.model_eval_pool = [args.model]
    else:
        args.model_eval_pool = args.model_eval_pool.split("_")
    accs_all_exps = dict() # record performances of all experiments
    for key in args.model_eval_pool:
        accs_all_exps[key] = []                   
                
        ''' Evaluate all model_eval '''
        for model_eval in args.model_eval_pool:
            for _ in range(args.num_eval):
                net_eval = get_model(args, model_eval, channel, num_classes, im_size).to(args.device) # get a random model
                _, acc = evaluate(args, net_eval, image_syn, label_syn, testloader, normalize)
                accs_all_exps[model_eval].append(acc)


    print('\n==================== Final Results ====================\n')
    for key in args.model_eval_pool:
        accs = accs_all_exps[key]        
        print('Train on %s, Evaluate on %s for %d: mean  = %.2f%%  std = %.2f%%'%(args.model, key, len(accs), np.mean(accs), np.std(accs)))
        with open(f'{args.save_path}/{args.dataset}/{args.exp_name}/{key}_final_results.txt', 'w') as f:
            f.write(f'mean = {np.mean(accs)}, std = {np.std(accs)}')


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Parameter Processing')

    # data
    parser.add_argument('--data_path', type=str, default='ANONYMIZED')
    parser.add_argument('--dataset', type=str, default='ANONYMIZED')

    # save
    parser.add_argument('--save_path', type=str, default='results')
    parser.add_argument('--exp_name', type=str, default=None)

    # repeat
    parser.add_argument('--num_eval', type=int, default=3)

    # training
    parser.add_argument('--model', type=str, default='ResNet10AP')
    parser.add_argument('--ipc', type=int, default=10)
    parser.add_argument('--num_seed_vec', type=int, default=16)
    parser.add_argument('--num_decoder', type=int, default=12)

    # evaluation
    parser.add_argument('--model_eval_pool', type=str, default=None)  # underscore-separated model names
    parser.add_argument('--epoch', type=int, default=200)
    parser.add_argument('--eval_opt', type=str, default="sgd")
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--wd', type=float, default=5e-4)
    parser.add_argument('--init_std', type=float, default=0.0)
    parser.add_argument('--batch', type=int, default=128)
    parser.add_argument('--mix_p', type=float, default=0.0)
    parser.add_argument('--beta', type=float, default=1.0)
    parser.add_argument('--warmup', action='store_true')
    parser.add_argument('--not_aug', action='store_true')
    parser.add_argument('--mixup', action='store_true')

    # device: '--gpu-id' is kept for backward compatibility; '--gpu_id' matches
    # the underscore style of every other flag. Both store into args.gpu_id,
    # which main() reads to pick the CUDA device.
    parser.add_argument('--gpu-id', '--gpu_id', dest='gpu_id', type=int, default=0)

    args = parser.parse_args()
    # NOTE(review): presumably fills in attributes main() reads but the parser
    # does not define (hdims, kernel_size, stride, padding) — confirm in utils.
    default_args(args)

    # Default experiment name mirrors the directory the generator was saved to.
    if args.exp_name is None:
        args.exp_name = f'{args.model}_{args.ipc}_{args.num_seed_vec}_{args.num_decoder}'
    print(args.exp_name)

    main(args)


