import copy
import torch
import torch.nn.functional as F

from pex.utils.util import (DEFAULT_DEVICE, epsilon_greedy_sample,
                            extract_sub_dict)
from pex.utils.util import update_exponential_moving_average

from pex.algorithms.iql import IQL, EXP_ADV_MAX


def expectile_loss(diff: torch.Tensor, expectile: "torch.Tensor | float"):
    """Return the mean asymmetrically weighted squared error of ``diff``.

    Elements with ``diff > 0`` are weighted by ``expectile``; every other
    element is weighted by ``1 - expectile``.

    Args:
        diff: Tensor of residuals.
        expectile: Weight applied to positive residuals. Either a tensor that
            broadcasts against ``diff`` or a plain Python number.

    Returns:
        Scalar tensor holding the mean of the weighted squared residuals.
    """
    if not torch.is_tensor(expectile):
        # Lift plain numbers to a 0-dim tensor matching ``diff`` so callers do
        # not have to build a weight tensor themselves and the result keeps
        # ``diff``'s dtype and device. Tensor inputs are passed through as-is.
        expectile = torch.as_tensor(expectile, dtype=diff.dtype, device=diff.device)
    weight = torch.where(diff > 0, expectile, 1 - expectile)
    return (weight * diff.pow(2)).mean()
# def expectile_loss(diff, expectile):
#     weight = torch.where(diff > 0, expectile, (1 - expectile))
#     return (weight * (diff**2)).mean()
# Upper bound used to clamp the exponentiated advantage weights in the policy
# loss. This assignment rebinds the name that is also imported from
# pex.algorithms.iql at the top of the file.
EXP_ADV_MAX = 100.0

