import copy
import gymnasium as gym
import itertools
import logging
import numpy as np
import torch
import torch.nn as nn

from .utils import ReplayMemory, StochasticTrainingAlgorithm
from . import utils

import latency_env.misc.argparser_types as at
from latency_env.misc.argparser import Argument as Arg
from latency_env.misc.argparser import ArgumentList as ArgList

# Module-level logger. A NullHandler is attached so that this module does not
# impose any logging output configuration of its own; real handlers are left
# for the application to install.
LOG = logging.getLogger(__name__)
LOG.addHandler(logging.NullHandler())


# Command-line arguments understood by the SAC trainer. They extend the common
# trainer arguments from `utils.trainer_arguments`; the parsed result is what
# `SAC.__init__` receives as `args`.
arguments = utils.trainer_arguments + ArgList(
    initial_alpha = Arg("--initial-alpha", default=0.2, type=at.posfloat,
                        help="Initial temperature, a.k.a. entropy regularization coefficient."),
    use_double_q = Arg("--use-double-q", default=True, type=at.boolean,
                       help="Whether to use double Q functions when learning."),
    lr_pi = Arg("--lr-pi", default=1e-5, type=at.posfloat,
                help="Actor learning rate."),
    lr_q = Arg("--lr-q", default=1e-4, type=at.posfloat,
               help="Critic learning rate."),
    lr_alpha = Arg("--lr-alpha", default=0.0, type=at.nonnegfloat,
                   help="Temperature learning rate."),
    optim_pi = Arg("--optim-pi", default="adam", type=str.lower, choices=utils.OPTIMIZERS.keys(),
                   help="Actor network optimizer."),
    optim_q = Arg("--optim-q", default="adam", type=str.lower, choices=utils.OPTIMIZERS.keys(),
                  help="Critic network optimizer."),
    optim_alpha = Arg("--optim-alpha", default="adam", type=str.lower, choices=utils.OPTIMIZERS.keys(),
                      help="Temperature network optimizer."),
    norm_clip_pi = Arg("--norm-clip-pi", default=None, type=float,
                       help="Actor gradient norm clipping."),
    norm_clip_q = Arg("--norm-clip-q", default=None, type=float,
                      help="Critic gradient norm clipping."),
    entropy_threshold = Arg("--entropy-threshold", default=-1.0, type=at.negfloat,
                            help="Temperature entropy optimization threshold."),
    discount = Arg("--discount", default=0.99, type=at.floatr_o0_1o,
                   help="Discount for future rewards."),
    replay_size = Arg("--replay-size", default=10_000, type=at.posint,
                      help="Maximum size of replay memory."),
    batch_size = Arg("--batch-size", default=64, type=at.posint,
                     help="Size of minibatches when performing network updates."),
    update_delay = Arg("--update-delay", default=1, type=at.posint,
                       help="How many environment steps to wait before performing minibatch updates."),
    update_iterations = Arg("--update-iterations", default=1, type=at.posint,
                            help="How many minibatch updates to perform in an update group."),
    # NOTE: `at.boolean` (as used by --use-double-q) rather than the builtin
    # `bool`: `bool("False")` is True, so a plain `type=bool` cannot be
    # switched off from the command line.
    update_first = Arg("--update-first", default=False, type=at.boolean,
                       help="Perform minibatch updates before sampling from the policy."),
    polyak = Arg("--polyak", default=1e-3, type=float,
                 help="Polyak factor for soft updates."),
    polyak_delay = Arg("--polyak-delay", default=1, type=at.posint,
                       help="How many minibatch steps to wait before performing the target update."),
    max_trajectory = Arg("--max-trajectory", default=10_000, type=at.posint,
                         help="Maximum trajectory length."),
)


