
import argparse
import copy
import gzip
import heapq
import itertools
import os
import sys
import pickle
from collections import defaultdict
from itertools import count

import numpy as np
from scipy.stats import norm
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.distributions.categorical import Categorical
import string
from MAFlowEnv import MAGridEnv
import pandas as pd

# Command-line interface for the training script. Options are grouped by the
# component that consumes them; several groups (MCMC, PPO, SAC) are kept for
# CLI compatibility even though only the 'flownet' method is wired up below.
parser = argparse.ArgumentParser()

parser.add_argument("--save_path", default='results/', type=str)
parser.add_argument("--device", default='cpu', type=str)
# NOTE: as used in main(), passing --progress *disables* the tqdm bar and the
# periodic metric printing.
parser.add_argument("--progress", action='store_true')

# Optimisation / model
parser.add_argument("--method", default='flownet', type=str)
parser.add_argument("--learning_rate", default=1e-4, help="Learning rate", type=float)
parser.add_argument("--opt", default='adam', type=str)
parser.add_argument("--adam_beta1", default=0.9, type=float)
parser.add_argument("--adam_beta2", default=0.999, type=float)
parser.add_argument("--momentum", default=0.9, type=float)
parser.add_argument("--mbsize", default=16, help="Minibatch size", type=int)
parser.add_argument("--train_to_sample_ratio", default=1, type=float)
parser.add_argument("--n_hid", default=512, type=int)
parser.add_argument("--n_layers", default=2, type=int)
parser.add_argument("--n_train_steps", default=20000, type=int)
parser.add_argument("--num_empirical_loss", default=200000, type=int,
                    help="Number of samples used to compute the empirical distribution loss")
# Env
parser.add_argument('--func', default='corners')
parser.add_argument("--horizon", default=8, type=int)
parser.add_argument("--ndim", default=4, type=int)
parser.add_argument("--agent_num", default=4, type=int)


# MCMC
parser.add_argument("--bufsize", default=16, help="MCMC buffer size", type=int)

# Flownet
parser.add_argument("--bootstrap_tau", default=0., type=float)
parser.add_argument("--replay_strategy", default='none', type=str) # top_k none
parser.add_argument("--replay_sample_size", default=2, type=int)
# Must be an int: the replay buffer uses it as a slice bound, and a float
# (e.g. "--replay_buf_size 50" parsed as 50.0) raises TypeError there.
parser.add_argument("--replay_buf_size", default=100, type=int)

# PPO
parser.add_argument("--ppo_num_epochs", default=32, type=int) # number of SGD steps per epoch
parser.add_argument("--ppo_epoch_size", default=16, type=int) # number of sampled minibatches per epoch
parser.add_argument("--ppo_clip", default=0.2, type=float)
parser.add_argument("--ppo_entropy_coef", default=1e-1, type=float)
parser.add_argument("--clip_grad_norm", default=0., type=float)

# SAC
parser.add_argument("--sac_alpha", default=0.98*np.log(1/3), type=float)


# One-element list holding the active device, so it can be replaced in place
# and the tensor helpers below pick up the new value on their next call.
_dev = [torch.device('cpu')]


def tf(x):
    """Build a FloatTensor from ``x`` and move it to the active device."""
    as_float = torch.FloatTensor(x)
    return as_float.to(_dev[0])


def tl(x):
    """Build a LongTensor from ``x`` and move it to the active device."""
    as_long = torch.LongTensor(x)
    return as_long.to(_dev[0])


# DEFlowNet_v2

def set_device(dev):
    """Make ``dev`` the device that ``tf`` and ``tl`` place new tensors on."""
    _dev[0] = dev


def make_mlp(l, act=nn.LeakyReLU(), tail=[]):
    """Build an MLP whose final linear layer has no activation.

    ``l`` lists the layer widths (input first, output last). The same ``act``
    module instance is inserted after every linear layer except the last, and
    the modules in ``tail`` are appended after the final linear layer.
    """
    final_linear = len(l) - 2
    modules = []
    for depth, (n_in, n_out) in enumerate(zip(l, l[1:])):
        modules.append(nn.Linear(n_in, n_out))
        if depth < final_linear:
            modules.append(act)
    # ``modules + tail`` builds a new list, so the shared default is not mutated.
    return nn.Sequential(*(modules + tail))


