from __future__ import print_function
import torch
import argparse

from utils import load_data
from utils import landmark_o_crop

from network_and_loss import FeatNet

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# NOTE(review): `dtype` is not referenced anywhere else in this script.
dtype = torch.float


def _str2bool(value):
    """Convert a command-line token such as ``"False"`` or ``"1"`` to a bool.

    Without a ``type=`` converter argparse hands the raw string through, and
    every non-empty string -- including ``"False"`` -- is truthy, so
    ``--random_flip False`` would *enable* the augmentation.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('1', 'true', 't', 'yes', 'y'):
        return True
    if token in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % value)


parser = argparse.ArgumentParser(description='Prior modeling: FeatNet')

# Hyperparameters
parser.add_argument('--task', type=str, default="COFW", help='Task dataset')
parser.add_argument('--random_scale', type=_str2bool, default=False, help='Whether to apply random scaling')
parser.add_argument('--random_flip', type=_str2bool, default=False, help='Whether to apply random flip')
parser.add_argument('--random_rotation', type=_str2bool, default=False, help='Whether to apply random rotation')
parser.add_argument('--batch_size', type=int, default=64, help='Batch size')

# Parsed at import time from the process command line; read by main().
args = parser.parse_args()

def main():
    """Restore the trained per-part FeatNets and compute their latent priors.

    Builds the data loaders from the parsed command-line ``args``, loads the
    five FeatNet state dicts from ``../../Checkpoint/FeatNet_2000epoch.pth``
    onto ``device``, and passes them, together with the training loader, to
    ``prior_modeling``.
    """
    torch.cuda.empty_cache()
    # Only the training split is used to estimate the prior; the test loader is unused.
    train_loader, _test_loader = load_data(args.task, args.batch_size, args.random_scale,
                                           args.random_flip, args.random_rotation)

    featnet_leye = FeatNet().to(device)
    featnet_reye = FeatNet().to(device)
    featnet_mouth = FeatNet().to(device)
    featnet_nose = FeatNet().to(device)
    featnet_jaw = FeatNet().to(device)

    # map_location remaps tensors saved on a GPU so the checkpoint also loads
    # on a CPU-only machine (where `device` falls back to 'cpu').
    checkpoint = torch.load("../../Checkpoint/FeatNet_2000epoch.pth", map_location=device)
    featnet_leye.load_state_dict(checkpoint['featnet_leye'])
    featnet_reye.load_state_dict(checkpoint['featnet_reye'])
    featnet_mouth.load_state_dict(checkpoint['featnet_mouth'])
    featnet_nose.load_state_dict(checkpoint['featnet_nose'])
    featnet_jaw.load_state_dict(checkpoint['featnet_jaw'])

    prior_modeling(featnet_leye, featnet_reye, featnet_mouth, featnet_nose, featnet_jaw, train_loader)
    
    
    
def prior_modeling(featnet_leye, featnet_reye, featnet_mouth,
                   featnet_nose, featnet_jaw, train_loader) :
    """Compute the mean FeatNet latent code for each facial part and save it.

    Each part network is switched to eval mode and run without gradients over
    every batch of ``train_loader``. Its 128-d outputs are reshaped to
    ``(K, batch, 128)``, with K = 11 (left eye), 11 (right eye), 20 (mouth),
    9 (nose) and 17 (jaw). They are concatenated along the batch axis and
    averaged over it.

    The resulting ``(K, 128)`` tensors are saved as a dict with keys
    ``leye_z_mean``, ``reye_z_mean``, ``mouth_z_mean``, ``nose_z_mean`` and
    ``jaw_z_mean`` to ``../../Checkpoint/FeatNet_prior.pth``.

    Raises:
        ValueError: if ``train_loader`` yields no batches (the mean would be NaN).
    """
    nets = {
        'leye': featnet_leye,
        'reye': featnet_reye,
        'mouth': featnet_mouth,
        'nose': featnet_nose,
        'jaw': featnet_jaw,
    }
    for net in nets.values():
        net.eval()

    # Per-batch features are collected and concatenated once at the end;
    # repeatedly torch.cat-ing onto a growing tensor is quadratic in #batches.
    features = {part: [] for part in nets}

    with torch.no_grad():
        for images, landmark_coords in train_loader:
            images, landmark_coords = images.to(device), landmark_coords.to(device)
            # NOTE(review): coordinates are reshaped to 29 points, but the slices
            # below index up to 68 entries (looks like the usual 68-point layout:
            # jaw 0-16, brows 17-26, nose 27-35, eyes 36-47, mouth 48-67).
            # Confirm the layout returned by landmark_o_crop.
            landmark_coords = landmark_coords.view(-1, 29, 2)
            landmark_o = landmark_o_crop(images, landmark_coords)

            leye_o = torch.cat((landmark_o[17:22], landmark_o[36:42]), dim=0).view(-1, 6, 27, 27)
            reye_o = torch.cat((landmark_o[22:27], landmark_o[42:48]), dim=0).view(-1, 6, 27, 27)
            mouth_o = landmark_o[48:68]
            nose_o = landmark_o[27:36]
            jaw_o = landmark_o[:17]

            # Each part is encoded by its own network (nose and jaw previously
            # went through featnet_mouth by mistake).
            features['leye'].append(featnet_leye(leye_o).view(11, -1, 128))
            features['reye'].append(featnet_reye(reye_o).view(11, -1, 128))
            features['mouth'].append(featnet_mouth(mouth_o).view(20, -1, 128))
            features['nose'].append(featnet_nose(nose_o).view(9, -1, 128))
            features['jaw'].append(featnet_jaw(jaw_o).view(17, -1, 128))

    if not features['leye']:
        raise ValueError('train_loader yielded no batches; cannot compute the FeatNet prior')

    state = {
        part + '_z_mean': torch.cat(chunks, dim=1).mean(1)
        for part, chunks in features.items()
    }

    torch.save(state, "../../Checkpoint/FeatNet_prior.pth")
    


if __name__ == '__main__':
    # Script entry point; importing this module does not call main().
    main()
