#
# Copyright (c) 2020-2022 ANONYMOUS RESEARCHERS.  All rights reserved.
#
# ANONYMOUS RESEARCHERS and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from ANONYMOUS RESEARCHERS is strictly prohibited.
#
#
"""
Stein MPC.
"""
# Third Party
import os
import torch
import numpy as np

# Storm
# from storm.mpc.control.control_base import Controller

# Goal sets.
from ..stein.svgd import SVGD
from ..stein.kernels import RBFMedianKernel
from ..util.optim import FullBatchLBFGS


class SteinMPC(object):
    """
    Model-predictive controller that optimizes a set of action-sequence
    particles with Stein variational gradient descent (SVGD).

    Each particle is a flattened ``(horizon * d_action)`` action sequence.
    Every call to :meth:`optimize` warm-starts (or resets) the particles,
    runs a number of SVGD/optimizer steps, and returns one action sequence
    selected according to ``sample_mode``.

    .. inheritance-diagram:: SteinMPC
       :parts: 1
    """

    def __init__(
        self,
        cost_fn,
        num_particles,
        d_action,
        horizon,
        n_iters=1,
        rollout_fn=None,
        sample_mode="mean",
        hotstart=True,
        seed=0,
        init_cov=0.1,
        kernel_sigma=None,
        stein_alpha=1.,  # Repulsive force parameter.
        conv_eps=0.05,  # Gradient norm as a percentage of total norm to be considered converged.
        shift_mode="best",  # Shift the trajectories. Options: "shift", "best", "reset"
        log_prior=None,  # A callable prior function.
        weight_fn=None,
        optim_type="adam",
        optim_params=None,
        tensor_args=None,
    ):
        """
        Parameters
        ----------
        cost_fn : callable
            Called as ``cost_fn(trajectories, actions)``; its negation is used
            as the particle log likelihood.
        num_particles : int
            Number of action sequences sampled at every iteration
        d_action : int
            Dimension of a single action.
        horizon : int
            Number of actions in each action sequence.
        n_iters : int
            Default number of optimizer steps per call to ``optimize``.
        rollout_fn : callable
            Called as ``rollout_fn(state, actions)`` to produce trajectories.
        sample_mode : str
            How the returned action sequence is chosen: "mean", "best" or
            "sample".
        hotstart : bool
            If True, ``optimize`` warm-starts from the previous particles via
            ``_shift``; otherwise the particles are re-sampled.
        seed : int
            Currently unused by this class; kept for interface compatibility.
        init_cov : float
            Passed as the ``std`` argument of ``torch.normal`` when sampling
            initial particles and warm-start noise.
        kernel_sigma : float, optional
            Kernel bandwidth. If None, ``gamma`` is derived from the particle
            dimension instead.
        stein_alpha : float
            Repulsive force parameter forwarded to SVGD.
        conv_eps : float
            Convergence threshold on the gradient percentage reported by SVGD.
        shift_mode : str
            Warm-start strategy: "shift", "best" or "reset".
        log_prior : callable, optional
            Called as ``log_prior(trajectories, actions)`` and added to the
            log likelihood.
        weight_fn : callable, optional
            Called as ``weight_fn(trajectories)``; if given, replaces the SVGD
            weights when picking the best particle.
        optim_type : str
            One of "sgd", "adam" or "lbfgs".
        optim_params : dict, optional
            Keyword arguments for the optimizer. Defaults to ``{"lr": 1e-2}``.
        tensor_args : dict, optional
            ``device`` / ``dtype`` used for all tensors created here. Defaults
            to CPU / ``torch.float32``.
        """
        # Build default dicts per instance so they are never shared.
        if optim_params is None:
            optim_params = {"lr": 1e-2}
        if tensor_args is None:
            tensor_args = {"device": torch.device("cpu"), "dtype": torch.float32}

        self.n_iters = n_iters
        self._rollout_fn = rollout_fn
        self.sample_mode = sample_mode
        self.hotstart = hotstart
        self.shift_mode = shift_mode
        self.tensor_args = tensor_args

        self.cost_fn = cost_fn
        self.num_particles = num_particles
        self.horizon = horizon
        self.d_action = d_action
        self.optim_params = optim_params
        self.optim_type = optim_type
        # If provided, will use these weights instead of those from Stein to calculate best particle.
        self.weight_fn = weight_fn
        self.log_prior = log_prior

        self.init_cov = init_cov
        self.conv_eps = conv_eps

        # Number of completed calls to optimize() since the last reset.
        self.num_steps = 0

        # Robot state placeholder; overwritten on every call to optimize().
        self.state = torch.zeros(3 * d_action + 1, **self.tensor_args)

        if kernel_sigma is None:
            # Median heuristic, where gamma = 1 / (sigma * sqrt(2 * DIM))
            gamma = 1. / np.sqrt(2 * self.horizon * self.d_action)
            kernel = RBFMedianKernel(gamma=gamma)
        else:
            kernel = RBFMedianKernel(sigma=kernel_sigma)

        p_size = (self.num_particles, self.horizon * self.d_action)
        init_particles = torch.normal(0, self.init_cov, p_size, **self.tensor_args)

        # Forward the user-supplied repulsive force parameter (was hard-coded to 1).
        self.stein = SVGD(init_particles, self.log_likelihood, kernel, alpha=stein_alpha)
        self.optimizer = self.init_optimizer(self.optim_type, self.optim_params)

    def init_optimizer(self, optim_type, optim_params=None):
        """
        Build a fresh optimizer over the current Stein particles.

        Parameters
        ----------
        optim_type : str
            One of "sgd", "adam" or "lbfgs".
        optim_params : dict, optional
            Keyword arguments forwarded to the optimizer constructor.

        Raises
        ------
        ValueError
            If ``optim_type`` is not recognized.
        """
        if optim_params is None:
            optim_params = {}

        if optim_type == "sgd":
            optim_cls = torch.optim.SGD
        elif optim_type == "adam":
            optim_cls = torch.optim.Adam
        elif optim_type == "lbfgs":
            optim_cls = FullBatchLBFGS
        else:
            raise ValueError("Unrecognized optimizer type: {}".format(optim_type))

        return optim_cls([self.stein.optim_parameters()], **optim_params)

    def update_goal(self, **kwargs):
        """Forward new goal parameters to ``cost_fn.update_params``."""
        self.cost_fn.update_params(**kwargs)

    def update_cost(self, cost, weight_fn=None, log_prior=None):
        """
        Replace the cost function.

        Note that ``weight_fn`` and ``log_prior`` are overwritten as well, so
        omitting them clears any previously configured values.
        """
        self.cost_fn = cost
        self.weight_fn = weight_fn
        self.log_prior = log_prior

    def action_particles(self):
        """Return the particles viewed as (num_particles, horizon, d_action)."""
        return self.stein.particles().view(self.num_particles, self.horizon, self.d_action)

    def rollout(self, state=None, actions=None):
        """
        Roll out trajectories with the configured rollout function.

        ``state`` defaults to the stored state and ``actions`` to the current
        action particles.
        """
        if state is None:
            state = self.state
        if actions is None:
            actions = self.action_particles()
        # TODO: Make more efficient by caching rollouts when neither the state
        # nor the action particles changed since the previous call.
        return self._rollout_fn(state, actions)

    def _get_action_seq(self, mode="mean", return_idx=False):
        """
        Select one action sequence from the current particles.

        Parameters
        ----------
        mode : str
            "mean" averages the particles, "best" picks the particle with the
            largest weight, "sample" picks a uniformly random particle.
        return_idx : bool
            If True, also return the index of the chosen particle (None for
            "mean", where no single particle is chosen).

        Returns
        -------
        torch.Tensor or (torch.Tensor, int or None)
            The (horizon, d_action) action sequence, optionally with its index.

        Raises
        ------
        ValueError
            If ``mode`` is not recognized.
        """
        particles = self.stein.particles()
        idx = None
        if mode == "mean":
            action = torch.mean(particles, dim=0)
        elif mode == "best":
            idx = self.calc_weights().argmax().item()
            action = particles[idx, :]
        elif mode == "sample":
            idx = torch.randint(0, self.num_particles, (1,)).item()
            action = particles[idx, :]
        else:
            raise ValueError("Unidentified sampling mode in get_next_action")

        action = action.view(self.horizon, self.d_action)
        if return_idx:
            return action, idx
        return action

    def calc_weights(self):
        """
        Calculate the weights of the current Stein particles.

        Uses ``weight_fn`` on freshly rolled-out trajectories when it is set,
        otherwise the weights computed by SVGD.
        """
        if self.weight_fn is not None:
            actions = self.action_particles()
            trajectories = self._rollout_fn(self.state, actions)
            return self.weight_fn(trajectories)

        return self.stein.calc_weights()

    def log_likelihood(self, particles):
        """
        Log likelihood of each particle: the negated cost plus optional prior.

        Parameters
        ----------
        particles : torch.Tensor
            Flattened particles, viewable as (num_particles, horizon, d_action).
        """
        actions = particles.view(self.num_particles, self.horizon, self.d_action)
        trajectories = self._rollout_fn(self.state, actions)
        log_likelihood = -self.cost_fn(trajectories, actions)

        if log_likelihood.ndim > 1:
            # Make sure that there is one cost per particle.
            # TODO: Allow other methods of combining costs.
            log_likelihood = log_likelihood.view(self.num_particles, -1).sum(-1)

        # If there is a prior, add that too.
        if self.log_prior is not None:
            log_likelihood = log_likelihood + self.log_prior(trajectories, actions)

        return log_likelihood

    def _step_closure(self):
        """Take one optimizer step using a closure (used for "lbfgs")."""
        def closure():
            self.optimizer.zero_grad()
            loss = self.stein.update(create_graph=True)
            return -loss.sum()

        loss = closure()
        options = {'closure': closure, 'current_loss': loss}
        self.optimizer.step(options)

    def _step(self):
        """Take one plain optimizer step on the Stein update."""
        self.optimizer.zero_grad()
        self.stein.update()
        self.optimizer.step()

    def _save_data(self, idx, path):
        """Save the current action particles to ``<path>/<idx as 3 digits>.npy``."""
        # Create the parent directory if it does not already exist.
        os.makedirs(path, exist_ok=True)

        file_path = os.path.join(path, "{:03d}.npy".format(idx))
        particles = self.action_particles()
        np.save(file_path, particles.detach().cpu().numpy())

    def optimize(self, state, shift_steps=1, n_iters=None,
                 return_idx=False, save_data=None, stop_early=False):
        """
        Optimize for best action at current state

        Parameters
        ----------
        state : torch.Tensor
            state to calculate optimal action from
        shift_steps : int
            Number of time steps to shift the particles when hot-starting.
        n_iters : int, optional
            Number of optimizer steps; defaults to ``self.n_iters``.
        return_idx : bool
            If True, also return the index of the selected particle.
        save_data : str, optional
            Directory in which particles and gradient norms are saved after
            every iteration.
        stop_early : bool
            If True, stop iterating as soon as ``check_convergence`` holds.

        Returns
        -------
        action : torch.Tensor
            Selected (horizon, d_action) action sequence, on the device and in
            the dtype of the input state. With ``return_idx`` a tuple
            ``(action, index)`` is returned instead.
        """
        n_iters = n_iters if n_iters is not None else self.n_iters
        # Remember the input device/dtype so the result can be converted back.
        inp_device = state.device
        inp_dtype = state.dtype
        # Tensor.to is not in-place: keep the converted tensor.
        self.state = state.to(**self.tensor_args)

        # shift distribution to hotstart from previous timestep
        if self.hotstart:
            self._shift(shift_steps)
        else:
            self.reset_distribution()

        # Reset the optimizer before optimizing. This is done after the
        # particles are reset so that the optimizer references the current
        # particle tensor even if the reset rebinds it.
        self.optimizer = self.init_optimizer(self.optim_type, self.optim_params)

        # Initialize visualization data to save.
        if save_data is not None:
            grad_norms = []
            self._save_data(0, save_data)

        for i in range(n_iters):
            if self.optim_type == "lbfgs":
                self._step_closure()
            else:
                self._step()

            # Save the current action sequence to visualize.
            if save_data is not None:
                grad_pct = self.stein.calc_grad_percentage(mean=True)
                # Detach first: numpy() fails on tensors that track gradients.
                grad_norms.append(grad_pct.detach().cpu().squeeze().numpy())
                self._save_data(i + 1, save_data)

            # check if converged
            if stop_early and self.check_convergence():
                print("Converged!", i)
                break

        # calculate best action
        curr_action_seq = self._get_action_seq(mode=self.sample_mode, return_idx=return_idx)

        if save_data is not None:
            np.save(os.path.join(save_data, "grads.npy"), grad_norms)

        self.num_steps += 1

        if return_idx:
            curr_action_seq, action_idx = curr_action_seq
            return curr_action_seq.to(inp_device, dtype=inp_dtype), action_idx
        return curr_action_seq.to(inp_device, dtype=inp_dtype)

    def check_convergence(self):
        """Return whether the mean gradient percentage is below ``conv_eps``."""
        return self.stein.calc_grad_percentage(mean=True).squeeze() < self.conv_eps

    def _update_distribution(self, trajectories):
        """
        Update distribution of particles based on given trajectories.

        Not implemented yet; currently a no-op.
        """
        # TODO
        pass

    def generate_rollouts(self, state):
        """
        Placeholder for sampling a batch of actions and rolling out a
        trajectory for each particle.

        Not implemented yet; currently a no-op that returns None.

        Parameters
        ----------
        state : dict or np.ndarray
            Initial state to set the simulation env to
        """
        # TODO
        pass

    def _shift(self, shift_steps=1, repeat=False):
        """
        Warm-start the particles for the next time step.

        Depending on ``shift_mode`` the new particles are fresh samples
        ("reset"), noisy copies of the sequence selected by ``sample_mode``
        ("best"), or the current particles ("shift"). Except for "reset", the
        sequences are then shifted forward in time by ``shift_steps``.

        Parameters
        ----------
        shift_steps : int
            Number of time steps to shift by.
        repeat : bool
            If True, the vacated trailing steps repeat the last action of the
            unshifted sequence; otherwise they are filled with zeros.

        Raises
        ------
        ValueError
            If ``shift_mode`` is not recognized.
        """
        p_size = (self.num_particles, self.horizon, self.d_action)
        if self.shift_mode == "reset":
            particles = torch.normal(0, self.init_cov, p_size, **self.tensor_args)
            shift_steps = 0  # Ensure not shifted.
        elif self.shift_mode == "best":
            # Get the chosen trajectory from the current particle set.
            seq = self._get_action_seq(mode=self.sample_mode)
            particles = seq.tile(self.num_particles, 1, 1)  # Repeat the selected trajectory for each particle.

            # Add some Gaussian noise.
            noise = torch.normal(0, self.init_cov, p_size, **self.tensor_args)
            particles = particles + noise
        elif self.shift_mode == "shift":
            particles = self.action_particles().view(*p_size)
        else:
            raise ValueError("SteinMPC: Unrecognized shift mode: " + self.shift_mode)

        # Shift the trajectories by the shift step.
        if shift_steps > 0:
            # Capture the final action before rolling so that the fill value
            # stays valid even when shift_steps >= horizon.
            last_action = particles[:, -1:, :].clone() if repeat else None
            particles = particles.roll(-shift_steps, 1)
            # Fill the last elements of the trajectory.
            if repeat:
                particles[:, -shift_steps:, :] = last_action
            else:
                particles[:, -shift_steps:, :] = 0.

        # Reset the distribution.
        self.stein.reset(particles.view(self.num_particles, -1))

    def reset_distribution(self, init_particles=None):
        """
        Reset control distribution

        Parameters
        ----------
        init_particles : torch.Tensor, optional
            Particles to reset to. If None, new particles are sampled from a
            zero-mean normal with standard deviation ``init_cov``.
        """
        if init_particles is None:
            p_size = (self.num_particles, self.horizon * self.d_action)
            init_particles = torch.normal(0, self.init_cov, p_size, **self.tensor_args)

        self.stein.reset(init_particles)
        self.num_steps = 0

    def _calc_val(self, cost_seq, act_seq):
        """Not implemented; always raises NotImplementedError."""
        raise NotImplementedError("_calc_val not implemented")
