from __future__ import print_function
import torch
import torch.nn.functional as F
import argparse

from utils import load_data
from utils import load_state
from utils import make_batch
from utils import save_model
from utils import abs_coord_to_norm

from environment import Env

from models import Agent
from models import Agent_relative
from models import dirPolNet


# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dtype = torch.float


def _str2bool(value):
    """Convert a command-line token into a real boolean.

    Without an explicit ``type=`` argparse hands the raw string through, and
    every non-empty string (including ``"False"``) is truthy. This converter
    makes ``--flag False`` actually disable the flag.

    Args:
        value: The raw token from the command line (or an existing bool).

    Returns:
        ``True`` or ``False``.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean
            spelling; argparse turns this into a usage error.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value (true/false), got {!r}'.format(value))


parser = argparse.ArgumentParser(description='Landmark Detection with Active Inference')

parser.add_argument('--task', type=str, default='COFW', help='which task to run (CelebA_aligned)')

# Data-augmentation switches; parsed with _str2bool so "False"/"0" turn them off.
parser.add_argument('--random_scale', type=_str2bool, default=True, help='Whether to apply random scale')
parser.add_argument('--random_flip', type=_str2bool, default=True, help='Whether to apply random flip')
parser.add_argument('--random_rotation', type=_str2bool, default=False, help='Whether to apply random rotation')

parser.add_argument('--batch_size', type=int, default=64, help='batch size for training')
parser.add_argument('--num_epochs', type=int, default=1000, help='maximum number of epochs')
parser.add_argument('--learning_rate', type=float, default=5E-4, help='Model learning rate')


# Parsed once at import time; main() and train() read this module-level object.
args = parser.parse_args()

def main():
    """Train the three direction-policy networks and checkpoint them.

    Builds the data loader from the command-line options, creates the three
    agents (left eye, right eye, others) and installs the priors returned by
    ``load_state()``, then optimises one ``dirPolNet`` per agent with a single
    Adam optimizer. Per-epoch average losses are printed, and the networks
    are saved every 200 epochs and once more after the final epoch.
    """
    torch.cuda.empty_cache()
    train_loader, _ = load_data(args.task, args.batch_size, args.random_scale, args.random_flip, args.random_rotation)

    agent_leye = Agent(args.batch_size).to(device)
    agent_reye = Agent(args.batch_size).to(device)
    agent_others = Agent_relative(args.batch_size).to(device)

    leye_state, reye_state, others_state = load_state()
    agent_leye.set_prior(leye_state)
    agent_reye.set_prior(reye_state)
    agent_others.set_prior(others_state)

    dirpolnet_leye = dirPolNet().to(device)
    dirpolnet_reye = dirPolNet().to(device)
    dirpolnet_others = dirPolNet().to(device)

    # Only the policy networks are handed to the optimizer; the agents'
    # parameters are never updated by this script.
    optimizer = torch.optim.Adam([
        {'params': dirpolnet_leye.parameters(), 'lr': args.learning_rate},
        {'params': dirpolnet_reye.parameters(), 'lr': args.learning_rate},
        {'params': dirpolnet_others.parameters(), 'lr': args.learning_rate}])

    last_epoch = args.num_epochs - 1
    for epoch in range(args.num_epochs):
        loss_leye, loss_reye, loss_others = train(agent_leye, agent_reye, agent_others,
                                                  dirpolnet_leye, dirpolnet_reye, dirpolnet_others,
                                                  train_loader, optimizer, epoch)

        print("Epoch: {}/{}.. ".format(epoch+1, args.num_epochs).ljust(14),
              "Loss_leye: {:.3f}.. ".format(loss_leye).ljust(14),
              "Loss_reye: {:.3f}.. ".format(loss_reye).ljust(14),
              "Loss_others: {:.3f}.. ".format(loss_others).ljust(14))

        # Periodic checkpoint every 200 epochs, plus a final one so that runs
        # whose length is not a multiple of 200 do not lose their weights.
        if epoch % 200 == 199 or epoch == last_epoch:
            save_model("dirPolNet", dirpolnet_leye, dirpolnet_reye, dirpolnet_others, epoch+1)
        
        

def _backprop_direction_loss(dirpolnet, batch):
    """Run one policy network on a batch, backpropagate, and return the loss.

    Args:
        dirpolnet: The direction-policy network to evaluate.
        batch: The 6-tuple produced by ``make_batch``:
            ``(embedding_ft0, embedding_cd0, prior_ft, prior_cd, lambda_ft,
            direction)``.

    Returns:
        The cross-entropy loss for this batch as a Python float.
    """
    embedding_ft0, embedding_cd0, prior_ft, prior_cd, lambda_ft, direction = batch
    # NOTE(review): F.cross_entropy expects unnormalised logits; confirm that
    # dirPolNet does not already apply a softmax to its output.
    scores = dirpolnet(embedding_ft0, embedding_cd0, prior_ft, prior_cd, lambda_ft, 1 - lambda_ft)
    loss = F.cross_entropy(scores, direction)
    loss.backward()
    return loss.item()


def train(agent_leye, agent_reye, agent_others,
          dirpolnet_leye, dirpolnet_reye, dirpolnet_others,
          train_loader, optimizer, epoch):
    """Train the three direction-policy networks for one pass over the data.

    The agents are put in eval mode and only used to generate batches through
    ``make_batch``; the policy networks are put in train mode and updated with
    a single ``optimizer.step()`` per mini-batch after all three losses have
    been backpropagated.

    Args:
        agent_leye, agent_reye, agent_others: Agents used to build batches.
        dirpolnet_leye, dirpolnet_reye, dirpolnet_others: Policy networks
            being trained, one per agent.
        train_loader: Iterable yielding ``(images, landmark_coords)`` pairs;
            must support ``len()``.
        optimizer: Optimizer holding the policy networks' parameters.
        epoch: Current epoch index (accepted for the caller's convenience;
            not used here).

    Returns:
        A tuple ``(loss_leye, loss_reye, loss_others)`` with each loss
        averaged over the number of batches in ``train_loader``.
    """
    agent_leye.eval()
    agent_reye.eval()
    agent_others.eval()

    dirpolnet_leye.train()
    dirpolnet_reye.train()
    dirpolnet_others.train()

    loss_leye_total = 0.0
    loss_reye_total = 0.0
    loss_others_total = 0.0

    num_batches = len(train_loader)

    for images, landmark_coords in train_loader:
        images, landmark_coords = images.to(device), landmark_coords.to(device)
        landmark_coords = landmark_coords.view(-1, 29, 2)

        # Gradients accumulate in .grad across backward() calls, so they must
        # be cleared before each update or every step would use the sum of
        # all gradients seen so far.
        optimizer.zero_grad()

        env = Env(images, center_init=False, coord=None)

        # The batches are generated in the original leye -> reye -> others
        # order because all three share the same env instance.
        # NOTE(review): the meaning of the 9 / 11 arguments to make_batch is
        # not visible from this file; confirm against utils.make_batch.
        loss_leye = _backprop_direction_loss(dirpolnet_leye, make_batch(env, agent_leye, 9))
        loss_leye_total += loss_leye / num_batches

        loss_reye = _backprop_direction_loss(dirpolnet_reye, make_batch(env, agent_reye, 9))
        loss_reye_total += loss_reye / num_batches

        # Landmark index 21 is normalised and passed to make_batch as the
        # reference coordinate for the relative agent.
        reference_c = abs_coord_to_norm(landmark_coords[:, 21], img_size=[256, 256])
        loss_others = _backprop_direction_loss(dirpolnet_others,
                                               make_batch(env, agent_others, 11, reference_c))
        loss_others_total += loss_others / num_batches

        # One step updates all three networks; they share this optimizer.
        optimizer.step()

    return loss_leye_total, loss_reye_total, loss_others_total




# Start training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
