from __future__ import print_function
import torch
import torch.nn.functional as F
import argparse

from utils import load_data
from utils import load_state
from utils import make_batch
from utils import save_model

from environment import Env

from models import Agent
from models import dirPolNet


# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dtype = torch.float


def _str2bool(value):
    """Convert a command-line string such as 'True', 'false', '1' or 'no' to a bool.

    argparse hands option values over as strings, so without this converter
    ``--random_flip False`` would produce the non-empty (and therefore truthy)
    string 'False'.  Real bool values (the declared defaults) pass through.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


parser = argparse.ArgumentParser(description='Landmark Detection with Active Inference')

parser.add_argument('--task', type=str, default='COFW', help='which task to run (CelebA_aligned)')

# Data-augmentation switches; each takes an explicit boolean value, e.g. `--random_flip False`.
parser.add_argument('--random_scale', type=_str2bool, default=True, help='Whether to apply random scaling')
parser.add_argument('--random_flip', type=_str2bool, default=True, help='Whether to apply random flip')
parser.add_argument('--random_rotation', type=_str2bool, default=False, help='Whether to apply random rotation')

parser.add_argument('--batch_size', type=int, default=64, help='batch size for training')
parser.add_argument('--num_epochs', type=int, default=1000, help='maximum number of epochs')
parser.add_argument('--learning_rate', type=float, default=5E-4, help='Model learning rate')


# Parsed at import time; main() reads its settings from this module-level namespace.
args = parser.parse_args()

def main():
    """Train the five per-region direction-policy networks (dirPolNet).

    One ``Agent`` is built per facial region (left eye, right eye, mouth,
    nose, jaw) and given its prior from ``load_state()``.  Only the
    ``dirPolNet`` parameters are handed to the optimizer.  Each epoch calls
    ``train`` and prints the five mean losses.  A checkpoint is written every
    ``save_every`` epochs and additionally after the final epoch, so a run
    whose ``--num_epochs`` is not a multiple of ``save_every`` still saves
    its last model.
    """
    save_every = 200  # epochs between periodic checkpoints

    torch.cuda.empty_cache()
    train_loader, _ = load_data(args.task, args.batch_size, args.random_scale,
                                args.random_flip, args.random_rotation)

    agent_leye = Agent(args.batch_size).to(device)
    agent_reye = Agent(args.batch_size).to(device)
    agent_mouth = Agent(args.batch_size).to(device)
    agent_nose = Agent(args.batch_size).to(device)
    agent_jaw = Agent(args.batch_size).to(device)

    # load_state() is unpacked in region order: leye, reye, mouth, nose, jaw.
    leye_state, reye_state, mouth_state, nose_state, jaw_state = load_state()
    agent_leye.set_prior(leye_state)
    agent_reye.set_prior(reye_state)
    agent_mouth.set_prior(mouth_state)
    agent_nose.set_prior(nose_state)
    agent_jaw.set_prior(jaw_state)

    dirpolnet_leye = dirPolNet().to(device)
    dirpolnet_reye = dirPolNet().to(device)
    dirpolnet_mouth = dirPolNet().to(device)
    dirpolnet_nose = dirPolNet().to(device)
    dirpolnet_jaw = dirPolNet().to(device)
    dirpolnets = (dirpolnet_leye, dirpolnet_reye, dirpolnet_mouth,
                  dirpolnet_nose, dirpolnet_jaw)

    # One parameter group per network (all with the same learning rate).
    optimizer = torch.optim.Adam(
        [{'params': net.parameters(), 'lr': args.learning_rate} for net in dirpolnets])

    for epoch in range(args.num_epochs):
        loss_leye, loss_reye, loss_mouth, loss_nose, loss_jaw \
            = train(agent_leye, agent_reye, agent_mouth, agent_nose, agent_jaw,
                    dirpolnet_leye, dirpolnet_reye, dirpolnet_mouth, dirpolnet_nose, dirpolnet_jaw,
                    train_loader, optimizer, epoch)

        print("Epoch: {}/{}.. ".format(epoch+1, args.num_epochs).ljust(14),
              "Loss_leye: {:.3f}.. ".format(loss_leye).ljust(14),
              "Loss_reye: {:.3f}.. ".format(loss_reye).ljust(14),
              "Loss_mouth: {:.3f}.. ".format(loss_mouth).ljust(14),
              "Loss_nose: {:.3f}.. ".format(loss_nose).ljust(14),
              "Loss_jaw: {:.3f}.. ".format(loss_jaw).ljust(14),)

        completed = epoch + 1
        # Periodic checkpoint, plus the final epoch so short runs are not lost.
        if completed % save_every == 0 or completed == args.num_epochs:
            save_model("dirPolNet", dirpolnet_leye, dirpolnet_reye, dirpolnet_mouth,
                       dirpolnet_nose, dirpolnet_jaw, completed)
        
        

def train(agent_leye, agent_reye, agent_mouth, agent_nose, agent_jaw,
          dirpolnet_leye, dirpolnet_reye, dirpolnet_mouth, dirpolnet_nose, dirpolnet_jaw,
          train_loader, optimizer, epoch):
    """Run one training epoch of the five per-region policy networks.

    For every batch an ``Env`` is built from the images.  Then, for each
    region, ``make_batch`` produces the network inputs and target
    ``direction`` from that region's agent, and the region's ``dirPolNet``
    is trained with cross-entropy.  Gradients are cleared once per batch, all
    five losses are back-propagated, and a single ``optimizer.step()`` is
    taken.

    Args:
        agent_leye, agent_reye, agent_mouth, agent_nose, agent_jaw:
            per-region agents; switched to eval mode and not optimized here.
        dirpolnet_leye, dirpolnet_reye, dirpolnet_mouth, dirpolnet_nose, dirpolnet_jaw:
            per-region policy networks; switched to train mode.
        train_loader: iterable of ``(images, landmark_coords)`` batches that
            also supports ``len()``.
        optimizer: optimizer over the dirPolNet parameters.
        epoch: current epoch index (not used by this function).

    Returns:
        Tuple of mean per-batch losses over the epoch, in the order
        (leye, reye, mouth, nose, jaw).
    """
    # (agent, policy net, third make_batch argument) per region, in the order
    # of the returned losses.  The third values sum to 68 — presumably the
    # region's landmark count; confirm against make_batch.
    regions = (
        (agent_leye, dirpolnet_leye, 11),
        (agent_reye, dirpolnet_reye, 11),
        (agent_mouth, dirpolnet_mouth, 20),
        (agent_nose, dirpolnet_nose, 9),
        (agent_jaw, dirpolnet_jaw, 17),
    )

    for agent, net, _ in regions:
        agent.eval()
        net.train()

    num_batches = len(train_loader)
    loss_totals = [0.0] * len(regions)

    # Ground-truth landmarks are not consumed by this loop (Env gets coord=None).
    for images, _landmark_coords in train_loader:
        images = images.to(device)
        env = Env(images, center_init=False, coord=None)

        # Clear gradients left over from the previous step; otherwise .grad
        # keeps accumulating across batches and epochs and every step uses
        # the running sum of all past gradients.
        optimizer.zero_grad()

        for idx, (agent, net, n_points) in enumerate(regions):
            embedding_ft0, embedding_cd0, prior_ft, prior_cd, lambda_ft, direction \
                = make_batch(env, agent, n_points)
            logits = net(embedding_ft0, embedding_cd0, prior_ft, prior_cd,
                         lambda_ft, 1 - lambda_ft)
            loss = F.cross_entropy(logits, direction)
            loss.backward()
            loss_totals[idx] += loss.item() / num_batches

        optimizer.step()

    return tuple(loss_totals)




if __name__ == '__main__':
    # Start training only when run as a script, not when imported.
    main()
