from typing import Dict, Any, Sequence
from collections import defaultdict, namedtuple

import argparse
import os
import logging
import copy
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.utils.tensorboard import SummaryWriter

from dataclasses import dataclass, field
from pettingzoo.mpe import simple_tag_v2


# Module-level logger used by every function in this file. It is set to INFO
# and gets its own StreamHandler so messages show up without any external
# logging configuration.
# NOTE(review): the logger is keyed by __file__ (a path) rather than the
# conventional __name__ -- confirm nothing relies on that name before changing.
logger = logging.getLogger(__file__)
logger.setLevel(logging.INFO)
logger.addHandler(logging.StreamHandler())


def to_one_hot(index_array, size):
    """Encode an array of class indices as one-hot rows.

    :param index_array: numpy array of integer indices; row ``i`` of the result
        is hot at column ``index_array[i]``
    :param size: number of columns (classes) of the encoding
    :return: float array of shape ``(index_array.size, size)``
    """
    n_rows = index_array.size
    one_hot = np.zeros((n_rows, size))
    # Pair every row number with its index so exactly one cell per row is set.
    one_hot[np.arange(n_rows), index_array] = 1.0
    return one_hot


class QNet(nn.Module):
    """Three-layer MLP mapping an observation to one value per discrete action."""

    def __init__(self, obs_space: gym.spaces.Space, act_space: gym.spaces.Space):
        """Build the network.

        :param obs_space: observation space, must be a ``gym.spaces.Box``; its
            first dimension is used as the input width
        :param act_space: action space, must be a ``gym.spaces.Discrete``; its
            ``n`` is used as the output width
        """
        super().__init__()
        self._obs_space = obs_space
        self._act_space = act_space

        # network
        assert isinstance(self._obs_space, gym.spaces.Box), self._obs_space
        assert isinstance(self._act_space, gym.spaces.Discrete), self._act_space
        self._f1 = nn.Linear(self._obs_space.shape[0], 256)
        self._f2 = nn.Linear(256, 64)
        self._out = nn.Linear(64, self._act_space.n)

    def forward(self, state: np.ndarray):
        """Forward to get values

        :param state: np.ndarray
            the state input; it is converted to a float32 tensor so that
            float64 observations do not clash with the float32 layer weights
        :return: values
            tensor whose last dimension has one entry per action
        """

        # as_tensor + explicit dtype: accepts any numeric array and always
        # yields float32, matching the dtype of the linear layers.
        state = torch.as_tensor(state, dtype=torch.float32)
        x = F.relu(self._f1(state))
        h = F.relu(self._f2(x))
        values = self._out(h)

        return values


class LinearFeatureNet(nn.Module):
    """MLP that embeds a (state, one-hot action) pair into a feature vector."""

    def __init__(
        self, obs_space: gym.spaces.Box, act_space: gym.spaces.Discrete, out_num: int
    ):
        """Create the layers.

        :param obs_space: observation space; ``shape[0]`` gives the state width
        :param act_space: action space; ``n`` gives the one-hot action width
        :param out_num: length of the produced feature vector
        """
        super().__init__()

        self._obs_space = obs_space
        self._act_space = act_space

        # The input is the state concatenated with the one-hot action.
        input_width = self._obs_space.shape[0] + self._act_space.n
        self._f1 = nn.Linear(input_width, 128)
        self._f2 = nn.Linear(128, 64)
        self._out = nn.Linear(64, out_num)

    def forward(self, state, action):
        """Return the feature for ``state`` joined with the one-hot ``action``.

        Both inputs are concatenated along their last axis and cast to float32
        before being fed through the network.
        """
        joint = np.concatenate([state, action], axis=-1)
        inputs = torch.tensor(joint.astype(np.float32))
        hidden = F.relu(self._f1(inputs))
        hidden = F.relu(self._f2(hidden))
        return self._out(hidden)


