from __future__ import print_function
import torch
import argparse

from utils import load_data
from utils import landmark_o_crop

from network_and_loss import FeatNet

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dtype = torch.float


def _str2bool(value):
    """Convert a command-line string such as 'True', 'false', '1' or 'no' to a bool.

    Booleans are returned unchanged. Any other string raises
    ``argparse.ArgumentTypeError`` so argparse reports a clean usage error.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Expected a boolean value, got %r' % (value,))


parser = argparse.ArgumentParser(description='Prior modeling: FeatNet')

# Hyperparameters
parser.add_argument('--task', type=str, default="COFW", help='Task dataset')
# The boolean flags are parsed with _str2bool. Without a type, argparse keeps the
# raw string, so "--random_flip False" would yield the truthy string "False".
# nargs='?' with const=True also lets a bare "--random_flip" mean True.
parser.add_argument('--random_scale', type=_str2bool, nargs='?', const=True, default=False,
                    help='Whether to apply random scaling')
parser.add_argument('--random_flip', type=_str2bool, nargs='?', const=True, default=False,
                    help='Whether to apply random flip')
parser.add_argument('--random_rotation', type=_str2bool, nargs='?', const=True, default=False,
                    help='Whether to apply random rotation')
parser.add_argument('--batch_size', type=int, default=64, help='Batch size')

args = parser.parse_args()

def main():
    """Restore the pretrained FeatNets and compute the FeatNet prior on the training split.

    Hyperparameters come from the module-level ``args``. The training loader is
    built with ``load_data``. The three FeatNets are restored from
    ``../../Checkpoint/FeatNet_2000epoch.pth`` (keys ``featnet_leye``,
    ``featnet_reye`` and ``featnet_others``) and passed to ``prior_modeling``.
    """
    torch.cuda.empty_cache()
    # load_data also returns a test loader; it is not needed for the prior.
    train_loader, _test_loader = load_data(
        args.task, args.batch_size, args.random_scale, args.random_flip, args.random_rotation)

    featnet_leye = FeatNet().to(device)
    featnet_reye = FeatNet().to(device)
    featnet_others = FeatNet().to(device)

    # map_location lets a checkpoint that was saved from GPU tensors be restored
    # when ``device`` fell back to CPU. Otherwise torch.load fails on CPU-only hosts.
    checkpoint = torch.load("../../Checkpoint/FeatNet_2000epoch.pth", map_location=device)
    featnet_leye.load_state_dict(checkpoint['featnet_leye'])
    featnet_reye.load_state_dict(checkpoint['featnet_reye'])
    featnet_others.load_state_dict(checkpoint['featnet_others'])

    prior_modeling(featnet_leye, featnet_reye, featnet_others, train_loader)
    
    
    
def prior_modeling(featnet_leye, featnet_reye, featnet_others, train_loader,
                   save_path="../../Checkpoint/FeatNet_prior.pth"):
    """Average FeatNet codes per landmark over ``train_loader`` and save the result.

    Processing steps:
    - Each batch is moved to ``device``.
    - Its landmark coordinates are reshaped to ``(-1, 29, 2)`` and cropped with
      ``landmark_o_crop``.
    - The crops are split into three groups: ``leye_idx`` (9 landmarks),
      ``reye_idx`` (9 landmarks) and ``18:`` (11 landmarks).
    - Each group is encoded by its own FeatNet into 128-dim codes.
    - The codes are averaged over every sample seen, per landmark.

    Args:
        featnet_leye: FeatNet applied to the crops selected by ``leye_idx``.
        featnet_reye: FeatNet applied to the crops selected by ``reye_idx``.
        featnet_others: FeatNet applied to crops ``18:``.
        train_loader: iterable of ``(images, landmark_coords)`` batches.
        save_path: destination passed to ``torch.save`` (defaults to the
            original hard-coded path).

    Returns:
        The saved dict. ``'leye_z_mean'`` and ``'reye_z_mean'`` are shaped
        ``(9, 128)``; ``'others_z_mean'`` is shaped ``(11, 128)``.

    Raises:
        ValueError: if ``train_loader`` yields no batches. Otherwise the means
            would be NaN and would silently be written to disk.
    """
    featnet_leye.eval()
    featnet_reye.eval()
    featnet_others.eval()

    leye_idx = [0, 2, 4, 5, 8, 10, 12, 13, 16]
    reye_idx = [1, 3, 6, 7, 9, 11, 14, 15, 17]
    num_others = 11  # landmarks 18..28 of the 29 per sample

    # Keep per-batch codes in lists and concatenate once at the end.
    # Concatenating inside the loop would re-copy the full history on every batch.
    leye_chunks, reye_chunks, others_chunks = [], [], []

    with torch.no_grad():
        for images, landmark_coords in train_loader:
            images, landmark_coords = images.to(device), landmark_coords.to(device)
            landmark_coords = landmark_coords.view(-1, 29, 2)
            # NOTE(review): dim 0 of landmark_o is presumably the landmark index
            # (it is indexed with landmark ids below) -- confirm in landmark_o_crop.
            landmark_o = landmark_o_crop(images, landmark_coords)
            leye_chunks.append(
                featnet_leye(landmark_o[leye_idx].view(-1, 2, 27, 27))
                .view(len(leye_idx), -1, 128))
            reye_chunks.append(
                featnet_reye(landmark_o[reye_idx].view(-1, 2, 27, 27))
                .view(len(reye_idx), -1, 128))
            others_chunks.append(
                featnet_others(landmark_o[18:].view(-1, 2, 27, 27))
                .view(num_others, -1, 128))

    if not leye_chunks:
        raise ValueError("train_loader yielded no batches; cannot compute the FeatNet prior")

    # Dim 1 stacks samples across all batches; average it away.
    state = {
        'leye_z_mean': torch.cat(leye_chunks, dim=1).mean(1),
        'reye_z_mean': torch.cat(reye_chunks, dim=1).mean(1),
        'others_z_mean': torch.cat(others_chunks, dim=1).mean(1)
        }

    torch.save(state, save_path)
    return state
    


if __name__ == "__main__":
    # Compute and save the prior only when this file is executed directly.
    main()
