# SAC with single Q function
import logging
import time
import os
 
import gym
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

from typing import Tuple, Deque
from collections import deque
from omegaconf import DictConfig
from stable_baselines3.common.buffers import ReplayBuffer
from torch.utils.tensorboard import SummaryWriter
from dris.utils import avg_max_50p, make_env, make_ff_net, polyak_average, wandb_init_wrapper, wandb_finish_wrapper
from dris.dris import DRIS
import THIRD_PARTY.io as io_helper

logger = logging.getLogger(__name__)

class QNetwork(nn.Module):
    """Critic network that scores a (state, action) pair with a single output."""

    def __init__(self, envs: gym.vector.SyncVectorEnv, network_cfg: DictConfig) -> None:
        """Build the feed-forward critic.

        Args:
            envs: vectorised environment; only its ``single_observation_space``
                and ``single_action_space`` shapes are read.
            network_cfg: config providing ``layers``, ``activation_fn`` and
                ``final_fn``, forwarded unchanged to ``make_ff_net``.
        """
        super().__init__()
        obs_dim = np.prod(envs.single_observation_space.shape)
        act_dim = np.prod(envs.single_action_space.shape)
        # forward() concatenates state and action, so the input width is the
        # sum of both flattened sizes.
        self.model = make_ff_net(
            inp_dim=obs_dim + act_dim,
            out_dim=1,
            layers=network_cfg.layers,
            activation_fn=network_cfg.activation_fn,
            final_fn=network_cfg.final_fn,
        )

    def forward(self, s, a) -> torch.Tensor:
        """Return the model output for states ``s`` and actions ``a``.

        Both inputs are concatenated along dimension 1 before being fed to
        the underlying network.
        """
        state_action = torch.cat((s, a), dim=1)
        return self.model(state_action)

class Actor(nn.Module):
    """Stochastic policy producing tanh-squashed Gaussian actions.

    A shared trunk feeds two heads: one for the mean and one for the log
    standard deviation. Sampled actions are squashed with ``tanh`` and then
    rescaled with ``action_scale`` / ``action_bias``, which are derived from
    the bounds of ``envs.action_space``.
    """

    # Range the log-std head is mapped onto in forward().
    LOG_STD_MIN = -5
    LOG_STD_MAX = 2

    def __init__(self, envs: gym.vector.SyncVectorEnv, network_cfg: DictConfig) -> None:
        """Build the trunk, the two heads and the action rescaling tensors.

        Args:
            envs: vectorised environment; its observation/action space shapes
                and action bounds are read.
            network_cfg: config providing ``layers`` and ``activation_fn``.
                The last entry of ``layers`` is the trunk output width and
                therefore the input width of both heads.
        """
        super().__init__()
        self.model = make_ff_net(
            inp_dim=np.prod(envs.single_observation_space.shape),
            out_dim=network_cfg.layers[-1],
            layers=network_cfg.layers[:-1],
            activation_fn=network_cfg.activation_fn,
            final_fn=network_cfg.activation_fn,
        )
        self.model_mean = nn.Linear(
            network_cfg.layers[-1],
            np.prod(envs.single_action_space.shape)
        )

        # Tanh keeps the raw log-std in [-1, 1]; forward() rescales it.
        self.model_logstd = nn.Sequential(
            nn.Linear(
                network_cfg.layers[-1],
                np.prod(envs.single_action_space.shape)
            ),
            nn.Tanh()
        )

        # action rescaling
        # NOTE: these are plain tensor attributes (not registered buffers), so
        # they are moved between devices by the `to` override below.
        self.action_scale = torch.FloatTensor((
            envs.action_space.high - envs.action_space.low
        ) / 2.0)
        self.action_bias = torch.FloatTensor((
            envs.action_space.high + envs.action_space.low
        ) / 2.0)

    def forward(self, s) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(mean, log_std)`` of the pre-squash Gaussian for states ``s``."""
        x = self.model(s)
        mean = self.model_mean(x)
        log_std = self.model_logstd(x)
        # TODO should be hyperparams?
        # Affine map of the tanh output from [-1, 1] to [LOG_STD_MIN, LOG_STD_MAX].
        log_std = self.LOG_STD_MIN + 0.5 * (self.LOG_STD_MAX - self.LOG_STD_MIN) * (log_std + 1)
        return mean, log_std

    def get_action(self, x) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Sample an action for each state in ``x``.

        Returns:
            A tuple ``(action, log_prob, mean)``: the reparameterised sample
            after squashing and rescaling, its log-probability summed over
            dimension 1 (kept as a size-1 dimension), and the squashed and
            rescaled mean (the deterministic action).
        """
        mean, log_std = self(x)
        std = log_std.exp()

        normal_dist = torch.distributions.Normal(mean, std)
        x_t = normal_dist.rsample()  # rsample keeps the sample differentiable
        y_t = torch.tanh(x_t)

        action = y_t * self.action_scale + self.action_bias
        log_prob = normal_dist.log_prob(x_t)

        # enforce action bounds
        # Change-of-variables correction for tanh + rescaling; the 1e-6 keeps
        # the log finite when |y_t| approaches 1.
        log_prob -= torch.log(self.action_scale * (1 - y_t.pow(2)) + 1e-6)
        log_prob = log_prob.sum(1, keepdim=True)

        mean = torch.tanh(mean) * self.action_scale + self.action_bias
        return action, log_prob, mean

    def to(self, device):
        """Move the module and the rescaling tensors to ``device``."""
        self.action_scale = self.action_scale.to(device)
        self.action_bias = self.action_bias.to(device)
        return super().to(device)

    def select_action(self, state):
        """Return the deterministic (mean) action for one state as a flat numpy array.

        Args:
            state: numpy array holding a single observation; it is reshaped
                to one row before being fed to the network.
        """
        # Follow the device the weights live on instead of assuming CUDA, so
        # evaluation also works on CPU-only machines.
        device = next(self.parameters()).device
        state = torch.as_tensor(state.reshape(1, -1), dtype=torch.float32, device=device)
        _, _, action = self.get_action(state)
        return action.detach().cpu().numpy().flatten()

