import logging
import time
import os
 
import gym
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

from omegaconf import DictConfig
from stable_baselines3.common.buffers import ReplayBuffer
from typing import Deque
from collections import deque
from torch.utils.tensorboard import SummaryWriter
from dris.utils import avg_max_50p, make_env, make_ff_net, polyak_average, wandb_init_wrapper, wandb_finish_wrapper
from dris.dris import DRIS
import THIRD_PARTY.io as io_helper
from THIRD_PARTY import decorate_hook
class QNetwork(nn.Module):
    """State-action value network.

    The wrapped feed-forward model takes the observation and the action
    concatenated along the feature axis and emits a single output per row.
    """

    def __init__(self, envs: gym.vector.SyncVectorEnv, network_cfg: DictConfig) -> None:
        """Build the critic from the per-environment spaces of ``envs``.

        Args:
            envs: Vector env; only ``single_observation_space`` and
                ``single_action_space`` shapes are read.
            network_cfg: Provides ``layers``, ``activation_fn`` and
                ``final_fn``, forwarded unchanged to ``make_ff_net``.
        """
        super().__init__()
        # Input width = flattened observation size + flattened action size.
        obs_dim = np.prod(envs.single_observation_space.shape)
        act_dim = np.prod(envs.single_action_space.shape)
        self.model = make_ff_net(
            inp_dim=obs_dim + act_dim,
            out_dim=1,
            layers=network_cfg.layers,
            activation_fn=network_cfg.activation_fn,
            final_fn=network_cfg.final_fn,
        )

    def forward(self, s, a) -> torch.Tensor:
        """Return the model output for states ``s`` and actions ``a``.

        Both inputs are joined along dimension 1, so they must agree on
        every other dimension.
        """
        state_action = torch.cat((s, a), dim=1)
        return self.model(state_action)

