import torch
import time
from torch import nn
import torch.nn.functional as F
import numpy as np
import random
from email.mime.text import MIMEText
from utils.sampling import cifar10_noiid, cifar100_distill, cifar10_global
from utils.sampling import cifar100_noiid, tiny_imagenet_distill, cifar100_global
from utils.options import args_parser
from nets.resnets import ResNetCifar
from nets.cnn import CNNCifar
from alg.fed import Fed_Distill_hetero
if __name__ == '__main__':
    # Script entry point: seed RNGs, parse CLI args, build per-client and
    # global data loaders for the chosen dataset, build the model, and run
    # heterogeneous federated distillation training.
    import sys

    # Seed every RNG source the pipeline may draw from so runs are repeatable.
    np.random.seed(0)
    random.seed(0)
    torch.manual_seed(0)
    torch.cuda.manual_seed_all(0)
    # Deterministic cuDNN kernels alone are not enough: benchmark mode
    # autotunes convolution algorithms per run and can select different ones,
    # so it is disabled explicitly as well.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    args = args_parser()
    # Use the requested GPU when CUDA is available and --gpu is not -1.
    args.device = torch.device(
        'cuda:{}'.format(args.gpu)
        if torch.cuda.is_available() and args.gpu != -1
        else 'cpu'
    )
    args.path_checkpoint = (
        f"checkpoint/{args.method}_{args.model}_{args.alpha}_{args.dataset}.pth.tar"
    )
    print(args.path_checkpoint)

    if args.dataset == 'CIFAR10':
        args.num_classes = 10
        # NOTE(review): names suggest per-client (non-IID) loaders/lengths
        # keyed by client — confirm against utils.sampling.
        dataloader_train_dict, dataloader_test_dict, train_len_dict, test_len_dict = \
            cifar10_noiid(args=args, root='')
        # CIFAR-100 serves as the distillation (transfer) set for CIFAR-10.
        dataloader_distill = cifar100_distill(args=args, root='')
        # The global train loader is not used below; only the test loader is.
        dataloader_train_global, dataloader_test_global = cifar10_global(args=args, root='')
    elif args.dataset == 'CIFAR100':
        args.num_classes = 100
        dataloader_train_dict, dataloader_test_dict, train_len_dict, test_len_dict = \
            cifar100_noiid(args=args, root='')
        # Tiny-ImageNet serves as the distillation set for CIFAR-100.
        dataloader_distill = tiny_imagenet_distill(args=args, root='')
        dataloader_train_global, dataloader_test_global = cifar100_global(args=args, root='')
    else:
        # Without this branch the loaders above stay unbound and the script
        # dies later with a NameError instead of a clear message.
        sys.exit('Error: unrecognized dataset')

    if 'resnet' in args.model:
        model = ResNetCifar(args, 1).to(args.device)
    elif args.model == 'cnn':
        # NOTE(review): argument order (1, args) differs from ResNetCifar's
        # (args, 1) — preserved as-is; verify against nets.cnn.
        model = CNNCifar(1, args).to(args.device)
    else:
        sys.exit('Error: unrecognized model')

    fed = Fed_Distill_hetero(args, model, dataloader_train_dict, dataloader_test_dict,
                             dataloader_test_global, train_len_dict, test_len_dict,
                             dataloader_distill)

    fed.train()