@torch.no_grad()
def eval_policy(policy, env_name, seed, eval_episodes=10):
    """Roll out ``policy`` in a fresh environment and return the mean episode reward.

    Args:
        policy: object exposing ``select_action(state)`` that returns an
            action accepted by the environment's ``step``.
        env_name: id passed to ``gym.make``.
        seed: base seed; the evaluation environment is seeded with ``seed + 100``.
        eval_episodes: number of episodes to average over.

    Returns:
        The sum of all rewards collected divided by ``eval_episodes``.
    """
    eval_env = gym.make(env_name)
    total_reward = 0.
    try:
        eval_env.seed(seed + 100)
        for _ in range(eval_episodes):
            state, done = eval_env.reset(), False
            while not done:
                action = policy.select_action(np.array(state))
                state, reward, done, _ = eval_env.step(action)
                total_reward += reward
    finally:
        # A new environment is created on every call, so release it even when
        # a rollout raises; otherwise each evaluation round leaks one.
        eval_env.close()

    avg_reward = total_reward / eval_episodes

    print("---------------------------------------")
    print(f"Evaluation over {eval_episodes} episodes: {avg_reward:.3f}")
    print("---------------------------------------")
    return avg_reward


def create_prune_callback(prune_dict):
    """Build a callback that decides whether the experiment should be pruned.

    Args:
        prune_dict: mapping from a global-step check post to the minimum
            average return expected once that step has been passed. It is
            copied, so the caller's mapping is never modified.

    Returns:
        ``step_callback(global_step, avg_return) -> bool`` which returns
        ``True`` when the run should be pruned.
    """
    # Work on a private copy: passed check posts are discarded below, and
    # doing that on the caller's mapping would silently edit its config and
    # break reuse of the same mapping for another callback.
    pending = dict(prune_dict)

    def step_callback(global_step: int, avg_return: float) -> bool:
        """Check ``avg_return`` against the earliest check post still pending.

        At most one check post is examined per call. A check post that has
        been passed successfully is removed; a failed one is kept.
        """
        if not pending:
            return False

        check_post = min(pending)
        if global_step <= check_post:  # first check post not reached
            return False

        expected = pending[check_post]
        if avg_return < expected:
            logger.error(f"experiment pruned at {global_step}, current avg return ({avg_return}) lesser than expected ({expected})")
            return True

        # successfully passed this check post, stop checking it
        del pending[check_post]
        return False

    return step_callback