class Factorization:
    """Pairwise interaction term built from one feature network per agent.

    For a main agent the interaction offset is the dot product between its own
    feature vector and the (detached) mean of the other agents' feature
    vectors.
    """

    def __init__(
        self,
        possible_agents: Sequence[str],
        agent_obs_spaces: Dict[str, gym.spaces.Box],
        agent_act_spaces: Dict[str, gym.spaces.Discrete],
        feature_length: int,
    ):
        """Create one ``LinearFeatureNet`` per agent.

        :param possible_agents: ids of all agents
        :param agent_obs_spaces: observation space per agent id (must be Box)
        :param agent_act_spaces: action space per agent id (must be Discrete)
        :param feature_length: length of every agent's feature vector
        """

        self._obs_spaces = agent_obs_spaces
        self._act_spaces = agent_act_spaces
        self._agents = possible_agents

        # check observation space: should be box
        for aid, obs_space in agent_obs_spaces.items():
            assert isinstance(obs_space, gym.spaces.Box), (aid, obs_space)
        # check action space: should be discrete
        for aid, action_space in agent_act_spaces.items():
            assert isinstance(action_space, gym.spaces.Discrete), (aid, action_space)

        # build sub nets: all agents share the same feature vector network layer size (not variables)
        self._feature_nets = {
            k: LinearFeatureNet(
                self._obs_spaces[k], self._act_spaces[k], feature_length
            )
            for k in self._agents
        }

        # Flat list (not a generator) so it can be consumed more than once.
        self._parameters = []
        for net in self._feature_nets.values():
            self._parameters.extend(net.parameters())

    def parameters(self):
        """Return the list of trainable parameters of all feature networks."""
        return self._parameters

    def state_dict(self):
        """Return ``{agent_id: state_dict}`` built from the networks' current state."""
        return {aid: fnet.state_dict() for aid, fnet in self._feature_nets.items()}

    def load_state_dict(self, state_dict, path):
        """Load per-agent weights from ``state_dict``.

        :param state_dict: mapping ``agent_id -> state dict`` as produced by
            :meth:`state_dict`
        :param path: kept for backward compatibility and otherwise unused; it
            used to be forwarded as the positional ``strict`` flag of
            ``nn.Module.load_state_dict``, which is not a path argument
        """
        for k, state in state_dict.items():
            self._feature_nets[k].load_state_dict(state)

    def __call__(
        self,
        main_agent_id: str,
        states: Dict[str, np.ndarray],
        actions: Dict[str, np.ndarray],
    ):
        """Compute the interaction offset of ``main_agent_id``.

        :param main_agent_id: agent whose feature is multiplied with the others
        :param states: batched state per agent id
        :param actions: batched one-hot action per agent id
        :return: tensor with one offset per batch row
        """
        # Keep the dict order instead of a set so the result is reproducible.
        other_agent_ids = [aid for aid in states if aid != main_agent_id]
        v_feature = self._feature_nets[main_agent_id](
            states[main_agent_id], actions[main_agent_id]
        )
        u_features = [
            self._feature_nets[aid](states[aid], actions[aid])
            for aid in other_agent_ids
        ]
        # Each feature is (batch, feature); concatenating along dim 1 and
        # reshaping gives (batch, n_other_agents, feature).
        inner_shape = u_features[0].shape
        u_features = torch.cat(u_features, dim=1).reshape(
            (inner_shape[0], -1, inner_shape[1])
        )

        mean_u = torch.mean(u_features, dim=1)
        # The mean of the other agents is detached: gradients only flow
        # through the main agent's feature.
        offset = v_feature * mean_u.detach()
        offset = torch.sum(offset, dim=1)

        return offset


# Batch of transitions for one agent; every field holds the stacked values of
# the sampled rows (see SampleBatch.sample).
Transition = namedtuple("Transition", "cur_obs, action, next_obs, reward, done")


