from __future__ import division
import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torch.optim as optim
from tensorboardX import SummaryWriter
from setproctitle import setproctitle as ptitle

import json
from model import build_model
from player_util import Agent
from environment import create_env

class HLoss(nn.Module):
    """Entropy-style criterion computed on raw logits.

    Without a prior the module returns the mean entropy of ``softmax(x)``
    taken along dimension 1. With a prior it returns the mean of
    ``-sum(p * (log p - log prior))`` taken along the last dimension, i.e.
    the negated KL divergence between ``softmax(x)`` and ``prior``.
    """

    def __init__(self):
        super(HLoss, self).__init__()

    def forward(self, x, prior=None):
        """Return the scalar criterion for logits ``x``.

        Args:
            x: Logits. The class dimension is dim 1 when ``prior`` is None
                and the last dim otherwise.
            prior: Optional tensor of prior probabilities; it is viewed as
                ``(-1, x.size(-1))`` and broadcast against ``x``.

        Returns:
            A 0-dim tensor holding the mean over all remaining positions.
        """
        if prior is not None:
            log_prior = torch.log(prior).view(-1, x.size(-1))
            log_ratio = F.log_softmax(x, dim = -1) - log_prior
            weighted = F.softmax(x, dim = -1) * log_ratio
            return (-weighted.sum(-1)).mean()

        # No prior: plain entropy, normalised over dim 1 (not the last dim).
        weighted = F.softmax(x, dim=1) * F.log_softmax(x, dim=1)
        return (-weighted.sum(1)).mean()