class ReplayBuffer(object):
    """Keeps the highest-reward terminal states and replays backward trajectories.

    Entries are stored as ``(reward, state)`` tuples in ascending reward order,
    so ``self.buf[0]`` is always the lowest-reward entry currently kept.
    States are only stored when ``args.replay_strategy == 'top_k'``.
    """

    def __init__(self, args, env):
        self.buf = []
        self.strat = args.replay_strategy
        self.sample_size = args.replay_sample_size
        # Cast to int: the size is used as a slice bound in ``add`` and a float
        # value (e.g. parsed from the command line) would raise TypeError.
        self.bufsize = int(args.replay_buf_size)
        self.env = env

    def add(self, x, r_x):
        """Insert state ``x`` with reward ``r_x`` if it belongs in the top-k."""
        if self.strat != 'top_k':
            return
        if len(self.buf) < self.bufsize or r_x > self.buf[0][0]:
            # Re-sort and keep only the ``bufsize`` largest entries.
            self.buf = sorted(self.buf + [(r_x, x)])[-self.bufsize:]

    def sample(self):
        """Return the transitions of ``sample_size`` randomly drawn backward trajectories.

        Returns an empty list when nothing has been stored yet.
        """
        if not self.buf:
            return []
        idxs = np.random.randint(0, len(self.buf), self.sample_size)
        transitions = []
        for i in idxs:
            transitions.extend(self.generate_backward(*self.buf[i]))
        return transitions

    def generate_backward(self, r, s0):
        """Walk backward from terminal state ``s0`` until the all-zero state.

        Each step records ``[parents, actions, [r], [obs(s)], [done]]`` as
        tensors; only the first (terminal) transition carries the reward ``r``
        and ``done=True``.
        """
        s = np.int8(s0)
        # A state strictly inside the grid can only have terminated through the
        # explicit stop action.
        used_stop_action = s.max() < self.env.horizon - 1
        done = True
        traj = []
        while s.sum() > 0:
            parents, actions = self.env.parent_transitions(s, used_stop_action)
            # add the transition
            traj.append([tf(i) for i in (parents, actions, [r], [self.env.obs(s)], [done])])
            # Then randomly choose a parent state; the stop action does not
            # move the state, so nothing is decremented in that case.
            if not used_stop_action:
                i = np.random.randint(0, len(parents))
                a = actions[i]
                s[a] -= 1
            # Values for intermediary trajectory states:
            used_stop_action = False
            done = False
            r = 0
        return traj


