import numpy as np
import argparse
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from backbone.model_resnet import ResNet_50, ResNet_152
from backbone.MobileFaceNets import MobileFaceNet
from head.metrics import CosFace, ArcFace
from loss.focal import FocalLoss
from util.utils import separate_resnet_bn_paras, warm_up_lr, schedule_lr, AverageMeter, accuracy
from util.data_utils_balanced import prepare_data
import random

# Pick the compute device once for the whole script, then pin every RNG the
# training run touches (torch CPU + all CUDA devices, NumPy, Python `random`)
# to the same fixed seed so that repeated runs are comparable.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

for _seed_rng in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed, random.seed):
    _seed_rng(222)
del _seed_rng  # keep the module namespace free of the loop variable

# Trade cuDNN speed for reproducibility: no kernel autotuning, and only
# deterministic algorithms.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Factories for the supported backbones / margin heads. Only the model chosen
# on the command line is constructed; the previous eager dicts instantiated
# every candidate (ResNet_152 included) just to use one of them.
# NOTE: because unused models are no longer built, the chosen model's initial
# weights consume different torch RNG draws than the eager version did, so
# results under seed 222 are not bit-identical to earlier runs.
BACKBONE_FACTORIES = {
    'MobileFaceNet': lambda args: MobileFaceNet(embedding_size=args.embedding_size, out_h=7, out_w=7),
    'ResNet_50': lambda args: ResNet_50(args.input_size),
    'ResNet_152': lambda args: ResNet_152(args.input_size),
}

HEAD_FACTORIES = {
    'ArcFace': lambda args, num_class: ArcFace(in_features=args.embedding_size, out_features=num_class,
                                               device_id=args.gpu_id),
    'CosFace': lambda args, num_class: CosFace(in_features=args.embedding_size, out_features=num_class,
                                               device_id=args.gpu_id),
}