class OUR_PQV(IQL):
    """IQL variant whose batches are split into two differently-trained parts.

    Every batch given to :meth:`update` is split at ``offline_batch_size``.
    The leading rows get the IQL losses (Q target bootstrapped from the value
    function, advantage-weighted behaviour cloning for the policy). The
    trailing rows get a TD Q target built from freshly selected next actions
    and a SAC-style policy loss. The value function is fit with the expectile
    loss on the whole batch.
    """

    # Number of leading rows of each batch handled by the IQL losses; all rows
    # after it are handled by the TD / SAC losses.
    # NOTE(review): this has to match how the caller lays out its batches
    # (the code appears to expect offline samples first, online samples
    # second) -- confirm against the training loop. Override on a subclass or
    # instance to use a different split.
    offline_batch_size = 128

    def __init__(self, critic, vf, policy, optimizer_ctor, alpha,
                 tau, beta, discount, target_update_rate, ckpt_path, inv_temperature, copy_to_target=True):
        """Build the agent and optionally initialise it from a checkpoint.

        Args:
            critic, vf, policy, optimizer_ctor: Forwarded to ``IQL.__init__``.
            alpha: Coefficient of the log-probability term in the SAC-style
                policy loss.
            tau: Expectile used for the value-function loss.
            beta: Scale applied to the advantage before exponentiation in the
                advantage-weighted policy loss.
            discount: Discount factor applied to bootstrapped values.
            target_update_rate: Rate passed to the exponential moving average
                update of the target critic.
            ckpt_path: Path of a checkpoint to load, or ``None`` to skip
                loading.
            inv_temperature: Stored on ``self._inv_temperature``; not read by
                the methods defined in this class.
            copy_to_target: When loading a checkpoint, initialise the target
                critic from the ``"critic"`` weights if true, otherwise from
                the checkpoint's ``"target_critic"`` weights.
        """
        super().__init__(critic=critic, vf=vf, policy=policy,
                         optimizer_ctor=optimizer_ctor,
                         max_steps=None,
                         tau=tau, beta=beta,
                         discount=discount,
                         target_update_rate=target_update_rate,
                         use_lr_scheduler=False)

        self.alpha = alpha
        self._inv_temperature = inv_temperature
        # Separate copy of the policy; it receives the same checkpoint weights
        # as ``self.policy`` but is not touched by the updates in this class.
        self.policy_offline = copy.deepcopy(self.policy).to(DEFAULT_DEVICE)

        if ckpt_path is not None:
            self._load_checkpoint(ckpt_path, copy_to_target)

    def _load_checkpoint(self, ckpt_path, copy_to_target):
        """Load policy, critic, target-critic and value weights from disk."""
        # Without CUDA, remap stored tensors to the CPU; otherwise keep the
        # devices recorded in the checkpoint.
        map_location = None if torch.cuda.is_available() else torch.device('cpu')
        # NOTE: torch.load unpickles the file, so only trusted checkpoints
        # should be passed here.
        checkpoint = torch.load(ckpt_path, map_location=map_location)

        policy_state = extract_sub_dict("policy", checkpoint)
        critic_state = extract_sub_dict("critic", checkpoint)

        self.policy_offline.load_state_dict(policy_state)
        self.policy.load_state_dict(policy_state)
        self.critic.load_state_dict(critic_state)

        if copy_to_target:
            target_state = critic_state
        else:
            target_state = extract_sub_dict("target_critic", checkpoint)
        self.target_critic.load_state_dict(target_state)

        self.vf.load_state_dict(extract_sub_dict("vf", checkpoint))

    def update(self, observations, actions, next_observations, rewards, terminals):
        """Run one gradient step on the value function, critic and policy.

        All arguments are batched tensors sharing the same leading dimension.
        Rows ``[:offline_batch_size]`` use the value function of the next
        observation as the Q bootstrap; the remaining rows use the target
        critic evaluated at an action selected for the next observation.

        Raises:
            ValueError: If the batch has no rows after ``offline_batch_size``.
        """
        split = self.offline_batch_size
        batch_size = observations.shape[0]
        if batch_size <= split:
            # An empty trailing slice would make the TD and SAC losses NaN.
            raise ValueError(
                f"batch of {batch_size} rows leaves nothing after the first "
                f"{split} rows"
            )

        with torch.no_grad():
            target_q = self.target_critic.min(observations, actions)
            next_v = self.vf(next_observations)

            next_obs_on = next_observations[split:]
            next_actions = self.select_action(next_obs_on)
            next_q = self.target_critic.min(next_obs_on, next_actions)

        # Update value function (whole batch). ``target_q`` carries no graph
        # because it was computed under no_grad.
        v = self.vf(observations)
        adv = target_q - v
        # Size the expectile weights from ``adv`` itself so that any batch
        # size works.
        v_loss = expectile_loss(adv, torch.full_like(adv, self.tau))
        self.v_optimizer.zero_grad(set_to_none=True)
        v_loss.backward()
        self.v_optimizer.step()

        # Update Q function: leading rows bootstrap from V(s').
        not_done_off = 1. - terminals[:split].float()
        targets = rewards[:split] + not_done_off * self.discount * next_v[:split]
        qs = self.critic(observations[:split], actions[:split])
        q_loss_off = sum(F.mse_loss(q, targets) for q in qs) / len(qs)

        # Trailing rows bootstrap from the target critic at the selected
        # next action (TD error).
        not_done_on = 1. - terminals[split:].float()
        targets_td = rewards[split:] + not_done_on * self.discount * next_q
        qs_td = self.critic(observations[split:], actions[split:])
        q_loss_td = sum(F.mse_loss(q, targets_td) for q in qs_td) / len(qs_td)

        q_loss = (q_loss_off + q_loss_td) / 2
        self.q_optimizer.zero_grad(set_to_none=True)
        q_loss.backward()
        self.q_optimizer.step()

        # Update target Q network
        update_exponential_moving_average(self.target_critic, self.critic, self.target_update_rate)

        self.policy_update(observations, adv, actions)

    def policy_update(self, observations, adv, actions):
        """Take one policy step on the mixed IQL / SAC objective.

        Args:
            observations: Full batch of observations.
            adv: Advantage for the full batch; only the leading
                ``offline_batch_size`` rows are used and they are detached.
            actions: Full batch of actions; only the leading rows are used.
        """
        split = self.offline_batch_size

        # -------- advantage-weighted BC term on the leading rows --------
        exp_adv_off = torch.exp(self.beta * adv[:split].detach()).clamp(max=EXP_ADV_MAX)
        dist_off = self.policy(observations[:split])
        bc_losses = -dist_off.log_prob(actions[:split].detach())
        loss_iql = torch.mean(exp_adv_off * bc_losses)

        # -------- SAC term on the trailing rows --------
        obs_on = observations[split:]
        dist_on = self.policy(obs_on)
        a_pi = dist_on.rsample()  # reparameterized sample keeps the gradient path
        logp_pi = dist_on.log_prob(a_pi)
        q1_on, q2_on = self.critic(obs_on, a_pi)
        q_pi_on = torch.min(q1_on, q2_on)
        loss_sac = torch.mean(self.alpha * logp_pi - q_pi_on)

        policy_loss = 0.5 * loss_iql + 0.5 * loss_sac

        self.policy_optimizer.zero_grad(set_to_none=True)
        policy_loss.backward()
        self.policy_optimizer.step()

        if self.use_lr_scheduler:
            self.policy_lr_schedule.step()
    
