import copy
import gymnasium as gym
import itertools
import logging
import numpy as np
import torch
import torch.nn as nn
from typing import Tuple

from collections import namedtuple
from . import sac, utils
from .utils import ReplayMemory, StochasticTrainingAlgorithm
from ..delayed_mdp import SimulatedInteractionLayer

import latency_env.misc.argparser_types as at
from latency_env.misc.argparser import Argument as Arg
from latency_env.misc.argparser import ArgumentList as ArgList

# Per-module logger. The NullHandler keeps this module silent unless the
# application configures logging handlers itself.
LOG = logging.getLogger(name=__name__)
LOG.addHandler(hdlr=logging.NullHandler())


# BPQL accepts every SAC option, plus the delay (in environment steps) for
# which the extended state s_bar is constructed.
arguments = sac.arguments + ArgList(
    assumed_delay=Arg(
        "--assumed-delay",
        default=1,
        type=at.posint,
        help="Assumed delay in terms of constant time steps.",
    ),
)


class BPQLReplayMemory(ReplayMemory):
    """Replay buffer that additionally stores the BPQL-specific tensors.

    On top of the fields provided by ``ReplayMemory``, every entry holds:

    * ``s_bar`` / ``s_bar_next``: 1-D vectors of length
      ``prod(s_shape) + assumed_delay * prod(a_shape)``, i.e. room for one
      flattened observation followed by ``assumed_delay`` flattened actions;
    * ``future_state`` / ``future_next_state``: tensors shaped like a single
      observation (``s_shape``).

    All four are preallocated with ``maxsize`` rows and dtype ``s_type``.
    """
    ParameterOrder = ReplayMemory.ParameterOrder + [
        "s_bar", "s_bar_next", "future_state", "future_next_state",
    ]

    def __init__(self, assumed_delay, *args, **kwargs):
        super().__init__(*args, **kwargs)

        obs_numel = 1
        for dim in self.s_shape:
            obs_numel *= dim
        act_numel = 1
        for dim in self.a_shape:
            act_numel *= dim
        # One flattened observation followed by `assumed_delay` flattened actions.
        self.s_bar_shape = (obs_numel + (assumed_delay * act_numel),)

        def _zeros(shape):
            # Preallocated, non-differentiable storage with one row per buffer slot.
            return torch.zeros((self.maxsize,) + shape, requires_grad=False, dtype=self.s_type)

        for key, shape in (
            ("s_bar", self.s_bar_shape),
            ("s_bar_next", self.s_bar_shape),
            ("future_state", self.s_shape),
            ("future_next_state", self.s_shape),
        ):
            self.rb[key] = _zeros(shape)