def optimize_ToM(state, cam_states, comm_domains, available_actions, args, params, optimizer_ToM, shared_model, device_share, env):
    """Fit the ToM heads of ``shared_model`` on a batch of stored rollouts.

    The flat batch is folded into ``count = batch / env.max_steps`` episodes
    and walked in segments of ``args.num_steps`` steps. Each segment is
    replayed ``args.ToM_train_loops`` times from the same starting hidden
    state, and every replay performs one optimizer step on::

        (KLDiv(log ToM_goal, real_goal) + 0.5 * BCE(target cover)) / count

    where ``real_goal`` is the one-hot ``(1 - actions, actions)`` built from
    the actions the model returns in that forward pass, repeated for every
    observer and with the observer's own row removed. The BCE term is only
    produced by the 'ToM-v2' and 'ToM-v5' model families.

    Args:
        state: Tensor whose size unpacks to (batch, agents, num_both, obs_dim).
        cam_states: Tensor whose size unpacks to (batch, agents, cam_dim).
        comm_domains: Only read for 'ToM-v5' models; reshaped to
            (count, max_steps, agents, agents, 1).
        available_actions: Only read when ``args.mask_actions`` is truthy;
            reshaped to (count, max_steps, agents, targets, -1).
        args: Reads ``model``, ``num_steps``, ``lstm_out``, ``mask_actions``
            and ``ToM_train_loops``.
        params: Parameters whose gradient norm is clipped to 20.
        optimizer_ToM: Optimizer stepped once per replay.
        shared_model: Model that is called and updated.
        device_share: Device on which per-step slices and hidden states live.
        env: Reads ``n``, ``num_target`` and ``max_steps``.

    Returns:
        Tuple ``(ToM_loss_sum, ToM_loss_mean, ToM_target_loss_mean,
        ToM_target_acc_mean)`` of detached 1-element tensors. The sums take
        the last replay of every segment; the means divide by
        ``agents * (agents - 1) * targets * batch``.

    Raises:
        ValueError: If ``args.model`` names none of 'ToM-v2', 'ToM-v4',
            'ToM-v5'.
    """
    num_agents = env.n
    num_targets = env.num_target
    max_steps = env.max_steps
    seg_num = int(max_steps/args.num_steps)
    batch_size, num_agents, num_both, obs_dim = state.size()
    count = int(batch_size/max_steps)
    print("batch_size = ",batch_size)

    # Resolve the model family once; precedence is v2, then v4, then v5.
    if 'ToM-v2' in args.model:
        variant = 'v2'
    elif 'ToM-v4' in args.model:
        variant = 'v4'
    elif 'ToM-v5' in args.model:
        variant = 'v5'
    else:
        raise ValueError("optimize_ToM does not support model '{}'".format(args.model))

    # state, cam_states are only moved to the device when being used
    state = state.reshape(count, max_steps, num_agents, num_both, obs_dim)
    batch_size, num_agents, cam_dim = cam_states.size()
    cam_states = cam_states.reshape(count, max_steps, num_agents, cam_dim)
    if 'ToM-v5' in args.model:
        comm_domains = comm_domains.reshape(count, max_steps, num_agents, num_agents, 1)
        h_ToM = torch.zeros(count, num_agents, num_agents, args.lstm_out).to(device_share)
    else:
        h_ToM = torch.zeros(count, num_agents, num_agents-1, args.lstm_out).to(device_share)
    hself = torch.zeros(count, num_agents, args.lstm_out ).to(device_share)
    # hidden states at the start of the current segment; every replay restarts here
    hself_start = hself.clone().detach()
    hToM_start = h_ToM.clone().detach()

    if args.mask_actions:
        available_actions = available_actions.reshape(count, max_steps, num_agents, num_targets, -1)

    # Loop invariants: loss criteria and the mask selecting "every agent but me".
    BCE_criterion = torch.nn.BCELoss(reduction='sum')
    KL_criterion = torch.nn.KLDivLoss(reduction='sum')
    off_diag_mask = (torch.ones(num_agents, num_agents) - torch.diag(torch.ones(num_agents))).bool()

    ToM_loss_sum = torch.zeros(1).to(device_share)
    ToM_target_loss_sum = torch.zeros(1).to(device_share)
    ToM_target_acc_sum = torch.zeros(1).to(device_share)
    for seg in range(seg_num):
        # Defined up front so the bookkeeping below works with zero replays.
        ToM_loss = torch.zeros(1).to(device_share)
        ToM_target_loss = torch.zeros(1).to(device_share)
        ToM_target_acc = torch.zeros(1).to(device_share)
        for train_loop in range(args.ToM_train_loops):
            hself = hself_start.clone().detach()
            h_ToM = hToM_start.clone().detach()
            ToM_goal_steps = []
            real_goal_steps = []
            ToM_target_loss = torch.zeros(1).to(device_share)
            ToM_target_acc = torch.zeros(1).to(device_share)
            for s_i in range(args.num_steps):
                step = seg * args.num_steps + s_i
                available_action = available_actions[:,step].to(device_share) if args.mask_actions else None
                state_t = state[:,step].to(device_share)
                cam_t = cam_states[:,step].to(device_share)
                if variant == 'v2':
                    _, actions, _, _, hself, h_ToM, ToM_goal, ToM_target_cover, real_cover, _ = \
                        shared_model(state_t, hself, h_ToM, cam_t)
                elif variant == 'v4':
                    _, actions, _, _, hself, h_ToM, ToM_goal, _, _, _ = \
                        shared_model(state_t, hself, h_ToM, cam_t, available_actions = available_action)
                else:
                    _, actions, _, _, hself, h_ToM, ToM_goal, _, _, _, real_cover, ToM_target_cover = \
                        shared_model(state_t, hself, h_ToM, cam_t, comm_domains[:,step].to(device_share), available_actions = available_action)

                if variant != 'v4':
                    # v2 / v5 additionally predict which targets are covered
                    ToM_target_loss += BCE_criterion(ToM_target_cover.float(), real_cover.float())
                    ToM_target_cover_discrete = (ToM_target_cover > 0.6)
                    ToM_target_acc += torch.sum((ToM_target_cover_discrete == real_cover))

                # One-hot goal per (agent, target), repeated for every observer;
                # the mask then removes each observer's own row.
                real_goal = torch.cat((1-actions,actions),-1).detach()
                real_goal_duplicate = real_goal.reshape(count, 1, num_agents, num_targets, -1).repeat(1, num_agents, 1, 1, 1)
                real_goal_duplicate = real_goal_duplicate[:,off_diag_mask].reshape(count, 1, num_agents, num_agents-1, num_targets, -1)

                ToM_goal_steps.append(ToM_goal.unsqueeze(1))
                real_goal_steps.append(real_goal_duplicate)

            # Concatenate along the step axis once instead of growing per step.
            ToM_prob = torch.cat(ToM_goal_steps, 1).float()
            real_prob = torch.cat(real_goal_steps, 1).float()
            ToM_loss = torch.zeros(1).to(device_share)
            ToM_loss += KL_criterion(ToM_prob.log(), real_prob)

            loss = ToM_loss + 0.5 * ToM_target_loss
            loss = loss/(count)
            shared_model.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(params, 20)
            optimizer_ToM.step()

        # update hidden state start & loss sum
        hself_start = hself.clone().detach()
        hToM_start = h_ToM.clone().detach()
        # detach so the running statistics do not hold on to the autograd graph
        ToM_loss_sum += ToM_loss.detach()
        ToM_target_loss_sum += ToM_target_loss.detach()
        ToM_target_acc_sum += ToM_target_acc.detach()

    print("ToM_loss =",ToM_loss_sum.sum().data)
    print("ToM Target loss=",ToM_target_loss_sum.sum().data)
    cnt_all = (num_agents * (num_agents-1) * num_targets * batch_size)
    ToM_loss_mean = ToM_loss_sum/cnt_all
    ToM_target_loss_mean = ToM_target_loss_sum/cnt_all
    ToM_target_acc_mean = ToM_target_acc_sum/cnt_all
    return ToM_loss_sum, ToM_loss_mean, ToM_target_loss_mean, ToM_target_acc_mean