class Actor(nn.Module):
    """Deterministic policy network mapping an observation to an action.

    The wrapped feed-forward model has one input per flattened observation
    element and one output per flattened action element.
    """

    def __init__(self, envs: gym.vector.SyncVectorEnv, network_cfg: DictConfig) -> None:
        """Build the policy from the per-environment spaces of ``envs``.

        Args:
            envs: Vector env; only ``single_observation_space`` and
                ``single_action_space`` shapes are read.
            network_cfg: Provides ``layers``, ``activation_fn`` and
                ``final_fn``, forwarded unchanged to ``make_ff_net``.
        """
        super().__init__()
        self.model = make_ff_net(
            inp_dim=np.prod(envs.single_observation_space.shape),
            out_dim=np.prod(envs.single_action_space.shape),
            layers=network_cfg.layers,
            activation_fn=network_cfg.activation_fn,
            final_fn=network_cfg.final_fn,
        )

    def forward(self, s) -> torch.Tensor:
        """Return the model output for a batch of observations ``s``."""
        return self.model(s)

    def select_action(self, state):
        """Return the action for a single observation as a flat NumPy array.

        Args:
            state: NumPy array holding one observation; it is reshaped to a
                batch of size one and converted to float32.

        Returns:
            1-D ``numpy.ndarray`` with the flattened network output.
        """
        # Run on whatever device the network lives on rather than assuming
        # CUDA, so the policy also works when it was moved to CPU.
        first_param = next(self.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')
        state = torch.as_tensor(state.reshape(1, -1), dtype=torch.float32, device=device)
        return self(state).detach().cpu().numpy().flatten()

@torch.no_grad()
def eval_policy(policy, env_name, seed, eval_episodes=10):
    """Roll out ``policy`` in a fresh environment and return the mean return.

    Args:
        policy: Object exposing ``select_action(np.ndarray)``.
        env_name: Environment id passed to ``gym.make``.
        seed: Base seed; the evaluation env is seeded with ``seed + 100``.
        eval_episodes: Number of episodes to average over.

    Returns:
        Sum of all rewards collected divided by ``eval_episodes``.
    """
    eval_env = gym.make(env_name)
    try:
        eval_env.seed(seed + 100)

        total_reward = 0.
        for _ in range(eval_episodes):
            state, done = eval_env.reset(), False
            while not done:
                action = policy.select_action(np.array(state))
                state, reward, done, _ = eval_env.step(action)
                total_reward += reward
    finally:
        # Release the environment even if a rollout raises; a new env is
        # created on every call, so leaving it open leaks one per evaluation.
        eval_env.close()

    avg_reward = total_reward / eval_episodes

    print("---------------------------------------")
    print(f"Evaluation over {eval_episodes} episodes: {avg_reward:.3f}")
    print("---------------------------------------")
    return avg_reward

# NOTE(review): the imports visible at the top of this file bring in
# `decorate_hook`, not `decorate_exception_hook`; unless the latter is defined
# elsewhere this decorator raises NameError at import time -- confirm.
@decorate_exception_hook
def train(cfg: DictConfig):
    """Train a deterministic actor and a single Q-critic (DDPG-style loop).

    Each step interacts with one vectorised environment, stores the
    transition in a replay buffer and, once ``learning_starts`` steps have
    passed, updates the critic with the loss returned by ``dris.loss``, and
    every ``policy_freq`` steps updates the actor and the target networks.

    Side effects: writes TensorBoard scalars to ``summary``, checkpoints
    under ``checkpoints/<step>/``, ``latest.pt``, ``rb.pkl`` and appends
    evaluation results to ``eval.txt`` (all relative to the working dir).

    Args:
        cfg: Run configuration. Fields read here: ``seed``,
            ``torch_deterministic``, ``device``, ``task.name``,
            ``capture_video``, ``total_timesteps``, ``total_eval_cnt``,
            ``wandb.launch``, ``dris`` and ``hyperparams`` (``actor``,
            ``critic``, ``lr``, ``buffer``, ``learning_starts``,
            ``exploration_noise``, ``batch_size``, ``gamma``,
            ``max_grad_norm``, ``policy_freq``, ``tau``, ``emit_freq``).
    """
    logger = logging.getLogger(__name__)
    if cfg.wandb.launch:
        wandb_init_wrapper(cfg)

    writer = SummaryWriter('summary')

    # seeding
    random.seed(cfg.seed)
    np.random.seed(cfg.seed)
    torch.manual_seed(cfg.seed)
    torch.backends.cudnn.deterministic = cfg.torch_deterministic

    device = torch.device(cfg.device)

    # NOTE: SyncVectorEnv takes care of auto-reset on done
    envs = gym.vector.SyncVectorEnv([make_env(cfg.task.name, 0, 0, cfg.capture_video)])
    assert isinstance(envs.single_action_space, gym.spaces.Box), 'only continuous action space is supported'

    max_action = float(envs.single_action_space.high[0])
    actor = Actor(envs, cfg.hyperparams.actor).to(device)
    qf1 = QNetwork(envs, cfg.hyperparams.critic).to(device)
    dris = DRIS(cfg.dris)

    logger.info('actor = \n' + str(actor.model))
    logger.info('qf1 = \n' + str(qf1.model))
    logger.info('qf1/dris = ' + str(dris))

    # Target networks start as exact copies of the online networks.
    target_actor = Actor(envs, cfg.hyperparams.actor).to(device)
    qf1_target = QNetwork(envs, cfg.hyperparams.critic).to(device)

    target_actor.load_state_dict(actor.state_dict())
    qf1_target.load_state_dict(qf1.state_dict())

    q_optimizer = optim.Adam(list(qf1.parameters()), lr=cfg.hyperparams.lr)
    actor_optimizer = optim.Adam(list(actor.parameters()), lr=cfg.hyperparams.lr)

    rb = ReplayBuffer(
        cfg.hyperparams.buffer.size,
        envs.single_observation_space,
        envs.single_action_space,
        device=device,
        optimize_memory_usage=cfg.hyperparams.buffer.optimize_memory
    )
    # Rolling windows (last 10 samples) of wall-clock time per phase.
    timer = {
        'env_interaction': deque([0.0] * 10, maxlen=10),
        'q_update': deque([0.0] * 10, maxlen=10),
        'policy_update': deque([0.0] * 10, maxlen=10),
        'target_network_update': deque([0.0] * 10, maxlen=10)
    }

    start_time = time.time()

    obs = envs.reset()
    mj_sim_list = []
    # Clamp to 1 so the modulo checks below cannot divide by zero when
    # total_eval_cnt exceeds total_timesteps.
    every_n_eval = max(1, cfg.total_timesteps // cfg.total_eval_cnt)
    every_n_ckpt = every_n_eval * 2
    # Stays None until the first policy update so that logging can skip it.
    actor_loss = None
    for global_step in range(cfg.total_timesteps):
        # ========== START ENV INTERACTION ===============
        env_interaction_start_time = time.time()
        if global_step < cfg.hyperparams.learning_starts:
            # Warm-up phase: uniformly sampled actions.
            actions = envs.action_space.sample()
        else:
            # Behaviour policy needs no autograd graph.
            with torch.no_grad():
                mean_action = actor.forward(torch.Tensor(obs).to(device))[0].cpu().numpy()
            # Sample noise per action dimension. `envs.action_space` is the
            # batched space whose leading dimension is the number of envs,
            # so its shape[0] would yield one scalar shared by all dimensions.
            noise = np.random.normal(
                0,
                max_action * cfg.hyperparams.exploration_noise,
                size=envs.single_action_space.shape,
            )
            actions = np.array(
                [(mean_action + noise).clip(envs.single_action_space.low, envs.single_action_space.high)]
            )

        next_obs, rewards, dones, infos = envs.step(actions)
        mj_sim_list.append(envs.envs[0].sim.get_state())

        for info in infos:
            if 'episode' in info.keys():
                logger.info(f"global_step={global_step}, episodic_return={info['episode']['r']}")
                writer.add_scalar('charts/episodic_return', info['episode']['r'], global_step)
                writer.add_scalar('charts/episodic_length', info['episode']['l'], global_step)
                break

        # On episode end `next_obs` already holds the post-reset observation;
        # store the true final observation in the buffer instead.
        real_next_obs = next_obs.copy()
        for idx, d in enumerate(dones):
            if d:
                real_next_obs[idx] = infos[idx]['terminal_observation']
        rb.add(obs, real_next_obs, actions, rewards, dones, infos)

        obs = next_obs
        timer['env_interaction'].appendleft(time.time() - env_interaction_start_time)
        # ========== END ENV INTERACTION ===============

        if global_step > cfg.hyperparams.learning_starts:
            data = rb.sample(cfg.hyperparams.batch_size)

            # =========== START Q UPDATE ====================
            q_update_start_time = time.time()
            with torch.no_grad():
                next_state_actions = (target_actor.forward(data.next_observations.float())).clamp(
                    envs.single_action_space.low[0], envs.single_action_space.high[0]
                )
                qf1_next_target = qf1_target.forward(data.next_observations.float(), next_state_actions)
                # One-step bootstrapped target; terminal transitions drop the bootstrap term.
                target_qf1 = data.rewards.flatten() + (1 - data.dones.flatten()) * cfg.hyperparams.gamma * (qf1_next_target).view(-1)

            curr_qf1 = qf1(data.observations.float(), data.actions).view(-1)

            qf1_loss, qf1_info = dris.loss(curr_qf1, target_qf1)
            dris.update(qf1_info, global_step)

            q_optimizer.zero_grad()
            qf1_loss.backward()
            nn.utils.clip_grad_norm_(list(qf1.parameters()), cfg.hyperparams.max_grad_norm)
            q_optimizer.step()
            timer['q_update'].appendleft(time.time() - q_update_start_time)
            # =========== END Q UPDATE ====================

            if global_step % cfg.hyperparams.policy_freq == 0:
                # =========== START Policy UPDATE ====================
                policy_update_start_time = time.time()
                actor_losses = -qf1.forward(data.observations.float(), actor.forward(data.observations.float()))
                actor_loss = dris.update_actor_loss(actor_losses)
                actor_optimizer.zero_grad()
                actor_loss.backward()
                nn.utils.clip_grad_norm_(list(actor.parameters()), cfg.hyperparams.max_grad_norm)
                actor_optimizer.step()
                timer['policy_update'].appendleft(time.time() - policy_update_start_time)
                # =========== END Policy UPDATE ====================

                # update target networks
                target_network_update_start_time = time.time()
                polyak_average(actor.parameters(), target_actor.parameters(), cfg.hyperparams.tau)
                polyak_average(qf1.parameters(), qf1_target.parameters(), cfg.hyperparams.tau)
                timer['target_network_update'].appendleft(time.time() - target_network_update_start_time)

            if global_step % cfg.hyperparams.emit_freq == 0:
                writer.add_scalar('losses/qf1_loss', qf1_loss.item(), global_step)
                # The actor is only updated every `policy_freq` steps, so an
                # emit step can precede the first actor update.
                if actor_loss is not None:
                    writer.add_scalar('losses/actor_loss', actor_loss.item(), global_step)
                writer.add_scalar('charts/SPS', int(global_step / (time.time() - start_time)), global_step)
                writer.add_scalar('qf1/values', curr_qf1.mean().item(), global_step)
                for key in qf1_info:
                    writer.add_scalar(f'qf1/{key}', qf1_info[key], global_step)

                # record timer
                time_env_interaction = avg_max_50p(timer['env_interaction'])
                time_q_update = avg_max_50p(timer['q_update'])
                time_policy_update = avg_max_50p(timer['policy_update'])
                time_target_network_update = avg_max_50p(timer['target_network_update'])

                writer.add_scalar('timer/env_interaction', time_env_interaction, global_step)
                writer.add_scalar('timer/q_update', time_q_update, global_step)
                writer.add_scalar('timer/policy_update', time_policy_update, global_step)
                writer.add_scalar('timer/target_network_update', time_target_network_update, global_step)
                writer.add_scalar('timer/total_time',
                    time_env_interaction +
                    time_q_update +
                    time_policy_update +
                    time_target_network_update,
                    global_step
                )
                writer.add_scalar('timer/total_update_time',
                    time_q_update +
                    time_policy_update +
                    time_target_network_update,
                    global_step
                )

            if (global_step+1) % every_n_ckpt == 0:
                os.makedirs(f'checkpoints/{global_step}', exist_ok=True)
                torch.save({
                    'global_step': global_step,
                    'actor_state_dict': actor.state_dict(),
                    'target_state_dict': target_actor.state_dict(),
                    'qf1_state_dict': qf1.state_dict(),
                    'target_qf1_state_dict': qf1_target.state_dict(),
                }, f'checkpoints/{global_step}/checkpoint.pt')
                # store replay observations (used for determining overestimation bias)
                # Keep a random subset of at most 2000 simulator states
                # collected since the previous checkpoint.
                random.shuffle(mj_sim_list)
                io_helper.dump(f'checkpoints/{global_step}/mj_sim.pkl', mj_sim_list[:2000])
                mj_sim_list = []
                with open(f'checkpoints/{global_step}/rb_size.txt', 'w') as f:
                    f.write('%d' % rb.size())
                # Optimizer state and the replay buffer are overwritten in
                # place; only the latest copy is kept.
                torch.save({
                    'global_step': global_step,
                    'actor_optimizer_state_dict': actor_optimizer.state_dict(),
                    'q_optimizer_state_dict': q_optimizer.state_dict(),
                }, "latest.pt")
                io_helper.dump("rb.pkl", rb)

            if (global_step+1) % every_n_eval == 0:
                # NOTE(review): the evaluation seed is the literal 10, not
                # cfg.seed -- confirm this is intentional.
                avg_reward = eval_policy(actor, cfg.task.name, 10)
                # Pass the step explicitly so successive evaluations are not
                # all written without a step index.
                writer.add_scalar('eval/avg_r', avg_reward, global_step)
                with open('eval.txt', 'a') as f:
                    f.write(f'\n{global_step:09d}:{avg_reward}')

    envs.close()
    writer.close()
    if cfg.wandb.launch:
        wandb_finish_wrapper()