class DEFlowNetAgent(object):
    """Multi-agent flow network with one independent MLP per agent.

    Every agent's network maps the flattened joint observation to
    ``args.ndim + 1`` values. ``learn_from`` treats these as log-flows and
    scores a joint action as the sum of the values each agent assigns to its
    own action.
    """

    def __init__(self, args, envs):
        self.args = args
        self.model = make_mlp([args.horizon * args.ndim * args.agent_num] +
                              [args.n_hid] * args.n_layers +
                              [args.ndim + 1])  # +1 for the stop action
        self.model.to(args.dev)
        # Every agent starts from an identical copy of the template network.
        self.model_list = [copy.deepcopy(self.model) for _ in range(args.agent_num)]

        # Flat list of all per-agent parameters, handed to the optimiser.
        self.p = nn.ParameterList()
        for agent_model in self.model_list:
            self.p.extend(agent_model.parameters())

        self.target = copy.deepcopy(self.model)
        self.target_model_list = copy.deepcopy(self.model_list)
        self.envs = envs
        self.ndim = args.ndim
        self.tau = args.bootstrap_tau
        self.replay = ReplayBuffer(args, envs[0])

        self.soft_plus = torch.nn.Softplus()

    def parameters(self):
        """Return the trainable parameters of all per-agent networks."""
        return self.p

    def observation(self, state, idx):
        """Reorder ``state`` so that agent ``idx``'s chunk comes first.

        ``state`` is split into ``agent_num`` chunks along dim 1; the result
        stacks agent ``idx``'s chunk followed by the remaining chunks (in their
        original order) along a new dim 1.
        """
        state_list = list(torch.chunk(state, self.args.agent_num, dim=1))
        own_state = state_list.pop(idx)
        return torch.stack([own_state] + state_list, dim=1)

    def sample_many(self, mbsize, all_visited):
        """Roll out all environments to termination and return their transitions.

        Transitions drawn from the replay buffer are prepended. Each terminal
        state is appended to ``all_visited`` (flattened to a tuple) and offered
        to the replay buffer.

        NOTE(review): ``done`` has ``mbsize`` entries but is zipped with
        ``self.envs``; this presumably assumes ``mbsize == len(self.envs)``.
        NOTE(review): networks are fed the raw state here, whereas
        ``learn_from`` feeds each agent the reordering from ``observation`` —
        confirm this difference is intended.
        """
        batch = []
        batch += self.replay.sample()
        s = tf([i.reset()[0] for i in self.envs])
        done = [False] * mbsize
        input_dim = self.args.agent_num * self.args.horizon * self.args.ndim
        while not all(done):
            with torch.no_grad():
                # sampling version v1: every agent samples its own action
                # independently from its network's outputs used as logits.
                flat_s = s.view(-1, input_dim)
                joint_acts = torch.stack(
                    [Categorical(logits=agent_model(flat_s)).sample()
                     for agent_model in self.model_list],
                    dim=1)

            # Only environments that are still running receive an action.
            active_envs = [e for d, e in zip(done, self.envs) if not d]
            step = [env.step(a) for env, a in zip(active_envs, joint_acts)]
            p_a = [self.envs[0].parent_transitions(sp_state, all(a == self.ndim))
                   for a, (_, _, _, sp_state) in zip(joint_acts, step)]
            batch += [[tf(i) for i in (p, a, [r], [sp], [d])]
                      for (p, a), (sp, r, d, _) in zip(p_a, step)]
            # Map each still-running env index to its row in ``step``.
            c = count(0)
            m = {j: next(c) for j in range(mbsize) if not done[j]}
            done = [bool(d or step[m[i]][2]) for i, d in enumerate(done)]
            s = tf([i[0] for i in step if not i[2]])
            for (_, r, d, sp) in step:
                if d:
                    all_visited.append(tuple(sp.flatten()))
                    self.replay.add(tuple(sp), r)
        return batch

    def _joint_next_q(self, models, sp):
        """Score every joint action in ``sp`` as the sum of per-agent outputs of ``models``."""
        input_dim = self.args.horizon * self.args.ndim * self.args.agent_num
        per_agent_out = [
            models[agent_idx](self.observation(sp, agent_idx).view(-1, input_dim))
            for agent_idx in range(self.args.agent_num)
        ]
        action_set = self.envs[0].actions_possible_set
        next_q = []
        for action_idx in range(len(action_set)):
            joint_action = action_set[action_idx]
            picked = torch.stack(
                [per_agent_out[agent_idx][:, joint_action[agent_idx]]
                 for agent_idx in range(self.args.agent_num)],
                dim=1)
            next_q.append(picked.sum(dim=1))
        return torch.stack(next_q, dim=1)

    def learn_from(self, it, batch):
        """Compute the flow-matching loss for a batch of transitions.

        Returns ``(loss, term_loss, flow_loss)``: the mean squared difference
        between log in-flow and log out-flow, plus the same quantity restricted
        to terminal and non-terminal transitions (no gradient). ``it`` is
        accepted for interface compatibility and not used.
        """
        epi = 0
        loginf = tf([1000])
        # For every parent row, the index of the transition it belongs to.
        batch_idxs = tl([i for i, transition in enumerate(batch)
                         for _ in range(len(transition[0]))])
        parents, actions, r, sp, done = map(torch.cat, zip(*batch))
        n_parents = parents.shape[0]
        rows = torch.arange(n_parents, device=parents.device)

        independent_flow = []
        for agent_idx in range(self.args.agent_num):
            parents_ob = self.observation(parents, agent_idx)
            logits = self.model_list[agent_idx](parents_ob.view(n_parents, -1))
            independent_flow.append(logits[rows, actions[:, agent_idx].long()])
        independent_flow = torch.stack(independent_flow, dim=1)

        parents_Qsa = torch.sum(independent_flow, dim=1, keepdim=True)

        # Sum the parents' flows per child state. The accumulator must live on
        # the same device as the flows, otherwise index_add_ fails off-CPU.
        in_flow = torch.log(torch.zeros((sp.shape[0],), device=parents_Qsa.device)
                            .index_add_(0, batch_idxs, torch.exp(parents_Qsa.squeeze(1))))

        if self.tau > 0:
            # Bootstrapped target: use the slowly-updated target networks.
            with torch.no_grad():
                next_q = self._joint_next_q(self.target_model_list, sp)
        else:
            next_q = self._joint_next_q(self.model_list, sp)

        # Terminal states have no outgoing action flow: push it to a very
        # negative log value so only the reward term contributes.
        next_qd = next_q * (1-done).unsqueeze(1) + done.unsqueeze(1) * (-loginf)
        out_flow = torch.logsumexp(torch.cat([torch.log(torch.mean(r, dim=1)+epi)[:, None], next_qd], 1), 1)

        loss = (in_flow - out_flow).pow(2).mean()

        with torch.no_grad():
            term_loss = ((in_flow - out_flow) * done).pow(2).sum() / (done.sum() + 1e-20)
            flow_loss = ((in_flow - out_flow) * (1-done)).pow(2).sum() / ((1-done).sum() + 1e-20)

        if self.tau > 0:
            # Polyak-average the online weights into the target networks.
            for agent_idx in range(self.args.agent_num):
                for a, b in zip(self.model_list[agent_idx].parameters(), self.target_model_list[agent_idx].parameters()):
                    b.data.mul_(1-self.tau).add_(self.tau*a)

        return loss, term_loss, flow_loss