def optimize_Policy(test_data, state, cam_states, real_actions, reward, comm_domains, available_actions, args, params, optimizer_Policy, shared_model, device_share, env):
    """Run the actor-critic update of ``shared_model`` on stored rollouts.

    The flat batch is folded into ``count = batch / env.max_steps`` episodes
    (one per worker) and walked in segments of ``args.num_steps`` steps. Each
    segment is replayed ``args.policy_train_loops`` times from the same
    starting hidden state. A replay re-runs the model on the stored
    observations, scores the stored ``real_actions`` under the fresh action
    probabilities, bootstraps the return from the value of the first step of
    the next segment (zero for the last segment), and takes one optimizer
    step on::

        (policy_loss.sum() + 0.5 * value_loss.sum()) / (count * 4)

    The policy term uses Generalized Advantage Estimation with discount
    ``args.gamma`` and factor ``args.tau``. The edge sparsity term is
    accumulated and returned but is not part of the optimised loss.

    Args:
        test_data: Unused; kept so existing callers keep working.
        state: Tensor whose size unpacks to (batch, agents, num_both, obs_dim).
        cam_states: Tensor whose size unpacks to (batch, agents, cam_dim).
        real_actions: Indices gathered from the action probabilities;
            reshaped to (count, max_steps, agents, targets, 1).
        reward: Tensor whose size unpacks to (batch, agents, r_dim).
        comm_domains: Only read for 'ToM-v5' models.
        available_actions: Only read when ``args.mask_actions`` is truthy.
        args: Reads ``model``, ``num_steps``, ``workers``, ``lstm_out``,
            ``mask_actions``, ``policy_train_loops``, ``entropy``, ``gamma``
            and ``tau``.
        params: Parameters whose gradient norm is clipped to 5.
        optimizer_Policy: Optimizer stepped once per replay.
        shared_model: Model that is called and updated.
        device_share: Device on which per-step slices and hidden states live.
        env: Reads ``n``, ``num_target`` and ``max_steps``.

    Returns:
        Tuple ``(policy_loss_sum, value_loss_sum, Sparsity_loss_sum,
        entropies_all)`` of detached tensors, summed over the last replay of
        every segment.

    Raises:
        AssertionError: If ``max_steps`` is not a multiple of
            ``args.num_steps`` or ``count`` differs from ``args.workers``.
        ValueError: If ``args.model`` matches none of the supported families.
    """
    num_agents = env.n
    num_targets = env.num_target
    max_steps = env.max_steps
    assert max_steps % args.num_steps == 0
    seg_num = int(max_steps/args.num_steps)
    batch_size, num_agents, num_both, obs_dim = state.size()
    count = int(batch_size/max_steps)

    if count != args.workers:
        print(count)
    assert count == args.workers

    # state, cam_states, reward, real_actions are moved to the device only when used
    state = state.reshape(count, max_steps, num_agents, num_both, obs_dim)
    batch_size, num_agents, cam_dim = cam_states.size()
    cam_states = cam_states.reshape(count, max_steps, num_agents, cam_dim)
    batch_size, num_agents, r_dim = reward.size()
    reward = reward.reshape(count, max_steps, num_agents, r_dim)

    real_actions = real_actions.reshape(count, max_steps, num_agents, num_targets, 1)
    if 'ToM-v5' in args.model:
        comm_domains = comm_domains.reshape(count, max_steps, num_agents, num_agents, 1)
        h_ToM = torch.zeros(count, num_agents, num_agents, args.lstm_out).to(device_share)
    else:
        h_ToM = torch.zeros(count, num_agents, num_agents-1, args.lstm_out).to(device_share)
        comm_domains = None

    hself = torch.zeros(count, num_agents, args.lstm_out ).to(device_share)
    # hidden states at the start of the current segment; every replay restarts here
    hself_start = hself.clone().detach()
    hToM_start = h_ToM.clone().detach()
    if args.mask_actions:
        available_actions = available_actions.reshape(count, max_steps, num_agents, num_targets, -1)

    # Loop invariants.
    uses_edges = ('ToM' in args.model and 'v2' not in args.model) or 'comm' in args.model
    w_entropies = float(args.entropy)
    # The entropy criterion is only needed by models that predict edges.
    criterionH = HLoss() if uses_edges else None
    edge_prior = torch.FloatTensor(np.array([0.7, 0.3])).to(device_share)

    policy_loss_sum = torch.zeros(count, num_agents, num_targets, 1).to(device_share)
    value_loss_sum = torch.zeros(count, num_agents, 1).to(device_share)
    entropies_all = torch.zeros(1).to(device_share)
    Sparsity_loss_sum = torch.zeros(count, 1).to(device_share)

    for seg in range(seg_num):
        # Defined up front so the bookkeeping below works with zero replays.
        policy_loss = torch.zeros(count, num_agents, num_targets, 1).to(device_share)
        value_loss = torch.zeros(count, num_agents, 1).to(device_share)
        entropies_sum = torch.zeros(1).to(device_share)
        Sparsity_loss = torch.zeros(count, 1).to(device_share)

        for train_loop in range(args.policy_train_loops):
            hself = hself_start.clone().detach()
            h_ToM = hToM_start.clone().detach()
            values = []
            entropies = []
            log_probs = []
            rewards = []
            edge_logits = []
            for s_i in range(args.num_steps):
                step = s_i + seg * args.num_steps
                available_action = available_actions[:,step].to(device_share) if args.mask_actions else None
                state_t = state[:,step].to(device_share)
                if 'ToM-v2' in args.model:
                    value_multi, actions, entropy, log_prob, hn_self, hn_ToM, ToM_goal, ToM_target, real_cover, probs =\
                        shared_model(state_t, hself, h_ToM, cam_states[:,step].to(device_share))
                    hself = hn_self
                    h_ToM = hn_ToM
                elif 'comm' in args.model:
                    value_multi, actions, entropy, log_prob, hn_self, edge_logit, comm_edges, probs =\
                        shared_model(state_t, hself)
                    hself = hn_self
                elif 'ToM-v4' in args.model:
                    value_multi, actions, entropy, log_prob, hn_self, hn_ToM, ToM_goal, edge_logit, comm_edges, probs =\
                        shared_model(state_t, hself, h_ToM, cam_states[:,step].to(device_share), available_actions= available_action)
                    hself = hn_self
                    h_ToM = hn_ToM
                elif 'ToM-v5' in args.model:
                    value_multi, actions, entropy, log_prob, hn_self, hn_ToM, ToM_goal, edge_logit, comm_edges, probs, real_cover, ToM_target_cover =\
                        shared_model(state_t, hself, h_ToM, cam_states[:,step].to(device_share), comm_domains[:,step].to(device_share), available_actions= available_action)
                    hself = hn_self
                    # Carry the ToM hidden state forward like the other ToM
                    # families; without this the v5 model would see the
                    # segment-start state at every step.
                    h_ToM = hn_ToM
                elif 'decentralized' in args.model:
                    value_multi, actions, entropy, log_prob, probs = shared_model(state_t, available_actions= available_action)
                elif 'center' in args.model:
                    value_multi, actions, entropy, log_prob, probs = shared_model(state_t, cam_states[:,step].to(device_share), available_actions= available_action)
                else:
                    raise ValueError("optimize_Policy does not support model '{}'".format(args.model))

                values.append(value_multi)
                entropies.append(entropy)
                # log-probability of the action that was actually taken in the rollout
                log_probs.append(torch.log(probs).gather(-1, real_actions[:,step].to(device_share)))
                rewards.append(reward[:,step].to(device_share))

                if uses_edges:
                    edge_logits.append(edge_logit)

            R = torch.zeros(count, num_agents, 1).to(device_share)
            if seg < seg_num -1:
                # not the last segment of the episode: bootstrap from the next step's value
                next_step = (seg+1) * args.num_steps
                available_action = available_actions[:,next_step].to(device_share) if args.mask_actions else None
                next_state = state[:,next_step].to(device_share)
                if 'ToM-v2' in args.model:
                    value_multi, *others = shared_model(next_state, hself, h_ToM, cam_states[:,next_step].to(device_share))
                elif 'comm' in args.model:
                    value_multi, *others = shared_model(next_state, hself)
                elif 'ToM-v4' in args.model:
                    value_multi, *others = shared_model(next_state, hself, h_ToM, cam_states[:,next_step].to(device_share), available_actions= available_action)
                elif 'ToM-v5' in args.model:
                    value_multi, *others = shared_model(next_state, hself, h_ToM, cam_states[:,next_step].to(device_share), comm_domains[:,next_step].to(device_share), available_actions= available_action)
                elif 'decentralized' in args.model:
                    value_multi, *others = shared_model(next_state, available_actions= available_action)
                elif 'center' in args.model:
                    value_multi, *others = shared_model(next_state, cam_states[:,next_step].to(device_share), available_actions= available_action)
                else:
                    raise ValueError("optimize_Policy does not support model '{}'".format(args.model))
                R = value_multi.clone().detach()

            R = R.to(device_share)
            values.append(R)

            policy_loss = torch.zeros(count, num_agents, num_targets, 1).to(device_share)
            value_loss = torch.zeros(count, num_agents, 1).to(device_share)
            entropies_sum = torch.zeros(1).to(device_share)
            Sparsity_loss = torch.zeros(count, 1).to(device_share)
            gae = torch.zeros(count, num_agents, 1).to(device_share)

            for i in reversed(range(args.num_steps)):
                R = args.gamma * R + rewards[i]
                advantage = R - values[i]
                value_loss = value_loss + 0.5 * advantage.pow(2)
                # Generalized Advantage Estimation
                delta_t = rewards[i] + args.gamma * values[i + 1].data - values[i].data
                gae = gae * args.gamma * args.tau + delta_t

                # one advantage per agent, shared by all of that agent's targets
                gae_duplicate = gae.unsqueeze(2).repeat(1,1,num_targets,1)
                policy_loss = policy_loss - (w_entropies * entropies[i]) - (log_probs[i] * gae_duplicate)
                entropies_sum += entropies[i].sum()

                if uses_edges:
                    Sparsity_loss += -criterionH(edge_logits[i], edge_prior)

            shared_model.zero_grad()
            # the sparsity term is tracked for logging only
            loss = policy_loss.sum() + 0.5 * value_loss.sum()
            loss = loss/(count * 4)
            loss.backward()

            torch.nn.utils.clip_grad_norm_(params, 5)
            optimizer_Policy.step()

        # update hidden state start for the next segment
        hself_start = hself.clone().detach()
        hToM_start = h_ToM.clone().detach()
        # detach so the running statistics do not hold on to the autograd graph
        policy_loss_sum += policy_loss.detach()
        value_loss_sum += value_loss.detach()
        Sparsity_loss_sum += Sparsity_loss.detach()
        entropies_all += entropies_sum.detach()

    return policy_loss_sum, value_loss_sum, Sparsity_loss_sum, entropies_all

