'''
Sampling (trajectory generation) and learning (flow-matching loss) procedures for FCN.
'''
import argparse
import copy
from itertools import count

import numpy as np
import torch
import torch.nn as nn
from torch.distributions.categorical import Categorical

parser = argparse.ArgumentParser()

parser.add_argument("--save_path", default='results/flow_FDN.pkl.gz', type=str)
parser.add_argument("--device", default='cpu', type=str)
parser.add_argument("--progress", action='store_true')

# Training / optimisation
parser.add_argument("--method", default='flownet', type=str)
parser.add_argument("--opt", default='adam', type=str)
parser.add_argument("--adam_beta1", default=0.9, type=float)
parser.add_argument("--adam_beta2", default=0.999, type=float)
parser.add_argument("--momentum", default=0.9, type=float)
parser.add_argument("--mbsize", default=16, help="Minibatch size", type=int)
parser.add_argument("--train_to_sample_ratio", default=1, type=float)
parser.add_argument("--n_hid", default=512, type=int)
parser.add_argument("--n_layers", default=2, type=int)
parser.add_argument("--n_train_steps", default=20000, type=int)
parser.add_argument("--num_empirical_loss", default=200000, type=int,
                    help="Number of samples used to compute the empirical distribution loss")
# Env
parser.add_argument('--func', default='corners')
parser.add_argument("--horizon", default=8, type=int)
parser.add_argument("--ndim", default=4, type=int)
parser.add_argument("--agent_num", default=4, type=int)

# Flownet
# Soft-update rate of the per-agent target networks; 0 disables them.
parser.add_argument("--bootstrap_tau", default=0., type=float)
# Replay strategy: 'top_k' keeps the best terminal states, 'none' disables replay.
parser.add_argument("--replay_strategy", default='none', type=str)
parser.add_argument("--replay_sample_size", default=2, type=int)
# Replay capacity is used as a list-slice bound, so it must be an int.
parser.add_argument("--replay_buf_size", default=100, type=int)


# Current target device for the tensor helpers below. It is held in a
# one-element list so set_device() can replace it in place.
_dev = [torch.device('cpu')]


def tf(x):
    """Build a float tensor from ``x`` and move it to the current device."""
    return torch.FloatTensor(x).to(_dev[0])


def tl(x):
    """Build an integer (long) tensor from ``x`` and move it to the current device."""
    return torch.LongTensor(x).to(_dev[0])


def set_device(dev):
    """Make ``dev`` the device that ``tf`` / ``tl`` move new tensors to."""
    # Mutate the shared list rather than rebinding it, so readers of _dev see the change.
    _dev[:] = [dev]


def make_mlp(l, act=None, tail=None):
    """Build a multilayer perceptron with no activation after its last layer.

    Args:
        l: layer widths ``[in, hidden_1, ..., out]``; every consecutive pair
            becomes one ``nn.Linear``.
        act: activation module placed after every ``nn.Linear`` except the last.
            Defaults to a new ``nn.LeakyReLU()`` created per call; the single
            instance is reused between the hidden layers of the returned network.
        tail: optional iterable of modules appended after the final layer.

    Returns:
        An ``nn.Sequential`` (empty when ``l`` has fewer than two entries).
    """
    # Defaults are built per call rather than at definition time, so separate
    # networks never share a default module or list object.
    if act is None:
        act = nn.LeakyReLU()
    n_linear = len(l) - 1
    layers = []
    for n, (i, o) in enumerate(zip(l, l[1:])):
        layers.append(nn.Linear(i, o))
        if n < n_linear - 1:  # no activation after the output layer
            layers.append(act)
    if tail is not None:
        layers.extend(tail)
    return nn.Sequential(*layers)