class FQLAgent:
    """Factorized Q-learning agent.

    Holds a Q network, a target Q network, a behaviour network used to pick
    actions, and a reference to the (shared) factorization module. The joint
    value is the agent's own Q-value for its action plus the factorization
    offset.
    """

    def __init__(
        self,
        agent_id: str,
        agent_obs_spaces,
        agent_act_spaces,
        gamma: float,
        lr: float,
        factorization_module: Factorization,
    ):
        """Build networks and optimizers.

        :param agent_id: id of this agent
        :param agent_obs_spaces: observation space per agent id
        :param agent_act_spaces: action space per agent id
        :param gamma: discount factor
        :param lr: learning rate of both Adam optimizers
        :param factorization_module: interaction module used for the offset
        """
        self._obs_space = agent_obs_spaces[agent_id]
        self._act_space = agent_act_spaces[agent_id]
        self._agent_num = len(agent_act_spaces)
        self._agent_id = agent_id

        self._qnet = QNet(self._obs_space, self._act_space)
        self._target_qnet = QNet(self._obs_space, self._act_space)
        self._behavior = QNet(self._obs_space, self._act_space)
        self._factorization = factorization_module

        # Materialise as lists: ``Module.parameters()`` returns a one-shot
        # generator, and these collections are read by both the optimizers
        # and later code. Storing the generators left them exhausted after
        # optimizer construction, which turned ``sync_q`` into a no-op.
        self._parameters = {
            "behavior": list(self._behavior.parameters()),
            "q": list(self._qnet.parameters()),
            "target_q": list(self._target_qnet.parameters()),
            "factorization": list(self._factorization.parameters()),
        }

        self._lr = lr
        self._gamma = gamma
        self._fql_optimizer = optim.Adam(
            self._parameters["q"] + self._parameters["factorization"], lr=lr
        )
        self._behavior_optimizer = optim.Adam(self._parameters["behavior"], lr=lr)

    @property
    def factorization(self):
        """The factorization module used by this agent."""
        return self._factorization

    @property
    def target_q_net(self):
        """The target Q network."""
        return self._target_qnet

    def __call__(self, states, agent_actions):
        return self.forward(states, agent_actions)

    def forward(
        self,
        states: Dict[str, np.ndarray],
        agent_actions: Dict[str, np.ndarray],
    ):
        """Compute the joint Q-values

        :param states: batched state per agent id
        :param agent_actions: batched one-hot action per agent id
        :return: tensor with one joint value per batch row
        """

        qvalues = self._qnet(states[self._agent_id])
        own_action = torch.from_numpy(agent_actions[self._agent_id]).float()
        # Select the Q-value of the taken action per row. Summing without
        # ``dim`` collapsed the whole batch into a single scalar.
        chosen_q = torch.sum(qvalues * own_action, dim=-1)
        offset = self._factorization(self._agent_id, states, agent_actions)
        return chosen_q + offset

    def sync_q(self, tau=1.0):
        """tau=1.0 means hard update, smaller tau means softer"""

        # logger.debug(f"Update target Q network with tau={tau}")
        q_vars = self._parameters["q"]
        tq_vars = self._parameters["target_q"]

        with torch.no_grad():
            for q_var, tq_var in zip(q_vars, tq_vars):
                tq_var.data.copy_(tau * q_var.data + (1.0 - tau) * tq_var.data)

    def compute_actions(self, state: np.ndarray):
        """Return the greedy action index of the behaviour network per row."""
        with torch.no_grad():
            values = self._behavior(state)
        actions = torch.argmax(values, dim=-1).numpy()
        return actions

    def fql_loss(
        self,
        agent_batches: Dict[str, Transition],
        next_joint_q,
    ):
        """MSE between the joint Q estimate and the one-step TD target.

        :param agent_batches: sampled transitions per agent id
        :param next_joint_q: joint value of the next state/action, one per row
        :return: scalar loss tensor
        """
        # compute joint q
        # collect states and actions
        agent_states = {k: batch.cur_obs for k, batch in agent_batches.items()}
        agent_actions = {k: batch.action for k, batch in agent_batches.items()}
        estimated_q = self.forward(agent_states, agent_actions)

        own_batch = agent_batches[self._agent_id]
        reward = torch.from_numpy(own_batch.reward).float().reshape(-1)
        done = torch.from_numpy(own_batch.done).float().reshape(-1)
        target_q = reward + (1.0 - done) * self._gamma * next_joint_q
        loss = F.mse_loss(estimated_q, target_q)
        return loss

    def behavior_loss(self, agent_batches: Dict[str, Transition]):
        """MSE pulling the behaviour network's value towards the joint Q.

        :param agent_batches: sampled transitions per agent id
        :return: scalar loss tensor
        """
        agent_states = {k: batch.cur_obs for k, batch in agent_batches.items()}
        agent_actions = {k: batch.action for k, batch in agent_batches.items()}

        # The joint Q is only a regression target here, so it is computed
        # without gradients; otherwise the backward pass would also populate
        # gradients of the Q network and the factorization.
        with torch.no_grad():
            joint_q = self.forward(agent_states, agent_actions)
        # mse loss for behavior model
        b_q = torch.sum(
            self._behavior(agent_states[self._agent_id])
            * torch.from_numpy(agent_actions[self._agent_id]).float(),
            dim=1,
        )
        loss = F.mse_loss(b_q, joint_q)
        return loss

    def train(
        self,
        agent_batches: Dict[str, Transition],
        agent_next_states: Dict[str, np.ndarray],
        agent_next_actions: Dict[str, np.ndarray],
    ):
        """Run one optimization step for the FQL loss and the behaviour loss.

        :param agent_batches: sampled transitions per agent id
        :param agent_next_states: batched next state per agent id
        :param agent_next_actions: batched one-hot next action per agent id
        :return: dict with ``fql_loss`` and ``behavior_loss`` as numpy values
        """
        # The bootstrap target is a constant for the optimizer.
        with torch.no_grad():
            interaction_factorization = self._factorization(
                self._agent_id, agent_next_states, agent_next_actions
            )
            next_q_values = torch.sum(
                self._target_qnet(agent_next_states[self._agent_id])
                * torch.from_numpy(agent_next_actions[self._agent_id]).float(),
                dim=-1,
            )
            next_joint_q = next_q_values + interaction_factorization

        logger.debug(f"\tOptimization for FQL ...")
        self._fql_optimizer.zero_grad()
        fql_loss = self.fql_loss(agent_batches, next_joint_q.detach())
        fql_loss.backward()
        self._fql_optimizer.step()

        logger.debug(f"\tOptimization for behaviour policy ...")
        self._behavior_optimizer.zero_grad()
        behavior_loss = self.behavior_loss(agent_batches)
        behavior_loss.backward()
        self._behavior_optimizer.step()

        return {
            "fql_loss": fql_loss.detach().numpy(),
            "behavior_loss": behavior_loss.detach().numpy(),
        }

    def state_dict(self):
        """Return the current state of all sub-modules keyed by role."""
        return {
            "behavior": self._behavior.state_dict(),
            "q": self._qnet.state_dict(),
            "target_q": self._target_qnet.state_dict(),
            "factorization": self._factorization.state_dict(),
        }

    def load_state_dict(self, state_dict, path):
        """Load the state produced by :meth:`state_dict`.

        :param state_dict: dict with keys ``behavior``, ``q``, ``target_q`` and
            ``factorization``
        :param path: forwarded to the factorization module, whose signature
            requires it; it is not passed to the Q networks because the second
            positional argument of ``nn.Module.load_state_dict`` is ``strict``
        """
        self._behavior.load_state_dict(state_dict["behavior"])
        self._factorization.load_state_dict(state_dict["factorization"], path)
        self._qnet.load_state_dict(state_dict["q"])
        self._target_qnet.load_state_dict(state_dict["target_q"])


