import argparse
import itertools
import logging
import math
import os
import time
from datetime import datetime

import numpy as np
import torch
import torch.optim as optim
from tensorboardX import SummaryWriter

from environment import create_env
from utils import setup_logger

def next_choice(agent_num, total, curr):
    """Return the successor of ``curr`` treated as a base-``total`` counter.

    The last position is the least significant digit. It is incremented, and
    any position whose value reaches ``total`` wraps to 0 and carries into the
    position before it. Once the carry is absorbed, the remaining positions
    are copied (assuming every entry of ``curr`` is already below ``total``).
    An all-maximal input wraps to all zeros.

    Returns a new float ``np.ndarray`` of length ``agent_num``; ``curr`` is
    not modified.
    """
    successor = np.zeros(agent_num)
    carry = 1
    for pos in range(agent_num - 1, -1, -1):
        digit = curr[pos] + carry
        if digit < total:
            successor[pos] = digit
            carry = 0
        else:
            # Wrap this digit and (re)start a carry into the next position.
            successor[pos] = 0
            carry = 1
    return successor

def _copy_cam_info(cam_info):
    """Return a ``[[x, y], [angle]]`` copy of each camera entry in ``cam_info``.

    Only the first rotation component is kept. The copy shares no lists with
    ``cam_info``, so candidates can be edited without touching caller data.
    """
    copied = []
    for info in cam_info:
        loc_x, loc_y = info[0]
        copied.append([[loc_x, loc_y], [info[1][0]]])
    return copied


def find_best(env, cam_info, mode):
    """Choose new camera angles for all agents according to ``mode``.

    For modes 0-2 every camera may turn by ``(k - total // 2) *
    env.rotation_scale`` for ``k`` in ``range(total)``, where ``total =
    len(env.discrete_actions)``. All ``total ** env.n`` joint combinations
    are scored with ``env.simplified_multi_reward``. Mode 3 instead turns
    each camera by a random -1, 0 or +1 multiple of ``env.rotation_scale``.

    Args:
        env: environment exposing ``n``, ``discrete_actions``,
            ``rotation_scale`` and ``simplified_multi_reward(info)``, which
            returns ``(reward, angle_sum)``.
        cam_info: per-camera ``[[x, y], [angle, ...]]`` entries. Not mutated.
        mode: 0 = accept a candidate that improves reward OR angle sum
            relative to the current pick, 1 = maximise reward,
            2 = minimise angle sum, 3 = random.

    Returns:
        ``(best_info, reward)``. ``best_info`` has the ``cam_info`` layout
        with only the first rotation component kept, and ``reward`` is that
        candidate's reward. If there are no candidates (no discrete actions),
        ``(None, -inf)`` is returned.

    Raises:
        ValueError: if ``mode`` is not 0, 1, 2 or 3.
    """
    if mode not in (0, 1, 2, 3):
        raise ValueError("invalid mode: {}".format(mode))
    agent_num = env.n

    if mode == 3:
        # Random baseline: each camera turns by -1, 0 or +1 rotation steps.
        candidate = _copy_cam_info(cam_info)
        for i in range(agent_num):
            candidate[i][1][0] += env.rotation_scale * np.random.randint(-1, 2)
        reward, _ = env.simplified_multi_reward(candidate)
        return candidate, reward

    total = len(env.discrete_actions)
    # Centre the action indices so that index `cnt` means "no rotation".
    cnt = total // 2
    # Unbeatable sentinels guarantee the first candidate is accepted, so
    # best_info is always set when at least one candidate exists.
    max_reward = -math.inf
    min_angle_sum = math.inf
    best_info = None
    # product() varies the last camera fastest, the same order next_choice() produces.
    for offsets in itertools.product(range(total), repeat=agent_num):
        candidate = _copy_cam_info(cam_info)
        for i, k in enumerate(offsets):
            candidate[i][1][0] += (k - cnt) * env.rotation_scale
        reward, angle_sum = env.simplified_multi_reward(candidate)
        if mode == 0:
            # NOTE(review): both trackers are overwritten on acceptance, so
            # max_reward can drop when a candidate wins on angle only; this
            # compares against the current pick, not the global best.
            better = reward > max_reward or angle_sum < min_angle_sum
        elif mode == 1:
            better = reward > max_reward
        else:
            better = angle_sum < min_angle_sum
        if better:
            max_reward = reward
            min_angle_sum = angle_sum
            best_info = candidate
    return best_info, max_reward

def greedy_action(env):
    """Unfinished stub: call ``env.get_cam_info()`` once per camera and discard it.

    Each call's result is unpacked into two values and then thrown away, and
    the function returns ``None``.

    NOTE(review): this module's own helper is ``get_cam_info(env)``, not an env
    method; confirm whether the env really provides ``get_cam_info`` and what
    this function was meant to compute. It is not called in this file.
    """
    for _ in range(env.n):
        _loc, _rot = env.get_cam_info()

