import copy
import gc
import random
import time
import warnings

import hydra
import numpy as np
import torch
import torch.autograd.profiler as profiler
from hydra.experimental import compose, initialize
from omegaconf import OmegaConf
from torch._C import device

import mbrl.util
import mbrl.util.common
import mbrl.util.mujoco as mujoco_util
from mbrl.algorithms.uasc import rollout_model_and_populate_imagined_buffer
from mbrl.models.one_dim_tr_model import OneDTransitionRewardModel
from mbrl.third_party.unrolled_actor_soft_critic.agent import Agent
from mbrl.third_party.unrolled_actor_soft_critic.agent.unrolling import (
    unroll_actor_adjoint,
    unroll_actor_direct,
)
from mbrl.third_party.unrolled_actor_soft_critic.buffer import CircularReplayBuffer


def _time_unroll(
    unroll_fn,
    agent,
    agent_cfg,
    imagined_buffer,
    dynamics_model,
    *,
    seed,
    device,
    rollout_length,
    num_iter,
):
    """Return the mean wall-clock seconds per call of ``unroll_fn``.

    ``agent`` is deep-copied, so every variant starts from the same weights and
    the caller's agent is never modified. One warm-up call is made before the
    clock starts (so one-off start-up costs are not counted). Then ``num_iter``
    calls are timed and their mean duration is returned.
    """
    agent_copy = copy.deepcopy(agent)

    def run_once():
        # Each call gets a fresh SGD optimizer over the copy's actor parameters.
        unroll_fn(
            agent_cfg,
            imagined_buffer,
            dynamics_model,
            agent_copy,
            seed=seed,
            device=device,
            logger=None,
            profile=False,
            rollout_length=rollout_length,
            optimizer_override=torch.optim.SGD(agent_copy.actor.parameters(), lr=1e-4),
        )

    def synchronize():
        # CUDA work is launched asynchronously; wait for it so the timer only
        # measures completed work. No-op on CPU.
        if device.type == "cuda":
            torch.cuda.synchronize(device)

    gc.collect()
    run_once()  # warm-up, deliberately kept outside the timed region
    synchronize()
    start_time = time.perf_counter()
    for _ in range(num_iter):
        run_once()
    synchronize()
    return (time.perf_counter() - start_time) / num_iter


def timed_test(rollout_lengths=(1, 2, 4, 6, 8, 10, 20, 40, 80, 100), num_iter=20):
    """Benchmark direct vs. adjoint actor unrolling with the UASC config.

    Composes the ``main`` Hydra config from ``conf`` with the ``uasc`` algorithm
    and ``uasc_hopper`` overrides (forced onto CPU). It then fills a replay
    buffer with random-agent transitions and fills an imagined buffer from model
    rollouts. For each rollout length it prints one tab-separated line: the
    rollout length, the mean seconds per ``unroll_actor_direct`` call, and the
    mean seconds per ``unroll_actor_adjoint`` call.

    Args:
        rollout_lengths: rollout lengths to benchmark, in print order.
        num_iter: number of timed calls averaged per measurement (each
            measurement is preceded by one untimed warm-up call).

    Raises:
        ValueError: if ``num_iter`` is not positive.
    """
    # ``mbrl.planning`` is used below but is not imported explicitly at module
    # level. This must stay the first statement: it binds ``mbrl`` as a local
    # name for the whole function.
    import mbrl.planning

    if num_iter < 1:
        raise ValueError(f"num_iter must be positive, got {num_iter}")

    # Compose the UASC config with the ``uasc_hopper`` overrides:
    initialize(config_path="conf", job_name="main")
    cfg = compose(
        config_name="main",
        overrides=[
            "algorithm=uasc",
            "overrides=uasc_hopper",
            "algorithm.agent.actor_train.rollouts_per_batch=196",
            "algorithm.agent.actor_train.gradient_clip=-1.0",
            "device=cpu",
        ],
    )
    device = torch.device(cfg.device)

    warnings.filterwarnings('ignore', "`(np.object|np.bool)` is a deprecated alias for the builtin")

    # Create environment:
    assert cfg.algorithm.name == "uasc"
    env, termination_fn, _ = mujoco_util.make_env(cfg)
    mbrl.planning.complete_agent_cfg(env, cfg.algorithm.agent)

    obs_shape = env.observation_space.shape
    act_shape = env.action_space.shape

    # Replay buffer, filled with transitions collected by a random agent:
    replay_buffer = mbrl.util.common.create_replay_buffer(
        cfg, obs_shape, act_shape, rng=np.random.default_rng(seed=0)
    )
    mbrl.util.common.rollout_agent_trajectories(
        env,
        cfg.algorithm.initial_exploration_steps,
        mbrl.planning.RandomAgent(env),
        {},
        replay_buffer=replay_buffer,
    )

    # Dynamics model (not trained anywhere in this script):
    dynamics_model = mbrl.util.common.create_one_dim_tr_model(cfg, obs_shape, act_shape)
    model_env = mbrl.models.ModelEnv(env, dynamics_model, termination_fn, None)

    # Imagined buffer populated once from model rollouts and reused by every run.
    agent = hydra.utils.instantiate(cfg.algorithm.agent)
    imagined_buffer = CircularReplayBuffer(obs_shape, act_shape, device=device)
    imagined_buffer.resize(240)
    # NOTE(review): the meaning of 5 and 49 is defined by
    # rollout_model_and_populate_imagined_buffer -- confirm before changing.
    rollout_model_and_populate_imagined_buffer(
        model_env,
        replay_buffer,
        agent,
        imagined_buffer,
        cfg.algorithm.rollout_samples_action,
        5,
        49,
    )

    for rollout_length in rollout_lengths:
        # Both variants go through the same helper, so the warm-up call is
        # excluded from the timing in exactly the same way for each.
        timing_kwargs = dict(
            seed=cfg.seed,
            device=device,
            rollout_length=rollout_length,
            num_iter=num_iter,
        )
        direct_time = _time_unroll(
            unroll_actor_direct,
            agent,
            cfg.algorithm.agent,
            imagined_buffer,
            dynamics_model,
            **timing_kwargs,
        )
        adjoint_time = _time_unroll(
            unroll_actor_adjoint,
            agent,
            cfg.algorithm.agent,
            imagined_buffer,
            dynamics_model,
            **timing_kwargs,
        )

        print(f"{rollout_length}\t{direct_time}\t{adjoint_time}")

if __name__ == "__main__":
    # Only run the benchmark when this file is executed as a script, not on import.
    timed_test()
