# Soft Actor-Critic (SAC) with twin Q-functions (clipped double-Q targets)
import logging
import time
import os
 
import gym
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

from typing import Tuple, Deque
from collections import deque
from omegaconf import DictConfig
from stable_baselines3.common.buffers import ReplayBuffer
from torch.utils.tensorboard import SummaryWriter
from dris.utils import avg_max_50p, make_env, make_ff_net, polyak_average, wandb_init_wrapper, wandb_finish_wrapper
from dris.dris import DRIS
import THIRD_PARTY.io as io_helper
from THIRD_PARTY import decorate_hook

class QNetwork(nn.Module):
    """State-action value network Q(s, a).

    The observation and action are concatenated along dim 1 and fed through a
    feed-forward net built by ``make_ff_net`` with a single output unit.
    """

    def __init__(self, envs: gym.vector.SyncVectorEnv, network_cfg: DictConfig) -> None:
        super().__init__()
        obs_size = np.prod(envs.single_observation_space.shape)
        act_size = np.prod(envs.single_action_space.shape)
        self.model = make_ff_net(
            inp_dim=obs_size + act_size,
            out_dim=1,
            layers=network_cfg.layers,
            activation_fn=network_cfg.activation_fn,
            final_fn=network_cfg.final_fn,
        )

    def forward(self, s, a) -> torch.Tensor:
        """Return Q-values for states ``s`` and actions ``a``.

        Assumes 2-D ``(batch, features)`` inputs (they are joined on dim 1);
        the output presumably has shape ``(batch, 1)`` since ``out_dim=1``.
        """
        state_action = torch.cat((s, a), dim=1)
        return self.model(state_action)

class Actor(nn.Module):
    """Tanh-squashed Gaussian policy used by SAC.

    ``self.model`` is a feed-forward trunk. ``model_mean`` and ``model_logstd``
    are linear heads on the trunk output. The log-std head ends in ``tanh`` and
    is affinely mapped into ``[LOG_STD_MIN, LOG_STD_MAX]``. Sampled actions are
    squashed with ``tanh`` and rescaled onto the environment's action box.
    """

    LOG_STD_MIN = -5
    LOG_STD_MAX = 2

    def __init__(self, envs: gym.vector.SyncVectorEnv, network_cfg: DictConfig) -> None:
        super().__init__()
        act_space = envs.single_action_space
        act_dim = np.prod(act_space.shape)
        trunk_out = network_cfg.layers[-1]

        # Trunk: the last configured width is the trunk's output size (and the
        # heads' input size); the preceding widths are its hidden layers.
        self.model = make_ff_net(
            inp_dim=np.prod(envs.single_observation_space.shape),
            out_dim=trunk_out,
            layers=network_cfg.layers[:-1],
            activation_fn=network_cfg.activation_fn,
            final_fn=network_cfg.activation_fn,
        )
        self.model_mean = nn.Linear(trunk_out, act_dim)

        self.model_logstd = nn.Sequential(
            nn.Linear(trunk_out, act_dim),
            nn.Tanh(),
        )

        # Affine map from tanh's (-1, 1) onto [low, high].
        # These are built from the per-env ``single_action_space``, so their
        # shape is (act_dim,). The batched ``envs.action_space`` has shape
        # (num_envs, act_dim), which only broadcasts against a replay batch
        # when num_envs == 1.
        # They are registered as non-persistent buffers so that
        # .to()/.cuda()/.cpu() move them together with the parameters, while
        # they stay out of state_dict (checkpoint keys are unchanged).
        self.register_buffer(
            'action_scale',
            torch.as_tensor((act_space.high - act_space.low) / 2.0, dtype=torch.float32),
            persistent=False,
        )
        self.register_buffer(
            'action_bias',
            torch.as_tensor((act_space.high + act_space.low) / 2.0, dtype=torch.float32),
            persistent=False,
        )

    def forward(self, s) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the pre-tanh ``mean`` and the bounded ``log_std`` for states ``s``."""
        x = self.model(s)
        mean = self.model_mean(x)
        log_std = self.model_logstd(x)
        # tanh output in (-1, 1) -> (LOG_STD_MIN, LOG_STD_MAX)
        # TODO should be hyperparams?
        log_std = self.LOG_STD_MIN + 0.5 * (self.LOG_STD_MAX - self.LOG_STD_MIN) * (log_std + 1)
        return mean, log_std

    def get_action(self, x) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Sample a squashed action with the reparameterization trick.

        Returns:
            action: ``tanh(sample) * action_scale + action_bias``.
            log_prob: log-density of ``action``, summed over action dims
                (``keepdim=True``). It includes the tanh/scale change-of-variables
                correction.
            mean: the deterministic action ``tanh(mean) * action_scale + action_bias``.
        """
        mean, log_std = self(x)
        std = log_std.exp()

        normal_dist = torch.distributions.Normal(mean, std)
        x_t = normal_dist.rsample()  # rsample keeps the sample differentiable
        y_t = torch.tanh(x_t)

        action = y_t * self.action_scale + self.action_bias
        log_prob = normal_dist.log_prob(x_t)

        # Jacobian of the tanh squash and the affine rescale; 1e-6 avoids log(0)
        log_prob -= torch.log(self.action_scale * (1 - y_t.pow(2)) + 1e-6)
        log_prob = log_prob.sum(1, keepdim=True)

        mean = torch.tanh(mean) * self.action_scale + self.action_bias
        return action, log_prob, mean

    @torch.no_grad()
    def select_action(self, state):
        """Return the deterministic (squashed-mean) action for one observation.

        ``state`` is flattened into a single-row batch and placed on this
        module's own device (no CUDA assumption). The action is returned as a
        flat numpy array.
        """
        device = self.action_scale.device
        state = torch.as_tensor(
            np.asarray(state).reshape(1, -1), dtype=torch.float32, device=device
        )
        _, _, action = self.get_action(state)
        return action.cpu().numpy().flatten()