class ReplayBuffer(object):
    """Keep the highest-reward terminal states and replay them as backward trajectories.

    Only the ``'top_k'`` strategy stores anything; for any other
    ``args.replay_strategy`` :meth:`add` is a no-op and :meth:`sample` returns ``[]``.

    Args:
        args: namespace providing ``replay_strategy``, ``replay_sample_size`` and
            ``replay_buf_size``.
        env: environment exposing ``obs``, ``parent_transitions`` and ``horizon``,
            used by :meth:`generate_backward` to rebuild trajectories.
    """

    def __init__(self, args, env):
        # (reward, state) pairs kept sorted by ascending score.
        self.buf = []
        self.strat = args.replay_strategy
        self.sample_size = args.replay_sample_size
        # The size may arrive as a float from the CLI; list slicing needs an int.
        self.bufsize = int(args.replay_buf_size)
        self.env = env

    @staticmethod
    def _score(r_x):
        """Scalar ranking key for a (scalar or per-agent vector) reward: its mean.

        The learner also reduces the reward with a mean over its second dim.
        """
        return float(np.mean(np.asarray(r_x, dtype=np.float64)))

    def add(self, x, r_x):
        """Store terminal state ``x`` with reward ``r_x`` if it ranks in the top ``bufsize``."""
        if self.strat != 'top_k' or self.bufsize <= 0:
            return
        if len(self.buf) < self.bufsize or self._score(r_x) > self._score(self.buf[0][0]):
            self.buf.append((r_x, x))
            # Sort on the score only: states may not be orderable (e.g. tuples of
            # arrays) and vector rewards do not compare meaningfully.
            self.buf.sort(key=lambda entry: self._score(entry[0]))
            self.buf = self.buf[-self.bufsize:]

    def sample(self):
        """Return the transitions of ``sample_size`` randomly replayed trajectories.

        Returns ``[]`` when the buffer is empty.
        """
        if not self.buf:
            return []
        idxs = np.random.randint(0, len(self.buf), self.sample_size)
        return [transition
                for i in idxs
                for transition in self.generate_backward(*self.buf[i])]

    def generate_backward(self, r, s0):
        """Rebuild one trajectory ending in ``s0`` by walking randomly backward.

        Each transition is ``[parents, actions, [reward], [obs], [done]]``, with every
        element converted by ``tf``; the terminal transition comes first.
        """
        s = np.int8(s0)  # copy, so decrementing below never touches the stored state
        used_stop_action = s.max() < self.env.horizon - 1
        done = True
        # Intermediate transitions get zero reward with the terminal reward's
        # shape, so all reward tensors can be concatenated by the learner.
        zero_r = np.zeros(np.shape(r)).tolist()
        traj = []
        while s.sum() > 0:
            parents, actions = self.env.parent_transitions(s, used_stop_action)
            traj.append([tf(i) for i in (parents, actions, [r], [self.env.obs(s)], [done])])
            # A stop transition keeps s unchanged; otherwise step to a random parent.
            if not used_stop_action:
                i = np.random.randint(0, len(parents))
                a = actions[i]
                s[a] -= 1
            used_stop_action = False
            done = False
            r = zero_r
        return traj