def train(cfg: DictConfig):
    """Train a SAC agent that uses a single Q-network and DRIS-based losses.

    Builds a one-environment ``SyncVectorEnv`` for ``cfg.task.name``, an
    ``Actor``, a ``QNetwork`` with a target copy and a replay buffer, then
    runs ``cfg.total_timesteps`` environment steps. Once
    ``cfg.hyperparams.learning_starts`` steps have passed, every step also
    performs a critic update, and periodically an actor (and optional
    entropy-coefficient) update, a target-network update, metric emission,
    checkpointing and evaluation.

    Side effects:
        * TensorBoard events are written to ``summary``.
        * Checkpoints go to ``checkpoints/<global_step>/``; ``latest.pt`` and
          ``rb.pkl`` are overwritten at every checkpoint.
        * Evaluation results are appended to ``eval.txt``.

    Args:
        cfg: experiment configuration. Reads ``task``, ``wandb``, ``seed``,
            ``torch_deterministic``, ``device``, ``capture_video``,
            ``hyperparams``, ``dris``, ``total_timesteps`` and
            ``total_eval_cnt``.

    Returns:
        ``None``. Returns early, after closing resources, when the prune
        callback asks for the experiment to be stopped.
    """
    if cfg.task.prune:
        prune_callback = create_prune_callback(cfg.task.prune)
    if cfg.wandb.launch:
        wandb_init_wrapper(cfg)

    writer = SummaryWriter('summary')

    # seeding
    random.seed(cfg.seed)
    np.random.seed(cfg.seed)
    torch.manual_seed(cfg.seed)
    torch.backends.cudnn.deterministic = cfg.torch_deterministic

    device = torch.device(cfg.device)

    # NOTE: SyncVectorEnv takes care of auto-reset on done
    envs = gym.vector.SyncVectorEnv([make_env(cfg.task.name, 0, 0, cfg.capture_video)])
    assert isinstance(envs.single_action_space, gym.spaces.Box), 'only continuous action space is supported'

    actor = Actor(envs, cfg.hyperparams.actor).to(device)
    qf1 = QNetwork(envs, cfg.hyperparams.critic).to(device)
    dris = DRIS(cfg.dris)

    logger.info('qf1 = \n' + str(qf1.model))
    logger.info('qf1/dris = ' + str(dris))

    # The target critic starts as an exact copy of the online critic.
    qf1_target = QNetwork(envs, cfg.hyperparams.critic).to(device)
    qf1_target.load_state_dict(qf1.state_dict())

    q_optimizer = optim.Adam(list(qf1.parameters()), lr=cfg.hyperparams.q_lr)
    actor_optimizer = optim.Adam(list(actor.parameters()), lr=cfg.hyperparams.policy_lr)

    if cfg.hyperparams.alpha_autotune:
        # Target entropy is minus the number of action dimensions.
        target_entropy = -torch.prod(torch.Tensor(envs.single_action_space.shape).to(device)).item()
        log_alpha = torch.zeros(1, requires_grad=True, device=device)
        alpha = log_alpha.exp().item()
        alpha_optimizer = optim.Adam([log_alpha], lr=cfg.hyperparams.q_lr)
    else:
        alpha = cfg.hyperparams.alpha

    envs.single_observation_space.dtype = np.float32
    rb = ReplayBuffer(
        cfg.hyperparams.buffer.size,
        envs.single_observation_space,
        envs.single_action_space,
        device=device,
        optimize_memory_usage=cfg.hyperparams.buffer.optimize_memory,
        handle_timeout_termination=True,
    )
    # Durations of the last 10 occurrences of each phase, newest first.
    timer = {
        'env_interaction': deque([0.0] * 10, maxlen=10),
        'q_update': deque([0.0] * 10, maxlen=10),
        'policy_update': deque([0.0] * 10, maxlen=10),
        'target_network_update': deque([0.0] * 10, maxlen=10)
    }

    start_time = time.time()

    obs = envs.reset()
    mj_sim_list = []
    # Clamp to 1 so a config with more evaluations than timesteps does not
    # lead to a modulo-by-zero below.
    every_n_eval = max(1, cfg.total_timesteps // cfg.total_eval_cnt)
    every_n_ckpt = every_n_eval * 2
    hist_episode_returns = deque([0] * 10, maxlen=10)
    # The actor/alpha losses only exist after the first policy update, which
    # may come later than the first emit step; None marks "nothing to log yet".
    actor_loss = None
    alpha_loss = None
    for global_step in range(cfg.total_timesteps):
        # ========== START ENV INTERACTION ===============
        env_interaction_start_time = time.time()
        if global_step < cfg.hyperparams.learning_starts:
            # Random actions until learning starts.
            actions = envs.action_space.sample()
        else:
            actions, _, _ = actor.get_action(torch.Tensor(obs).to(device))
            actions = actions.detach().cpu().numpy()

        next_obs, rewards, dones, infos = envs.step(actions)
        mj_sim_list.append(envs.envs[0].sim.get_state())

        for info in infos:
            if 'episode' in info.keys():
                logger.info(f"global_step={global_step}, episodic_return={info['episode']['r']}")
                writer.add_scalar('charts/episodic_return', info['episode']['r'], global_step)
                writer.add_scalar('charts/episodic_length', info['episode']['l'], global_step)
                hist_episode_returns.appendleft(info['episode']['r'])
                if cfg.task.prune:
                    to_prune: bool = prune_callback(global_step, avg_max_50p(hist_episode_returns))
                    if to_prune:
                        envs.close()
                        writer.close()
                        if cfg.wandb.launch:
                            wandb_finish_wrapper()
                        return
                break

        # On done, the stepped observation is replaced by the one stored under
        # 'terminal_observation' so the buffer keeps the true final observation.
        real_next_obs = next_obs.copy()
        for idx, d in enumerate(dones):
            if d:
                real_next_obs[idx] = infos[idx]['terminal_observation']
        rb.add(obs, real_next_obs, actions, rewards, dones, infos)

        obs = next_obs
        timer['env_interaction'].appendleft(time.time() - env_interaction_start_time)
        # ========== END ENV INTERACTION ===============

        if global_step > cfg.hyperparams.learning_starts:
            data = rb.sample(cfg.hyperparams.batch_size)

            # =========== START Q UPDATE ====================
            q_update_start_time = time.time()
            with torch.no_grad():
                # Entropy-regularised bootstrap target from the target critic.
                next_state_actions, next_state_log_prob, _ = actor.get_action(data.next_observations)
                qf1_next_target = qf1_target(data.next_observations, next_state_actions)
                qf_next_target = qf1_next_target - alpha * next_state_log_prob
                next_q_value = data.rewards.flatten() + (1 - data.dones.flatten()) * cfg.hyperparams.gamma * (qf_next_target).view(-1)

            qf1_a_values = qf1(data.observations, data.actions).view(-1)

            # dris.loss returns the critic loss plus an info mapping that is
            # fed back into dris.update and logged under 'qf1/' below.
            qf1_loss, qf1_info = dris.loss(qf1_a_values, next_q_value)
            dris.update(qf1_info, global_step)

            q_optimizer.zero_grad()
            qf1_loss.backward()
            nn.utils.clip_grad_norm_(list(qf1.parameters()), cfg.hyperparams.max_grad_norm) # extra addition
            q_optimizer.step()
            timer['q_update'].appendleft(time.time() - q_update_start_time)
            # =========== END Q UPDATE ====================

            if global_step % cfg.hyperparams.policy_freq == 0:
                for _ in range(
                    cfg.hyperparams.policy_freq
                ): # compensating for delay
                    # =========== START Policy UPDATE ====================
                    policy_update_start_time = time.time()
                    pi, log_pi, _ = actor.get_action(data.observations)
                    actor_losses = -qf1.forward(data.observations.float(), pi)
                    actor_loss = dris.update_actor_loss(actor_losses) + (alpha * log_pi).mean()

                    actor_optimizer.zero_grad()
                    actor_loss.backward()
                    nn.utils.clip_grad_norm_(list(actor.parameters()), cfg.hyperparams.max_grad_norm)
                    actor_optimizer.step()

                    if cfg.hyperparams.alpha_autotune:
                        # log_pi is recomputed without gradients so only
                        # log_alpha is updated by the alpha loss.
                        with torch.no_grad():
                            _, log_pi, _ = actor.get_action(data.observations)
                        alpha_loss = (-log_alpha * (log_pi + target_entropy)).mean()

                        alpha_optimizer.zero_grad()
                        alpha_loss.backward()
                        alpha_optimizer.step()
                        alpha = log_alpha.exp().item()
                    timer['policy_update'].appendleft(time.time() - policy_update_start_time)
                    # =========== END Policy UPDATE ====================

            # update target networks
            if global_step % cfg.hyperparams.target_freq == 0:
                target_network_update_start_time = time.time()
                polyak_average(qf1.parameters(), qf1_target.parameters(), cfg.hyperparams.tau)
                timer['target_network_update'].appendleft(time.time() - target_network_update_start_time)

            if global_step % cfg.hyperparams.emit_freq == 0:
                writer.add_scalar('losses/qf1_loss', qf1_loss.item(), global_step)
                if actor_loss is not None:
                    writer.add_scalar('losses/actor_loss', actor_loss.item(), global_step)
                writer.add_scalar('charts/SPS', int(global_step / (time.time() - start_time)), global_step)

                writer.add_scalar('qf1/values', qf1_a_values.mean().item(), global_step)
                for key in qf1_info:
                    writer.add_scalar(f'qf1/{key}', qf1_info[key], global_step)

                writer.add_scalar("losses/alpha", alpha, global_step)
                if cfg.hyperparams.alpha_autotune and alpha_loss is not None:
                    writer.add_scalar("losses/alpha_loss", alpha_loss.item(), global_step)

                # record timer
                time_env_interaction = avg_max_50p(timer['env_interaction'])
                time_q_update = avg_max_50p(timer['q_update'])
                time_policy_update = avg_max_50p(timer['policy_update'])
                time_target_network_update = avg_max_50p(timer['target_network_update'])

                writer.add_scalar('timer/env_interaction', time_env_interaction, global_step)
                writer.add_scalar('timer/q_update', time_q_update, global_step)
                writer.add_scalar('timer/policy_update', time_policy_update, global_step)
                writer.add_scalar('timer/target_network_update', time_target_network_update, global_step)
                writer.add_scalar('timer/total_time',
                    time_env_interaction +
                    time_q_update +
                    time_policy_update +
                    time_target_network_update,
                    global_step
                )
                writer.add_scalar('timer/total_update_time',
                    time_q_update +
                    time_policy_update +
                    time_target_network_update,
                    global_step
                )

            if (global_step+1) % every_n_ckpt == 0:
                os.makedirs(f'checkpoints/{global_step}', exist_ok=True)
                torch.save({
                    'global_step': global_step,
                    'actor_state_dict': actor.state_dict(),
                    'qf1_state_dict': qf1.state_dict(),
                    'target_qf1_state_dict': qf1_target.state_dict(),
                }, f'checkpoints/{global_step}/checkpoint.pt')
                # store replay observations (used for determining overestimation bias)
                # Keep a random subset of at most 2000 simulator states
                # collected since the previous checkpoint.
                random.shuffle(mj_sim_list)
                io_helper.dump(f'checkpoints/{global_step}/mj_sim.pkl', mj_sim_list[:2000])
                mj_sim_list = []
                with open(f'checkpoints/{global_step}/rb_size.txt', 'w') as f:
                    f.write('%d' % rb.size())
                torch.save({
                    'global_step': global_step,
                    'actor_optimizer_state_dict': actor_optimizer.state_dict(),
                    'q_optimizer_state_dict': q_optimizer.state_dict(),
                }, "latest.pt")
                io_helper.dump("rb.pkl", rb)

            if (global_step+1) % every_n_eval == 0:
                # NOTE(review): the third argument is eval_policy's `seed`, so
                # evaluation is always seeded with 10 rather than cfg.seed;
                # confirm this is intended.
                avg_reward = eval_policy(actor, cfg.task.name, 10)
                # Pass the step explicitly; without it every evaluation point
                # is recorded without a step and the curve collapses.
                writer.add_scalar('eval/avg_r', avg_reward, global_step)
                with open('eval.txt', 'a') as f:
                    f.write(f'\n{global_step:09d}:{avg_reward}')

    envs.close()
    writer.close()
    if cfg.wandb.launch:
        wandb_finish_wrapper()