@torch.no_grad()
def eval_policy(policy, env_name, seed, eval_episodes=10):
    """Roll out ``policy`` for ``eval_episodes`` episodes and return the mean return.

    A fresh ``gym`` env named ``env_name`` is created and seeded with
    ``seed + 100``. Actions come from ``policy.select_action``. The env is
    always closed, even if a rollout raises.

    Returns:
        Average undiscounted episodic return.
    """
    eval_env = gym.make(env_name)
    try:
        eval_env.seed(seed + 100)

        total_reward = 0.
        for _ in range(eval_episodes):
            state, done = eval_env.reset(), False
            while not done:
                action = policy.select_action(np.array(state))
                state, reward, done, _ = eval_env.step(action)
                total_reward += reward
    finally:
        # Release simulator/render resources; previously the env leaked on every eval.
        eval_env.close()

    avg_reward = total_reward / eval_episodes

    print("---------------------------------------")
    print(f"Evaluation over {eval_episodes} episodes: {avg_reward:.3f}")
    print("---------------------------------------")
    return avg_reward


def _log_timers(writer, timer: dict, global_step: int) -> None:
    """Write the per-phase timings (``avg_max_50p`` of each deque) and their totals."""
    time_env_interaction = avg_max_50p(timer['env_interaction'])
    time_q_update = avg_max_50p(timer['q_update'])
    time_policy_update = avg_max_50p(timer['policy_update'])
    time_target_network_update = avg_max_50p(timer['target_network_update'])

    writer.add_scalar('timer/env_interaction', time_env_interaction, global_step)
    writer.add_scalar('timer/q_update', time_q_update, global_step)
    writer.add_scalar('timer/policy_update', time_policy_update, global_step)
    writer.add_scalar('timer/target_network_update', time_target_network_update, global_step)
    writer.add_scalar('timer/total_time',
        time_env_interaction +
        time_q_update +
        time_policy_update +
        time_target_network_update,
        global_step
    )
    writer.add_scalar('timer/total_update_time',
        time_q_update +
        time_policy_update +
        time_target_network_update,
        global_step
    )