def make_opt(params, args):
    """Create the optimiser selected by ``args.opt`` for ``params``.

    Returns ``None`` when ``params`` is empty. ``'adam'`` builds
    ``torch.optim.Adam`` and ``'msgd'`` builds ``torch.optim.SGD`` with
    momentum.

    Raises:
        ValueError: if ``args.opt`` names an unknown optimiser.
    """
    params = list(params)
    if not params:
        return None
    if args.opt == 'adam':
        return torch.optim.Adam(params, args.learning_rate,
                                betas=(args.adam_beta1, args.adam_beta2))
    if args.opt == 'msgd':
        return torch.optim.SGD(params, args.learning_rate, momentum=args.momentum)
    # Previously this fell through to an UnboundLocalError on ``opt``.
    raise ValueError(f"unknown optimizer {args.opt!r}; expected 'adam' or 'msgd'")


def compute_empirical_distribution_error(env, visited, mods_number=320):
    """Compare the empirical distribution of ``visited`` with the true density.

    Args:
        env: object whose ``true_density_2()`` returns
            ``(density, end_states, rewards)``; ``density`` is indexed like
            ``end_states``.
        visited: sequence of visited terminal states (hashable tuples).
        mods_number: how many of the highest-density states count as modes;
            clamped to the number of end states.

    Returns:
        ``(k1, n_modes_found, k1_found, average_return, max_return)`` where
        ``k1`` is the summed absolute error over all end states, ``k1_found``
        the same restricted to the modes, and the two returns are the mean and
        maximum density value over ``visited``. With no visited states a
        placeholder 5-tuple is returned so callers can always unpack five
        values.
    """
    if not len(visited):
        return 1, 0, 1, 0.0, 0.0
    # The metrics are scalars, so they are computed on CPU regardless of the
    # device used for training.
    to_float = torch.FloatTensor

    hist = defaultdict(int)
    for state in visited:
        hist[state] += 1
    td, end_states, _ = env.true_density_2()
    true_density = to_float(td)
    Z = sum(hist[state] for state in end_states)
    estimated_density = to_float([hist[state] / Z for state in end_states])
    k1 = abs(estimated_density - true_density).mean().item() * len(end_states)

    # The modes are the ``n_modes`` end states with the largest true density.
    n_modes = min(mods_number, len(end_states))
    top_idx = np.argpartition(td, -n_modes)[-n_modes:]
    mods = [tuple(state) for state in np.array(end_states)[top_idx].tolist()]
    n_modes_found = sum(1 for state in mods if hist[state] > 0)

    mods_reward = td[top_idx]
    true_density_found = to_float(mods_reward / np.sum(mods_reward))
    Z_found = sum(hist[state] for state in mods)
    # +1 keeps the division defined when no mode has been visited.
    estimated_density_found = to_float([hist[state] / (Z_found + 1) for state in mods])
    k1_found = abs(estimated_density_found - true_density_found).mean().item() * n_modes

    # note: calculate the reward. A dict replaces repeated ``list.index``
    # scans; ``setdefault`` keeps the first occurrence, as ``list.index`` does.
    index_of = {}
    for idx, state in enumerate(end_states):
        index_of.setdefault(state, idx)
    returns = [td[index_of[tuple(state)]] for state in visited]
    average_return = np.mean(np.array(returns))
    max_return = max(returns)
    return k1, n_modes_found, k1_found, average_return, max_return


