# Copyright 2023 OmniSafeAI Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================


from __future__ import annotations

import os
import random
import sys
import time
from collections import deque
from typing import Callable
import math

import numpy as np
try: 
    from isaacgym import gymutil
except ImportError:
    pass
import torch
import torch.nn as nn
import torch.optim
from torch.nn.utils.clip_grad import clip_grad_norm_
from torch.utils.data import DataLoader, TensorDataset

from safepo.common.buffer import VectorizedOnPolicyBuffer
from safepo.common.env import make_sa_mujoco_env, make_sa_isaac_env
from safepo.common.logger import EpochLogger
from safepo.common.model import ActorNoCritic
from safepo.utils.config import single_agent_args, isaac_gym_map, parse_sim_params

# Number of conjugate-gradient iterations used to approximate F^-1 g in main.
CONJUGATE_GRADIENT_ITERS = 50
# Maximum number of step-size shrinkages tried by the backtracking line search in main.
TRPO_SEARCHING_STEPS = 100
# Default coefficient for mixing the reward and cost step directions; main
# replaces it with ``args.beta`` when that argument is not None.
BETA = 0.65

# Hyper-parameters for non-Isaac-Gym tasks. ``steps_per_epoch`` / ``total_steps``
# are absent here, so main falls back to the command-line values. Every key is
# also merged into the saved run config, including ones main does not read.
default_cfg = {
    'hidden_sizes': [64, 64],
    'gamma': 0.99,
    'gamma_c': 0.99,
    'target_kl': 0.01,
    'target_kl_init': 0.01,
    'batch_size': 20000,
    'learning_iters': 0,
    'max_grad_norm': 0.5,
}

# Hyper-parameters for Isaac Gym tasks (selected when the task is in isaac_gym_map).
isaac_gym_specific_cfg = {
    'total_steps': 100000000,
    'steps_per_epoch': 32768,
    'hidden_sizes': [1024, 1024, 512],
    'gamma': 0.96,
    # main builds the buffer with config["gamma_c"]; mirror the reward discount
    # (as default_cfg does) so Isaac Gym runs do not fail with a KeyError.
    'gamma_c': 0.96,
    'target_kl': 0.016,
    'num_mini_batch': 4,
    'use_value_coefficient': True,
    'learning_iters': 8,
    'max_grad_norm': 1.0,
    'use_critic_norm': False,
}


def get_flat_params_from(model: torch.nn.Module) -> torch.Tensor:
    """Concatenate every trainable parameter of ``model`` into one flat tensor.

    Parameters are visited in ``model.parameters()`` order and only those with
    ``requires_grad=True`` are included, which is the same order and filter
    used by ``set_param_values_to_model``.

    Args:
        model: module whose trainable parameters are flattened.

    Returns:
        A new 1-D tensor (``torch.cat`` copies) holding the parameter values.

    Raises:
        AssertionError: if ``model`` has no trainable parameters.
    """
    # reshape (unlike view) also copes with non-contiguous parameter storage.
    flat_params = [
        param.data.reshape(-1)
        for param in model.parameters()
        if param.requires_grad
    ]
    assert flat_params, "No trainable parameters were found in model."
    return torch.cat(flat_params)