class SAC(StochasticTrainingAlgorithm):
    """
    Implementation of Soft Actor-Critic

    This is a stochastic policy gradient algorithm. The policy pi is called as
    ``pi(states)`` and must return a 4-tuple whose first element is the
    sampled action(s) and whose third element is the log probability of taking
    that action(s) under this policy; the remaining two elements are not used
    by this class. The critics are called as ``q(states, actions)``.
    """
    def __init__(self, pi: nn.Module, q1: nn.Module, q2: nn.Module, env: gym.core.Env, args: "arguments"):
        """
        Set up the networks, their target copies, the optimizers and the
        replay memory.

        pi   -- actor network
        q1   -- first critic network
        q2   -- second critic network (its target copy is maintained even when
                double Q learning is disabled)
        env  -- environment to interact with
        args -- parsed trainer arguments (see the module-level `arguments`)
        """
        super().__init__(env=env, args=args)
        self.env = env
        self.p = args

        # Setup networks; every trained network gets a deep-copied target twin
        self.pi = pi
        self.pi_targ = copy.deepcopy(self.pi)
        self.q1 = q1
        self.q1_targ = copy.deepcopy(self.q1)
        self.q2 = q2
        self.q2_targ = copy.deepcopy(self.q2)

        # Alpha parameter, stored as log(alpha) so that alpha = exp(log_alpha)
        # stays positive whatever the optimizer does
        self.log_alpha = torch.nn.Parameter(torch.tensor(
            np.log(self.p.initial_alpha),
            dtype=torch.float32,
            device=self.device,
        ))

        # Move everything to proper device before creating optimizers
        self.pi = self.pi.to(self.device)
        self.pi_targ = self.pi_targ.to(self.device)
        self.q1 = self.q1.to(self.device)
        self.q1_targ = self.q1_targ.to(self.device)
        self.q2 = self.q2.to(self.device)
        self.q2_targ = self.q2_targ.to(self.device)

        # Setup optimizers; both critics share a single optimizer
        self.pi_opt = utils.make_optimizer(self.pi, optim=args.optim_pi, lr=args.lr_pi)
        self.q_opt = utils.make_optimizer(self.q1, self.q2, optim=args.optim_q, lr=args.lr_q)
        self.alpha_opt = utils.make_optimizer(self.log_alpha, optim=args.optim_alpha, lr=args.lr_alpha)

        # Setup replay buffer
        self.replay = ReplayMemory(maxsize=self.p.replay_size, env=self.env)

        # Countdown state: environment steps left until the next update group
        # and minibatch updates left until the next polyak update
        self.update_before_sampling = self.p.update_first
        self.steps_since_update = self.p.update_delay
        self.updates_since_polyak = self.p.polyak_delay

    def evalcopy(self):
        """
        Return the base class evalcopy blob extended with CPU deep copies of
        pi, q1 and q2, plus the environment and the arguments.

        The environment and the arguments are stored by reference, not copied.
        """
        blob = super().evalcopy()
        blob["pi"] = copy.deepcopy(self.pi).to("cpu")
        blob["q1"] = copy.deepcopy(self.q1).to("cpu")
        blob["q2"] = copy.deepcopy(self.q2).to("cpu")
        blob["env"] = self.env
        blob["p"] = self.p
        return blob

    @classmethod
    def from_evalcopy(cls, blob, device=None):
        """
        Build a new instance from a blob produced by `evalcopy`.

        blob   -- mapping with the keys "pi", "q1", "q2", "env" and "p"
        device -- if given, overrides the `device` attribute of the arguments

        When a device is given, the blob and its arguments are shallow-copied
        first so that the caller's objects are not modified.
        """
        if device is not None:
            blob = copy.copy(blob)
            blob["p"] = copy.copy(blob["p"])
            blob["p"].device = device
        return cls(
            pi=blob["pi"],
            q1=blob["q1"],
            q2=blob["q2"],
            env=blob["env"],
            args=blob["p"],
        )

    def modules(self):
        """Return all networks held by this trainer, target copies included."""
        return (self.pi, self.pi_targ, self.q1, self.q1_targ, self.q2, self.q2_targ)

    def evaluate(self, trajectory_length=None, **kwargs):
        """
        Run the base class evaluation, with the trajectory length defaulting
        to the `max_trajectory` argument when none is given.
        """
        if trajectory_length is None:
            trajectory_length = self.p.max_trajectory

        return super().evaluate(trajectory_length=trajectory_length, **kwargs)

    def train(self, n_episodes=None, n_steps=None, follow_policy=True):
        """
        Trains the networks contained within with SAC.

        n_episodes    -- forwarded to `self.progress`
        n_steps       -- forwarded to `self.progress`
        follow_policy -- when False, actions are requested from
                         `sample_action` with `random=True`

        Every episode resets the environment with the episode value as seed,
        then alternates environment steps with groups of minibatch updates,
        and finally records the total episode reward.
        """
        for ep in self.progress(n_episodes=n_episodes, n_steps=n_steps):
            r_total = 0.0
            (s, _), done = self.env.reset(seed=ep), False
            for _step in self.progress_steps(limit=self.p.max_trajectory):
                if done:
                    break

                # With `update_first` set, the very first iteration skips the
                # environment step and goes straight to an update group.
                # NOTE(review): that update samples from the replay memory
                # before this call has added anything to it -- confirm that
                # ReplayMemory copes with this.
                if not self.update_before_sampling:
                    # Take a step
                    a, info = self.sample_action(s, random=(not follow_policy), with_info=True)
                    s_next, r, terminated, truncated, _ = self.env.step(a)
                    done = terminated or truncated

                    # Only `terminated` is stored as the done flag, so
                    # truncated episodes still bootstrap from the next state
                    self.replay.add(s, a, r, s_next, terminated)
                    s = s_next

                    r_total += r

                    self._record_elementwise("actor/action", a)
                    if "std" in info:
                        self._record_elementwise("actor/action_std", info["std"])

                # Minibatch network updates
                self.steps_since_update -= 1
                if self.steps_since_update <= 0 or self.update_before_sampling:
                    self.update_before_sampling = False
                    self.steps_since_update = self.p.update_delay
                    self._run_update_group()

            self.record_episode_total_reward(r_total)

    def _record_elementwise(self, name, values):
        """Record every element of `values` under `name`, suffixed with its index."""
        it = np.nditer(values, flags=["multi_index"])
        for val in it:
            index = it.multi_index
            # Zero-dimensional values get no index suffix
            suffix = ("/" + ",".join(str(i) for i in index)) if index != () else ""
            self.record_iteration_value(name + suffix, val)

    def _run_update_group(self):
        """Perform `update_iterations` minibatch updates and record the mean losses."""
        n_updates = self.p.update_iterations
        losses_q = np.zeros((n_updates,), dtype=np.float32)
        losses_pi = np.zeros((n_updates,), dtype=np.float32)
        losses_alpha = np.zeros((n_updates,), dtype=np.float32)
        for update_step in range(n_updates):
            loss_q, loss_pi, loss_alpha = self._minibatch_update()
            losses_q[update_step] = loss_q
            losses_pi[update_step] = loss_pi
            losses_alpha[update_step] = loss_alpha

            # Polyak (soft target) updates
            self.updates_since_polyak -= 1
            if self.updates_since_polyak <= 0:
                self.updates_since_polyak = self.p.polyak_delay
                self._soft_update_targets()

        self.record_iteration_value("critic/loss", losses_q.mean())
        self.record_iteration_value("actor/loss", losses_pi.mean())
        self.record_iteration_value("temperature/loss", losses_alpha.mean())
        self.record_iteration_value("temperature/value", torch.exp(self.log_alpha).item())

    def _minibatch_update(self):
        """
        Sample one minibatch and update critics, actor and temperature on it.

        Returns the critic, actor and temperature losses as Python floats.
        """
        # Sample batch for updates
        bs, ba, br, bsn, bd = self.replay.sample(self.p.batch_size, device=self.device)

        # Pre-fetch the alpha for later; inside the critic and actor losses it
        # is a constant, only the temperature loss differentiates through it
        with torch.no_grad():
            alpha = torch.exp(self.log_alpha)

        # Q update
        self.q_opt.zero_grad()
        q1val = self.q1(bs, ba)
        if self.p.use_double_q:
            q2val = self.q2(bs, ba)
        # NOTE(review): the Q values, rewards and done flags are assumed to
        # share one shape -- a mismatch would broadcast silently. TODO confirm.
        with torch.no_grad():
            if self.p.use_double_q:
                # Next action from the current policy, value from the smaller
                # of the two target critics
                an, _, logp_an, _ = self.pi(bsn)
                qval_targ = torch.min(self.q1_targ(bsn, an), self.q2_targ(bsn, an))
            else:
                # Single critic: next action from the target policy
                an, _, logp_an, _ = self.pi_targ(bsn)
                qval_targ = self.q1_targ(bsn, an)
            ys = br + self.p.discount * (1 - bd) * (qval_targ - alpha * logp_an)

        loss_q = ((q1val - ys)**2).mean()
        if self.p.use_double_q:
            loss_q = loss_q + ((q2val - ys)**2).mean()
        loss_q.backward()
        if self.p.norm_clip_q is not None:
            # Both critics share `q_opt`, so their joint gradient norm is
            # clipped (there is no `self.q` attribute to clip on). Parameters
            # without a gradient are ignored by clip_grad_norm_.
            torch.nn.utils.clip_grad_norm_(
                [*self.q1.parameters(), *self.q2.parameters()],
                self.p.norm_clip_q,
            )

        self.q_opt.step()

        # pi (policy) update
        self.pi_opt.zero_grad()

        a, _, logp_a, _ = self.pi(bs)
        qval = self.q1(bs, a)
        if self.p.use_double_q:
            qval = torch.min(qval, self.q2(bs, a))

        # This backward pass also leaves gradients on the critics; they are
        # discarded by the next `q_opt.zero_grad()`
        loss_pi = -(qval - alpha * logp_a).mean()
        loss_pi.backward()
        if self.p.norm_clip_pi is not None:
            torch.nn.utils.clip_grad_norm_(self.pi.parameters(), self.p.norm_clip_pi)

        self.pi_opt.step()

        # alpha (temperature) update
        self.alpha_opt.zero_grad()

        # J(alpha) = -alpha * (E[logp(a|s)] + H), with H the entropy
        # threshold. The bracket is detached so that only log_alpha receives
        # a gradient.
        entropy_gap = (logp_a.mean() + self.p.entropy_threshold).detach()
        loss_alpha = -1.0 * (torch.exp(self.log_alpha) * entropy_gap)
        loss_alpha.backward()

        self.alpha_opt.step()

        return loss_q.item(), loss_pi.item(), loss_alpha.item()

    def _soft_update_targets(self):
        """Move every target network towards its trained twin by the polyak factor."""
        tau = self.p.polyak
        pairs = (
            (self.pi, self.pi_targ),
            (self.q1, self.q1_targ),
            (self.q2, self.q2_targ),
        )
        with torch.no_grad():
            for train, target in pairs:
                for p_train, p_targ in zip(train.parameters(), target.parameters()):
                    # p_targ <- (1 - tau) * p_targ + tau * p_train, in place
                    p_targ.data.mul_(1 - tau)
                    p_targ.data.add_(tau * p_train.data)
