import logging
import time
import os
 
import gym
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

from typing import Deque
from collections import deque
from omegaconf import DictConfig
from stable_baselines3.common.buffers import ReplayBuffer
from torch.utils.tensorboard import SummaryWriter
from dris.utils import avg_max_50p, make_env, make_ff_net, polyak_average, wandb_init_wrapper, wandb_finish_wrapper
from dris.dris import DRIS
import THIRD_PARTY.io as io_helper
from THIRD_PARTY import decorate_hook
class QNetwork(nn.Module):
    """Critic network Q(s, a) with a single output unit.

    The state and action tensors are concatenated along dim 1 and fed through a
    feed-forward net built by ``make_ff_net``.
    """

    def __init__(self, envs: gym.vector.SyncVectorEnv, network_cfg: DictConfig) -> None:
        super().__init__()
        # Input width = flattened per-env observation size + flattened per-env action size.
        obs_dim = np.prod(envs.single_observation_space.shape)
        act_dim = np.prod(envs.single_action_space.shape)
        self.model = make_ff_net(
            inp_dim=obs_dim + act_dim,
            out_dim=1,
            layers=network_cfg.layers,
            activation_fn=network_cfg.activation_fn,
            final_fn=network_cfg.final_fn,
        )

    def forward(self, s, a) -> torch.Tensor:
        """Return the network output for the state/action pair joined along dim 1."""
        joint = torch.cat((s, a), dim=1)
        return self.model(joint)