def load_ToM_data(args, ToM_history):
    """Stack the workers' ToM step records into two tensors.

    Args:
        args: Reads ``workers`` (how many entries of ``ToM_history`` to use)
            and ``num_steps``.
        ToM_history: Indexable by worker rank; each entry is a list of dicts
            holding at least the keys ``'state'`` and ``'cam_states'``.

    Returns:
        Tuple ``(state, cam_states)`` of tensors whose first dimension runs
        over all records, worker 0 first.

    Raises:
        AssertionError: If a worker's record count is not a multiple of
            ``args.num_steps``.
    """
    records = []
    for worker in range(args.workers):
        worker_records = ToM_history[worker]
        assert  len(worker_records) % args.num_steps == 0
        records.extend(worker_records)

    state = torch.from_numpy(np.array([record['state'] for record in records]))
    cam_states = torch.from_numpy(np.array([record['cam_states'] for record in records]))
    return state, cam_states

def load_Policy_data(args, Policy_history):
    """Stack the workers' policy step records into three tensors.

    Args:
        args: Reads ``workers`` (how many entries of ``Policy_history`` to
            use) and ``num_steps``.
        Policy_history: Indexable by worker rank; each entry is a list of
            dicts holding at least the keys ``'state'``, ``'cam_states'``
            and ``'reward'``.

    Returns:
        Tuple ``(state, cam_states, reward)`` of tensors whose first
        dimension runs over all records, worker 0 first. ``reward`` gets an
        extra trailing dimension of size 1.

    Raises:
        AssertionError: If a worker's record count is not a multiple of
            ``args.num_steps``.
    """
    records = []
    for worker in range(args.workers):
        worker_records = Policy_history[worker]
        assert len(worker_records) % args.num_steps == 0
        records.extend(worker_records)

    def stack(key):
        # one tensor per field, records along the first axis
        return torch.from_numpy(np.array([record[key] for record in records]))

    state = stack('state')
    cam_states = stack('cam_states')
    reward = stack("reward").unsqueeze(-1)
    return state, cam_states, reward

