from __future__ import print_function
import time
import torch
import argparse

from Utils import load_data
from Utils import load_state
from Utils import load_dirpolnet
from Utils import load_optuna_setting
from Utils import norm_coord_to_abs
from Utils import NME_calc
from Utils import NME_calc_landmarkwise

from Environment import Env

from Models import Agent
from Models import Agent_jaw
from Models import dirPolNet


# Device/dtype shared by the whole script: first CUDA device when available, else CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
dtype = torch.float


def _str2bool(value):
    """Parse a command-line boolean such as 'true'/'false', '1'/'0', 'yes'/'no'.

    Without an explicit ``type``, argparse stores the raw string, so
    ``--random_flip False`` would yield the truthy string ``'False'``.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


parser = argparse.ArgumentParser(description='Landmark Detection with Active Inference')

parser.add_argument('--task', type=str, default='300W', help='which task to run (CelebA_aligned)')
parser.add_argument('--n_landmarks', type=int, default=68, help='the number of landmarks')
parser.add_argument('--log_interval', type=int, default=5, help='interval for log [batches]')

# Dataset preprocessing
parser.add_argument('--random_scale', type=_str2bool, default=False, help='Whether to apply random scale')
parser.add_argument('--random_flip', type=_str2bool, default=False, help='Whether to apply random flip')
parser.add_argument('--random_rotation', type=_str2bool, default=False, help='Whether to apply random rotation')

# Test setting
parser.add_argument('--batch_size', type=int, default=50, help='batch size for test')
parser.add_argument('--maximum_stage', type=int, default=2, help='Maximum detection stage for each landmark')
# nargs='+' keeps the parsed value a list: test() indexes it as
# args.max_timestep[0] (odd stages) and args.max_timestep[1] (even stages),
# so a bare int from the command line would fail there.
parser.add_argument('--max_timestep', type=int, nargs='+', default=[30, 30, 30, 30],
                    help='maximum number of time-steps (list; at least two values: odd stages, even stages)')
parser.add_argument('--center_coord_init', type=_str2bool, default=True, help='set initial coordinate to center')

# NOTE: parsed at import time; the rest of the module reads this global.
args = parser.parse_args()


def main():
    """Build the five region agents, load their assets, run ``test`` and print the results.

    Steps, in order: clear the CUDA cache, obtain the test loader from
    ``load_data``, construct the agents and dirPolNets on ``device``, load
    the per-region states, the dirPolNets and the Optuna hyper-parameters
    (stored as attributes on the global ``args``), call ``test`` and print
    the overall NME, the per-landmark NME and the detection time-steps
    averaged over dimension 1.
    """
    torch.cuda.empty_cache()

    # Only the second element returned by load_data is used here.
    _unused, test_loader = load_data(
        args.task,
        args.batch_size,
        args.random_scale,
        args.random_flip,
        args.random_rotation,
    )

    def build(agent_cls):
        # Every agent shares the same constructor arguments and device.
        return agent_cls(args.batch_size, args.maximum_stage).to(device)

    agent_leye = build(Agent)
    agent_reye = build(Agent)
    agent_mouth = build(Agent)
    agent_nose = build(Agent)
    agent_jaw = build(Agent_jaw)

    leye_state, reye_state, mouth_state, nose_state, jaw_state = load_state()

    # One dirPolNet per region, created in leye, reye, mouth, nose, jaw order.
    fresh_nets = [dirPolNet().to(device) for _ in range(5)]
    (dirpolnet_leye,
     dirpolnet_reye,
     dirpolnet_mouth,
     dirpolnet_nose,
     dirpolnet_jaw) = load_dirpolnet(*fresh_nets)

    # Hyper-parameters are unpacked first, then attached to the global args
    # namespace where test() reads them.
    (lambda_control_start, lambda_decrease, lambda_ft_init, lambda_freq,
     thr_control_start, thr_increase, thr_init, thr_freq,
     lambda_ft_1stage, lambda_ft_2stage) = load_optuna_setting()
    args.lambda_control_start = lambda_control_start
    args.lambda_decrease = lambda_decrease
    args.lambda_ft_init = lambda_ft_init
    args.lambda_freq = lambda_freq
    args.thr_control_start = thr_control_start
    args.thr_increase = thr_increase
    args.thr_init = thr_init
    args.thr_freq = thr_freq
    args.lambda_ft_1stage = lambda_ft_1stage
    args.lambda_ft_2stage = lambda_ft_2stage

    results = test(
        agent_leye, leye_state, dirpolnet_leye,
        agent_reye, reye_state, dirpolnet_reye,
        agent_mouth, mouth_state, dirpolnet_mouth,
        agent_nose, nose_state, dirpolnet_nose,
        agent_jaw, jaw_state, dirpolnet_jaw,
        test_loader,
    )
    NME_pupil_mean, NME_pupil_l_mean, detection_timesteps_total = results

    # Average the accumulated time-steps over dimension 1.
    detection_timesteps_mean = detection_timesteps_total.mean(1)

    overall_line = "NME_pupil_mean: {:.3f}.. ".format(NME_pupil_mean).ljust(15)
    print(overall_line)

    per_landmark_line = "NME_pupil_l_mean: \n{}.. \n".format(NME_pupil_l_mean).ljust(15)
    print(per_landmark_line)

    timestep_line = "detection_timesteps_total_mean: \n{}.. \n".format(detection_timesteps_mean).ljust(15)
    print(timestep_line)


    
    
# Landmark indices handled by each facial-region agent, stored as
# (indices detected in stages 1-2, indices detected in every later stage).
_LEYE_LANDMARKS = ([39], [17, 18, 19, 20, 21, 36, 37, 38, 40, 41])
_REYE_LANDMARKS = ([42], [22, 23, 24, 25, 26, 43, 44, 45, 46, 47])
_MOUTH_LANDMARKS = ([51], [48, 49, 50, 52, 53, 54, 55, 56, 57, 58, 59,
                           60, 61, 62, 63, 64, 65, 66, 67])
_NOSE_LANDMARKS = ([30], [27, 28, 29, 31, 32, 33, 34, 35])
# The jaw agent only runs stages 1 and 2, always on the same 17 landmarks.
_JAW_LANDMARKS = (list(range(17)), list(range(17)))


def _detect_region(agent, state, dirpolnet, env, landmarks, n_stages,
                   inferred_landmark_coords, detection_timesteps_total, n_batches):
    """Run every detection stage of one region agent on the current batch.

    Args:
        agent: region agent exposing ``set_hyperparams``, ``set_prior``,
            ``progress_detection``, ``detection_time`` and ``landmark_coords``.
        state: prior state handed to ``agent.set_prior``.
        dirpolnet: network handed to ``agent.progress_detection``.
        env: environment built from the current batch of images.
        landmarks: ``(stage_1_2_indices, later_stage_indices)`` pair.
        n_stages: number of stages to run (stages are numbered from 1).
        inferred_landmark_coords: tensor indexed by landmark on dim 0;
            rows for the detected landmarks are overwritten in place.
        detection_timesteps_total: tensor indexed by landmark on dim 0;
            ``agent.detection_time / n_batches`` is accumulated in place.
        n_batches: number of batches in the test loader.

    Raises:
        RuntimeError: if the agent has never signalled completion (returned
            ``None``) by the end of a stage, so no coordinates exist to seed
            the next stage.
    """
    first_idxs, later_idxs = landmarks
    center_init = args.center_coord_init
    coord = None
    region_coords = None  # latest coordinates reported by the agent

    for stage in range(1, n_stages + 1):
        # A fresh list each stage: list indexing selects rows on dim 0.
        l_idxs = list(first_idxs if stage <= 2 else later_idxs)

        env.coord_init(center_init, coord, len(l_idxs))
        o_t = env.c_to_o()

        # Odd stages read column 0 of the hyper-parameters, even stages column 1.
        idx = 0 if stage % 2 == 1 else 1

        lambda_control_start = args.lambda_control_start[l_idxs][:, idx]
        lambda_ft = args.lambda_ft_1stage[l_idxs] if idx == 0 else args.lambda_ft_2stage[l_idxs]
        lambda_freq = args.lambda_freq[l_idxs][:, idx]
        thr_control_start = args.thr_control_start[l_idxs][:, idx]
        thr_increase = args.thr_increase[l_idxs][:, idx]
        thr_init = args.thr_init[l_idxs][:, idx]
        thr_freq = args.thr_freq[l_idxs][:, idx]

        agent.set_hyperparams(thr_init, thr_control_start, thr_freq, thr_increase,
                              lambda_control_start, lambda_freq, lambda_ft)
        agent.set_prior(state, stage)

        max_t = args.max_timestep[idx]
        for t in range(max_t):
            # Tell the agent when it is on the final allowed time-step.
            end_step = t == max_t - 1

            act_to_env = agent.progress_detection(o_t, dirpolnet, end_step)

            # None means detection is complete for all samples. Use an
            # identity check: `== None` on an array-like action is not a
            # reliable None test.
            if act_to_env is None:
                detection_timesteps_total[l_idxs] += agent.detection_time / n_batches
                region_coords = agent.landmark_coords  # [n_l, B, 2]
                inferred_landmark_coords[l_idxs] = region_coords
                break

            env.apply_action(act_to_env)
            o_t = env.current_o

        if region_coords is None:
            # Previously this surfaced as an UnboundLocalError, or silently
            # reused coordinates left over from the previous batch.
            raise RuntimeError(
                "agent never completed detection in stage {} "
                "(max_timestep={})".format(stage, max_t))

        # Later stages start from the coordinates found so far instead of the center.
        coord = norm_coord_to_abs(region_coords, img_size=[256, 256]).long()
        center_init = False


def test(agent_leye, leye_state, dirpolnet_leye,
         agent_reye, reye_state, dirpolnet_reye,
         agent_mouth, mouth_state, dirpolnet_mouth,
         agent_nose, nose_state, dirpolnet_nose,
         agent_jaw, jaw_state, dirpolnet_jaw, test_loader):
    """Evaluate the five region agents on ``test_loader``.

    Every agent and dirPolNet is put in eval mode and each agent's ``eps``
    is set to 0. For each batch the regions are processed in the order
    left eye, right eye, mouth, nose, jaw; the first four run
    ``2 * args.maximum_stage`` stages, the jaw runs 2.

    Args:
        agent_* / *_state / dirpolnet_*: agent, prior state and dirPolNet
            for each of the five regions.
        test_loader: iterable with ``len()`` yielding
            ``(images, tpts, pts, center, scale)`` batches.

    Returns:
        tuple ``(NME_pupil_mean, NME_pupil_l_mean, detection_timesteps_total)``:
        100 x the mean of ``NME_calc`` over all samples, 100 x the
        per-landmark mean of ``NME_calc_landmarkwise``, and a
        ``(n_landmarks, batch_size)`` tensor accumulating
        ``detection_time / len(test_loader)`` per landmark.

    Raises:
        RuntimeError: if an agent never signals completion for a stage.
    """
    for agent, dirpolnet in ((agent_leye, dirpolnet_leye),
                             (agent_reye, dirpolnet_reye),
                             (agent_mouth, dirpolnet_mouth),
                             (agent_nose, dirpolnet_nose),
                             (agent_jaw, dirpolnet_jaw)):
        agent.eval()
        agent.eps = 0.
        dirpolnet.eval()

    # for batch-wise
    NME_pupil_samples = []
    # NOTE(review): never incremented anywhere in this function, so the log
    # line below always reports 0 - confirm whether failure counting is wanted.
    detect_fail_batch = 0

    NME_pupil_l_samples = [[] for _ in range(args.n_landmarks)]
    detection_timesteps_total = torch.zeros(args.n_landmarks, args.batch_size).to(device)

    # NOTE(review): per-batch throughput is collected but not returned or printed.
    fps_list = []

    n_batches = len(test_loader)
    full_stages = 2 * args.maximum_stage
    regions = (
        (agent_leye, leye_state, dirpolnet_leye, _LEYE_LANDMARKS, full_stages),
        (agent_reye, reye_state, dirpolnet_reye, _REYE_LANDMARKS, full_stages),
        (agent_mouth, mouth_state, dirpolnet_mouth, _MOUTH_LANDMARKS, full_stages),
        (agent_nose, nose_state, dirpolnet_nose, _NOSE_LANDMARKS, full_stages),
        (agent_jaw, jaw_state, dirpolnet_jaw, _JAW_LANDMARKS, 2),
    )

    with torch.no_grad():
        for i, (images, tpts, pts, center, scale) in enumerate(test_loader):
            inferred_landmark_coords = torch.zeros(args.n_landmarks, args.batch_size, 2).to(device)

            images = images.to(device)
            env = Env(images)

            start_time = time.perf_counter()
            for agent, state, dirpolnet, landmarks, n_stages in regions:
                _detect_region(agent, state, dirpolnet, env, landmarks, n_stages,
                               inferred_landmark_coords, detection_timesteps_total,
                               n_batches)
            end_time = time.perf_counter()
            elapsed = end_time - start_time
            fps_list.append(images.size(0) / elapsed)

            # [n_landmarks, B, 2] -> [B, n_landmarks, 2]
            inferred_landmark_coords = inferred_landmark_coords.permute(1, 0, 2)
            inferred_landmark_coords_abs = norm_coord_to_abs(
                inferred_landmark_coords, [256, 256]).view(-1, args.n_landmarks, 2)
            # Swap the order of the two coordinate components before scoring.
            inferred_landmark_coords_abs = inferred_landmark_coords_abs.flip(dims=[-1])
            NME_pupil = NME_calc(inferred_landmark_coords_abs, pts, center, scale)

            NME_pupil_samples.extend(NME_pupil.tolist())

            NME_pupil_l = NME_calc_landmarkwise(inferred_landmark_coords_abs, pts, center, scale)
            for l in range(args.n_landmarks):
                NME_pupil_l_samples[l].extend(NME_pupil_l[:, l].tolist())

            if i % args.log_interval == args.log_interval - 1:
                print("Batches: {}/{}.. ".format(i+1, n_batches).ljust(14),
                      "Detect_fail: {}.. ".format(detect_fail_batch))

    NME_pupil_mean = 100*torch.FloatTensor(NME_pupil_samples).mean()

    NME_pupil_l_mean = 100*torch.FloatTensor(NME_pupil_l_samples).mean(1)

    return NME_pupil_mean, NME_pupil_l_mean, detection_timesteps_total



# Script entry point: run the evaluation only when this file is executed directly.
if __name__ == "__main__":
    main()
    