import argparse
import copy
import gzip
import heapq
import itertools
import os
import pickle
from collections import defaultdict
from itertools import count

import numpy as np
from scipy.stats import norm
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.distributions.categorical import Categorical
import string


def f64(n, x):
    """Convert the integer ``n`` to its base-``x`` string representation.

    The digit alphabet has 64 symbols: ``0-9``, then ``a-z``, ``A-Z``, ``@`` and
    ``_``, so symbol ``i`` stands for the digit value ``i``. The base must be in
    ``[2, 64]`` and must not exceed ``n`` (``x == n`` is accepted and yields ``'10'``).

    Args:
        n: the integer to convert.
        x: the target base.

    Returns:
        The base-``x`` representation of ``n`` as a string, or the string
        ``'ERROR'`` when ``x`` is outside ``[2, 64]`` or greater than ``n``.
    """
    if x < 2 or x > 64 or x > n:
        return 'ERROR'

    # 64-symbol alphabet; index i is the digit for value i.
    symbols = string.digits + string.ascii_letters + '@_'

    # Collect digits least-significant first, then reverse once at the end
    # (avoids repeatedly prepending to a string).
    digits = []
    while n >= x:
        n, r = divmod(n, x)
        digits.append(symbols[r])
    digits.append(symbols[n])  # here n < x: the most significant digit
    return ''.join(reversed(digits))


# Target device for tensors built by ``tf`` / ``tl``. It is kept in a one-element
# list and read at call time, so assigning ``_dev[0]`` changes the device used
# by later calls (presumably so callers can switch to a GPU at runtime).
_dev = [torch.device('cpu')]


def tf(x):
    """Return ``x`` as a float32 tensor placed on ``_dev[0]``."""
    device = _dev[0]
    return torch.FloatTensor(x).to(device)


def tl(x):
    """Return ``x`` as an int64 tensor placed on ``_dev[0]``."""
    device = _dev[0]
    return torch.LongTensor(x).to(device)


def make_mlp(l, act=nn.LeakyReLU(), tail=None):
    """Build an MLP as an ``nn.Sequential`` with no activation after the last layer.

    Args:
        l: layer widths ``[in, hidden..., out]``; one ``nn.Linear`` is created for
            each consecutive pair of widths.
        act: activation module inserted after every Linear except the last one.
            The same module instance is reused at every position.
        tail: optional iterable of extra modules appended after the final Linear
            (``None`` means no extra modules).

    Returns:
        ``nn.Sequential`` of the Linear/activation layers followed by ``tail``.
    """
    layers = []
    last = len(l) - 2  # index of the final Linear layer
    for n, (i, o) in enumerate(zip(l, l[1:])):
        layers.append(nn.Linear(i, o))
        if n < last:
            layers.append(act)
    if tail is not None:
        layers.extend(tail)
    return nn.Sequential(*layers)