def train(cfg: DictConfig):
    """Train a SAC agent whose two critics are fit with ``DRIS.loss``.

    Each environment step does the following:
      1. Act: uniformly random before ``learning_starts``, otherwise sampled
         from the policy. Store the transition in the replay buffer and record
         the simulator state (``sim.get_state()``).
      2. Once past ``learning_starts``, run one critic update. Every
         ``policy_freq`` steps, run ``policy_freq`` actor updates (plus
         temperature updates when ``alpha_autotune`` is set). Every
         ``target_freq`` steps, Polyak-average the target critics.
      3. Periodically write TensorBoard scalars, checkpoints and evaluations.

    Side effects (relative to the working directory): TensorBoard logs in
    ``summary/``, checkpoints in ``checkpoints/<step>/``, plus ``latest.pt``,
    ``rb.pkl`` and ``eval.txt``. It also starts/finishes a W&B run when
    ``cfg.wandb.launch`` is set.
    """
    logger = logging.getLogger(__name__)
    if cfg.wandb.launch:
        wandb_init_wrapper(cfg)

    writer = SummaryWriter('summary')

    # seeding
    random.seed(cfg.seed)
    np.random.seed(cfg.seed)
    torch.manual_seed(cfg.seed)
    torch.backends.cudnn.deterministic = cfg.torch_deterministic

    device = torch.device(cfg.device)

    # NOTE: SyncVectorEnv takes care of auto-reset on done
    envs = gym.vector.SyncVectorEnv([make_env(cfg.task.name, 0, 0, cfg.capture_video)])
    assert isinstance(envs.single_action_space, gym.spaces.Box), 'only continuous action space is supported'

    actor = Actor(envs, cfg.hyperparams.actor).to(device)
    qf1 = QNetwork(envs, cfg.hyperparams.critic).to(device)
    qf2 = QNetwork(envs, cfg.hyperparams.critic).to(device)
    dris1 = DRIS(cfg.dris)
    dris2 = DRIS(cfg.dris)

    logger.info('qf1 = \n' + str(qf1.model))
    logger.info('qf1/dris1 = ' + str(dris1))
    logger.info('qf2 = \n' + str(qf2.model))  # was logging qf1.model by mistake
    logger.info('qf2/dris2 = ' + str(dris2))

    qf1_target = QNetwork(envs, cfg.hyperparams.critic).to(device)
    qf2_target = QNetwork(envs, cfg.hyperparams.critic).to(device)
    qf1_target.load_state_dict(qf1.state_dict())
    qf2_target.load_state_dict(qf2.state_dict())

    q_optimizer = optim.Adam(list(qf1.parameters()) + list(qf2.parameters()), lr=cfg.hyperparams.q_lr)
    actor_optimizer = optim.Adam(list(actor.parameters()), lr=cfg.hyperparams.policy_lr)

    if cfg.hyperparams.alpha_autotune:
        # standard SAC heuristic: target entropy = -|A|
        target_entropy = -torch.prod(torch.Tensor(envs.single_action_space.shape).to(device)).item()
        log_alpha = torch.zeros(1, requires_grad=True, device=device)
        alpha = log_alpha.exp().item()
        alpha_optimizer = optim.Adam([log_alpha], lr=cfg.hyperparams.q_lr)
    else:
        alpha = cfg.hyperparams.alpha

    envs.single_observation_space.dtype = np.float32
    rb = ReplayBuffer(
        cfg.hyperparams.buffer.size,
        envs.single_observation_space,
        envs.single_action_space,
        device=device,
        optimize_memory_usage=cfg.hyperparams.buffer.optimize_memory,
        handle_timeout_termination=True,
    )
    timer = {
        'env_interaction': deque([0.0] * 10, maxlen=10),
        'q_update': deque([0.0] * 10, maxlen=10),
        'policy_update': deque([0.0] * 10, maxlen=10),
        'target_network_update': deque([0.0] * 10, maxlen=10)
    }

    # Set on the first actor/temperature update; logging is skipped until then
    # (otherwise an emit step before the first policy update raised NameError).
    actor_loss = None
    alpha_loss = None

    start_time = time.time()

    obs = envs.reset()
    # Simulator states gathered since the last checkpoint (grows until then).
    mj_sim_list = []
    # max(1, ...) guards against ZeroDivisionError in the modulo below when
    # total_timesteps < total_eval_cnt.
    every_n_eval = max(1, cfg.total_timesteps // cfg.total_eval_cnt)
    every_n_ckpt = every_n_eval * 2
    for global_step in range(cfg.total_timesteps):
        # ========== START ENV INTERACTION ===============
        env_interaction_start_time = time.time()
        if global_step < cfg.hyperparams.learning_starts:
            actions = envs.action_space.sample()
        else:
            actions, _, _ = actor.get_action(torch.Tensor(obs).to(device))
            actions = actions.detach().cpu().numpy()

        next_obs, rewards, dones, infos = envs.step(actions)
        mj_sim_list.append(envs.envs[0].sim.get_state())

        for info in infos:
            if 'episode' in info.keys():
                logger.info(f"global_step={global_step}, episodic_return={info['episode']['r']}")
                writer.add_scalar('charts/episodic_return', info['episode']['r'], global_step)
                writer.add_scalar('charts/episodic_length', info['episode']['l'], global_step)
                break

        # next_obs is already the reset observation for finished envs; store
        # the true terminal observation instead.
        real_next_obs = next_obs.copy()
        for idx, d in enumerate(dones):
            if d:
                real_next_obs[idx] = infos[idx]['terminal_observation']
        rb.add(obs, real_next_obs, actions, rewards, dones, infos)

        obs = next_obs
        timer['env_interaction'].appendleft(time.time() - env_interaction_start_time)
        # ========== END ENV INTERACTION ===============

        if global_step > cfg.hyperparams.learning_starts:
            data = rb.sample(cfg.hyperparams.batch_size)

            # =========== START Q UPDATE ====================
            q_update_start_time = time.time()
            with torch.no_grad():
                next_state_actions, next_state_log_prob, _ = actor.get_action(data.next_observations)
                qf1_next_target = qf1_target(data.next_observations, next_state_actions)
                qf2_next_target = qf2_target(data.next_observations, next_state_actions)
                min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) - alpha * next_state_log_prob
                next_q_value = data.rewards.flatten() + (1 - data.dones.flatten()) * cfg.hyperparams.gamma * (min_qf_next_target).view(-1)

            qf1_a_values = qf1(data.observations, data.actions).view(-1)
            qf2_a_values = qf2(data.observations, data.actions).view(-1)

            qf1_loss, qf1_info = dris1.loss(qf1_a_values, next_q_value)
            qf2_loss, qf2_info = dris2.loss(qf2_a_values, next_q_value)

            qf_loss = qf1_loss + qf2_loss

            dris1.update(qf1_info, global_step)
            dris2.update(qf2_info, global_step)

            q_optimizer.zero_grad()
            qf_loss.backward()
            nn.utils.clip_grad_norm_(list(qf1.parameters()) + list(qf2.parameters()), cfg.hyperparams.max_grad_norm) # extra addition
            q_optimizer.step()
            timer['q_update'].appendleft(time.time() - q_update_start_time)
            # =========== END Q UPDATE ====================

            if global_step % cfg.hyperparams.policy_freq == 0:
                for _ in range(
                    cfg.hyperparams.policy_freq
                ): # compensating for delay
                    # =========== START Policy UPDATE ====================
                    policy_update_start_time = time.time()
                    pi, log_pi, _ = actor.get_action(data.observations)
                    qf1_pi = qf1(data.observations, pi)
                    qf2_pi = qf2(data.observations, pi)
                    min_qf_pi = torch.min(qf1_pi, qf2_pi).view(-1)
                    actor_loss = dris1.update_actor_loss(-min_qf_pi) + (alpha * log_pi).mean()

                    actor_optimizer.zero_grad()
                    actor_loss.backward()
                    nn.utils.clip_grad_norm_(list(actor.parameters()), cfg.hyperparams.max_grad_norm)
                    actor_optimizer.step()

                    if cfg.hyperparams.alpha_autotune:
                        with torch.no_grad():
                            _, log_pi, _ = actor.get_action(data.observations)
                        alpha_loss = (-log_alpha * (log_pi + target_entropy)).mean()

                        alpha_optimizer.zero_grad()
                        alpha_loss.backward()
                        alpha_optimizer.step()
                        alpha = log_alpha.exp().item()

                    timer['policy_update'].appendleft(time.time() - policy_update_start_time)
                    # =========== END Policy UPDATE ====================

            # update target networks
            if global_step % cfg.hyperparams.target_freq == 0:
                target_network_update_start_time = time.time()
                polyak_average(qf1.parameters(), qf1_target.parameters(), cfg.hyperparams.tau)
                polyak_average(qf2.parameters(), qf2_target.parameters(), cfg.hyperparams.tau)
                timer['target_network_update'].appendleft(time.time() - target_network_update_start_time)

            if global_step % cfg.hyperparams.emit_freq == 0:
                writer.add_scalar('losses/qf1_loss', qf1_loss.item(), global_step)
                writer.add_scalar('losses/qf2_loss', qf2_loss.item(), global_step)
                writer.add_scalar('losses/qf_loss', qf_loss.item(), global_step)
                if actor_loss is not None:
                    writer.add_scalar('losses/actor_loss', actor_loss.item(), global_step)
                writer.add_scalar('charts/SPS', int(global_step / (time.time() - start_time)), global_step)

                writer.add_scalar('qf1/values', qf1_a_values.mean().item(), global_step)
                for key in qf1_info:
                    writer.add_scalar(f'qf1/{key}', qf1_info[key], global_step)
                writer.add_scalar('qf2/values', qf2_a_values.mean().item(), global_step)
                for key in qf2_info:
                    writer.add_scalar(f'qf2/{key}', qf2_info[key], global_step)

                writer.add_scalar("losses/alpha", alpha, global_step)
                if cfg.hyperparams.alpha_autotune and alpha_loss is not None:
                    writer.add_scalar("losses/alpha_loss", alpha_loss.item(), global_step)

                _log_timers(writer, timer, global_step)

            if (global_step+1) % every_n_ckpt == 0:
                os.makedirs(f'checkpoints/{global_step}', exist_ok=True)
                torch.save({
                    'global_step': global_step,
                    'actor_state_dict': actor.state_dict(),
                    'qf1_state_dict': qf1.state_dict(),
                    'qf2_state_dict': qf2.state_dict(),
                    'target_qf1_state_dict': qf1_target.state_dict(),
                    'target_qf2_state_dict': qf2_target.state_dict(),
                }, f'checkpoints/{global_step}/checkpoint.pt')
                # store replay observations (used for determining overestimation bias):
                # a random subset of up to 2000 sim states collected since the last checkpoint
                random.shuffle(mj_sim_list)
                io_helper.dump(f'checkpoints/{global_step}/mj_sim.pkl', mj_sim_list[:2000])
                mj_sim_list = []
                with open(f'checkpoints/{global_step}/rb_size.txt', 'w') as f:
                    f.write('%d' % rb.size())
                torch.save({
                    'global_step': global_step,
                    'actor_optimizer_state_dict': actor_optimizer.state_dict(),
                    'q_optimizer_state_dict': q_optimizer.state_dict(),
                }, "latest.pt")
                io_helper.dump("rb.pkl", rb)

            if (global_step+1) % every_n_eval == 0:
                # NOTE(review): the eval seed is a fixed 10 (independent of cfg.seed);
                # confirm this is intended rather than cfg.seed.
                avg_reward = eval_policy(actor, cfg.task.name, 10)
                # global_step was missing, so every eval point landed on the same x-position
                writer.add_scalar('eval/avg_r', avg_reward, global_step)
                with open('eval.txt', 'a') as f:
                    f.write(f'\n{global_step:09d}:{avg_reward}')

    envs.close()
    writer.close()
    if cfg.wandb.launch:
        wandb_finish_wrapper()
