from __future__ import print_function
import torch
import torch.nn.functional as F
import argparse

from utils import load_data
from utils import random_random_crop_reference_c
from utils import abs_coord_to_norm
from utils import norm_coord_to_abs
from utils import save_model
from utils import scheduler_step
from utils import random_o_crop

from network_and_loss import RelCoordNet
from network_and_loss import PWConLoss

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float  # NOTE(review): not referenced in the visible code of this file


def _str2bool(value):
    """Convert a command-line string to a bool for argparse flags.

    Without a ``type=`` converter argparse passes the raw string through, so
    ``--random_flip False`` would produce the truthy string ``"False"``.
    Accepts (case-insensitively) true/t/yes/y/1 and false/f/no/n/0; anything
    else is rejected with ``argparse.ArgumentTypeError``. Non-string defaults
    (already bools) are returned unchanged.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value (true/false), got {!r}".format(value))


parser = argparse.ArgumentParser(description='RelCoordNet training with PWConLoss')

# Hyperparameters
parser.add_argument('--task', type=str, default="COFW", help='Task dataset')
parser.add_argument('--mode', type=str, default="all", help='Mode of SupConLoss')
parser.add_argument('--random_scale', type=_str2bool, default=True, help='Whether to apply random scaling')
parser.add_argument('--random_flip', type=_str2bool, default=True, help='Whether to apply random flip')
parser.add_argument('--random_rotation', type=_str2bool, default=False, help='Whether to apply random rotation')

parser.add_argument('--log_interval', type=int, default=100, help='interval for log [iteration]')
parser.add_argument('--batch_size', type=int, default=64, help='Batch size')
parser.add_argument('--num_epochs', type=int, default=2000, help='Maximum number of epochs')
parser.add_argument('--learning_rate_proj', type=float, default=1E-2, help='Learning rate for the conv and projection layers')
parser.add_argument('--learning_rate_head', type=float, default=1E-3, help='Learning rate for the regression head')
parser.add_argument('--lr_decay_rate', type=float, default=0.1, help='Learning rate decay rate')
parser.add_argument('--lr_decay_interval', type=int, default=1000, help='Learning rate decay interval')

# Parsed once at import time; main()/train()/test() read this global.
args = parser.parse_args()

def main():
    """Train RelCoordNet and report the test regression error after every epoch.

    Builds the train/test loaders, the network, an Adam optimizer with two
    parameter groups (conv + proj at ``args.learning_rate_proj``, head at
    ``args.learning_rate_head``) and the PWConLoss criterion. It then runs
    ``args.num_epochs`` epochs. Every epoch after the first applies
    ``scheduler_step`` first, and a checkpoint is saved after every 500th epoch.
    """
    torch.cuda.empty_cache()
    loaders = load_data(args.task, args.batch_size, args.random_scale,
                        args.random_flip, args.random_rotation)
    train_loader, test_loader = loaders

    model = RelCoordNet().to(device)

    # conv and proj share one learning rate; the regression head has its own.
    backbone_params = list(model.conv.parameters())
    backbone_params.extend(model.proj.parameters())
    param_groups = [
        {'params': backbone_params, 'lr': args.learning_rate_proj},
        {'params': model.head.parameters(), 'lr': args.learning_rate_head},
    ]
    optimizer = torch.optim.Adam(param_groups)

    criterion = PWConLoss(args.mode).to(device)

    for epoch in range(args.num_epochs):
        # The very first epoch runs with the initial learning rates.
        if epoch > 0:
            scheduler_step(optimizer, epoch, args.lr_decay_interval, args.lr_decay_rate)

        pwconloss, mseloss = train(model, train_loader, criterion, optimizer)
        error_mean, error_std = test(model, test_loader)

        epoch_no = epoch + 1
        print("\nEpoch: {}/{}..".format(epoch_no, args.num_epochs).ljust(14),
              "PWConloss: {:.3f}.. ".format(pwconloss).ljust(12),
              "MSEloss: {:.3f}.. ".format(mseloss).ljust(12))
        print("Regression_error_mean: {:.3f}.. ".format(error_mean).ljust(12),
              "Regression_error_std: {:.3f} .. ".format(error_std).ljust(12))

        # Checkpoint after epochs 500, 1000, 1500, ...
        if epoch_no % 500 == 0:
            save_model("RelCoordNet", model, optimizer, epoch_no)
        
        
        
def train(relcoordnet, train_loader, criterion, optimizer):
    """Run one training epoch and return the batch-averaged losses.

    For each batch, ``random_random_crop_reference_c`` supplies two crop sets
    (``random_o1`` / ``random_o2``) with their reference coordinates, offsets
    relative to those references, and pairwise distance/relationship terms.
    Both crop sets are fed through the network as one stacked batch, together
    with constant per-sample maps of the normalised reference coordinates.
    The objective is::

        loss = PWConLoss(projections, ...) + 1000 * MSE(relative_c, target offsets)

    Target offsets are computed in normalised coordinates (``abs_coord_to_norm``).

    Args:
        relcoordnet: model whose forward returns ``(z, relative_c)``.
        train_loader: iterable of ``(images, landmark_coords)`` batches; the
            landmark coordinates are not used by this function.
        criterion: PWConLoss instance.
        optimizer: optimizer over the model's parameters.

    Returns:
        (pwconloss, mseloss): Python floats, each the mean over all batches
        (``0`` for an empty loader).
    """
    relcoordnet.train()
    num_batches = len(train_loader)
    total_pwconloss = 0
    total_mseloss = 0

    # Landmark coordinates are never used here, so they are not moved to the device.
    for images, _ in train_loader:
        images = images.to(device)
        B = images.size(0)

        (reference_coords1, reference_coords2,
         random_o1, random_coords1_relative,
         random_o2, random_coords2_relative,
         r1_r1_distance, r1_r1_relationship,
         r1_r2_distance, r1_r2_relationship,
         r2_r2_distance, r2_r2_relationship) = random_random_crop_reference_c(images)

        # Regression targets: crop position relative to its reference point,
        # expressed in normalised coordinates.
        reference_coords1_norm = abs_coord_to_norm(reference_coords1)
        random_coords1_norm = abs_coord_to_norm(random_coords1_relative + reference_coords1)
        reference_coords2_norm = abs_coord_to_norm(reference_coords2)
        random_coords2_norm = abs_coord_to_norm(random_coords2_relative + reference_coords2)
        random_coords1_norm_relative = random_coords1_norm - reference_coords1_norm
        random_coords2_norm_relative = random_coords2_norm - reference_coords2_norm
        label_coords_norm_relative = torch.cat((random_coords1_norm_relative, random_coords2_norm_relative), 0)

        # Both crop sets go through the network as a single 2B batch.
        random1_random2 = torch.cat((random_o1, random_o2), 0)

        # Broadcast each (B, 2) reference coordinate into a constant 2-channel
        # 27x27 map. The 27x27 size presumably matches RelCoordNet's feature
        # map (TODO confirm). The ones tensor is built directly on the target
        # device instead of on the CPU and then copied.
        ones_map = torch.ones(B, 2, 27, 27, device=device)
        reference_coords1_norm_channel = reference_coords1_norm.view(B, 2, 1, 1) * ones_map
        reference_coords2_norm_channel = reference_coords2_norm.view(B, 2, 1, 1) * ones_map
        reference_coords_input = torch.cat((reference_coords1_norm_channel, reference_coords2_norm_channel), dim=0)

        relcoordnet.zero_grad()
        optimizer.zero_grad()
        z, relative_c = relcoordnet(random1_random2, reference_coords_input)

        # Regroup embeddings as (B, 2, ...): index 0 from crop set 1, index 1 from set 2.
        z1, z2 = torch.split(z, [B, B], dim=0)
        projections = torch.cat((z1.unsqueeze(1), z2.unsqueeze(1)), dim=1)
        pwconloss = criterion(projections, r1_r1_distance, r1_r1_relationship,
                              r1_r2_distance, r1_r2_relationship, r2_r2_distance, r2_r2_relationship)

        mseloss = F.mse_loss(relative_c, label_coords_norm_relative)

        total_pwconloss += pwconloss.item() / num_batches
        total_mseloss += mseloss.item() / num_batches

        # The coordinate-regression term is weighted by a fixed factor of 1000.
        loss = pwconloss + 1000 * mseloss
        loss.backward()
        optimizer.step()

    return total_pwconloss, total_mseloss



def test(relcoordnet, test_loader):
    """Measure absolute-coordinate regression error on the test set.

    For each batch the function:

    1. Samples a random reference point per image, uniformly in [0, 256) on
       both axes. The 256 is hard-coded and presumably the image size (TODO
       confirm).
    2. Takes one random crop per image via ``random_o_crop``.
    3. Predicts the crop's offset from the reference point.
    4. Converts reference + offset back to absolute coordinates and compares
       them with the true crop coordinates.

    Args:
        relcoordnet: model whose forward returns ``(z, relative_c)``.
        test_loader: iterable of ``(images, _)`` batches.

    Returns:
        (error_mean, error_std): 0-dim tensors holding the mean and standard
        deviation of the per-sample Euclidean distance between predicted and
        true crop coordinates. Both are NaN for an empty loader, and the std
        is NaN for a single sample.
    """
    with torch.no_grad():
        relcoordnet.eval()
        # Per-batch |pred - target| rows, concatenated once at the end (the
        # old per-iteration torch.cat was quadratic). The empty (0, 2) seed
        # keeps the empty-loader result well-defined without a dummy row.
        error_rows = [torch.zeros(0, 2).to(device)]

        for images, _ in test_loader:
            images = images.to(device)
            B = images.size(0)

            # Generated on the CPU and moved, as before, so the RNG stream is unchanged.
            reference_coords = torch.randint(0, 256, (B, 2)).to(device)
            reference_coords_norm = abs_coord_to_norm(reference_coords)
            reference_coords_input = reference_coords_norm.view(B, 2, 1, 1) * torch.ones(B, 2, 27, 27, device=device)

            # Request B crops rather than args.batch_size, so a smaller final
            # batch stays consistent with reference_coords_input's batch size.
            o, crop_coords = random_o_crop(images, B)

            _, relative_c = relcoordnet(o, reference_coords_input)
            c = relative_c + reference_coords_norm

            error_rows.append(torch.abs(norm_coord_to_abs(c) - crop_coords))

        error = torch.cat(error_rows, dim=0)
        error = torch.sqrt(error[:, 0] ** 2 + error[:, 1] ** 2)
        error_mean = error.mean()
        error_std = error.std()

    return error_mean, error_std



if __name__ == '__main__':
    # Entry point: start training only when this file is run as a script.
    main()
    
    
    
    
