from __future__ import division
import os
import time
import torch
import numpy as np
import torch.optim as optim
import argparse
from torch.autograd import Variable
from tensorboardX import SummaryWriter
from setproctitle import setproctitle as ptitle

from model import build_model
from player_util import Agent
from environment import create_env

def _str2bool(value):
    """Convert a command-line spelling of a boolean into ``True``/``False``.

    argparse passes option values as strings, and every non-empty string
    (including ``'False'``) is truthy, so boolean-valued options need an
    explicit converter.

    Args:
        value: the raw option value (a ``bool`` is returned unchanged).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


# Command-line interface. Help texts use argparse's ``%(default)s`` placeholder
# so the advertised default can never drift from the real one.
parser = argparse.ArgumentParser(description='A3C')
parser.add_argument('--lr', type=float, default=0.0005, metavar='LR',
                    help='learning rate (default: %(default)s)')
parser.add_argument('--gamma', type=float, default=0.9, metavar='G',
                    help='discount factor for rewards (default: %(default)s)')
parser.add_argument('--tau', type=float, default=1.00, metavar='T',
                    help='parameter for GAE (default: %(default)s)')
parser.add_argument('--entropy', type=float, default=0.01, metavar='T',
                    help='parameter for entropy (default: %(default)s)')
parser.add_argument('--grad-entropy', type=float, default=1.0, metavar='T',
                    help='parameter for grad-entropy (default: %(default)s)')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: %(default)s)')
parser.add_argument('--workers', type=int, default=1, metavar='W',
                    help='how many training processes to use (default: %(default)s)')
parser.add_argument('--num-steps', type=int, default=20, metavar='NS',
                    help='number of forward steps in A3C (default: %(default)s)')
parser.add_argument('--test-eps', type=int, default=1, metavar='M',
                    help='testing episode length')
parser.add_argument('--env', default='simple', metavar='Pose-v0',
                    help='environment to train on (default: %(default)s)')
parser.add_argument('--optimizer', default='Adam', metavar='OPT',
                    help='shares optimizer choice of Adam or RMSprop')
# Without a converter, "--amsgrad False" would store the truthy string 'False'.
parser.add_argument('--amsgrad', type=_str2bool, default=True, metavar='AM',
                    help='Adam optimizer amsgrad parameter (default: %(default)s)')
parser.add_argument('--load-coordinator-dir', default=None, metavar='LMD',
                    help='folder to load trained models from')
parser.add_argument('--load-executor-dir', default=None, metavar='LMD',
                    help='folder to load trained models from')
parser.add_argument('--log-dir', default='logs/', metavar='LG',
                    help='folder to save logs')
parser.add_argument('--model', default='single', metavar='M',
                    help='multi-shapleyV|')
# NOTE(review): the default is the scalar -1 while values given on the command
# line arrive as a list (nargs='+'); consumers must cope with both forms.
parser.add_argument('--gpu-ids', type=int, default=-1, nargs='+',
                    help='GPUs to use [-1 CPU only] (default: %(default)s)')
parser.add_argument('--norm-reward', dest='norm_reward', action='store_true',
                    help='normalize reward')
parser.add_argument('--render', dest='render', action='store_true',
                    help='render test')
parser.add_argument('--fix', dest='fix', action='store_true',
                    help='fix random seed')
# NOTE(review): this help text reads as the opposite of the flag name -- confirm
# the intended meaning before rewording it.
parser.add_argument('--shared-optimizer', dest='shared_optimizer', action='store_true',
                    help='use an optimizer without shared statistics.')
parser.add_argument('--train-mode', type=int, default=-1, metavar='TM',
                    help='his')
parser.add_argument('--input-size', type=int, default=80, metavar='IS',
                    help='input image size')
parser.add_argument('--lstm-out', type=int, default=32, metavar='LO',
                    help='lstm output size')
parser.add_argument('--sleep-time', type=int, default=0, metavar='LO',
                    help='seconds')
parser.add_argument('--max-step', type=int, default=5000000, metavar='LO',
                    help='max learning steps')
parser.add_argument('--render_save', dest='render_save', action='store_true',
                    help='render save')
parser.add_argument('--env-steps', type=int, default=20, metavar='NS',
                    help='number of steps in one env episode')

def sum_abs(x, y):
    """Return the sum of absolute element-wise differences between x and y.

    The result is whatever ``torch.abs(x - y)`` summed over all elements
    yields (a 0-dim tensor for tensor inputs).
    """
    difference = x - y
    return torch.abs(difference).sum()
def obs_wrapper(obs, num_agents, num_both):
    """Flatten the observation of each of the first ``num_agents`` agents.

    Args:
        obs: indexable by agent; each ``obs[i]`` must support ``reshape``.
            The original author's note gives the layout as
            num_agents * num_both * obs_dim.
        num_agents: how many leading entries of ``obs`` to flatten.
        num_both: accepted for interface compatibility; not used here.

    Returns:
        A list with one 1-D (``reshape(-1)``) observation per agent.
    """
    return [obs[agent].reshape(-1) for agent in range(num_agents)]
def get_comm_pairs_ToM(env):
    """Collect, for every agent, its position offset to each other agent.

    For agent ``i`` and every other agent ``j`` the offset is
    ``location(i) - location(j)`` on the first two coordinates returned by
    ``env.get_location``.

    Args:
        env: object exposing ``n`` (agent count) and ``get_location(idx)``.

    Returns:
        ``(target_loc_n, target_idx_n)``: two lists of length ``env.n``.
        ``target_loc_n[i]`` holds one 2-element numpy array per other agent,
        and ``target_idx_n[i]`` holds the matching agent indices.

    Both lists are also printed to stdout before returning.
    """
    target_loc_n, target_idx_n = [], []
    for me in range(env.n):
        my_loc = env.get_location(me)
        offsets, peers = [], []
        for other in range(env.n):
            if other != me:
                other_loc = env.get_location(other)
                dx = my_loc[0] - other_loc[0]
                dy = my_loc[1] - other_loc[1]
                offsets.append(np.array([dx, dy]))
                peers.append(other)
        target_loc_n.append(offsets)
        target_idx_n.append(peers)
    print(target_loc_n)
    print(target_idx_n)
    return target_loc_n, target_idx_n

def get_space(env):
    """Build per-agent shape tuples for observations, messages and target offsets.

    Args:
        env: object exposing ``n``, ``num_target``, ``num_obstacle`` and
            ``state_dim``.

    Returns:
        A triple of lists, each with ``env.n`` entries:
        - observation shapes ``((num_target + num_obstacle) * state_dim,)``
        - message shapes ``(n, (num_target + num_obstacle) * state_dim)``
        - target-location shapes ``(2,)``
    """
    agent_count = env.n
    entity_count = env.num_target + env.num_obstacle
    flat_obs_len = entity_count * env.state_dim

    agents = range(agent_count)
    obs_shape_n = [(flat_obs_len,) for _ in agents]
    message_shape_n = [(agent_count, flat_obs_len) for _ in agents]
    target_loc_space_n = [(2,) for _ in agents]
    return obs_shape_n, message_shape_n, target_loc_space_n

# ---- Script body: runs at import time --------------------------------------
# Parse the command line, create and reset the environment, and build the
# model on the CPU.
args = parser.parse_args()
env = create_env(args.env, args)
env.reset()

device_share = torch.device("cpu")
shared_model = build_model(env, args, device_share).to(device_share)

# Problem sizes taken from the environment and the CLI.
num_agents, num_targets = env.n, env.num_target
num_both = num_targets + env.num_obstacle
obs_dim = env.observation_space.shape[-1]
lstm_dim = args.lstm_out
batch = 3

# Random placeholder inputs; they are created here but not consumed by any
# code visible in this part of the file (presumably scratch/debug data).
state = torch.randn(batch, num_agents, num_both, obs_dim)
hself = torch.randn(batch, num_agents, lstm_dim)
hothers = torch.randn(batch, num_agents, num_agents - 1, lstm_dim)
cam_states = torch.randn(batch, num_agents, 3)

# Empty accumulators; nothing in the visible code fills them.
values, entropies = [], []
next_hself, next_hothers = [], []
comm_logits, test = [], []

# Report the per-agent shapes derived from the environment.
obs_shape_n, message_shape_n, target_loc_space_n = get_space(env)
print(obs_shape_n[0][0], message_shape_n, target_loc_space_n[0][0])


# NOTE: everything inside the triple-quoted string below is disabled code kept
# for reference only; as a bare string expression it is never executed.
# The disabled ToM_train routine calls load_data and optimize_ToM and reads
# args.ToM_frozen; none of these are defined or declared in the code above,
# so they would have to be provided before this could be re-enabled.
# NOTE(review): both wait loops inside it poll without sleeping (busy-wait),
# each giving up and returning after 180 seconds.
'''
def ToM_train(args, shared_model, optimizer_ToM, train_modes, n_iters, ToM_count, ToM_history, env=None):
    rank = args.workers
    n_iter = 0
    writer = SummaryWriter(os.path.join(args.log_dir, 'ToM'))
    ptitle('ToM Training')
    gpu_id = args.gpu_ids[rank % len(args.gpu_ids)]
    torch.manual_seed(args.seed + rank)
    env_name = args.env

    if gpu_id >= 0:
        torch.cuda.manual_seed(args.seed + rank)
        device = torch.device('cuda:' + str(gpu_id))
        if len(args.gpu_ids) > 1:
            device_share = torch.device('cpu')
        else:
            device_share = torch.device('cuda:' + str(args.gpu_ids[-1]))
    else:
        device_share = torch.device('cpu')
    #device_share = torch.device('cuda:0')
    if env == None:
        env = create_env(env_name, args)

    # # prepare model
    # model = build_model(env, args, device)
    # model = model.to(device)
    # model.train()

    params = []
    for name,param in shared_model.named_parameters():
        if 'ToM' in name or 'other' in name:
            params.append(param)

    ToM_len = args.ToM_frozen * args.workers * args.num_steps

    while True:
        t1 = time.time()
        while True:
            curr_time = time.time()
            if curr_time - t1 > 180:
                print("waiting too long for enough ToM")
                print("ToM_count:",ToM_count)
                return
            if sum(ToM_count) >= ToM_len:
                for rank in range(args.workers):
                    # inform workers to stop collecting trajectories
                    train_modes[rank] = -10
                break
            else:
                pass # wait for workers collecting ToM
        t2 = time.time()
        print("collected enough ToM trajectories after {} seconds".format(t2-t1))

        while True:
            flag = True
            # flag means whether all the workers have stopped to wait for ToM training
            for rank in range(args.workers):
                if train_modes[rank] != -20:
                    flag = False
            if flag:
                break
            curr_time = time.time()
            if curr_time - t2 > 180:
                print("waiting too long for the workers to stop")
                print("train_modes:",train_modes)
                return

        t3 = time.time()
        print("ToM training start after waiting for workers to stop for {} seconds".format(t3-t2))
        #print(train_modes)
        print("ToM_count:",ToM_count)

        state, cam_states = load_data(args, ToM_history)
        print("ToM data loaded")
        ToM_loss = optimize_ToM(state, cam_states, args, params, optimizer_ToM, shared_model, device_share, env)
        print("optimized based on ToM loss")
        n_steps = sum(n_iters) * args.num_steps
        writer.add_scalar('train/ToM_loss_sum', ToM_loss.sum(), n_steps)

        # with open(args.ToM_file,'r') as f:
        #     ToM_list = f.read().split('\n')[:-1]
        #     ToM_loss = optimize_ToM(ToM_list, params, optimizer_ToM, shared_model, device_share, env)
        #     # for ToM in ToM_list:
        #     #     ToM = json.loads(ToM)
        #     #     state = ToM['state']
        #     #     state = torch.from_numpy(np.array(state)).float().to(device)
        #     #     hself = ToM['hself']
        #     #     hself = torch.from_numpy(np.array(hself)).float().to(device)
        #     #     hothers = ToM['hothers']
        #     #     hothers = torch.from_numpy(np.array(hothers)).float().to(device)
        #     #     cam_states = ToM['cam_states']
        #     #     cam_states = torch.from_numpy(np.array(cam_states)).float().to(device)
        #     #     real_goals = ToM['real']
        #     #     real_goals = torch.from_numpy(np.array(real_goals)).float().to(device)
        #     n_steps = sum(n_iters) * args.num_steps
        #     writer.add_scalar('train/ToM_loss_sum', ToM_loss.sum(), n_steps)
        print("ToM training finished")
        print("---------------------")
        for rank in range(args.workers):
            ToM_count[rank] = 0
            ToM_history[rank] = []
            train_modes[rank] = -1

        if train_modes[0] == -100:
            env.close()
            break
'''