class BPQL(sac.SAC):
    """
    Implementation of BPQL, which is an extension of SAC.

    The policy acts on the extended state ``s_bar``: the flattened observation
    concatenated with the last ``assumed_delay`` flattened actions. The critics
    are trained on the "future" state, which is the observation received
    ``assumed_delay`` steps after the one the policy acted on, paired with the
    action that was taken.
    """
    def __init__(self, pi: nn.Module, q1: nn.Module, q2: nn.Module, env: SimulatedInteractionLayer, args: "arguments"):
        super().__init__(pi=pi, q1=q1, q2=q2, env=env, args=args)

        # Override the base-class replay buffer with one that also stores the
        # extended (s_bar) states and the future states.
        self.replay = BPQLReplayMemory(
            maxsize=self.p.replay_size, env=self.env,
            assumed_delay=self.p.assumed_delay,
        )

    def s_bar(self, s, action_trace):
        """Construct the extended state s_bar (notation from paper).

        Args:
            s: observation (array-like with ``.flatten()``).
            action_trace: list of actions, oldest first.

        Returns:
            ``s`` flattened, followed by each of the last ``assumed_delay``
            actions of ``action_trace`` flattened, as one 1-D numpy array.

        Raises:
            AssertionError: if ``action_trace`` holds fewer than
                ``assumed_delay`` actions.
        """
        # Relies on assumed_delay >= 1: with 0, `[-0:]` would select the whole trace.
        assert len(action_trace) >= self.p.assumed_delay, f"{action_trace}"
        return np.concatenate([s.flatten()] + [a.flatten() for a in action_trace[-self.p.assumed_delay:]])

    @torch.no_grad()
    def sample_action(self, state_action_info: Tuple["State", "ActionTrace"], **kwargs):
        """
        Sample an action for an ``(observation, action_trace)`` pair.

        The pair is turned into the extended state s_bar (see ``s_bar``), which
        is then passed to the SAC sampler together with ``kwargs``.
        """
        return super().sample_action(
            self.s_bar(*state_action_info),
            **kwargs,
        )

    def evaluate(self, trajectory_length=None, env=None, **kwargs):
        """
        Evaluates the BPQL policy.

        Wraps ``env.reset``/``env.step`` so that the policy is queried with
        ``(observation, action_trace)`` tuples. The trace starts with
        ``assumed_delay`` copies of ``env.default_action`` and grows by every
        action taken. Defaults are ``env=self.env`` and
        ``trajectory_length=self.p.max_trajectory``. Everything is forwarded
        to the base-class ``evaluate``.
        """
        if env is None:
            env = self.env
        if trajectory_length is None:
            trajectory_length = self.p.max_trajectory

        action_trace = []

        def fn_reset():
            nonlocal action_trace  # rebinds the shared trace for the new episode
            s, info = env.reset()
            action_trace = [env.default_action.copy() for _ in range(self.p.assumed_delay)]
            return (s, action_trace), info

        def fn_step(a):
            # Only mutates the shared list in place, so no `nonlocal` is needed.
            action_trace.append(a)
            s, r, term, trunc, info = env.step(a)
            return ((s, action_trace), r, term, trunc, info)

        return super().evaluate(trajectory_length=trajectory_length, env=env,
                                fn_reset=fn_reset, fn_step=fn_step, **kwargs)

    def train(self, n_episodes=None, n_steps=None, follow_policy=True):
        """
        Trains the networks contained within with BPQL.

        Args:
            n_episodes, n_steps: forwarded to ``self.progress`` to bound training.
            follow_policy: when False, actions are sampled with ``random=True``
                instead of following the policy.

        While ``self.update_before_sampling`` is set, a step performs no
        environment interaction and only triggers a network update (provided
        the replay buffer is non-empty).
        """
        for ep in self.progress(n_episodes=n_episodes, n_steps=n_steps):
            r_total = 0.0
            (s, _), done = self.env.reset(seed=ep), False

            # obs_stack[k] = (observation, action, reward, next observation, terminated) of step k
            obs_stack = []
            # Seeded with default actions so s_bar is defined from the very first step.
            act_trace = [self.env.default_action.copy() for _ in range(self.p.assumed_delay)]
            for _step in self.progress_steps(limit=self.p.max_trajectory):
                if done:
                    break

                if not self.update_before_sampling:
                    # Take a step
                    a, info = self.sample_action((s, act_trace), random=(not follow_policy), with_info=True)
                    s_next, r, terminated, truncated, _ = self.env.step(a)
                    done = terminated or truncated

                    obs_stack.append((s, a, r, s_next, terminated))
                    act_trace.append(a)
                    s = s_next
                    r_total += r

                    self._bpql_record_array("actor/action", a)
                    if "std" in info:
                        self._bpql_record_array("actor/action_std", info["std"])

                    self._bpql_store_transition(obs_stack, act_trace)

                # Minibatch network updates
                self.steps_since_update -= 1
                if (self.steps_since_update <= 0 or self.update_before_sampling) and len(self.replay) > 0:
                    self._bpql_update_networks()

            self.record_episode_total_reward(r_total)

    def _bpql_record_array(self, key, values):
        """Record each element of ``values`` under ``key``, suffixed "/i,j,..." for non-scalars."""
        it = np.nditer(values, flags=["multi_index"])
        for value in it:
            idx = ("/" + ",".join(str(i) for i in it.multi_index)) if it.multi_index != () else ""
            self.record_iteration_value(key + idx, value)

    def _bpql_store_transition(self, obs_stack, act_trace):
        """Add the transition from ``assumed_delay`` steps ago to the replay buffer, once it is available.

        The entry pairs step ``i``'s observation, action and extended states
        with the observation/next observation, reward and termination flag of
        the most recent step (``i + assumed_delay``).
        """
        # Only add to the buffer once we know what the actual (future) state was.
        if len(obs_stack) <= self.p.assumed_delay:
            return
        i = len(obs_stack) - (self.p.assumed_delay + 1)
        rs, ra, _, rs_next, _ = obs_stack[i]
        rs_future, _, rr, rsn_future, rterm = obs_stack[-1]

        # act_trace starts with `assumed_delay` default actions, so its prefix of
        # length i + assumed_delay ends with the actions taken before step i.
        rs_bar = self.s_bar(rs, act_trace[:i + self.p.assumed_delay])
        rs_bar_next = self.s_bar(rs_next, act_trace[:(i + 1) + self.p.assumed_delay])
        self.replay.add(rs, ra, rr, rs_next, rterm, rs_bar, rs_bar_next, rs_future, rsn_future)

    def _bpql_update_networks(self):
        """Run ``update_iterations`` rounds of critic, actor and temperature updates, then log mean losses."""
        self.update_before_sampling = False
        self.steps_since_update = self.p.update_delay
        losses_q = np.zeros((self.p.update_iterations,), dtype=np.float32)
        losses_pi = np.zeros((self.p.update_iterations,), dtype=np.float32)
        losses_alpha = np.zeros((self.p.update_iterations,), dtype=np.float32)
        for update_step in range(self.p.update_iterations):
            # Sample batch for updates
            bs, ba, br, bsn, bd, bs_bar, bsn_bar, bsf, bsnf = self.replay.sample(self.p.batch_size, device=self.device)

            # Pre-fetch the alpha for later
            with torch.no_grad():
                alpha = torch.exp(self.log_alpha)

            # Q update: critics are evaluated on the future state with the action taken.
            self.q_opt.zero_grad()
            q1val = self.q1(bsf, ba)
            if self.p.use_double_q:
                q2val = self.q2(bsf, ba)
            with torch.no_grad():
                if self.p.use_double_q:
                    an, _, logp_an, _ = self.pi(bsn_bar)
                    qval_targ = torch.min(self.q1_targ(bsnf, an), self.q2_targ(bsnf, an))
                else:
                    # NOTE(review): this branch samples from pi_targ while the
                    # double-Q branch samples from pi -- confirm this is intended.
                    an, _, logp_an, _ = self.pi_targ(bsn_bar)
                    qval_targ = self.q1_targ(bsnf, an)
                ys = br + self.p.discount * (1 - bd) * (qval_targ - alpha * logp_an)

            loss_q1 = ((q1val - ys)**2).mean()
            if self.p.use_double_q:
                loss_q2 = ((q2val - ys)**2).mean()
                loss_q = loss_q1 + loss_q2
            else:
                loss_q = loss_q1
            loss_q.backward()
            if self.p.norm_clip_q is not None:
                # NOTE(review): assumes sac.SAC exposes the critic parameters as
                # `self.q` (not defined in this class) -- confirm.
                torch.nn.utils.clip_grad_norm_(self.q.parameters(), self.p.norm_clip_q)
            self.q_opt.step()

            # pi (policy) update on the extended state
            self.pi_opt.zero_grad()
            a, _, logp_a, _ = self.pi(bs_bar)
            qval = self.q1(bsf, a)
            if self.p.use_double_q:
                qval = torch.min(qval, self.q2(bsf, a))
            loss_pi = -(qval - alpha * logp_a).mean()
            loss_pi.backward()
            if self.p.norm_clip_pi is not None:
                torch.nn.utils.clip_grad_norm_(self.pi.parameters(), self.p.norm_clip_pi)
            self.pi_opt.step()

            # alpha (temperature) update: J(a) = E[-a logp(a|s) - a*H]
            self.alpha_opt.zero_grad()
            loss_alpha = -1.0 * (
                torch.exp(self.log_alpha) * (
                    logp_a.mean() +
                    self.p.entropy_threshold
                ).detach()
            )
            loss_alpha.backward()
            self.alpha_opt.step()

            losses_q[update_step] = loss_q.item()
            losses_pi[update_step] = loss_pi.item()
            losses_alpha[update_step] = loss_alpha.item()

            # Polyak (soft target) updates every `polyak_delay` gradient steps
            self.updates_since_polyak -= 1
            if self.updates_since_polyak <= 0:
                self.updates_since_polyak = self.p.polyak_delay
                self._bpql_polyak_update()

        self.record_iteration_value("critic/loss", losses_q.mean())
        self.record_iteration_value("actor/loss", losses_pi.mean())
        self.record_iteration_value("temperature/loss", losses_alpha.mean())
        self.record_iteration_value("temperature/value", torch.exp(self.log_alpha).item())

    @torch.no_grad()
    def _bpql_polyak_update(self):
        """Soft-update every target network: target <- (1 - polyak) * target + polyak * online."""
        for net, net_targ in [(self.pi, self.pi_targ), (self.q1, self.q1_targ), (self.q2, self.q2_targ)]:
            for p_train, p_targ in zip(net.parameters(), net_targ.parameters()):
                p_targ.data.mul_(1 - self.p.polyak)
                p_targ.data.add_(self.p.polyak * p_train.data)