@dataclass
class SampleBatch:
    """Fixed-capacity ring buffer of transitions for a single agent.

    The five lists are kept in lockstep. Once ``buffer_size`` entries are
    stored, new transitions overwrite the oldest slot.
    """

    buffer_size: int
    cur_obs: list = field(default_factory=list)
    action: list = field(default_factory=list)
    next_obs: list = field(default_factory=list)
    reward: list = field(default_factory=list)
    done: list = field(default_factory=list)

    def __post_init__(self):
        # _flag: slot that the next overwrite targets; _size: stored entries.
        self._flag = 0
        self._size = 0

    def update_transition(self, obs, action, next_obs, reward, done):
        """Store one transition, overwriting the oldest one when full."""
        if len(self.cur_obs) < self.buffer_size:
            self.cur_obs.append(obs)
            self.action.append(action)
            self.next_obs.append(next_obs)
            self.reward.append(reward)
            self.done.append(done)
        else:
            self.cur_obs[self._flag] = obs
            self.action[self._flag] = action
            self.next_obs[self._flag] = next_obs
            self.reward[self._flag] = reward
            self.done[self._flag] = done

        self._flag = (self._flag + 1) % self.buffer_size
        self._size = min(self._size + 1, self.buffer_size)

    def update_transitions(self, obses, actions, next_obses, rewards, dones):
        """Store several transitions; stops at the shortest input sequence."""
        for obs, action, next_obs, reward, done in zip(
            obses, actions, next_obses, rewards, dones
        ):
            self.update_transition(obs, action, next_obs, reward, done)

    def cleaned_data(self):
        """Convert list data to numpy.ndarray

        :return: tuple ``(cur_obs, action, next_obs, reward, done)`` where every
            entry is the row-wise stack of the corresponding list
        """

        # np.vstack is the supported name; np.row_stack is a deprecated alias.
        return tuple(
            np.vstack(column)
            for column in (
                self.cur_obs,
                self.action,
                self.next_obs,
                self.reward,
                self.done,
            )
        )

    @property
    def size(self):
        """Number of transitions currently stored."""
        return self._size

    def sample(self, idxes):
        """Return a ``Transition`` holding the rows selected by ``idxes``."""
        cleaned_data = self.cleaned_data()
        return Transition(*(column[idxes] for column in cleaned_data))