def load_data(args, history):
    """Stack every field of the workers' step records into one tensor each.

    The field names and their order are taken from the first record. Values
    are looked up by name in every record, so a record whose dict was filled
    in a different order still lands in the right column.

    Args:
        args: Reads ``workers`` (how many entries of ``history`` to use).
        history: Indexable by worker rank; each entry is a list of dicts
            that all share the keys of the first record.

    Returns:
        A list with one tensor per field, in the key order of the first
        record; the first dimension runs over all records, worker 0 first.
        Fields whose name contains ``'reward'`` get an extra trailing
        dimension of size 1.

    Raises:
        IndexError: If the selected workers hold no records at all.
        KeyError: If a record lacks one of the first record's keys.
    """
    records = []
    for rank in range(args.workers):
        records.extend(history[rank])

    item_names = list(records[0])

    data_list = []
    for name in item_names:
        column = torch.from_numpy(np.array([record[name] for record in records]))
        if 'reward' in name:
            column = column.unsqueeze(-1)
        data_list.append(column)

    return data_list

def train(args, shared_model, optimizer_Policy, optimizer_ToM, train_modes, n_iters, curr_env_steps, ToM_count, ToM_history, Policy_history, step_history, loss_history, env=None):
    """Trainer loop: wait for the workers, then update the shared model.

    Each round:
      1. Poll ``train_modes`` until every worker entry equals -10 (the
         worker has finished collecting trajectories); give up and return
         after 180 seconds.
      2. Load ``Policy_history`` and run ``optimize_Policy``; log the losses.
      3. For ToM models, once ``sum(ToM_count)`` reaches
         ``args.ToM_frozen * args.workers * env.max_steps``, load
         ``ToM_history``, run ``optimize_ToM`` and clear the ToM buffers.
      4. If ``args.gamma_rate > 0``, possibly raise ``args.gamma`` and the
         episode length (``env.max_steps`` and ``curr_env_steps``).
      5. Clear ``Policy_history`` and set every ``train_modes`` entry to -1
         (presumably the workers' cue to resume collecting -- confirm
         against the worker code). An entry equal to -100 is treated as the
         stop signal: the environment is closed and the function returns.

    Args:
        args: Run configuration; also mutated (``gamma``) by the schedule.
        shared_model: Model updated by both optimizers. Parameters whose
            name contains 'ToM' or 'other' are clipped with the ToM update,
            all others with the policy update.
        optimizer_Policy: Optimizer passed to ``optimize_Policy``.
        optimizer_ToM: Optimizer passed to ``optimize_ToM``.
        train_modes: Per-worker status flags shared with the workers.
        n_iters: Per-worker counters; their sum is the logging step.
        curr_env_steps: Per-worker episode length, written by the schedule.
        ToM_count: Per-worker count of buffered ToM samples.
        ToM_history: Per-worker ToM record lists.
        Policy_history: Per-worker policy record lists.
        step_history: Unused.
        loss_history: Unused.
        env: Optional environment; created from ``args.env`` when None.
    """
    rank = args.workers
    writer = SummaryWriter(os.path.join(args.log_dir, 'Train'))
    try:
        ptitle('Training')
        gpu_id = args.gpu_ids[rank % len(args.gpu_ids)]
        torch.manual_seed(args.seed + rank)

        if gpu_id >= 0:
            torch.cuda.manual_seed(args.seed + rank)
            if len(args.gpu_ids) > 1:
                device_share = torch.device('cpu')
            else:
                device_share = torch.device('cuda:' + str(args.gpu_ids[-1]))
        else:
            device_share = torch.device('cpu')
        if env is None:
            env = create_env(args.env, args)

        # Split the parameters between the policy update and the ToM update.
        params = []
        params_ToM = []
        for name, param in shared_model.named_parameters():
            if 'ToM' in name or 'other' in name:
                params_ToM.append(param)
            else:
                params.append(param)

        train_step_cnt = 0
        while True:
            # NOTE(review): this polls the shared flags without sleeping.
            t1 = time.time()
            while True:
                if time.time() - t1 > 180:
                    print("waiting too long for workers")
                    print("train modes:",train_modes)
                    return
                # anything other than -10 means a worker is still collecting trajectories
                if all(train_modes[worker] == -10 for worker in range(args.workers)):
                    break

            t2 = time.time()

            print("training start after waiting for {} seconds".format(t2-t1))
            train_step_cnt += 1
            state, cam_states, real_actions, reward, comm_domains, available_actions = load_data(args, Policy_history)
            policy_loss, value_loss, Sparsity_loss, entropies_sum =\
                optimize_Policy([None, None], state, cam_states, real_actions, reward, comm_domains, available_actions, args, params, optimizer_Policy, shared_model, device_share, env)

            n_steps = sum(n_iters)

            writer.add_scalar('train/policy_loss_sum', policy_loss.sum(), n_steps)
            writer.add_scalar('train/value_loss_sum', value_loss.sum(), n_steps)
            writer.add_scalar('train/Sparsity_loss_sum', Sparsity_loss.sum(), n_steps)
            writer.add_scalar('train/entropies_sum', entropies_sum.sum(), n_steps)
            writer.add_scalar('train/gamma', args.gamma, n_steps)
            print("policy loss:{}".format(policy_loss.sum().data))
            print("value loss:{}".format(value_loss.sum().data))
            print("entropies:{}".format(entropies_sum.sum().data))
            print("Policy training finished")
            print("---------------------")

            # The ToM heads are only trained once enough samples are buffered.
            ToM_len = args.ToM_frozen * args.workers * env.max_steps
            if 'ToM' in args.model:
                if sum(ToM_count) >= ToM_len:
                    print("ToM training started")
                    state, cam_states, comm_domains, real_goals, available_actions = load_data(args, ToM_history)
                    print("ToM data loaded")
                    ToM_loss_sum, ToM_loss_avg, ToM_target_loss, ToM_target_acc = optimize_ToM(state, cam_states, comm_domains, available_actions, args, params_ToM, optimizer_ToM, shared_model, device_share, env)
                    print("optimized based on ToM loss")

                    writer.add_scalar('train/ToM_loss_sum', ToM_loss_sum.sum(), n_steps)
                    writer.add_scalar('train/ToM_loss_avg', ToM_loss_avg.sum(), n_steps)
                    writer.add_scalar('train/ToM_target_loss_avg', ToM_target_loss.sum(), n_steps)
                    writer.add_scalar('train/ToM_target_acc_avg', ToM_target_acc.sum(), n_steps)

                    for worker in range(args.workers):
                        ToM_history[worker] = []
                        ToM_count[worker] = 0
                    print("---------------------")

            if args.gamma_rate > 0:
                # schedule: gradually raise gamma and the episode length
                if n_steps >= args.start_eps * 20 * args.workers and args.gamma < args.gamma_final and train_step_cnt % (args.ToM_frozen) == 0:
                    if args.gamma > 0.4:
                        args.gamma = args.gamma * (1 + args.gamma_rate/2)
                    else:
                        args.gamma = args.gamma * (1 + args.gamma_rate)
                    new_env_step = int((args.gamma + 0.1)/0.2) * args.env_steps
                    env.max_steps = new_env_step
                    for worker in range(args.workers):
                        curr_env_steps[worker] = new_env_step

                print("gamma:",args.gamma)
                assert args.gamma < 0.95

            for worker in range(args.workers):
                Policy_history[worker] = []
                if train_modes[worker] == -100:
                    # stop signal: release the environment before leaving
                    env.close()
                    return
                train_modes[worker] = -1

            # A stop signal may also arrive right after the flags were reset.
            if train_modes[0] == -100:
                env.close()
                break
    finally:
        # flush and close the event file on every exit path
        writer.close()