from __future__ import print_function
import torch
import argparse

from utils import load_data
from utils import anchor_random_crop
from utils import save_model
from utils import scheduler_step
from network_and_loss import FeatNet
from network_and_loss import PWConLoss

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dtype = torch.float


def _str2bool(value):
    """Parse a command-line boolean option value.

    argparse has no built-in string-to-bool conversion: without a ``type`` the
    value stays a raw string, and every non-empty string (``"False"`` included)
    is truthy, so ``--random_flip False`` would still enable flipping.

    Args:
        value: The raw option string, or an already-boolean value.

    Returns:
        True for yes/true/t/y/1 and False for no/false/f/n/0 (case-insensitive,
        surrounding whitespace ignored).

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected, got {!r}".format(value))


parser = argparse.ArgumentParser(description='FeatNet training with PWConLoss')

# Hyperparameters
parser.add_argument('--task', type=str, default="COFW", help='Task dataset')
# Augmentation switches: values are parsed by `_str2bool` (e.g. `--random_flip false`),
# and a bare flag with no value (e.g. `--random_rotation`) means True.
parser.add_argument('--random_scale', type=_str2bool, nargs='?', const=True, default=True,
                    help='Whether to apply random scaling')
parser.add_argument('--random_flip', type=_str2bool, nargs='?', const=True, default=True,
                    help='Whether to apply random flip')
parser.add_argument('--random_rotation', type=_str2bool, nargs='?', const=True, default=False,
                    help='Whether to apply random rotation')

parser.add_argument('--batch_size', type=int, default=256, help='Batch size')
parser.add_argument('--num_epochs', type=int, default=2000, help='Maximum number of epochs')
parser.add_argument('--learning_rate', type=float, default=1E-2, help='Model learning rate')
parser.add_argument('--lr_decay_rate', type=float, default=0.1, help='Learning rate decay rate')
parser.add_argument('--lr_decay_interval', type=int, default=1500, help='Learning rate decay interval')

args = parser.parse_args()

def main():
    """Train the three region-specific FeatNets with PWConLoss.

    One FeatNet is trained per landmark group (left eye, right eye, remaining
    landmarks), all sharing a single Adam optimizer with one parameter group
    per network. Each epoch runs `train`, prints the three averaged losses,
    and periodically checkpoints all networks plus the optimizer via
    `save_model`. A checkpoint is also written after the final epoch, so the
    trained weights are kept even when `--num_epochs` is not a multiple of
    the checkpoint interval.
    """
    checkpoint_interval = 500  # epochs between periodic checkpoints

    torch.cuda.empty_cache()
    train_loader, _ = load_data(args.task, args.batch_size, args.random_scale,
                                args.random_flip, args.random_rotation)

    featnet_leye = FeatNet().to(device)
    featnet_reye = FeatNet().to(device)
    featnet_others = FeatNet().to(device)

    criterion = PWConLoss().to(device)
    optimizer = torch.optim.Adam([
        {'params': featnet_leye.parameters(), 'lr': args.learning_rate},
        {'params': featnet_reye.parameters(), 'lr': args.learning_rate},
        {'params': featnet_others.parameters(), 'lr': args.learning_rate}])

    for epoch in range(args.num_epochs):
        if epoch != 0:
            # NOTE(review): presumably decays every group's lr by lr_decay_rate
            # each lr_decay_interval epochs — confirm in utils.scheduler_step.
            scheduler_step(optimizer, epoch, args.lr_decay_interval, args.lr_decay_rate)

        loss_leye, loss_reye, loss_others = train(featnet_leye, featnet_reye, featnet_others,
                                                  train_loader, criterion, optimizer)

        print("\nEpoch: {}/{}.. ".format(epoch+1, args.num_epochs).ljust(14),
              "PWConLoss_leye: {:.3f}.. ".format(loss_leye).ljust(14),
              "PWConLoss_reye: {:.3f}.. ".format(loss_reye).ljust(14),
              "PWConLoss_others: {:.3f}.. ".format(loss_others).ljust(14),)

        completed_epochs = epoch + 1
        # Save periodically, and always after the last epoch so a run whose length is
        # not a multiple of the interval does not lose its final weights.
        if completed_epochs % checkpoint_interval == 0 or completed_epochs == args.num_epochs:
            save_model("FeatNet", featnet_leye, featnet_reye, featnet_others, optimizer,
                       completed_epochs)
        
        
        
def train(featnet_leye, featnet_reye, featnet_others, train_loader, criterion, optimizer):
    """Run one training epoch for the three region-specific FeatNets.

    For every batch, each network is trained on crops around its own subset
    of landmarks: `featnet_leye` on the left-eye indices, `featnet_reye` on
    the right-eye indices, and `featnet_others` on landmarks 18 onward. The
    three losses are backpropagated into a shared set of gradients, followed
    by a single `optimizer.step()`.

    Args:
        featnet_leye, featnet_reye, featnet_others: The three networks to train.
        train_loader: Iterable of (images, landmark_coords) batches; the
            coordinates must reshape to 29 (y, x) pairs per image.
        criterion: PWConLoss-style loss module.
        optimizer: Optimizer that owns the parameters of all three networks.

    Returns:
        Tuple (loss_leye, loss_reye, loss_others): each region's loss averaged
        over the number of batches in `train_loader`.
    """
    featnet_leye.train()
    featnet_reye.train()
    featnet_others.train()

    num_batches = len(train_loader)
    loss_leye = 0
    loss_reye = 0
    loss_others = 0

    leye_idx = [0, 2, 4, 5, 8, 10, 12, 13, 16]
    reye_idx = [1, 3, 6, 7, 9, 11, 14, 15, 17]

    for images, landmark_coords in train_loader:
        images, landmark_coords = images.to(device), landmark_coords.to(device)  # labels : [B,58], y-x-y-x order
        landmark_coords = landmark_coords.view(-1, 29, 2)

        # The optimizer holds every parameter of all three networks, so this single
        # call clears all gradients that the three backward passes accumulate into.
        optimizer.zero_grad()

        loss_leye += _contrastive_step(featnet_leye, images, landmark_coords[:, leye_idx],
                                       criterion) / num_batches
        loss_reye += _contrastive_step(featnet_reye, images, landmark_coords[:, reye_idx],
                                       criterion) / num_batches
        loss_others += _contrastive_step(featnet_others, images, landmark_coords[:, 18:],
                                         criterion) / num_batches

        optimizer.step()

    return loss_leye, loss_reye, loss_others


def _contrastive_step(featnet, images, region_coords, criterion):
    """Compute PWConLoss for one landmark region, backpropagate it, and return its value.

    Gradients are accumulated (not zeroed) so the caller can combine several
    regions before a single optimizer step.

    Args:
        featnet: Network that embeds the cropped views of this region.
        images: Batch of input images (batch size taken from dim 0).
        region_coords: Landmark coordinates of this region, per image.
        criterion: PWConLoss-style loss module.

    Returns:
        The loss as a Python float (read before `backward()`).
    """
    batch_size = images.size(0)
    anchor_view, random_view, landmark_select, \
        a_l_distance, a_l_relationship, r_l_distance, r_l_relationship \
        = anchor_random_crop(images, region_coords)
    # Embed both views in one forward pass, then split back into the two views.
    # NOTE(review): assumes each view holds exactly batch_size crops — confirm in utils.
    z = featnet(torch.cat((anchor_view, random_view), 0))
    z1, z2 = torch.split(z, [batch_size, batch_size], dim=0)
    projections = torch.stack((z1, z2), dim=1)  # same as cat of unsqueeze(1) along dim 1
    pwconloss = criterion(projections, landmark_select,
                          a_l_distance, a_l_relationship,
                          r_l_distance, r_l_relationship)
    loss_value = pwconloss.item()
    pwconloss.backward()
    return loss_value



if __name__ == '__main__':
    # Start training only when executed as a script, not when imported.
    main()