class SampleBuffer:
    """Collection of per-agent ``SampleBatch`` buffers sampled with shared indices."""

    def __init__(self, agent_keys: Sequence[str], buffer_size: int):
        """Create one ``SampleBatch`` of capacity ``buffer_size`` per agent key."""
        self._agent_buffer = {
            agent_key: SampleBatch(buffer_size) for agent_key in agent_keys
        }
        self._buffer_size = buffer_size
        self._size = 0

    def add(self, obs, actions, next_obs, rewards, dones):
        """Append each agent's transitions; every argument is keyed by agent id."""
        columns = (obs, actions, next_obs, rewards, dones)
        for agent_key, agent_buffer in self._agent_buffer.items():
            agent_buffer.update_transitions(*(column[agent_key] for column in columns))
            # The size of the buffer updated last is the one that is reported.
            self._size = agent_buffer.size

    def sample(self, batch_size) -> Dict[str, Any]:
        """Draw ``batch_size`` indices once and sample every agent with them."""
        idxes = np.random.choice(self._size, batch_size)
        return {
            agent_key: agent_buffer.sample(idxes)
            for agent_key, agent_buffer in self._agent_buffer.items()
        }

    @property
    def size(self):
        """Number of stored transitions as recorded by the last ``add``."""
        return self._size


def rollout(
    env, agents: Dict[str, FQLAgent], fragment_length: int, buffer: SampleBuffer
):
    """Play one fragment in ``env`` and push the transitions into ``buffer``.

    The environment is reset, then stepped agent by agent for at most
    ``fragment_length * len(env.possible_agents)`` iterations. Actions come from
    each agent's ``compute_actions``; agents reported as done are stepped with
    ``None``.

    :return: dict with ``agent_live`` (per-agent step counter, starting at -1)
        and ``agent_reward`` (per-agent sum of rewards)
    """
    env.reset()
    agent_reward = defaultdict(lambda: 0.0)
    agent_live_step = defaultdict(lambda: -1)

    observations = defaultdict(list)
    actions = defaultdict(list)
    next_observations = defaultdict(list)
    rewards = defaultdict(list)
    dones = defaultdict(list)

    iteration_budget = fragment_length * len(env.possible_agents)
    for aid in env.agent_iter(max_iter=iteration_budget):
        observation, reward, done, info = env.last()
        # logger.debug(f"\t [step {i}] agent {aid}: dead or not [{done}]")
        observations[aid].append(observation)
        rewards[aid].append(reward)
        dones[aid].append(done)
        agent_reward[aid] += reward
        agent_live_step[aid] += 1

        if done:
            env.step(None)
            continue

        # The policy works on batches, so wrap the single observation.
        action = agents[aid].compute_actions(np.asarray([observation]))[0]
        actions[aid].append(action)
        env.step(action)

    # clean data: cut every agent to the shortest observation trace, then shift
    # by one so that entry t pairs obs[t] with obs/reward/done at t + 1.
    clip_length = min(len(trace) for trace in observations.values())
    for k, obs in list(observations.items()):
        next_observations[k] = obs[1:clip_length]
        observations[k] = obs[: clip_length - 1]
        action_indices = np.asarray(actions[k][:clip_length], dtype=np.int32)
        actions[k] = to_one_hot(action_indices, size=env.action_spaces[k].n)
        dones[k] = dones[k][1:clip_length]
        rewards[k] = rewards[k][1:clip_length]

    buffer.add(observations, actions, next_observations, rewards, dones)

    # we drop the last step
    return {"agent_live": agent_live_step, "agent_reward": agent_reward}