def conjugate_gradients(
    fisher_product: Callable[[torch.Tensor], torch.Tensor],
    policy: ActorNoCritic,
    fvp_obs: torch.Tensor,
    vector_b: torch.Tensor,
    num_steps: int = 10,
    residual_tol: float = 1e-10,
    eps: float = 1e-6,
) -> torch.Tensor:
    """Approximately solve ``F x = vector_b`` with conjugate gradients.

    The matrix ``F`` is never formed explicitly. Instead
    ``fisher_product(v, policy, fvp_obs)`` is called to obtain ``F v``.
    Note that it receives three arguments, despite the narrower ``Callable``
    annotation.

    Args:
        fisher_product: callable returning the matrix-vector product ``F v``.
        policy: forwarded unchanged to ``fisher_product``.
        fvp_obs: forwarded unchanged to ``fisher_product``.
        vector_b: right-hand side of the linear system.
        num_steps: maximum number of CG iterations.
        residual_tol: stop early once the residual's L2 norm drops below this.
        eps: added to denominators to avoid division by zero.

    Returns:
        The approximate solution, a tensor shaped like ``vector_b``.
    """
    solution = torch.zeros_like(vector_b)
    residual = vector_b - fisher_product(solution, policy, fvp_obs)
    search_dir = residual.clone()
    residual_sq = torch.dot(residual, residual)

    for _iteration in range(num_steps):
        f_dir = fisher_product(search_dir, policy, fvp_obs)
        step_size = residual_sq / (torch.dot(search_dir, f_dir) + eps)
        solution += step_size * search_dir
        residual -= step_size * f_dir
        next_residual_sq = torch.dot(residual, residual)
        # Converged: residual norm below tolerance.
        if torch.sqrt(next_residual_sq) < residual_tol:
            break
        # New search direction, conjugate to the previous ones.
        beta_cg = next_residual_sq / (residual_sq + eps)
        search_dir = residual + beta_cg * search_dir
        residual_sq = next_residual_sq
    return solution


def set_param_values_to_model(model: torch.nn.Module, vals: torch.Tensor) -> None:
    """Write a flat vector back into ``model``'s trainable parameters.

    Consecutive slices of ``vals`` are reshaped to each trainable parameter's
    shape (in ``model.parameters()`` order, skipping ``requires_grad=False``)
    and assigned to ``param.data``. No copy is made, so each parameter becomes
    a view into ``vals``.

    Args:
        model: module whose trainable parameters are overwritten.
        vals: 1-D tensor whose length equals the total number of trainable
            elements.

    Raises:
        AssertionError: if ``vals`` is not a tensor, or if it is longer than
            the number of trainable elements. A too-short ``vals`` fails
            earlier, inside ``view``.
    """
    assert isinstance(vals, torch.Tensor)
    offset = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        shape = param.size()
        numel = param.numel()
        param.data = vals[offset : offset + numel].view(shape)
        offset += numel
    assert offset == len(vals), f"Lengths do not match: {offset} vs. {len(vals)}"


def get_flat_gradients_from(model: torch.nn.Module) -> torch.Tensor:
    """Concatenate the ``.grad`` of every trainable parameter into one flat tensor.

    Parameters whose ``.grad`` is ``None`` are skipped. In that case the
    result is shorter than the output of ``get_flat_params_from`` for the
    same model.

    Args:
        model: module whose accumulated gradients are read.

    Returns:
        A new 1-D tensor of gradient values in ``named_parameters()`` order.

    Raises:
        AssertionError: if no trainable parameter has a gradient.
    """
    flat_grads = [
        param.grad.view(-1)
        for _name, param in model.named_parameters()
        if param.requires_grad and param.grad is not None
    ]
    assert flat_grads, "No gradients were found in model parameters."
    return torch.cat(flat_grads)