class FCNAgent(object):
    """
        FCN algorithm.

        Each agent owns a network mapping its view of the joint observation to
        per-action logits. The log-flow of a joint action is the sum of the
        agents' logits for their own components of that action. Training matches,
        for every state, the log in-flow from its parents against the log
        out-flow (reward plus flow into its children).
    """
    def __init__(self, args, envs):
        """
        Args:
            args: parsed options; reads ``horizon``, ``ndim``, ``agent_num``,
                ``n_hid``, ``n_layers``, ``dev`` and ``bootstrap_tau`` here, plus
                the replay options read by ``ReplayBuffer``.
            envs: list of environments; ``envs[0]`` also supplies
                ``actions_possible_set`` and ``parent_transitions``.
        """
        self.args = args
        self.model = make_mlp([args.horizon * args.ndim * args.agent_num] +
                              [args.n_hid] * args.n_layers +
                              [args.ndim + 1])
        self.model.to(args.dev)
        # One independent network per agent, all starting from the same weights.
        self.model_list = [copy.deepcopy(self.model) for _ in range(args.agent_num)]

        # Every agent's parameters in one container, exposed via parameters().
        self.p = nn.ParameterList()
        for model in self.model_list:
            self.p.extend(model.parameters())

        self.target = copy.deepcopy(self.model)
        # Per-agent target networks, soft-updated in learn_from when tau > 0.
        self.target_model_list = copy.deepcopy(self.model_list)
        self.envs = envs
        self.ndim = args.ndim
        self.tau = args.bootstrap_tau
        self.replay = ReplayBuffer(args, envs[0])

        self.soft_plus = torch.nn.Softplus()

    def parameters(self):
        """Return all agent networks' parameters (an ``nn.ParameterList``)."""
        return self.p

    def observation(self, state, idx):
        """Reorder ``state`` so that agent ``idx``'s slice comes first.

        ``state`` is split into ``agent_num`` chunks along dim 1. Agent ``idx``'s
        chunk is moved to the front (the others keep their order), and the chunks
        are stacked along a new dim 1.
        """
        state_list = list(torch.chunk(state, self.args.agent_num, dim=1))
        agent_idx_state = state_list.pop(idx)
        return torch.stack([agent_idx_state] + state_list, dim=1)

    def _agent_flows(self, models, states):
        """Run each agent's network on its own view of ``states``.

        Returns one logit tensor per agent, with one row per state.
        """
        in_dim = self.args.agent_num * self.args.horizon * self.args.ndim
        return [models[k](self.observation(states, k).view(-1, in_dim))
                for k in range(self.args.agent_num)]

    def _joint_flow(self, models, states):
        """Log-flow of every joint action in ``envs[0].actions_possible_set``.

        ``actions_possible_set`` is indexed as ``[joint_action, agent]``. A joint
        action's log-flow is the sum over agents of each agent's logit for its own
        component. Returns a ``(batch, n_joint_actions)`` tensor.
        """
        action_set = self.envs[0].actions_possible_set
        per_agent = [flows[:, action_set[:, k].long()]
                     for k, flows in enumerate(self._agent_flows(models, states))]
        return torch.stack(per_agent, dim=1).sum(dim=1)

    def sample_many(self, mbsize, all_visited):
        """Roll out one episode per environment and collect the visited transitions.

        Args:
            mbsize: number of parallel rollouts; expected to equal ``len(self.envs)``.
            all_visited: list that receives ``tuple(state.flatten())`` for every
                terminal state reached.

        Returns:
            Replayed transitions followed by freshly sampled ones, each
            ``[parents, actions, [r], [sp], [done]]`` converted with ``tf``.
        """
        batch = list(self.replay.sample())
        s = tf([env.reset()[0] for env in self.envs])
        done = [False] * mbsize
        while not all(done):
            with torch.no_grad():
                joint_flow = self._joint_flow(self.model_list, s)
                acts_index = Categorical(logits=joint_flow).sample()
                joint_acts = self.envs[0].actions_possible_set.index_select(0, acts_index)

            # Only the still-running environments are stepped, in env order.
            active_envs = [env for d, env in zip(done, self.envs) if not d]
            step = [env.step(a) for env, a in zip(active_envs, joint_acts)]
            # parent_transitions is told the stop action was used when every
            # component of the joint action equals ndim.
            p_a = [self.envs[0].parent_transitions(sp_state, all(a == self.ndim))
                   for a, (_, _, _, sp_state) in zip(joint_acts, step)]
            batch += [[tf(i) for i in (p, a, [r], [sp], [d])]
                      for (p, a), (sp, r, d, _) in zip(p_a, step)]
            # step has one entry per env that was running, so consume it in order
            # for exactly those envs.
            step_done = (bool(res[2]) for res in step)
            done = [d or next(step_done) for d in done]
            s = tf([res[0] for res in step if not res[2]])
            for (_, r, d, sp_state) in step:
                if d:
                    all_visited.append(tuple(sp_state.flatten()))
                    self.replay.add(tuple(sp_state), r)
        return batch

    def learn_from(self, it, batch):
        """Compute the flow-matching loss for a batch of transitions.

        Args:
            it: training iteration (unused).
            batch: list of transitions ``[parents, actions, r, sp, done]``; the rows
                of ``parents`` / ``actions`` are the parents of ``sp`` and the
                actions leading from them to ``sp``.

        Returns:
            ``(loss, term_loss, flow_loss)``: the mean squared difference between
            log in-flow and log out-flow over all transitions, and (without grad)
            the same error restricted to terminal and non-terminal transitions.
        """
        epi = 0
        # Terminal states get child log-flow -loginf, so their out-flow is the reward.
        loginf = 1000.0
        parents, actions, r, sp, done = map(torch.cat, zip(*batch))
        # Index of the transition each parent row belongs to.
        batch_idxs = torch.tensor(
            [i for i, transition in enumerate(batch) for _ in range(len(transition[0]))],
            dtype=torch.long, device=parents.device)

        # Log-flow of each (parent, action) edge: sum of the agents' logits.
        agent_logits = self._agent_flows(self.model_list, parents)
        parents_Qsa = torch.stack(
            [logits.gather(1, actions[:, k].long().unsqueeze(1)).squeeze(1)
             for k, logits in enumerate(agent_logits)], dim=1).sum(dim=1)

        # Log in-flow of each sp: log of the summed exp edge flows from its parents.
        # The accumulator must share the flows' device/dtype for index_add_.
        in_flow = torch.log(
            torch.zeros(sp.shape[0], device=parents_Qsa.device, dtype=parents_Qsa.dtype)
            .index_add_(0, batch_idxs, torch.exp(parents_Qsa)))

        if self.tau > 0:
            # Bootstrapped out-flow from the target networks, without gradient.
            with torch.no_grad():
                next_q = self._joint_flow(self.target_model_list, sp)
        else:
            next_q = self._joint_flow(self.model_list, sp)
        next_qd = next_q * (1 - done).unsqueeze(1) + done.unsqueeze(1) * (-loginf)
        # Log out-flow: log of the reward (mean over dim 1) joined with child flows.
        out_flow = torch.logsumexp(
            torch.cat([torch.log(torch.mean(r, dim=1) + epi)[:, None], next_qd], 1), 1)

        loss = (in_flow - out_flow).pow(2).mean()

        with torch.no_grad():
            term_loss = ((in_flow - out_flow) * done).pow(2).sum() / (done.sum() + 1e-20)
            flow_loss = ((in_flow - out_flow) * (1 - done)).pow(2).sum() / ((1 - done).sum() + 1e-20)

        if self.tau > 0:
            # Polyak update: target <- (1 - tau) * target + tau * online.
            with torch.no_grad():
                for online, target in zip(self.model_list, self.target_model_list):
                    for a, b in zip(online.parameters(), target.parameters()):
                        b.mul_(1 - self.tau).add_(self.tau * a)

        return loss, term_loss, flow_loss