def train(
    buffer: SampleBuffer, agents: Dict[str, FQLAgent], batch_size: int = 64
) -> Dict[str, Any]:
    """Sample one batch and run one training step for every agent.

    :param buffer: replay buffer to sample from
    :param agents: agents keyed by id
    :param batch_size: number of transitions sampled per agent
    :return: ``{"fql_loss": {...}, "behavior_loss": {...}}`` keyed by agent id
    """
    agent_buffers = buffer.sample(batch_size=batch_size)

    logger.debug(f"Entering training phase (batch_size={batch_size})...")
    agent_next_states = {k: batch.next_obs for k, batch in agent_buffers.items()}
    # The width of the one-hot next action is taken from the stored one-hot
    # actions, so this function does not depend on a module-level ``env``
    # (which only exists when the file is executed as a script).
    agent_next_actions = {
        k: to_one_hot(
            agents[k].compute_actions(agent_next_states[k]),
            batch.action.shape[-1],
        )
        for k, batch in agent_buffers.items()
    }

    agent_fql_loss, agent_behavior_loss = dict(), dict()

    for k, agent in agents.items():
        statistic = agent.train(agent_buffers, agent_next_states, agent_next_actions)
        logger.debug(f"\tAgent {k} finished, return: {statistic}")
        agent_fql_loss[k] = statistic["fql_loss"]
        agent_behavior_loss[k] = statistic["behavior_loss"]

    return {"fql_loss": agent_fql_loss, "behavior_loss": agent_behavior_loss}