def parse_args(argv=None):
    """Parse and validate the command line for balanced training.

    Besides plain parsing, this pairs ``--groups_to_modify`` with the per-group
    ``--p_identities`` / ``--p_images`` values and replaces those two list
    attributes with ``{group: value}`` dicts (the form handed to
    ``prepare_data``).

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace``.

    Raises:
        SystemExit: via ``parser.error`` for an unknown backbone/head name, or
            when fewer probabilities than groups are given.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--data_train_root', default='/')
    parser.add_argument('--data_test_root', default='/')
    parser.add_argument('--demographics', default='/CelebA_demographics.txt', help='dict gender:idx')
    parser.add_argument('--backbone_name', default='', help='one of: ' + ', '.join(BACKBONE_FACTORIES))
    parser.add_argument('--head_name', default='', help='one of: ' + ', '.join(HEAD_FACTORIES))
    parser.add_argument('--train_loss', default='Focal', type=str,
                        help='currently not consulted: FocalLoss is always used')

    parser.add_argument('--groups_to_modify', default=['male', 'female'], type=str, nargs='+')
    parser.add_argument('--p_identities', default=[1.0, 1.0], type=float, nargs='+')
    parser.add_argument('--p_images', default=[1.0, 1.0], type=float, nargs='+')
    parser.add_argument('--min_num_images', default=3, type=int)

    # List-valued options need `nargs`: with a bare `type=int`, overriding them
    # on the CLI produced a scalar (e.g. `epoch in args.stages` then raised
    # TypeError, and fractional --mean/--std values could not be parsed).
    parser.add_argument('--batch_size', default=512, type=int)
    parser.add_argument('--input_size', default=[112, 112], type=int, nargs=2)
    parser.add_argument('--embedding_size', default=512, type=int)
    parser.add_argument('--weight_decay', default=5e-4, type=float)
    parser.add_argument('--momentum', default=0.9, type=float)
    parser.add_argument('--mean', default=[0.5, 0.5, 0.5], type=float, nargs='+')
    parser.add_argument('--std', default=[0.5, 0.5, 0.5], type=float, nargs='+')
    parser.add_argument('--stages', default=[35, 65, 95], type=int, nargs='+')
    parser.add_argument('--num_workers', default=4, type=int)

    parser.add_argument('--lr', default=0.1, type=float)
    parser.add_argument('--num_epoch', default=100, type=int)
    parser.add_argument('--gpu_id', default=[0], type=int, nargs='+', help='gpu id')
    parser.add_argument('--name', default='', type=str)
    parser.add_argument('--dataset', default='', type=str)

    args = parser.parse_args(argv)

    # Fail fast on bad model names, before prepare_data does any heavy work
    # (previously this surfaced as a KeyError only after data preparation).
    if args.backbone_name not in BACKBONE_FACTORIES:
        parser.error(f'--backbone_name must be one of {sorted(BACKBONE_FACTORIES)}, '
                     f'got {args.backbone_name!r}')
    if args.head_name not in HEAD_FACTORIES:
        parser.error(f'--head_name must be one of {sorted(HEAD_FACTORIES)}, got {args.head_name!r}')

    # Pair each group with its probability. Extra trailing values are ignored
    # (so the two-element defaults still work when only one group is given),
    # but too few values would silently drop groups, so reject that case.
    groups = args.groups_to_modify
    for option in ('p_identities', 'p_images'):
        if len(getattr(args, option)) < len(groups):
            parser.error(f'--{option} needs at least one value per group in '
                         f'--groups_to_modify ({len(groups)} given)')
    args.p_images = dict(zip(groups, args.p_images))
    args.p_identities = dict(zip(groups, args.p_identities))
    return args


def main():
    """Train the selected backbone + margin head on the re-balanced data.

    Parses the command line and builds the loaders via ``prepare_data``. It
    then constructs only the requested backbone and head and optimizes them
    with SGD for exactly ``--num_epoch`` epochs. The LR is adjusted per batch
    by ``warm_up_lr`` during the first ``num_epoch // 25`` epochs, and by
    ``schedule_lr`` at each epoch listed in ``--stages``.
    """
    args = parse_args()

    ####################################################################################################################################
    # ======= data, model and test data =======#
    dataloaders, num_class, demographic_to_labels_train, demographic_to_labels_test = prepare_data(args)

    backbone = BACKBONE_FACTORIES[args.backbone_name](args)
    head = HEAD_FACTORIES[args.head_name](args, num_class)
    # NOTE: --train_loss is parsed but not consulted; FocalLoss is always used.
    train_criterion = FocalLoss(elementwise=True)

    ####################################################################################################################
    # ======= optimizer =======#
    # Weight decay is applied to non-BN parameters only; BN parameters get none.
    backbone_paras_only_bn, backbone_paras_wo_bn = separate_resnet_bn_paras(backbone)
    _, head_paras_wo_bn = separate_resnet_bn_paras(head)
    optimizer = optim.SGD([{'params': backbone_paras_wo_bn + head_paras_wo_bn, 'weight_decay': args.weight_decay},
                           {'params': backbone_paras_only_bn}], lr=args.lr, momentum=args.momentum)

    backbone, head = backbone.to(device), head.to(device)
    # NOTE(review): --gpu_id is forwarded to the head but not passed as
    # `device_ids` here, so DataParallel uses all visible GPUs — confirm intent.
    backbone = nn.DataParallel(backbone).to(device)

    ####################################################################################################################################
    # ======= train =======#
    num_epoch_warm_up = args.num_epoch // 25  # the first 1/25 of the epochs are warm-up
    num_batch_warm_up = len(dataloaders['train']) * num_epoch_warm_up
    batch = 0  # global batch counter across epochs, drives the warm-up

    # range() runs exactly --num_epoch epochs (the old `while epoch <= num_epoch`
    # ran one extra).
    for epoch in range(args.num_epoch):
        backbone.train()
        head.train()
        # NOTE(review): meters are accumulated but never reported or saved.
        meters = {'loss': AverageMeter(), 'top5': AverageMeter()}

        if epoch in args.stages:  # per-stage LR adjustment after warm-up
            schedule_lr(optimizer)

        for inputs, labels, _sens_attr in tqdm(dataloaders['train']):
            if batch + 1 <= num_batch_warm_up:  # per-batch LR adjustment during warm-up
                warm_up_lr(batch + 1, num_batch_warm_up, args.lr, optimizer)

            inputs, labels = inputs.to(device), labels.to(device).long()
            features = backbone(inputs)
            outputs = head(features, labels)  # margin heads need the labels
            loss = train_criterion(outputs, labels).mean()  # criterion is elementwise
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # measure accuracy and record loss
            _, prec5 = accuracy(outputs.detach(), labels, topk=(1, 5))
            meters['loss'].update(loss.item(), inputs.size(0))
            meters['top5'].update(prec5.item(), inputs.size(0))

            batch += 1


if __name__ == '__main__':
    main()