class MAGridEnv:
    """Multi-agent hypergrid environment.

    Each of ``agent_num`` agents holds an integer position in
    ``{0, ..., horizon-1}^ndim``; all agents start at the origin. In DAG mode
    (``allow_backward=False``) an agent action ``i < ndim`` increments coordinate
    ``i`` and action ``ndim`` is the stop action. In chain mode
    (``allow_backward=True``) actions ``0..ndim-1`` increment and
    ``ndim..2*ndim-1`` decrement a coordinate, clipped to the grid.
    """

    def __init__(self, horizon, agent_num, ndim=2, xrange=(-1, 1), func=None, allow_backward=False):
        """
        Args:
            horizon: number of cells along each grid dimension.
            agent_num: number of agents.
            ndim: grid dimensionality.
            xrange: ``(low, high)`` real-valued extent that grid indices map onto.
            func: optional function of x-space locations; within this class it is
                only used by ``true_density``. Defaults to a cosine/Gaussian mixture.
            allow_backward: selects chain-mode stepping instead of DAG stepping.
        """
        self.horizon = horizon
        self.agent_num = agent_num           # number of agents
        self.start = [xrange[0]] * ndim      # initial location in x-space
        self.ndim = ndim                     # grid dimensionality
        self.width = xrange[1] - xrange[0]   # extent of the grid in x-space
        self.func = (
            (lambda x: ((np.cos(x * 50) + 1) * norm.pdf(x * 5)).prod(-1) + 0.01)
            if func is None else func)
        self.xspace = np.linspace(*xrange, horizon)  # x-coordinate of each grid index
        self.allow_backward = allow_backward
        self.R0 = 1e-5                       # constant added to every reward

        # Every joint action as an int64 tensor, one row per joint action.
        self.actions_possible_set = tf(self.possible_actions()).to(torch.int64)
        self.mcmc_actions_possible_set = tf(self.possible_actions_mcmc()).to(torch.int64)
        self._true_density = None            # cache shared by true_density / true_density_2

    def obs(self, s=None):
        """One-hot encode a joint state.

        Args:
            s: integer positions indexed as ``s[agent][dim]``; defaults to the
                current state ``self._state``.

        Returns:
            float32 array of shape ``(agent_num, horizon * ndim)``. For each agent,
            the ``d``-th block of ``horizon`` entries is one-hot at coordinate ``d``.
        """
        s = np.int32(self._state if s is None else s)
        z = np.zeros((self.agent_num, self.horizon * self.ndim), dtype=np.float32)
        for idx in range(self.agent_num):
            # Offset each dimension's coordinate into its own horizon-sized slot.
            z[idx][np.arange(len(s[idx])) * self.horizon + s[idx]] = 1
        return z

    def s2x(self, s):
        """Map integer grid positions to x-space coordinates.

        Returns:
            array of shape ``(agent_num, ndim)`` with ``xspace[s[agent][dim]]``.
        """
        return (self.obs(s).reshape((self.agent_num, self.ndim, self.horizon)) * self.xspace[None, :]).sum(-1)

    def reset(self):
        """Put every agent at the origin and reset the step counter.

        Returns:
            ``(obs, reward, state)`` for the initial joint state.
        """
        self._state = np.array([np.int32([0] * self.ndim)] * self.agent_num)
        self._step = 0
        return self.obs(), self.reward(self.s2x(self._state)), self._state

    def step(self, a, s=None):
        """Dispatch to ``step_chain`` or ``step_dag`` depending on ``allow_backward``."""
        if self.allow_backward:
            return self.step_chain(a, s)
        return self.step_dag(a, s)

    def step_dag(self, a, s=None):
        """Advance every agent by one forward (increment-only) action.

        Args:
            a: joint action, one entry per agent. ``a[idx] < ndim`` increments that
                coordinate and ``a[idx] == ndim`` is the stop action. It is compared
                elementwise with ``==``, so it should be an array/tensor.
            s: state to step from. When ``None``, the environment's own state is
                used and then updated (``self._state`` / ``self._step``); otherwise
                the environment is left untouched.

        Returns:
            ``(obs, reward, done, s)``. ``reward`` is an all-zero int array while not
            done, and the per-agent ``reward`` of the final state once done.
        """
        _s = s
        s = (self._state if s is None else s) + 0  # "+ 0" works on a copy

        for idx in range(self.agent_num):
            if a[idx] < self.ndim:
                s[idx][a[idx]] += 1

        # Done when every agent has some coordinate on the grid edge, or when
        # every agent chose the stop action.
        done = all(s.max(-1) >= self.horizon - 1) or all(a == self.ndim)

        # NOTE(review): the original author flagged this as possibly wrong. While
        # the episode continues, any coordinate that reached the edge
        # (horizon - 1) is pulled back by one.
        if not done:
            for idx in range(self.agent_num):
                for i in range(self.ndim):
                    if s[idx][i] >= self.horizon - 1:
                        s[idx][i] -= 1

        if _s is None:
            self._state = s
            self._step += 1

        if done:
            reward = self.reward(self.s2x(s))
        else:
            reward = np.array([0 for _ in range(self.agent_num)])
        return self.obs(s), reward, done, s

    def step_chain(self, a, s=None):
        """Advance every agent by one move in chain mode (forward and backward moves).

        Action ``k < ndim`` increments coordinate ``k``; action ``k >= ndim``
        decrements coordinate ``k - ndim``. Moves are clipped to ``[0, horizon-1]``.

        Args:
            a: joint action, one entry per agent.
            s: state to step from; when ``None`` the environment's own state is
                used and then updated.

        Returns:
            ``(obs, reward, s, reverse_a)``. ``reverse_a[idx]`` is the action that
            undoes agent ``idx``'s move (``(a + ndim) % (2 * ndim)``), or
            ``a[idx]`` itself when clipping left the position unchanged.
        """
        _s = s
        s = (self._state if s is None else s) + 0
        sc = s + 0  # copy of the pre-move state
        reverse_a = copy.deepcopy(a)

        for idx in range(self.agent_num):
            if a[idx] < self.ndim:
                s[idx][a[idx]] = min(s[idx][a[idx]]+1, self.horizon-1)
            if a[idx] >= self.ndim:
                s[idx][a[idx]-self.ndim] = max(s[idx][a[idx]-self.ndim]-1, 0)

            reverse_a[idx] = ((a[idx] + self.ndim) % (2 * self.ndim)) if any(sc[idx] != s[idx]) else a[idx]

        if _s is None:
            self._state = s
            self._step += 1
        return self.obs(s), self.reward(self.s2x(s)), s, reverse_a

    def parent_transitions(self, s, used_stop_action):
        """Enumerate (parent, joint action) pairs that lead to joint state ``s``.

        Args:
            s: joint state indexed as ``s[agent][dim]``.
            used_stop_action: if true, the only parent is ``s`` itself reached via
                the all-stop action.

        Returns:
            ``(parents, actions)``: one-hot observations of the parents, and the
            joint action taking each parent to ``s``. Per agent, a parent is ``s``
            with one positive coordinate decremented, unless that parent would have
            a coordinate on the edge (``horizon - 1``). An agent with no such parent
            contributes its current position with the stop action. Joint parents
            are the Cartesian product over agents, so at least one is always returned.
        """
        if used_stop_action:
            return [self.obs(s)], [[self.ndim for _ in range(self.agent_num)]]

        parents_each_agent = [[] for _ in range(self.agent_num)]
        actions_each_agent = [[] for _ in range(self.agent_num)]

        for agent_idx in range(self.agent_num):
            for i in range(self.ndim):
                if s[agent_idx][i] > 0:
                    sp = s[agent_idx] + 0
                    sp[i] -= 1
                    if sp.max() == self.horizon - 1:  # can't have a terminal parent
                        continue
                    parents_each_agent[agent_idx] += [sp]
                    actions_each_agent[agent_idx] += [i]

        for agent_idx in range(self.agent_num):
            if len(parents_each_agent[agent_idx]) == 0:
                parents_each_agent[agent_idx] += [s[agent_idx]]
            if len(actions_each_agent[agent_idx]) == 0:
                actions_each_agent[agent_idx] += [self.ndim]

        parents = [self.obs(np.array(list(parent))) for parent in itertools.product(*parents_each_agent)]
        actions = [np.array(list(action)) for action in itertools.product(*actions_each_agent)]
        return parents, actions

    def reward(self, x):
        """Reward of x-space locations, reduced over the last axis.

        Args:
            x: coordinates with the last axis indexing dimensions, e.g. the
                ``(agent_num, ndim)`` output of ``s2x``.

        Returns:
            Per-row reward: 0.5 if every ``|x| > 0.5``, plus 2 if every ``|x|``
            lies in ``(0.6, 0.8)``, plus ``R0``.
        """
        ax = abs(x)
        return (ax > 0.5).prod(-1) * 0.5 + ((ax < 0.8) * (ax > 0.6)).prod(-1) * 2 + self.R0

    def possible_states(self):
        """Placeholder; not implemented (returns ``None``)."""
        pass

    def possible_actions(self):
        """All joint DAG-mode actions, in ``itertools.product`` order.

        Returns:
            list of ``(ndim + 1) ** agent_num`` int arrays of length ``agent_num``.
            Agent 0 varies slowest, matching the ordering used by ``action_to_index``.
        """
        action_each_agent = [list(range(self.ndim + 1)) for _ in range(self.agent_num)]
        return [np.array(list(action)) for action in itertools.product(*action_each_agent)]

    def possible_actions_mcmc(self):
        """All joint chain-mode (forward/backward) actions, used by MCMC-based methods.

        Returns:
            list of ``(2 * ndim) ** agent_num`` int arrays of length ``agent_num``,
            in ``itertools.product`` order.
        """
        action_each_agent = [list(range(self.ndim * 2)) for _ in range(self.agent_num)]
        return [np.array(list(action)) for action in itertools.product(*action_each_agent)]

    def action_to_index(self, actions):
        """Map joint DAG-mode actions to their indices in ``possible_actions()``.

        Each row is read as a base-``(ndim + 1)`` number with agent 0 as the most
        significant digit.

        Args:
            actions: tensor of joint actions, one row per sample
                (assumed shape ``(batch, agent_num)`` — TODO confirm with callers).

        Returns:
            tensor of shape ``(batch,)``. It is floating point because the digit
            weights are built with ``tf``.
        """
        actions_length = actions.shape[0]
        action_dim_weight = tf([[(self.ndim+1)**(self.agent_num-i-1)
                                 for i in range(self.agent_num)]]).repeat(actions_length, 1)
        actions_idx = torch.sum(actions * action_dim_weight, dim=1)
        return actions_idx

    def true_density(self):
        """Single-agent style exact density over ``{0..horizon-1}^ndim`` using ``self.func``.

        NOTE(review): ``parent_transitions`` indexes ``s[agent][dim]``, but here
        ``s`` is a single ``ndim``-vector. This therefore appears to raise
        IndexError in the multi-agent setting; ``true_density_2`` is the
        multi-agent version. Confirm before relying on this method.
        """
        if self._true_density is not None:
            return self._true_density

        all_int_states = np.int32(list(itertools.product(*[list(range(self.horizon))]*self.ndim)))
        state_mask = np.array([len(self.parent_transitions(s, False)[0]) > 0 or sum(s) == 0
                               for s in all_int_states])
        all_xs = (np.float32(all_int_states) / (self.horizon-1) *
                  (self.xspace[-1] - self.xspace[0]) + self.xspace[0])
        traj_rewards = self.func(all_xs)[state_mask]
        self._true_density = (traj_rewards / traj_rewards.sum(),
                              list(map(tuple, all_int_states[state_mask])),
                              traj_rewards)
        return self._true_density

    def true_density_2(self):
        """Exact target distribution over joint multi-agent states, by enumeration.

        Every joint state in ``{0..horizon-1}^(agent_num * ndim)`` is scored by the
        mean of its per-agent rewards. The result is cached in
        ``self._true_density`` (the same cache ``true_density`` uses).

        Returns:
            ``(probs, states, rewards)``: normalized rewards, the matching states
            as flat tuples of length ``agent_num * ndim``, and the raw rewards.
        """
        if self._true_density is not None:
            return self._true_density

        all_int_states_list = list(itertools.product(*[list(range(self.horizon))]*(self.ndim*self.agent_num)))
        all_int_states = np.int32([np.array(list(i)).reshape(self.agent_num, self.ndim) for i in all_int_states_list])
        # Keep states that have a parent, or the origin. parent_transitions always
        # returns at least one parent, so this currently keeps every state.
        state_mask = np.array([len(self.parent_transitions(s, False)[0]) > 0 or s.sum() == 0
                               for s in all_int_states])
        traj_rewards = np.array([np.mean(self.reward(self.s2x(s))) for s in all_int_states])[state_mask]
        # Mask the states too, so they stay aligned with traj_rewards / probs.
        all_int_states_tuple = [tuple(s.flatten()) for s in all_int_states[state_mask]]
        self._true_density = (traj_rewards / traj_rewards.sum(),
                              all_int_states_tuple,
                              traj_rewards)
        return self._true_density


if __name__ == '__main__':
    # Placeholder entry point: running this module directly does nothing.
    ...