def fvp(
    params: torch.Tensor,
    policy: ActorNoCritic,
    fvp_obs: torch.Tensor,
    damping: float = 0.02,
) -> torch.Tensor:
    """Return the damped Fisher-vector product ``(H + damping * I) @ params``.

    ``H`` is the Hessian of the mean KL divergence between a detached copy of
    the actor's distribution on ``fvp_obs`` and the live one. At identical
    parameters this is the Fisher information matrix. It is obtained with
    double back-propagation, never materialised.

    Args:
        params: flat vector with one entry per element of
            ``policy.actor.parameters()`` (in that order).
        policy: object exposing ``actor``, a module mapping observations to a
            ``torch.distributions`` distribution.
        fvp_obs: observations on which the KL is evaluated.
        damping: Tikhonov damping added to the product for numerical
            stability. The default of 0.02 reproduces the original behaviour.

    Returns:
        A flat tensor shaped like ``params``.

    Side effects:
        Calls ``policy.actor.zero_grad()``. ``.grad`` fields are not written
        because ``torch.autograd.grad`` is used instead of ``backward``.
    """
    policy.actor.zero_grad()
    current_distribution = policy.actor(fvp_obs)
    with torch.no_grad():
        old_distribution = policy.actor(fvp_obs)
    kl = torch.distributions.kl.kl_divergence(
        old_distribution, current_distribution
    ).mean()

    actor_params = tuple(policy.actor.parameters())
    # First derivative kept differentiable so it can be differentiated again.
    grads = torch.autograd.grad(kl, actor_params, create_graph=True)
    flat_grad_kl = torch.cat([grad.reshape(-1) for grad in grads])

    # d/dtheta (grad_kl . params) == H @ params (Hessian-vector product).
    kl_p = (flat_grad_kl * params).sum()
    grads = torch.autograd.grad(kl_p, actor_params, retain_graph=False)
    flat_grad_grad_kl = torch.cat([grad.contiguous().reshape(-1) for grad in grads])

    return flat_grad_grad_kl + params * damping


def _importance_ratio(policy: ActorNoCritic, data: dict) -> tuple:
    """Return ``(ratio, distribution)`` for the buffered batch under the current actor.

    ``ratio`` is ``exp(log pi(act|obs) - data["log_prob"])``, with the current
    log-probability summed over the last dimension. ``distribution`` is the
    actor's output on ``data["obs"]``.
    """
    distribution = policy.actor(data["obs"])
    log_prob = distribution.log_prob(data["act"]).sum(dim=-1)
    return torch.exp(log_prob - data["log_prob"]), distribution


def _natural_gradient_direction(
    policy: ActorNoCritic,
    loss: torch.Tensor,
    fvp_obs: torch.Tensor,
    target_kl: float,
) -> tuple:
    """Back-propagate ``loss`` and return ``(grads, x, step_direction)``.

    ``x`` approximately solves ``F x = grads`` (conjugate gradients with
    ``fvp``). ``step_direction`` is ``x`` scaled so that the quadratic KL model
    ``0.5 * d^T F d`` equals ``target_kl``. The caller is expected to have
    zeroed the actor's gradients before building ``loss``.
    """
    loss.backward()
    grads = get_flat_gradients_from(policy.actor)
    x = conjugate_gradients(fvp, policy, fvp_obs, grads, CONJUGATE_GRADIENT_ITERS)
    assert torch.isfinite(x).all(), "x is not finite"
    xHx = torch.dot(x, fvp(x, policy, fvp_obs))
    assert xHx.item() >= 0, "xHx is negative"
    alpha = torch.sqrt(2 * target_kl / (xHx + 1e-8))
    step_direction = x * alpha
    assert torch.isfinite(step_direction).all(), "step_direction is not finite"
    return grads, x, step_direction


def _evaluate_candidate(
    policy: ActorNoCritic,
    data: dict,
    old_distribution,
    new_theta: torch.Tensor,
    loss_c_before: float,
) -> tuple:
    """Load ``new_theta`` into the actor and return ``(actual_improve_c, kl)``.

    ``actual_improve_c`` is the new cost loss ``-(ratio * target_value_c).mean()``
    minus ``loss_c_before``. ``kl`` is the mean KL from ``old_distribution``
    to the actor's new distribution on ``data["obs"]``.
    """
    set_param_values_to_model(policy.actor, new_theta)
    ratio, current_distribution = _importance_ratio(policy, data)
    actual_improve_c = (-(ratio * data["target_value_c"]).mean()).item() - loss_c_before
    kl = (
        torch.distributions.kl.kl_divergence(old_distribution, current_distribution)
        .mean()
        .item()
    )
    return actual_improve_c, kl


