import numpy as np
import torch

import global_context
from garage import TrajectoryBatch
from garagei import log_performance_ex
from iod import recurrent_sac_utils
from iod.iod import IOD
import copy

from iod.utils import (
    get_torch_concat_obs,
    FigManager,
    get_option_colors,
    record_video,
    draw_2d_gaussians,
)


class RecurrentMETRA(IOD):
    def __init__(
        self,
        *,
        qf1,
        qf2,
        log_alpha,
        tau,
        scale_reward,
        target_coef,
        replay_buffer,
        min_buffer_size,
        inner,
        num_alt_samples,
        split_group,
        dual_reg,
        dual_slack,
        dual_dist,
        pixel_shape=None,
        **kwargs,
    ):
        """Set up the critics, temperature and bookkeeping on top of ``IOD``.

        All arguments are keyword-only; anything not listed here is forwarded
        untouched to ``IOD.__init__`` through ``**kwargs``.

        Args:
            qf1, qf2: Critic objects; moved to ``self.device`` and deep-copied
                into ``target_qf1`` / ``target_qf2``.
            log_alpha: Object moved to ``self.device`` and registered as a
                parameter module (presumably the SAC log-temperature — confirm).
            tau: Stored as ``self.tau``.
            scale_reward: Stored as ``self._reward_scale_factor``; multiplies
                the rewards handed to the Q-function loss.
            target_coef: Scales the target entropy derived from the action
                dimensionality.
            replay_buffer: Buffer used for off-policy sampling, or ``None``.
            min_buffer_size: Minimum ``n_transitions_stored`` before training
                starts.
            inner: Selects the inner-product reward formulation.
            num_alt_samples, split_group: Stored as-is on the instance.
            dual_reg, dual_slack, dual_dist: Dual-regularisation settings read
                by the trajectory-encoder loss.
            pixel_shape: Stored as-is on the instance.
        """
        super().__init__(**kwargs)

        # Critics and their frozen-at-construction target copies.
        self.qf1, self.qf2 = qf1.to(self.device), qf2.to(self.device)
        self.target_qf1, self.target_qf2 = (
            copy.deepcopy(self.qf1),
            copy.deepcopy(self.qf2),
        )

        self.log_alpha = log_alpha.to(self.device)

        self.param_modules.update(
            qf1=self.qf1,
            qf2=self.qf2,
            log_alpha=self.log_alpha,
        )

        self.tau = tau

        self.replay_buffer = replay_buffer
        self.min_buffer_size = min_buffer_size
        self.inner = inner

        self.dual_reg = dual_reg
        self.dual_slack = dual_slack
        self.dual_dist = dual_dist

        self.num_alt_samples = num_alt_samples
        self.split_group = split_group

        self._reward_scale_factor = scale_reward
        # Target entropy is -(action_dim / 2) scaled by ``target_coef``.
        action_dim = np.prod(self._env_spec.action_space.shape).item()
        self._target_entropy = -action_dim / 2.0 * target_coef

        self.pixel_shape = pixel_shape

        assert self._trans_optimization_epochs is not None

    @property
    def policy(self):
        return {
            "option_policy": self.option_policy,
        }

    def _get_concat_obs(self, obs, option, dim=1):
        """Join ``obs`` and ``option`` with ``get_torch_concat_obs``.

        ``dim`` is forwarded unchanged as the helper's ``dim`` argument.
        """
        concat_obs = get_torch_concat_obs(obs, option, dim=dim)
        return concat_obs

    def _get_train_trajectories_kwargs(self, runner):
        if self.discrete:
            #### start debug
            extras = self._generate_option_extras(
                np.eye(self.dim_option)[
                    np.random.randint(0, self.dim_option, runner._train_args.batch_size)
                ]
            )
            # extras = self._generate_option_extras(np.eye(self.dim_option)[np.random.randint(0, 1, runner._train_args.batch_size)])

        else:
            # random_options = np.zeros((runner._train_args.batch_size, self.dim_option))
            random_options = np.random.randn(
                runner._train_args.batch_size, self.dim_option
            )
            if self.unit_length:
                random_options /= np.linalg.norm(random_options, axis=-1, keepdims=True)
            extras = self._generate_option_extras(random_options)
            #### debug

        return dict(
            extras=extras,
            sampler_key="option_policy",
        )

    def _flatten_data(self, data):
        epoch_data = {}
        for key, value in data.items():
            epoch_data[key] = torch.tensor(
                np.concatenate(value, axis=0), dtype=torch.float32, device=self.device
            )
        return epoch_data

    def _update_replay_buffer(self, data):
        if self.replay_buffer is not None:
            # Add paths to the replay buffer
            for i in range(len(data["actions"])):
                path = {}
                for key in data.keys():
                    cur_list = data[key][i]
                    if cur_list.ndim == 1:
                        cur_list = cur_list[..., np.newaxis]
                    path[key] = cur_list
                self.replay_buffer.add_path(path)

    # def _sample_replay_buffer(self):
    #     # samples = self.replay_buffer.sample_trajectories(self._trans_minibatch_size)
    #     samples = self.replay_buffer.sample_transitions(self._trans_minibatch_size)
    #     data = {}
    #     for key, value in samples.items():
    #         if value.shape[1] == 1 and 'option' not in key: # (# traj), (# len_path), n_dim
    #             value = np.squeeze(value, axis=1) # flatten
    #         data[key] = torch.from_numpy(value).float().to(self.device)
    #     return data
    def _sample_replay_buffer(self):
        samples = self.replay_buffer.sample_trajectories(self._trans_minibatch_size)
        # use 10 horizons for losses
        horizon = samples["obs"].shape[1]
        rand_idx = np.random.randint(0, horizon - 10)
        for key, value in samples.items():
            samples[key] = value[:, rand_idx : rand_idx + 10, ...]

        # samples = self.replay_buffer.sample_transitions(self._trans_minibatch_size)
        #### debug
        # samples = self.replay_buffer.sample_transitions(self._trans_minibatch_size)
        # for key, value in samples.items():
        # samples[key] = value[:, None, ...]
        ###

        data = {}
        for key, value in samples.items():
            if (
                value.shape[2] == 1 and "option" not in key
            ):  # (# traj), (# len_path), n_dim
                value = np.squeeze(value, axis=2)  # flatten
            data[key] = torch.from_numpy(value).float().to(self.device)
        return data

    def _train_once_inner(self, path_data):
        self._update_replay_buffer(path_data)

        epoch_data = self._flatten_data(path_data)

        tensors = self._train_components(epoch_data)

        return tensors

    def _train_components(self, epoch_data):
        if (
            self.replay_buffer is not None
            and self.replay_buffer.n_transitions_stored < self.min_buffer_size
        ):
            return {}

        for _ in range(self._trans_optimization_epochs):
            tensors = {}  # to keep track of losses

            if self.replay_buffer is None:
                v = self._get_mini_tensors(epoch_data)
            else:
                v = self._sample_replay_buffer()

            # #### debug
            # mask = torch.randint(0, v['obs'].shape[1], (v['obs'].shape[0],), device=self.device)
            # for key, value in v.items():
            #     # check dim
            #     if value.ndim == 2:
            #         v[key] = value.gather(1, mask.unsqueeze(1))
            #     elif value.ndim == 3:
            #         v[key] = value.gather(1, mask.unsqueeze(1).unsqueeze(2).repeat(1, 1, value.shape[2]))
            #     else:
            #         raise NotImplementedError
            # #### end debug

            # #### debug
            # self._optimize_te(tensors, v)
            # self._update_rewards(tensors, v)
            # ####
            self._optimize_op(tensors, v)

        return tensors

    # def _optimize_te(self, tensors, internal_vars):
    #     self._update_loss_te(tensors, internal_vars)

    #     self._gradient_descent(
    #         tensors['LossTe'],
    #         optimizer_keys=['traj_encoder'],
    #     )

    #     if self.dual_reg:
    #         self._update_loss_dual_lam(tensors, internal_vars)
    #         self._gradient_descent(
    #             tensors['LossDualLam'],
    #             optimizer_keys=['dual_lam'],
    #         )
    #         if self.dual_dist == 's2_from_s':
    #             self._gradient_descent(
    #                 tensors['LossDp'],
    #                 optimizer_keys=['dist_predictor'],
    #             )

    def _optimize_te(self, tensors, internal_vars):
        self._update_loss_te(tensors, internal_vars)

        self._gradient_descent(
            tensors["LossTe"],
            optimizer_keys=["traj_encoder"],
        )

        if self.dual_reg:
            self._update_loss_dual_lam(tensors, internal_vars)
            self._gradient_descent(
                tensors["LossDualLam"],
                optimizer_keys=["dual_lam"],
            )
            if self.dual_dist == "s2_from_s":
                self._gradient_descent(
                    tensors["LossDp"],
                    optimizer_keys=["dist_predictor"],
                )

    def _optimize_op(self, tensors, internal_vars):
        """Update the Q-functions, the option policy and the temperature, in that order.

        Each stage first refreshes its loss in ``tensors`` and then takes a
        gradient step; finally the target networks are updated through
        ``recurrent_sac_utils.update_targets``.
        """
        # Both critics share one optimizer step on the summed loss.
        self._update_loss_qf(tensors, internal_vars)
        qf_loss = tensors["LossQf1"] + tensors["LossQf2"]
        self._gradient_descent(qf_loss, optimizer_keys=["qf"])

        remaining_stages = (
            (self._update_loss_op, "LossSacp", "option_policy"),
            (self._update_loss_alpha, "LossAlpha", "log_alpha"),
        )
        for refresh_loss, loss_key, optimizer_key in remaining_stages:
            refresh_loss(tensors, internal_vars)
            self._gradient_descent(tensors[loss_key], optimizer_keys=[optimizer_key])

        recurrent_sac_utils.update_targets(self)

    def _update_rewards(self, tensors, v):
        obs = v["obs"]
        next_obs = v["next_obs"]
        full_obs = torch.cat([obs[:, 0, :].unsqueeze(1), next_obs], dim=1)

        if self.inner:
            z = self.traj_encoder(full_obs)[
                0
            ].mean  # traj_encoder returns (dist, hidden_states)
            cur_z = z[:, :-1, :]
            next_z = z[:, 1:, :]

            target_z = next_z - cur_z

            if self.discrete:
                masks = (
                    (v["options"] - v["options"].mean(dim=2, keepdim=True))
                    * self.dim_option
                    / (self.dim_option - 1 if self.dim_option != 1 else 1)
                )
                rewards = (target_z * masks).sum(dim=2)
            else:
                inner = (target_z * v["options"]).sum(dim=2)
                rewards = inner

            # For dual objectives
            v.update(
                {
                    "cur_z": cur_z,
                    "next_z": next_z,
                }
            )
        else:
            target_dists = self.traj_encoder(full_obs)[0]

            if self.discrete:
                logits = target_dists.mean[:, 1:]
                rewards = -torch.nn.functional.cross_entropy(
                    logits, v["options"].argmax(dim=2), reduction="none"
                )
            else:
                rewards = target_dists.log_prob(
                    torch.cat(v["options"][:, 0, :].unsqueeze(1), v["options"])
                )[:, 1:]

        tensors.update(
            {
                "PureRewardMean": rewards.mean(),
                "PureRewardStd": rewards.std(),
            }
        )

        v["rewards"] = rewards

    def _update_loss_te(self, tensors, v):
        self._update_rewards(tensors, v)
        rewards = v["rewards"]

        obs = v["obs"]
        next_obs = v["next_obs"]

        if self.dual_dist == "s2_from_s":
            s2_dist = self.dist_predictor(obs)
            loss_dp = -s2_dist.log_prob(next_obs - obs).mean()
            tensors.update(
                {
                    "LossDp": loss_dp,
                }
            )

        if self.dual_reg:
            dual_lam = self.dual_lam.param.exp()
            x = obs
            y = next_obs
            phi_x = v["cur_z"]
            phi_y = v["next_z"]

            if self.dual_dist == "l2":
                raise NotImplementedError
                cst_dist = torch.square(y - x).mean(dim=1)
            elif self.dual_dist == "one":
                cst_dist = torch.ones_like(x[:, :, 0])
            elif self.dual_dist == "s2_from_s":
                raise NotImplementedError
                s2_dist = self.dist_predictor(obs)
                s2_dist_mean = s2_dist.mean
                s2_dist_std = s2_dist.stddev
                scaling_factor = 1.0 / s2_dist_std
                geo_mean = torch.exp(
                    torch.log(scaling_factor).mean(dim=1, keepdim=True)
                )
                normalized_scaling_factor = (scaling_factor / geo_mean) ** 2
                cst_dist = torch.mean(
                    torch.square((y - x) - s2_dist_mean) * normalized_scaling_factor,
                    dim=1,
                )

                tensors.update(
                    {
                        "ScalingFactor": scaling_factor.mean(dim=0),
                        "NormalizedScalingFactor": normalized_scaling_factor.mean(
                            dim=0
                        ),
                    }
                )
            else:
                raise NotImplementedError

            cst_penalty = cst_dist - torch.square(phi_y - phi_x).mean(dim=2)
            cst_penalty = torch.clamp(cst_penalty, max=self.dual_slack)
            te_obj = rewards + dual_lam.detach() * cst_penalty
            # [batch_size, horizon]

            v.update({"cst_penalty": cst_penalty})
            tensors.update(
                {
                    "DualCstPenalty": cst_penalty.mean(),
                }
            )
        else:
            te_obj = rewards

        assert te_obj.ndim == 2
        loss_te = -te_obj.mean()

        tensors.update(
            {
                "TeObjMean": te_obj.mean(),
                "LossTe": loss_te,
            }
        )

    def _update_loss_dual_lam(self, tensors, v):
        log_dual_lam = self.dual_lam.param
        dual_lam = log_dual_lam.exp()
        loss_dual_lam = log_dual_lam * (v["cst_penalty"].detach()).mean()

        tensors.update(
            {
                "DualLam": dual_lam,
                "LossDualLam": loss_dual_lam,
            }
        )

    def _update_loss_qf(self, tensors, v):
        """Build option-conditioned inputs and delegate the critic loss.

        Current and next observations are passed through the option policy's
        ``process_observations`` and concatenated with their options along
        ``dim=2``.  The loss itself is computed by
        ``recurrent_sac_utils.update_loss_qf`` with rewards scaled by
        ``self._reward_scale_factor``.  Both concatenated inputs are stored
        back into ``v``.
        """
        policy = self.option_policy

        cat_obs = self._get_concat_obs(
            policy.process_observations(v["obs"]), v["options"], dim=2
        )
        next_cat_obs = self._get_concat_obs(
            policy.process_observations(v["next_obs"]), v["next_options"], dim=2
        )

        recurrent_sac_utils.update_loss_qf(
            self,
            tensors,
            v,
            obs=cat_obs,
            actions=v["actions"],
            next_obs=next_cat_obs,
            dones=v["dones"],
            rewards=v["rewards"] * self._reward_scale_factor,
            policy=policy,
        )

        v["processed_cat_obs"] = cat_obs
        v["next_processed_cat_obs"] = next_cat_obs

    def _update_loss_op(self, tensors, v):
        """Delegate the option-policy loss to ``recurrent_sac_utils.update_loss_sacp``.

        The policy input is the processed observation concatenated with the
        option along ``dim=2``.
        """
        policy = self.option_policy
        processed_obs = policy.process_observations(v["obs"])
        cat_obs = self._get_concat_obs(processed_obs, v["options"], dim=2)

        recurrent_sac_utils.update_loss_sacp(
            self, tensors, v, obs=cat_obs, policy=policy
        )

    def _update_loss_alpha(self, tensors, v):
        """Delegate the ``log_alpha`` loss to ``recurrent_sac_utils.update_loss_alpha``."""
        recurrent_sac_utils.update_loss_alpha(self, tensors, v)

    def _evaluate_policy(self, runner):
        """Roll out the option policy for evaluation and log plots, videos and metrics.

        Steps:
            1. Pick evaluation options: every one-hot option, spread as evenly
               as possible over ``num_random_trajectories`` (discrete), or
               random Gaussian options, unit-normalised when ``unit_length``
               is set (continuous).
            2. Collect deterministic trajectories and draw them
               (``TrajPlot_RandomZ``).
            3. Encode the trajectories and plot the encoder output at the last
               time step (``PhiPlot``).
            4. Optionally record videos for a fixed set of options.
            5. Log environment metrics and performance statistics.
        """
        if self.discrete:
            eye_options = np.eye(self.dim_option)
            # The first ``remainder`` options receive one extra trajectory.
            base, remainder = divmod(self.num_random_trajectories, self.dim_option)
            option_ids = [
                i
                for i in range(self.dim_option)
                for _ in range(base + (i < remainder))
            ]
            random_options = np.array([eye_options[i] for i in option_ids])
            colors = np.array(option_ids)

            from matplotlib import cm

            cmap = "tab10" if self.dim_option <= 10 else "tab20"
            # Resolve the colormap once instead of once per trajectory.
            # NOTE(review): cm.get_cmap is deprecated in recent Matplotlib
            # releases; switch to matplotlib.colormaps[...] when the pinned
            # version allows it.
            palette = cm.get_cmap(cmap)
            random_option_colors = np.array([palette(c)[:3] for c in colors])
        else:
            random_options = np.random.randn(
                self.num_random_trajectories, self.dim_option
            )
            if self.unit_length:
                random_options = random_options / np.linalg.norm(
                    random_options, axis=1, keepdims=True
                )
            random_option_colors = get_option_colors(random_options * 4)

        random_trajectories = self._get_trajectories(
            runner,
            sampler_key="option_policy",
            extras=self._generate_option_extras(random_options),
            worker_update=dict(
                _render=False,
                _deterministic_policy=True,
            ),
            env_update=dict(_action_noise_std=None),
        )

        with FigManager(runner, "TrajPlot_RandomZ") as fm:
            runner._env.render_trajectories(
                random_trajectories, random_option_colors, self.eval_plot_axis, fm.ax
            )

        data = self.process_samples(random_trajectories)
        # torch.stack requires every trajectory to have the same length.
        obs = torch.stack([torch.from_numpy(ob).to(self.device) for ob in data["obs"]])

        # Evaluation only: no autograd graph is needed for the encoder pass.
        with torch.no_grad():
            option_dists = self.traj_encoder(obs)[0]

        # Only the encoder output at the final time step is plotted.
        option_means = option_dists.mean[:, -1, :].detach().cpu().numpy()
        last_stddevs = option_dists.stddev[:, -1, :].detach().cpu()
        if self.inner:
            option_stddevs = torch.ones_like(last_stddevs).numpy()
        else:
            option_stddevs = last_stddevs.numpy()
        # The "samples" drawn on top are the means themselves.
        option_samples = option_means

        option_colors = random_option_colors

        with FigManager(runner, "PhiPlot") as fm:
            draw_2d_gaussians(option_means, option_stddevs, option_colors, fm.ax)
            draw_2d_gaussians(
                option_samples,
                [[0.03, 0.03]] * len(option_samples),
                option_colors,
                fm.ax,
                fill=True,
                use_adaptive_axis=True,
            )

        eval_option_metrics = {}

        # Videos
        if self.eval_record_video:
            if self.discrete:
                video_options = np.eye(self.dim_option)
            elif self.dim_option == 2:
                radius = 1.0 if self.unit_length else 1.5

                def _on_circle(angle):
                    # Point at ``angle`` eighths of a half-turn on the circle.
                    return [
                        radius * np.cos(angle * np.pi / 4),
                        radius * np.sin(angle * np.pi / 4),
                    ]

                # Eight directions with the zero option in the middle.
                video_options = np.array(
                    [_on_circle(a) for a in (3, 2, 1, 4)]
                    + [[0, 0]]
                    + [_on_circle(a) for a in (0, 5, 6, 7)]
                )
            else:
                video_options = np.random.randn(9, self.dim_option)
                if self.unit_length:
                    video_options = video_options / np.linalg.norm(
                        video_options, axis=1, keepdims=True
                    )
            video_options = video_options.repeat(self.num_video_repeats, axis=0)

            video_trajectories = self._get_trajectories(
                runner,
                sampler_key="local_option_policy",
                extras=self._generate_option_extras(video_options),
                worker_update=dict(
                    _render=True,
                    _deterministic_policy=True,
                ),
            )
            record_video(
                runner,
                "Video_RandomZ",
                video_trajectories,
                skip_frames=self.video_skip_frames,
            )

        eval_option_metrics.update(
            runner._env.calc_eval_metrics(
                random_trajectories, is_option_trajectories=True
            )
        )
        with global_context.GlobalContext({"phase": "eval", "policy": "option"}):
            log_performance_ex(
                runner.step_itr,
                TrajectoryBatch.from_trajectory_list(
                    self._env_spec, random_trajectories
                ),
                discount=self.discount,
                additional_records=eval_option_metrics,
            )
        self._log_eval_metrics(runner)