def get_cam_info(env):
    """Snapshot every camera's location and rotation from ``env``.

    Returns one ``[location, rotation]`` pair per camera index ``0..env.n-1``.
    Each pair holds the objects returned by ``env.get_location(i)`` and
    ``env.get_rotation(i)`` (not copies), queried in that order per camera.
    """
    return [[env.get_location(idx), env.get_rotation(idx)] for idx in range(env.n)]


# Command-line options for the optimal / random camera-rotation baseline evaluation.
parser = argparse.ArgumentParser(description='optimal')
parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)')
parser.add_argument('--workers', type=int, default=1, metavar='W', help='how many training processes to use (default: 1)')
parser.add_argument('--env', default='Pose-v3', metavar='Pose-v0', help='environment to train on (default: Pose-v3)')
parser.add_argument('--keep', type=int, default=10, metavar='K',
                    help='low-level target moves per env step; the step reward is their mean (default: 10)')
parser.add_argument('--test-loops', type=int, default=10, metavar='L',
                    help='number of evaluation loops; mean/std are reported across loops (default: 10)')
parser.add_argument('--test-eps', type=int, default=100, metavar='E', help='number of episodes per evaluation loop (default: 100)')
parser.add_argument('--env-steps', type=int, default=100, metavar='NS', help='number of steps in one env episode')
parser.add_argument('--render_save', dest='render_save', action='store_true', help='render save')
parser.add_argument('--log-dir', default='logs/optimal', metavar='LG', help='folder to save logs')
# choices= rejects unsupported modes at parse time instead of failing inside find_best.
parser.add_argument('--mode', type=int, default=3, choices=[0, 1, 2, 3], metavar='M',
                    help='best action mode, 0 for both, 1 for reward, 2 for distance , 3 for random')



args = parser.parse_args()
# --keep and --test-eps are divisors in the evaluation loop, and --test-loops
# feeds np.mean/np.std; reject non-positive values up front.
for _opt in ('keep', 'test_eps', 'test_loops'):
    if getattr(args, _opt) < 1:
        parser.error('--{} must be a positive integer'.format(_opt.replace('_', '-')))

curr_time = datetime.now().strftime('%b%d_%H-%M')
# TensorBoard scalars go to a per-mode, per-run directory; the text logger
# below writes under args.log_dir itself.
args.logdir = os.path.join(args.log_dir, "mode-" + str(args.mode), curr_time)
env = create_env(args.env, args)
env.max_steps = args.env_steps
writer = SummaryWriter(args.logdir)

log = {}
setup_logger('{}_log'.format(args.env),
             r'{0}/logger'.format(args.log_dir))
log['{}_log'.format(args.env)] = logging.getLogger(
    '{}_log'.format(args.env))
# Record the full run configuration at the top of the log.
d_args = vars(args)
for k in d_args.keys():
    log['{}_log'.format(args.env)].info('{0}: {1}'.format(k, d_args[k]))

ave_reward_list = []
start_time = time.time()

logger = log['{}_log'.format(args.env)]
try:
    for loop in range(args.test_loops):
        reward_sum = 0
        for episode in range(args.test_eps):
            env.reset()
            reward_eps = 0
            for _ in range(args.env_steps):
                # Each env step is split into args.keep low-level target moves.
                # After every move the cameras are re-aimed via find_best, and
                # the step reward is the mean of those per-move rewards.
                reward = 0
                for _low_step in range(args.keep):
                    env.target_move()
                    cam_info = get_cam_info(env)
                    best_info, gr = find_best(env, cam_info, args.mode)
                    reward += gr

                    # Push the chosen angle back into the env; `rot` is the
                    # object returned by env.get_rotation, updated in place.
                    for i in range(env.n):
                        loc, rot = cam_info[i]
                        rot[0] = best_info[i][1][0]
                        env.set_rotation(i, rot)

                reward = reward / args.keep
                reward_eps += reward
            reward_sum += reward_eps
            # The x-axis is the cumulative number of env steps across all loops.
            writer.add_scalar('reward', reward_eps, (episode + 1 + loop * args.test_eps) * args.env_steps)
            logger.info(
                "Time {0}, episode {1}, eps reward {2}".format(
                    time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - start_time)),
                    episode,
                    reward_eps))

        ave_reward = reward_sum / args.test_eps
        ave_reward_list.append(ave_reward)
        logger.info(
            "Time {0}, Loop {1}, ave eps reward {2}".format(
                time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - start_time)),
                loop,
                ave_reward))

    # Mean/std of the per-loop average episode rewards.
    reward_mean = np.mean(ave_reward_list)
    reward_std = np.std(ave_reward_list)
    logger.info("reward mean {0}, reward std {1}".format(reward_mean, reward_std))
finally:
    # Flush pending TensorBoard events even if the evaluation aborts.
    writer.close()