def _safe_episode_stats(returns, costs) -> tuple:
    """Return ``(safe_prob, safe_ret)`` for index-aligned episode returns/costs.

    An episode is "safe" when its cost equals 0. ``safe_prob`` is the fraction
    of safe episodes (0.0 when there are no episodes). ``safe_ret`` is the mean
    return over safe episodes only (0.0 when none is safe).
    """
    safe_returns = [ret for ret, cost in zip(returns, costs) if cost == 0]
    safe_prob = len(safe_returns) / len(returns) if len(returns) else 0.0
    safe_ret = float(np.mean(safe_returns)) if safe_returns else 0.0
    return safe_prob, safe_ret


def main(args, cfg_env=None):
    """Train an ``ActorNoCritic`` policy on ``args.task``.

    Each epoch does the following:

    1. Collects ``steps_per_epoch`` transitions into the on-policy buffer.
    2. On the final epoch only, optionally runs 100 deterministic evaluation
       episodes on ``eval_env``.
    3. Updates the actor with a BETA-weighted mix of a reward natural-gradient
       step and a cost natural-gradient step.
    4. Backtracks (factor 0.8) until the cost surrogate does not get worse and
       the KL stays within ``target_kl``.

    Args:
        args: parsed command-line namespace. ``vars(args)`` is updated in
            place with the selected config.
        cfg_env: environment config, only used for Isaac Gym tasks.
    """
    # set the random seed, device and number of threads
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(4)
    global BETA
    BETA = args.beta if args.beta is not None else BETA
    device = torch.device(f'{args.device}:{args.device_id}')

    is_isaac = args.task in isaac_gym_map.keys()
    if not is_isaac:
        env, obs_space, act_space = make_sa_mujoco_env(
            num_envs=args.num_envs, env_id=args.task, seed=args.seed
        )
        eval_env, _, _ = make_sa_mujoco_env(num_envs=1, env_id=args.task, seed=None)
        config = default_cfg
    else:
        sim_params = parse_sim_params(args, cfg_env, None)
        env = make_sa_isaac_env(args=args, cfg=cfg_env, sim_params=sim_params)
        eval_env = env
        obs_space = env.observation_space
        act_space = env.action_space
        args.num_envs = env.num_envs
        config = isaac_gym_specific_cfg

    # set training steps (config values override the command line when present)
    steps_per_epoch = config.get("steps_per_epoch", args.steps_per_epoch)
    total_steps = config.get("total_steps", args.total_steps)
    local_steps_per_epoch = steps_per_epoch // args.num_envs
    epochs = total_steps // steps_per_epoch
    # create the actor module
    policy = ActorNoCritic(
        obs_dim=obs_space.shape[0],
        act_dim=act_space.shape[0],
        hidden_sizes=config["hidden_sizes"],
    ).to(device)

    # create the vectorized on-policy buffer
    buffer = VectorizedOnPolicyBuffer(
        obs_space=obs_space,
        act_space=act_space,
        size=local_steps_per_epoch,
        device=device,
        num_envs=args.num_envs,
        gamma=config["gamma"],
        # A config without a separate cost discount falls back to the reward one.
        gamma_c=config.get("gamma_c", config["gamma"]),
        standardized_adv_c=False,
        standardized_adv_r=False,
    )

    # set up the logger
    dict_args = vars(args)
    dict_args.update(config)
    logger = EpochLogger(
        log_dir=args.log_dir,
        seed=str(args.seed),
    )
    rew_deque = deque(maxlen=50)
    cost_deque = deque(maxlen=50)
    len_deque = deque(maxlen=50)
    eval_rew_deque = deque(maxlen=100)
    eval_cost_deque = deque(maxlen=100)
    eval_len_deque = deque(maxlen=100)
    logger.save_config(dict_args)
    logger.setup_torch_saver(policy.actor)
    logger.log("Start with training.")
    obs, _ = env.reset()
    obs = torch.as_tensor(obs, dtype=torch.float32, device=device)
    ep_ret = np.zeros(args.num_envs)
    ep_cost = np.zeros(args.num_envs)
    ep_len = np.zeros(args.num_envs)

    # training loop
    for epoch in range(epochs):
        rollout_start_time = time.time()
        # collect samples until we have enough to update
        for steps in range(local_steps_per_epoch):
            with torch.no_grad():
                act, log_prob, value_r, value_c = policy.step(obs, deterministic=False)
            action = act.detach().squeeze() if is_isaac else act.detach().squeeze().cpu().numpy()
            next_obs, reward, cost, terminated, truncated, info = env.step(action)

            ep_ret += reward.cpu().numpy() if is_isaac else reward
            ep_cost += cost.cpu().numpy() if is_isaac else cost
            ep_len += 1
            next_obs, reward, cost, terminated, truncated = (
                torch.as_tensor(x, dtype=torch.float32, device=device)
                for x in (next_obs, reward, cost, terminated, truncated)
            )
            if "final_observation" in info:
                # Sub-envs without a final observation get a zero placeholder.
                info["final_observation"] = np.array(
                    [
                        array if array is not None else np.zeros(obs.shape[-1])
                        for array in info["final_observation"]
                    ],
                )
                info["final_observation"] = torch.as_tensor(
                    info["final_observation"],
                    dtype=torch.float32,
                    device=device,
                )
            buffer.store(
                obs=obs,
                act=act,
                reward=reward,
                cost=cost,
                value_r=value_r,
                value_c=value_c,
                log_prob=log_prob,
            )

            obs = next_obs
            epoch_end = steps >= local_steps_per_epoch - 1
            for idx, (done, time_out) in enumerate(zip(terminated, truncated)):
                if epoch_end or done or time_out:
                    # Terminal paths close with zero value; paths cut by the
                    # epoch end or a time-out use policy.step's value estimates.
                    last_value_r = torch.zeros(1, device=device)
                    last_value_c = torch.zeros(1, device=device)
                    if not done:
                        if epoch_end:
                            with torch.no_grad():
                                _, _, last_value_r, last_value_c = policy.step(
                                    obs[idx], deterministic=False
                                )
                        if time_out:
                            with torch.no_grad():
                                _, _, last_value_r, last_value_c = policy.step(
                                    info["final_observation"][idx], deterministic=False
                                )
                        last_value_r = last_value_r.unsqueeze(0)
                        last_value_c = last_value_c.unsqueeze(0)
                    if done or time_out:
                        rew_deque.append(ep_ret[idx])
                        cost_deque.append(ep_cost[idx])
                        len_deque.append(ep_len[idx])
                        logger.store(
                            **{
                                "Metrics/EpRet": np.mean(rew_deque),
                                "Metrics/EpCost": np.mean(cost_deque),
                                "Metrics/EpLen": np.mean(len_deque),
                            }
                        )
                        ep_ret[idx] = 0.0
                        ep_cost[idx] = 0.0
                        ep_len[idx] = 0.0
                        logger.logged = False

                    buffer.finish_path(
                        last_value_r=last_value_r, last_value_c=last_value_c, idx=idx
                    )
        rollout_end_time = time.time()

        # Safety statistics over the recent-episode window (safe == zero cost).
        safe_prob, safe_rew = _safe_episode_stats(rew_deque, cost_deque)
        logger.store(
            **{
                "Metrics/SafeProb": safe_prob,
                "Metrics/SafeEpRet": safe_rew,
            }
        )

        eval_start_time = time.time()

        # Evaluation episodes are only run after the final epoch.
        eval_episodes = 0 if (epoch < epochs - 1) else 100
        if args.use_eval:
            for _ in range(eval_episodes):
                eval_done = False
                eval_obs, _ = eval_env.reset()
                eval_obs = torch.as_tensor(eval_obs, dtype=torch.float32, device=device)
                eval_rew, eval_cost, eval_len = 0.0, 0.0, 0.0
                while not eval_done:
                    with torch.no_grad():
                        eval_act, _, _, _ = policy.step(eval_obs, deterministic=True)
                    # Step the same env that produced eval_obs; the training env
                    # has its own state and batch size.
                    (
                        eval_next_obs,
                        eval_reward,
                        eval_step_cost,
                        eval_terminated,
                        eval_truncated,
                        _,
                    ) = eval_env.step(eval_act.detach().squeeze().cpu().numpy())
                    eval_obs = torch.as_tensor(eval_next_obs, dtype=torch.float32, device=device)
                    eval_rew += eval_reward
                    eval_cost += eval_step_cost
                    eval_len += 1
                    eval_done = eval_terminated[0] or eval_truncated[0]
                eval_rew_deque.append(eval_rew)
                eval_cost_deque.append(eval_cost)
                eval_len_deque.append(eval_len)
            logger.store(
                **{
                    "Metrics/EvalEpRet": np.mean(eval_rew_deque) if eval_episodes > 0 else 0.0,
                    "Metrics/EvalEpCost": np.mean(eval_cost_deque) if eval_episodes > 0 else 0.0,
                    "Metrics/EvalEpLen": np.mean(eval_len_deque) if eval_episodes > 0 else 0.0,
                }
            )

        eval_end_time = time.time()

        # update policy
        data = buffer.get()
        fvp_obs = data["obs"]
        theta_old = get_flat_params_from(policy.actor)
        with torch.no_grad():
            # Distribution at theta_old: reference point of the KL trust region.
            old_distribution = policy.actor(data["obs"])

        # Natural-gradient direction for the reward surrogate.
        policy.actor.zero_grad()
        ratio, _ = _importance_ratio(policy, data)
        loss_pi_r = (ratio * data["target_value_r"]).mean()
        grads_r, x_r, step_direction_r = _natural_gradient_direction(
            policy, loss_pi_r, fvp_obs, config["target_kl"]
        )

        # Natural-gradient direction for the negated cost surrogate.
        policy.actor.zero_grad()
        ratio, _ = _importance_ratio(policy, data)
        loss_pi_c = -(ratio * data["target_value_c"]).mean()
        grads_c, x_c, step_direction_c = _natural_gradient_direction(
            policy, loss_pi_c, fvp_obs, config["target_kl"]
        )

        # Mix the two directions; BETA controls the target projection on grads_c.
        # NOTE(review): a zero denominator yields nan/inf. nan and negative
        # weights fall back to 0.0 below, but +inf is not guarded.
        g_c_dot_d_r = grads_c.dot(step_direction_r)
        g_c_dot_d_c = grads_c.dot(step_direction_c)
        weight = ((-g_c_dot_d_r + BETA * g_c_dot_d_c) / (-g_c_dot_d_r + g_c_dot_d_c)).item()
        weight = weight if weight > 0 else 0.0
        step_direction = (1 - weight) * step_direction_r + weight * step_direction_c

        # Backtracking line search; only scalar metrics are needed, so no graph.
        loss_c_before = loss_pi_c.item()
        with torch.no_grad():
            step_frac = 1.0
            actual_improve_c, final_kl = _evaluate_candidate(
                policy, data, old_distribution,
                theta_old + step_frac * step_direction, loss_c_before,
            )
            acceptance_step = 0
            for _ in range(TRPO_SEARCHING_STEPS):
                if not ((actual_improve_c < 0) or (final_kl > config["target_kl"])):
                    break
                step_frac *= 0.8
                actual_improve_c, final_kl = _evaluate_candidate(
                    policy, data, old_distribution,
                    theta_old + step_frac * step_direction, loss_c_before,
                )
                acceptance_step += 1
            else:
                # No acceptable step found: restore the pre-update parameters.
                step_frac = 0.0
                acceptance_step = -1
                set_param_values_to_model(policy.actor, theta_old)

            ratio, _ = _importance_ratio(policy, data)
            loss_pi_c = -(ratio * data["target_value_c"]).mean()
            loss_pi_r = (ratio * data["target_value_r"]).mean()

        logger.store(
            **{
                "Misc/FinalStepNorm": torch.norm(step_frac * step_direction).mean().item(),
                "Misc/g_r_norm": torch.norm(grads_r).mean().item(),
                "Misc/g_c_norm": torch.norm(grads_c).mean().item(),
                "Misc/H_inv_g_r": x_r.norm().item(),
                "Misc/H_inv_g_c": x_c.norm().item(),
                "Misc/AcceptanceStep": acceptance_step,
                "Misc/Weight": weight,
                "Loss/Loss_actor_r": loss_pi_r.item(),
                "Loss/Loss_actor_c": loss_pi_c.item(),
                "Train/KL": final_kl,
            },
        )

        update_end_time = time.time()
        if not logger.logged:
            # log data
            logger.log_tabular("Metrics/EpRet")
            logger.log_tabular("Metrics/EpCost")
            logger.log_tabular("Metrics/EpLen")
            logger.log_tabular("Metrics/SafeProb")
            logger.log_tabular("Metrics/SafeEpRet")
            if args.use_eval:
                logger.log_tabular("Metrics/EvalEpRet")
                logger.log_tabular("Metrics/EvalEpCost")
                logger.log_tabular("Metrics/EvalEpLen")
            logger.log_tabular("Train/Epoch", epoch + 1)
            logger.log_tabular("Train/TotalSteps", (epoch + 1) * steps_per_epoch)
            logger.log_tabular("Train/KL")
            logger.log_tabular("Loss/Loss_actor_r")
            logger.log_tabular("Loss/Loss_actor_c")
            logger.log_tabular("Time/Rollout", rollout_end_time - rollout_start_time)
            if args.use_eval:
                logger.log_tabular("Time/Eval", eval_end_time - eval_start_time)
            logger.log_tabular("Time/Update", update_end_time - eval_end_time)
            logger.log_tabular("Time/Total", update_end_time - rollout_start_time)
            logger.log_tabular("Value/RewardAdv", data["adv_r"].mean().item())
            logger.log_tabular("Value/CostAdv", data["adv_c"].mean().item())
            logger.log_tabular("Misc/FinalStepNorm")
            logger.log_tabular("Misc/g_r_norm")
            logger.log_tabular("Misc/g_c_norm")
            logger.log_tabular("Misc/H_inv_g_r")
            logger.log_tabular("Misc/H_inv_g_c")
            logger.log_tabular("Misc/AcceptanceStep")
            logger.log_tabular("Misc/Weight")

            logger.dump_tabular()
            if (epoch + 1) % 20 == 0 or epoch == 0:
                logger.torch_save(itr=epoch)
                if not is_isaac:
                    logger.save_state(
                        state_dict={
                            "Normalizer": env.obs_rms,
                        },
                        itr=epoch
                    )
    logger.close()


