from __future__ import print_function
import torch
import argparse

from utils import load_data
from utils import landmark_o_crop
from utils import abs_coord_to_norm

from network_and_loss import RelCoordNet

# Prefer the GPU whenever PyTorch can see one; otherwise everything runs on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Module-level floating-point dtype (torch.float is an alias of torch.float32).
# NOTE(review): not referenced by the code visible in this file.
dtype = torch.float32

def _str2bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``no``.

    Without a ``type=``, argparse keeps the raw string, and any non-empty
    string (including ``"False"``) is truthy. This helper maps the usual
    spellings to real booleans and rejects anything else.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


parser = argparse.ArgumentParser(description='Prior modeling: RelCoordNet')

parser.add_argument('--task', type=str, default="COFW", help='Task dataset')
# The augmentation switches still take an explicit value (e.g. ``--random_flip True``),
# so existing command lines keep working, but ``False``/``0`` now actually disable them.
parser.add_argument('--random_scale', type=_str2bool, default=False, help='Whether to apply random scaling')
parser.add_argument('--random_flip', type=_str2bool, default=False, help='Whether to apply random flip')
parser.add_argument('--random_rotation', type=_str2bool, default=False, help='Whether to apply random rotation')

parser.add_argument('--batch_size', type=int, default=64, help='Batch size')

# Parse sys.argv once at import time; main() reads the parsed options from
# this module-level namespace.
args = parser.parse_args(args=None)

def main():
    """Load the pretrained RelCoordNet checkpoint and compute its landmark prior.

    Builds the training loader from the command-line options in ``args``,
    restores the ``'relcoordnet'`` weights from the 2000-epoch checkpoint, and
    hands both to ``prior_modeling``.
    """
    torch.cuda.empty_cache()
    train_loader, _ = load_data(args.task, args.batch_size, args.random_scale,
                                args.random_flip, args.random_rotation)

    relcoordnet = RelCoordNet().to(device)

    # map_location remaps stored tensors onto the device chosen at module level.
    # Without it, a checkpoint saved from a GPU cannot be loaded when `device`
    # fell back to the CPU, because torch.load tries to restore CUDA tensors.
    checkpoint = torch.load("../../checkpoint/RelCoordNet_2000epoch.pth", map_location=device)
    relcoordnet.load_state_dict(checkpoint['relcoordnet'])

    prior_modeling(relcoordnet, train_loader)
    
    
    
def prior_modeling(relcoordnet, train_loader, save_path="../../checkpoint/RelCoordNet_prior.pth"):
    """Average RelCoordNet's per-landmark latent codes over a dataset and save them.

    Every crop is paired with landmark 21 of the same image as its reference
    coordinate. The first network output is viewed as ``(29, -1, 128)``, and
    its mean over all samples is stored.

    Args:
        relcoordnet: the network. Its forward takes ``(crops, reference_coords)``
            and returns a pair whose first element is viewable as ``(29, -1, 128)``.
        train_loader: iterable of ``(images, landmark_coords)`` batches, with
            ``landmark_coords`` viewable as ``(-1, 29, 2)``.
        save_path: file to write ``{"relative_landmark_z_mean": tensor}`` to.
            The default is the path used historically.

    Returns:
        The ``(29, 128)`` mean latent tensor that was saved.

    Raises:
        ValueError: if ``train_loader`` yields no samples. The old code
            silently saved NaNs in that case.
    """
    num_landmarks = 29
    latent_dim = 128
    crop_size = 27
    reference_landmark = 21

    relcoordnet.eval()

    # Running sum + count gives the same mean as concatenating every batch.
    # It avoids recopying the whole history on each iteration, and needs no
    # dummy zero column that has to be sliced off afterwards.
    z_sum = torch.zeros(num_landmarks, latent_dim, device=device)
    num_samples = 0

    with torch.no_grad():
        for images, landmark_coords in train_loader:
            images, landmark_coords = images.to(device), landmark_coords.to(device)
            landmark_coords = landmark_coords.view(-1, num_landmarks, 2)
            landmark_o = landmark_o_crop(images, landmark_coords)

            # repeat(29, 1) lays rows out landmark-major (row k*B + b -> image b),
            # matching the view(29, -1, 128) applied to the output below.
            # NOTE(review): assumes landmark_o_crop orders crops the same way — confirm.
            reference_landmark_coords = landmark_coords[:, reference_landmark].repeat(num_landmarks, 1)
            reference_landmark_coords = abs_coord_to_norm(reference_landmark_coords).view(-1, 2, 1, 1)

            # Size the spatial broadcast from the rows actually present. The last
            # batch may be smaller than args.batch_size, which broke the old
            # fixed-size torch.ones.
            num_rows = reference_landmark_coords.size(0)
            reference_landmark_coords_input = reference_landmark_coords * torch.ones(
                num_rows, 2, crop_size, crop_size, device=device)

            relative_landmark_z, _ = relcoordnet(landmark_o.view(-1, 2, crop_size, crop_size),
                                                 reference_landmark_coords_input)

            relative_landmark_z = relative_landmark_z.view(num_landmarks, -1, latent_dim)
            z_sum += relative_landmark_z.sum(dim=1)
            num_samples += relative_landmark_z.size(1)

    if num_samples == 0:
        raise ValueError("train_loader yielded no samples; cannot compute the prior mean")

    relative_landmark_z_mean = z_sum / num_samples

    state = {
        "relative_landmark_z_mean": relative_landmark_z_mean
        }

    torch.save(state, save_path)
    return relative_landmark_z_mean
    


# Only run the prior computation when executed as a script, not on import.
if __name__ == "__main__":
    main()