class Actor(nn.Module):
    """Deterministic policy network mapping a flattened observation to an action vector."""

    def __init__(self, envs: gym.vector.SyncVectorEnv, network_cfg: DictConfig) -> None:
        super().__init__()
        # Sized from a single (un-batched) env: flattened observation in, flattened action out.
        self.model = make_ff_net(
            inp_dim=np.prod(envs.single_observation_space.shape),
            out_dim=np.prod(envs.single_action_space.shape),
            layers=network_cfg.layers,
            activation_fn=network_cfg.activation_fn,
            final_fn=network_cfg.final_fn,
        )

    def forward(self, s) -> torch.Tensor:
        return self.model(s)

    def select_action(self, state):
        """Return the policy's action for one un-batched ``state`` as a flat numpy array.

        ``state`` is flattened into a batch of one, moved to the device holding this
        module's parameters (CPU if it has none), and evaluated without tracking
        gradients. Previously the input was always sent to CUDA, which failed for
        CPU-resident models.
        """
        first_param = next(self.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')
        state_t = torch.as_tensor(
            np.asarray(state).reshape(1, -1), dtype=torch.float32, device=device
        )
        with torch.no_grad():
            action = self(state_t)
        return action.cpu().numpy().flatten()

@torch.no_grad()
def eval_policy(policy, env_name, seed, eval_episodes=10):
    """Run ``policy`` on a fresh environment and return its mean episodic return.

    Args:
        policy: object with a ``select_action(np.ndarray)`` method returning the
            action to pass to ``env.step``.
        env_name: environment id passed to ``gym.make``.
        seed: base seed; the evaluation env is seeded with ``seed + 100``.
        eval_episodes: number of episodes to run.

    Returns:
        Total reward summed over all episodes divided by ``eval_episodes``.
    """
    eval_env = gym.make(env_name)
    try:
        eval_env.seed(seed + 100)

        avg_reward = 0.
        for _ in range(eval_episodes):
            state, done = eval_env.reset(), False
            while not done:
                action = policy.select_action(np.array(state))
                state, reward, done, _ = eval_env.step(action)
                avg_reward += reward
    finally:
        # Release the env (simulator / renderer) even if an episode raises;
        # previously a new env was leaked on every evaluation call.
        eval_env.close()

    avg_reward /= eval_episodes

    print("---------------------------------------")
    print(f"Evaluation over {eval_episodes} episodes: {avg_reward:.3f}")
    print("---------------------------------------")
    return avg_reward

def _log_timers(writer, timer, global_step: int) -> None:
    """Log each phase timing (reduced with ``avg_max_50p``) plus total and update-only sums."""
    time_env_interaction = avg_max_50p(timer['env_interaction'])
    time_q_update = avg_max_50p(timer['q_update'])
    time_policy_update = avg_max_50p(timer['policy_update'])
    time_target_network_update = avg_max_50p(timer['target_network_update'])

    writer.add_scalar('timer/env_interaction', time_env_interaction, global_step)
    writer.add_scalar('timer/q_update', time_q_update, global_step)
    writer.add_scalar('timer/policy_update', time_policy_update, global_step)
    writer.add_scalar('timer/target_network_update', time_target_network_update, global_step)
    writer.add_scalar(
        'timer/total_time',
        time_env_interaction + time_q_update + time_policy_update + time_target_network_update,
        global_step,
    )
    writer.add_scalar(
        'timer/total_update_time',
        time_q_update + time_policy_update + time_target_network_update,
        global_step,
    )


def _save_checkpoint(
    global_step: int,
    *,
    actor: nn.Module,
    target_actor: nn.Module,
    qf1: nn.Module,
    qf2: nn.Module,
    qf1_target: nn.Module,
    qf2_target: nn.Module,
    actor_optimizer: optim.Optimizer,
    q_optimizer: optim.Optimizer,
    rb,
    mj_sim_list: list,
) -> None:
    """Persist training state for ``global_step``.

    Writes the network weights, up to 2000 randomly chosen recorded simulator
    states and the replay-buffer size into ``checkpoints/<global_step>/``.
    Overwrites ``latest.pt`` (optimizer states only) and ``rb.pkl`` (the whole
    replay buffer) in the working directory. ``mj_sim_list`` is shuffled and
    then emptied in place.
    """
    ckpt_dir = f'checkpoints/{global_step}'
    os.makedirs(ckpt_dir, exist_ok=True)
    torch.save({
        'global_step': global_step,
        'actor_state_dict': actor.state_dict(),
        'target_state_dict': target_actor.state_dict(),
        'qf1_state_dict': qf1.state_dict(),
        'qf2_state_dict': qf2.state_dict(),
        'target_qf1_state_dict': qf1_target.state_dict(),
        'target_qf2_state_dict': qf2_target.state_dict(),
    }, f'{ckpt_dir}/checkpoint.pt')
    # store replay observations (used for determining overestimation bias)
    random.shuffle(mj_sim_list)
    io_helper.dump(f'{ckpt_dir}/mj_sim.pkl', mj_sim_list[:2000])
    mj_sim_list.clear()
    with open(f'{ckpt_dir}/rb_size.txt', 'w') as f:
        f.write('%d' % rb.size())
    torch.save({
        'global_step': global_step,
        'actor_optimizer_state_dict': actor_optimizer.state_dict(),
        'q_optimizer_state_dict': q_optimizer.state_dict(),
    }, "latest.pt")
    io_helper.dump("rb.pkl", rb)


def train(cfg: DictConfig):
    """Train a TD3-style agent whose twin critics are fit with DRIS losses.

    All settings come from ``cfg`` (seed, device, task, hyperparameters, DRIS and
    wandb options). Side effects: TensorBoard summaries in ``summary/``, periodic
    checkpoints under ``checkpoints/<step>/`` plus ``latest.pt`` / ``rb.pkl``,
    and evaluation returns appended to ``eval.txt``.
    """
    logger = logging.getLogger(__name__)
    if cfg.wandb.launch:
        wandb_init_wrapper(cfg)

    writer = SummaryWriter('summary')

    # seeding
    random.seed(cfg.seed)
    np.random.seed(cfg.seed)
    torch.manual_seed(cfg.seed)
    torch.backends.cudnn.deterministic = cfg.torch_deterministic

    device = torch.device(cfg.device)

    # NOTE: SyncVectorEnv takes care of auto-reset on done
    envs = gym.vector.SyncVectorEnv([make_env(cfg.task.name, 0, 0, cfg.capture_video)])
    assert isinstance(envs.single_action_space, gym.spaces.Box), 'only continuous action space is supported'

    action_low = envs.single_action_space.low
    action_high = envs.single_action_space.high
    # Per-env action dimensionality (envs.action_space.shape[0] is the number of envs).
    action_dim = envs.single_action_space.shape[0]
    max_action = float(action_high[0])

    actor = Actor(envs, cfg.hyperparams.actor).to(device)
    qf1 = QNetwork(envs, cfg.hyperparams.critic).to(device)
    qf2 = QNetwork(envs, cfg.hyperparams.critic).to(device)
    dris1 = DRIS(cfg.dris)
    dris2 = DRIS(cfg.dris)

    logger.info('actor = \n' + str(actor.model))
    logger.info('qf1 = \n' + str(qf1.model))
    logger.info('qf1/dris1 = ' + str(dris1))
    logger.info('qf2 = \n' + str(qf2.model))  # previously logged qf1's model here
    logger.info('qf2/dris2 = ' + str(dris2))

    target_actor = Actor(envs, cfg.hyperparams.actor).to(device)
    qf1_target = QNetwork(envs, cfg.hyperparams.critic).to(device)
    qf2_target = QNetwork(envs, cfg.hyperparams.critic).to(device)

    target_actor.load_state_dict(actor.state_dict())
    qf1_target.load_state_dict(qf1.state_dict())
    qf2_target.load_state_dict(qf2.state_dict())

    q_params = list(qf1.parameters()) + list(qf2.parameters())
    q_optimizer = optim.Adam(q_params, lr=cfg.hyperparams.lr)
    actor_optimizer = optim.Adam(list(actor.parameters()), lr=cfg.hyperparams.lr)

    # Presumably so the replay buffer allocates float32 observation storage — confirm.
    envs.single_observation_space.dtype = np.float32
    rb = ReplayBuffer(
        cfg.hyperparams.buffer.size,
        envs.single_observation_space,
        envs.single_action_space,
        device=device,
        optimize_memory_usage=cfg.hyperparams.buffer.optimize_memory,
        handle_timeout_termination=True,
    )
    # Rolling window of the 10 most recent durations (seconds) per phase.
    timer = {
        'env_interaction': deque([0.0] * 10, maxlen=10),
        'q_update': deque([0.0] * 10, maxlen=10),
        'policy_update': deque([0.0] * 10, maxlen=10),
        'target_network_update': deque([0.0] * 10, maxlen=10)
    }

    start_time = time.time()

    obs = envs.reset()
    mj_sim_list = []
    # max(1, ...) avoids a modulo-by-zero when total_eval_cnt > total_timesteps.
    every_n_eval = max(1, cfg.total_timesteps // cfg.total_eval_cnt)
    every_n_ckpt = every_n_eval * 2
    # Only assigned on policy-update steps; stays None until the first one so the
    # metrics block never reads an unbound name.
    actor_loss = None
    for global_step in range(cfg.total_timesteps):
        # ========== START ENV INTERACTION ===============
        env_interaction_start_time = time.time()
        if global_step < cfg.hyperparams.learning_starts:
            actions = envs.action_space.sample()
        else:
            with torch.no_grad():
                policy_action = actor(torch.Tensor(obs).to(device)).cpu().numpy()[0]
            # Independent Gaussian exploration noise for every action dimension
            # (sizing by envs.action_space.shape[0] gave one scalar shared by all dims).
            noise = np.random.normal(0, max_action * cfg.hyperparams.exploration_noise, size=action_dim)
            actions = np.array([(policy_action + noise).clip(action_low, action_high)])

        next_obs, rewards, dones, infos = envs.step(actions)
        mj_sim_list.append(envs.envs[0].sim.get_state())

        for info in infos:
            if 'episode' in info.keys():
                logger.info(f"global_step={global_step}, episodic_return={info['episode']['r']}")
                writer.add_scalar('charts/episodic_return', info['episode']['r'], global_step)
                writer.add_scalar('charts/episodic_length', info['episode']['l'], global_step)
                break

        # next_obs is already the reset observation for finished envs; store the
        # true final observation in the buffer instead.
        real_next_obs = next_obs.copy()
        for idx, d in enumerate(dones):
            if d:
                real_next_obs[idx] = infos[idx]['terminal_observation']
        rb.add(obs, real_next_obs, actions, rewards, dones, infos)

        obs = next_obs
        timer['env_interaction'].appendleft(time.time() - env_interaction_start_time)
        # ========== END ENV INTERACTION ===============

        if global_step > cfg.hyperparams.learning_starts:
            data = rb.sample(cfg.hyperparams.batch_size)

            # =========== START Q UPDATE ====================
            q_update_start_time = time.time()
            with torch.no_grad():
                # Target policy smoothing: independent clipped noise per sample and
                # per action dim (previously one vector was broadcast over the batch).
                clipped_noise = (torch.randn_like(data.actions) * cfg.hyperparams.policy_noise).clamp(
                    -cfg.hyperparams.noise_clip, cfg.hyperparams.noise_clip
                )
                next_state_actions = (target_actor(data.next_observations) + clipped_noise).clamp(
                    action_low[0], action_high[0]
                )
                qf1_next_target = qf1_target(data.next_observations, next_state_actions)
                qf2_next_target = qf2_target(data.next_observations, next_state_actions)
                min_qf_next_target = torch.min(qf1_next_target, qf2_next_target)
                next_q_value = data.rewards.flatten() + (1 - data.dones.flatten()) * cfg.hyperparams.gamma * (min_qf_next_target).view(-1)

            qf1_a_values = qf1(data.observations, data.actions).view(-1)
            qf2_a_values = qf2(data.observations, data.actions).view(-1)

            qf1_loss, qf1_info = dris1.loss(qf1_a_values, next_q_value)
            qf2_loss, qf2_info = dris2.loss(qf2_a_values, next_q_value)

            qf_loss = qf1_loss + qf2_loss

            dris1.update(qf1_info, global_step)
            dris2.update(qf2_info, global_step)

            q_optimizer.zero_grad()
            qf_loss.backward()
            nn.utils.clip_grad_norm_(q_params, cfg.hyperparams.max_grad_norm)
            q_optimizer.step()
            timer['q_update'].appendleft(time.time() - q_update_start_time)
            # =========== END Q UPDATE ====================

            if global_step % cfg.hyperparams.policy_freq == 0:
                # =========== START Policy UPDATE ====================
                policy_update_start_time = time.time()
                observations = data.observations.float()
                actor_losses = -qf1(observations, actor(observations))
                actor_loss = dris1.update_actor_loss(actor_losses)
                actor_optimizer.zero_grad()
                actor_loss.backward()
                nn.utils.clip_grad_norm_(list(actor.parameters()), cfg.hyperparams.max_grad_norm)
                actor_optimizer.step()
                timer['policy_update'].appendleft(time.time() - policy_update_start_time)
                # =========== END Policy UPDATE ====================

                # update target networks
                target_network_update_start_time = time.time()
                polyak_average(actor.parameters(), target_actor.parameters(), cfg.hyperparams.tau)
                polyak_average(qf1.parameters(), qf1_target.parameters(), cfg.hyperparams.tau)
                polyak_average(qf2.parameters(), qf2_target.parameters(), cfg.hyperparams.tau)
                timer['target_network_update'].appendleft(time.time() - target_network_update_start_time)

            if global_step % cfg.hyperparams.emit_freq == 0:
                writer.add_scalar('losses/qf1_loss', qf1_loss.item(), global_step)
                writer.add_scalar('losses/qf2_loss', qf2_loss.item(), global_step)
                writer.add_scalar('losses/qf_loss', qf_loss.item(), global_step)
                if actor_loss is not None:
                    writer.add_scalar('losses/actor_loss', actor_loss.item(), global_step)
                writer.add_scalar('charts/SPS', int(global_step / (time.time() - start_time)), global_step)

                writer.add_scalar('qf1/values', qf1_a_values.mean().item(), global_step)
                for key in qf1_info:
                    writer.add_scalar(f'qf1/{key}', qf1_info[key], global_step)
                writer.add_scalar('qf2/values', qf2_a_values.mean().item(), global_step)
                for key in qf2_info:
                    writer.add_scalar(f'qf2/{key}', qf2_info[key], global_step)

                _log_timers(writer, timer, global_step)

            if (global_step + 1) % every_n_ckpt == 0:
                _save_checkpoint(
                    global_step,
                    actor=actor,
                    target_actor=target_actor,
                    qf1=qf1,
                    qf2=qf2,
                    qf1_target=qf1_target,
                    qf2_target=qf2_target,
                    actor_optimizer=actor_optimizer,
                    q_optimizer=q_optimizer,
                    rb=rb,
                    mj_sim_list=mj_sim_list,
                )

            if (global_step + 1) % every_n_eval == 0:
                # NOTE(review): 10 is passed as the eval *seed* (eval_episodes keeps its
                # default of 10); confirm a fixed seed rather than cfg.seed is intended.
                avg_reward = eval_policy(actor, cfg.task.name, seed=10)
                # global_step was previously omitted, logging every eval at the same step.
                writer.add_scalar('eval/avg_r', avg_reward, global_step)
                with open('eval.txt', 'a') as f:
                    f.write(f'\n{global_step:09d}:{avg_reward}')

    envs.close()
    writer.close()
    if cfg.wandb.launch:
        wandb_finish_wrapper()