if __name__ == "__main__":
    import contextlib
    import traceback

    args, cfg_env = single_agent_args()
    relpath = time.strftime("%Y-%m-%d-%H-%M-%S")
    subfolder = "-".join(["seed", str(args.seed).zfill(3)])
    relpath = "-".join([subfolder, relpath])
    algo = os.path.basename(__file__).split(".")[0]
    args.log_dir = os.path.join(args.log_dir, args.experiment, args.task, algo, relpath)
    if args.write_terminal:
        main(args, cfg_env)
    else:
        os.makedirs(args.log_dir, exist_ok=True)
        terminal_log_path = os.path.join(args.log_dir, f"seed{args.seed}_terminal.log")
        error_log_path = os.path.join(args.log_dir, f"seed{args.seed}_error.log")
        with open(terminal_log_path, "w", encoding="utf-8") as f_out:
            with open(error_log_path, "w", encoding="utf-8") as f_error:
                # redirect_* restore the previous streams on exit, so nothing is
                # left pointing at the (then closed) log files.
                with contextlib.redirect_stdout(f_out), contextlib.redirect_stderr(f_error):
                    try:
                        main(args, cfg_env)
                    except Exception:
                        # Write the traceback while stderr is still the error
                        # log. After re-raising, the default handler reports it
                        # again on the restored console stderr.
                        traceback.print_exc()
                        raise