def save_model(agents: Dict[str, FQLAgent], model_dir: str):
    """Write every agent's state dict to ``{model_dir}/fql_{agent_id}``.

    :param agents: agents keyed by id
    :param model_dir: target directory; created (with parents) when missing
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(model_dir, exist_ok=True)

    for agent_id, fql_agent_model in agents.items():
        gen_dir_path = f"{model_dir}/fql_{agent_id}"
        torch.save(fql_agent_model.state_dict(), gen_dir_path)


def load_model(agents: Dict[str, FQLAgent], model_dir: str):
    """Restore every agent from ``{model_dir}/fql_{agent_id}``.

    :param agents: agents keyed by id; each one is updated in place
    :param model_dir: directory previously written by ``save_model``
    """
    for agent_id, fql_agent_model in agents.items():
        gen_dir_path = f"{model_dir}/fql_{agent_id}"
        # Read the checkpoint from disk; previously the agent's own in-memory
        # state was passed back in, so nothing was ever loaded.
        saved_state = torch.load(gen_dir_path)
        fql_agent_model.load_state_dict(saved_state, gen_dir_path)


def analysis(writer: SummaryWriter, epoch, rollout_info, train_info):
    """Log the rollout rewards and write all statistics as scalars.

    Top-level values are written under their key; values that are dicts are
    written entry by entry under ``"<key>/<sub key>"``.
    """
    agent_rewards = tuple(rollout_info["agent_reward"].values())
    logger.info(f"[Rollout - {epoch}] agent rewards: {agent_rewards}")

    for info_dict in (rollout_info, train_info):
        for name, value in info_dict.items():
            if not isinstance(value, dict):
                writer.add_scalar(name, value, epoch)
                continue
            for sub_key, sub_value in value.items():
                writer.add_scalar(f"{name}/{sub_key}", sub_value, epoch)


def learn(env, train_config):
    """Run the full rollout / train / sync loop and checkpoint the agents.

    :param env: multi-agent environment exposing ``possible_agents``,
        ``observation_spaces`` and ``action_spaces``
    :param train_config: dict with the keys ``n_round``, ``fragment_length``,
        ``save_interval``, ``lr``, ``gamma``, ``batch_size``,
        ``feature_length``, ``buffer_size``, ``tau`` and ``log_dir``
    """
    n_round = train_config["n_round"]
    fragment_length = train_config["fragment_length"]
    save_interval = train_config["save_interval"]
    learning_rate = train_config["lr"]
    gamma = train_config["gamma"]
    batch_size = train_config["batch_size"]
    feature_length = train_config["feature_length"]
    buffer_size = train_config["buffer_size"]
    tau = train_config["tau"]
    log_dir = train_config["log_dir"]

    # build agent policy model and their optimizers
    factorization_module = Factorization(
        env.possible_agents, env.observation_spaces, env.action_spaces, feature_length
    )
    agents = {
        k: FQLAgent(
            k,
            env.observation_spaces,
            env.action_spaces,
            gamma,
            learning_rate,
            factorization_module,
        )
        for k in env.possible_agents
    }

    # hard sync so the target networks start from the Q networks' weights
    for agent in agents.values():
        agent.sync_q()

    buffer = SampleBuffer(env.possible_agents, buffer_size)
    writer = SummaryWriter(log_dir=f"{log_dir}/run/expground/fql")

    try:
        round_iter = 0
        while round_iter < n_round:
            logger.debug(f"Start {round_iter}th training round ...")

            # rollout and evaluate
            rollout_info = rollout(env, agents, fragment_length, buffer)
            logger.info(f"\tcurrent buffer size={buffer.size}")
            # train
            train_info = train(buffer, agents, batch_size)
            # soft sync
            for agent in agents.values():
                agent.sync_q(tau)

            # analysis the rollout and training results
            analysis(writer, round_iter, rollout_info, train_info)

            if (round_iter + 1) % save_interval == 0:
                logger.info("[model saving] save model....")
                save_model(agents, f"{writer.log_dir}/model/{round_iter}")

            round_iter += 1

        # After the loop ``round_iter`` equals the number of finished rounds.
        # The last round (index ``round_iter - 1``) was checkpointed inside the
        # loop exactly when ``round_iter % save_interval == 0``; only save
        # again when it was not, and label it with the same index convention.
        if round_iter % save_interval != 0:
            logger.info("[model saving] save last model....")
            save_model(agents, f"{writer.log_dir}/model/{round_iter - 1}")
    finally:
        # Flush pending events even when training is interrupted.
        writer.close()


# Command line interface. Every option except --log_dir has a default that
# matches the values previously hard-coded in the training config.
parser = argparse.ArgumentParser("Factorized Q-learning")
parser.add_argument("--log_dir", help="log directory.", type=str, required=True)
parser.add_argument("--lr", help="set learning rate", type=float, default=1e-4)
parser.add_argument("--num_epoch", help="training epochs", type=int, default=1000)
# NOTE(review): --num_runs is parsed but not used anywhere in this file.
parser.add_argument("--num_runs", help="number of experiments", type=int, default=5)
parser.add_argument("--gamma", help="discounted factor", type=float, default=0.98)
parser.add_argument("--batch_size", help="training batch size", type=int, default=256)
parser.add_argument(
    "--buffer_size", help="size of ReplayBuffer", type=int, default=2000000
)
parser.add_argument(
    "--fragment_length", help="length of an episode", type=int, default=25
)


# NOTE(review): arguments are parsed at import time, so importing this module
# without --log_dir exits the interpreter; kept at module level because other
# code may read ``args``.
args = parser.parse_args()

if __name__ == "__main__":
    config = {
        "n_round": args.num_epoch,
        "fragment_length": args.fragment_length,
        "save_interval": 2,
        "lr": args.lr,
        "gamma": args.gamma,
        "batch_size": args.batch_size,
        "feature_length": 18,
        "buffer_size": args.buffer_size,
        "tau": 0.05,
        # learn() reads this key; it was missing, which raised KeyError.
        "log_dir": args.log_dir,
    }

    env = simple_tag_v2.env(
        num_good=1,
        num_adversaries=3,
        num_obstacles=2,
        max_cycles=args.fragment_length,
    )
    learn(env, config)
