import argparse
import copy
import math
import os
from itertools import chain

import numpy as np
import tensorboardX
import torch
import torch.nn.functional as F
import tqdm

from rl_utils import envs, nets, run, utils, device
import replay


def learn_critics(member, buffer, batch_size, gamma):
    """Take one optimizer step on both critics of ``member.agent``.

    The loss is half the batch mean of four squared terms: the soft Bellman
    error of ``critic1`` and of ``critic2``, plus, for each critic, the
    difference between its no-grad value and its gradient-tracked value at
    ``(next_state, sampled next action)``.

    Args:
        member: object exposing the agent to train as ``member.agent``. The
            agent is used through ``actor``, ``critic1``, ``critic2``,
            ``log_alpha``, ``critic_optimizer`` and ``train()``.
        buffer: object whose ``sample(batch_size)`` unpacks into
            ``(state, action, reward, next_state, done)`` tensors.
        batch_size: number of transitions requested from ``buffer``.
        gamma: discount factor.

    Returns:
        None. The critics are updated in place via ``agent.critic_optimizer``.
    """
    agent = member.agent

    # Draw a batch and move every tensor onto the training device.
    states, actions, rewards, next_states, dones = buffer.sample(batch_size)
    states, actions, rewards, next_states, dones = (
        t.to(device) for t in (states, actions, rewards, next_states, dones)
    )

    agent.train()

    # ---- target construction (no gradients flow through any of this) ----
    alpha = agent.log_alpha.exp()
    with torch.no_grad():
        next_dist = agent.actor(next_states)
        next_actions = next_dist.rsample()
        next_logp = next_dist.log_prob(next_actions).sum(-1, keepdim=True)
        frozen_q1 = agent.critic1(next_states, next_actions)
        frozen_q2 = agent.critic2(next_states, next_actions)
        # Clipped double-Q value with the entropy term subtracted.
        soft_value = torch.min(frozen_q1, frozen_q2) - (alpha * next_logp)

        # One discount power per entry along the last reward axis.
        # assumes rewards is (batch, n_steps) -- TODO confirm against buffer
        n_steps = rewards.shape[-1]
        discounts = torch.Tensor([gamma ** k for k in range(n_steps)])
        discounts = discounts.unsqueeze(0).to(device)
        n_step_return = (discounts * rewards).sum(1, keepdim=True)

        # The first state feature is used as the exponent of gamma for the
        # bootstrap term.
        # presumably it stores the action-repeat count -- verify with the env
        repeats = states[:, 0]
        bootstrap_discount = (gamma ** repeats).unsqueeze(1)
        not_done = 1.0 - dones
        td_target = n_step_return + bootstrap_discount * not_done * soft_value

    # ---- soft Bellman residuals on the sampled (state, action) pairs ----
    q1_pred = agent.critic1(states, actions)
    q2_pred = agent.critic2(states, actions)
    bellman_err1 = td_target - q1_pred
    bellman_err2 = td_target - q2_pred

    # ---- penalty on movement of Q(s_{t+1}, a_{t+1}) ----
    # NOTE(review): the frozen values and these predictions come from the
    # same weights, so for deterministic critics the differences are zero on
    # this single step -- confirm whether the critics are stochastic in
    # train mode or whether multiple inner steps were intended.
    next_q1 = agent.critic1(next_states, next_actions)
    next_q2 = agent.critic2(next_states, next_actions)
    drift1 = frozen_q1 - next_q1
    drift2 = frozen_q2 - next_q2

    per_sample_loss = (
        bellman_err1 ** 2 + bellman_err2 ** 2 + drift1 ** 2 + drift2 ** 2
    )
    critic_loss = 0.5 * per_sample_loss.mean()

    optimizer = agent.critic_optimizer
    optimizer.zero_grad()
    critic_loss.backward()
    optimizer.step()


def learn_actor(
    member, buffer, batch_size, target_entropy_mul,
):
    """Take one optimizer step on the actor, then one on ``log_alpha``.

    The actor loss is the negated batch mean of
    ``min(critic1, critic2) - exp(log_alpha) * log_prob`` evaluated at
    actions re-sampled from the actor. The temperature loss then uses the
    same (detached) log-probabilities against a target entropy of
    ``target_entropy_mul * -action.shape[1]``.

    Args:
        member: object exposing the agent to train as ``member.agent``. The
            agent is used through ``actor``, ``critic1``, ``critic2``,
            ``log_alpha``, ``online_actor_optimizer``,
            ``log_alpha_optimizer`` and ``train()``.
        buffer: object whose ``sample(batch_size)`` yields a sequence that
            starts with the state and action tensors.
        batch_size: number of transitions requested from ``buffer``.
        target_entropy_mul: multiplier applied to the negative of the
            action tensor's second dimension to obtain the target entropy.

    Returns:
        None. The actor and ``log_alpha`` are updated in place.
    """
    agent = member.agent
    agent.train()

    # Only states and actions are needed; the rest of the batch is ignored.
    states, actions, *_unused = buffer.sample(batch_size)
    states = states.to(device)
    actions = actions.to(device)

    ##################
    ## ACTOR UPDATE ##
    ##################
    policy = agent.actor(states)
    sampled_actions = policy.rsample()
    logp = policy.log_prob(sampled_actions).sum(-1, keepdim=True)
    q1 = agent.critic1(states, sampled_actions)
    q2 = agent.critic2(states, sampled_actions)
    min_q = torch.min(q1, q2)
    entropy_term = agent.log_alpha.exp() * logp
    objective = (min_q - entropy_term).mean()
    actor_loss = -objective

    actor_optimizer = agent.online_actor_optimizer
    actor_optimizer.zero_grad()
    actor_loss.backward()
    actor_optimizer.step()

    ##################
    ## ALPHA UPDATE ##
    ##################
    action_dim = float(actions.shape[1])
    target_entropy = target_entropy_mul * -action_dim
    # Detached so that only log_alpha receives a gradient from this loss.
    entropy_gap = (logp + target_entropy).detach()
    # exp(log_alpha) is evaluated again here rather than reusing the value
    # that fed the actor loss above.
    alpha_loss = (-agent.log_alpha.exp() * entropy_gap).mean()

    alpha_optimizer = agent.log_alpha_optimizer
    alpha_optimizer.zero_grad()
    alpha_loss.backward()
    alpha_optimizer.step()