def main(args):
    """Train the agent selected by ``args.method`` and pickle the results.

    Alternates sampling trajectories and gradient steps for
    ``args.n_train_steps + 1`` iterations, evaluates the empirical
    distribution error every 1000 iterations on freshly sampled trajectories,
    and finally writes losses, parameters, visited states and metrics to a
    gzip-compressed pickle at ``args.save_path``.
    """
    args.dev = torch.device(args.device)
    set_device(args.dev)

    args.is_mcmc = args.method in ['mars', 'mcmc']
    # args.is_mcmc = True
    env = MAGridEnv(args.horizon, args.agent_num, args.ndim, allow_backward=args.is_mcmc)
    envs = [MAGridEnv(args.horizon, args.agent_num, args.ndim, allow_backward=args.is_mcmc)
            for _ in range(args.bufsize)]
    # NOTE(review): not used below; kept in case the constructor has side
    # effects — confirm and remove.
    test_env = MAGridEnv(args.horizon, args.agent_num, args.ndim, allow_backward=args.is_mcmc)

    if args.method == 'flownet':
        agent = DEFlowNetAgent(args, envs)
    else:
        # Exit with a message and a non-zero status instead of silently
        # reporting success.
        sys.exit(f"unsupported --method {args.method!r}: only 'flownet' is implemented")

    opt = make_opt(agent.parameters(), args)

    # metrics
    all_losses = []
    all_visited = []
    empirical_distrib_losses = []

    ttsr = max(int(args.train_to_sample_ratio), 1)
    sttr = max(int(1/args.train_to_sample_ratio), 1) # sample to train ratio

    # Currently unreachable (non-flownet methods exit above); kept for when
    # PPO is wired back in.
    if args.method == 'ppo':
        ttsr = args.ppo_num_epochs
        sttr = args.ppo_epoch_size

    for i in tqdm(range(args.n_train_steps+1), disable=args.progress):
        data = []
        for j in range(sttr):
            data += agent.sample_many(args.mbsize, all_visited)
        for j in range(ttsr):
            losses = agent.learn_from(i * ttsr + j, data)    # returns (opt loss, *metrics)
            if losses is not None:
                losses[0].backward()
                if args.clip_grad_norm > 0:
                    torch.nn.utils.clip_grad_norm_(agent.parameters(),
                                                   args.clip_grad_norm)
                opt.step()
                opt.zero_grad()
                all_losses.append([loss.item() for loss in losses])
        if not i % 1000:
            # Evaluate on freshly sampled trajectories, collected in a separate
            # list so the training-time ``all_visited`` record is not
            # overwritten (and the outer loop index is not shadowed).
            eval_visited = []
            for _ in range(20):
                agent.sample_many(args.mbsize, eval_visited)
            empirical_distrib_losses.append(
                compute_empirical_distribution_error(env, eval_visited[-args.num_empirical_loss:]))
            if not args.progress:
                k1, n_modes_found, k1_found, avg_return, max_return = empirical_distrib_losses[-1]
                print('empirical L1 distance', k1, 'modes found', n_modes_found,
                      'k1_found', k1_found, 'return', avg_return, max_return)
                if len(all_losses):
                    print(*[f'{np.mean([row[j] for row in all_losses[-100:]]):.5f}'
                            for j in range(len(all_losses[0]))])

    # The default save path ('results/') names a directory, which gzip.open
    # cannot write to; fall back to a file inside it so results are not lost
    # at the very end of training.
    save_path = args.save_path
    if save_path.endswith(('/', os.sep)) or os.path.isdir(save_path):
        save_path = os.path.join(save_path, f'{args.method}.pkl.gz')
    root = os.path.dirname(save_path)
    if root:  # os.makedirs('') raises for a bare filename
        os.makedirs(root, exist_ok=True)
    with gzip.open(save_path, 'wb') as out_file:
        pickle.dump(
            {'losses': np.float32(all_losses),
             #'model': agent.model.to('cpu') if agent.model else None,
             'params': [p.data.to('cpu').numpy() for p in agent.parameters()],
             'visited': np.int8(all_visited),
             'emp_dist_loss': empirical_distrib_losses,
             # 'true_d': env.true_density()[0],
             'args': args},
            out_file)


if __name__ == '__main__':
    # Script entry point: keep torch on a single intra-op thread, then parse
    # the command line and start training.
    torch.set_num_threads(1)
    args = parser.parse_args()
    main(args)