import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR

from pex.utils.util import DEFAULT_DEVICE, update_exponential_moving_average
from torchmetrics.regression import LogCoshError

# Upper clamp applied to the exp(beta * advantage) weights in
# RIQL.policy_update, so a few large advantages cannot dominate the
# advantage-weighted policy loss.
EXP_ADV_MAX = 100.
def huber_loss(diff, sigma=1):
    """Element-wise smooth-L1 (Huber-style) loss parameterised by ``sigma``.

    The switch point is ``beta = 1 / sigma**2``. Residuals with magnitude
    below ``beta`` are penalised quadratically (``0.5 * d**2 / beta``), and
    larger ones linearly (``d - 0.5 * beta``). The two pieces agree at
    ``d == beta``.

    Args:
        diff: Tensor of residuals.
        sigma: Sharpness parameter; larger values shrink the quadratic zone.

    Returns:
        Tensor with the same shape as ``diff``; no reduction is applied.
    """
    threshold = 1. / sigma ** 2
    magnitude = diff.abs()
    quadratic = 0.5 * magnitude.pow(2) / threshold
    linear = magnitude - 0.5 * threshold
    return torch.where(magnitude < threshold, quadratic, linear)

def expectile_loss(diff, expectile):
    """Asymmetric squared loss used for expectile regression.

    Strictly positive residuals are weighted by ``expectile``, and all other
    residuals by ``1 - expectile``. The weighted squared residuals are then
    averaged over every element.

    Args:
        diff: Tensor of residuals.
        expectile: Weight applied to strictly positive residuals.

    Returns:
        Scalar tensor holding the mean weighted squared residual.
    """
    squared = diff.pow(2)
    asymmetric_weight = torch.where(diff > 0, expectile, 1 - expectile)
    return torch.mean(asymmetric_weight * squared)

class RIQL(nn.Module):
    """IQL-style agent with a quantile-pessimistic target-critic ensemble.

    Each call to :meth:`update` performs, in order:

    1. Value step: expectile regression of ``vf(s)`` towards the
       ``quantile``-th quantile (over dim 0) of the target critic's output.
    2. Q step: Huber regression of every critic output (iterated over dim 0)
       towards the clipped TD target
       ``r + discount * (1 - terminal) * vf(s')``.
    3. Soft update of ``target_critic`` towards ``critic``.
    4. Policy step: advantage-weighted behaviour cloning, with weights
       ``exp(beta * adv)`` clamped at ``EXP_ADV_MAX``.

    NOTE(review): the critic is assumed to return a stacked ensemble of
    Q-values shaped (num_critics, batch). This is inferred from the dim-0
    quantile and the per-row iteration in ``update``; TODO confirm against
    the critic module.
    """

    def __init__(self, critic, vf, policy, optimizer_ctor, max_steps,
                 tau, beta, online_policy=None, sigma=3.0, quantile=0.1,
                 discount=0.99, target_update_rate=0.005, use_lr_scheduler=True,
                 target_clip_min=-100., target_clip_max=1000.):
        """
        Args:
            critic: Q network called as ``critic(obs, act)``. It is moved to
                ``DEFAULT_DEVICE`` and deep-copied (with gradients disabled)
                to form ``target_critic``.
            vf: State-value network called as ``vf(obs)``.
            policy: Policy network whose output is either a
                ``torch.distributions.Distribution`` or a tensor of actions.
            optimizer_ctor: Callable mapping an iterable of parameters to an
                optimizer; used for V, Q and the policy.
            max_steps: ``T_max`` of the policy's cosine LR schedule.
            tau: Expectile used in the value loss.
            beta: Inverse temperature of the advantage weights.
            online_policy: Stored as-is; not used by this class.
            sigma: Huber ``sigma`` for the Q loss.
            quantile: Quantile of the target-critic ensemble used as V target.
            discount: Discount factor of the TD target.
            target_update_rate: Passed to ``update_exponential_moving_average``
                (presumably the Polyak coefficient; confirm in pex.utils).
            use_lr_scheduler: Whether to cosine-anneal the policy LR.
            target_clip_min: Lower bound applied to the TD target.
            target_clip_max: Upper bound applied to the TD target.
        """
        super().__init__()
        self.critic = critic.to(DEFAULT_DEVICE)
        self.target_critic = copy.deepcopy(critic).requires_grad_(False).to(DEFAULT_DEVICE)
        self.vf = vf.to(DEFAULT_DEVICE)
        self.policy = policy.to(DEFAULT_DEVICE)
        self.online_policy = online_policy
        self.v_optimizer = optimizer_ctor(self.vf.parameters())
        self.q_optimizer = optimizer_ctor(self.critic.parameters())
        self.policy_optimizer = optimizer_ctor(self.policy.parameters())
        self.tau = tau
        self.beta = beta
        self.sigma = sigma
        self.quantile = quantile
        self.discount = discount
        self.target_update_rate = target_update_rate
        self.use_lr_scheduler = use_lr_scheduler
        self.target_clip_min = target_clip_min
        self.target_clip_max = target_clip_max
        # Kept as an attribute for an alternative (log-cosh) Q loss; the
        # Huber loss below is what update() actually uses.
        self.LogCoshLoss = LogCoshError()
        if use_lr_scheduler:
            self.policy_lr_schedule = CosineAnnealingLR(self.policy_optimizer, max_steps)

    def update(self, observations, actions, next_observations, rewards, terminals):
        """Run one gradient step on V, Q and the policy, and soft-update the target critic.

        Args:
            observations, actions, next_observations, rewards, terminals:
                A batch of transitions. ``terminals`` is cast to float and
                masks the bootstrap term of the TD target.
        """
        with torch.no_grad():
            target_q_all = self.target_critic(observations, actions)
            # Pessimistic target: a low quantile across the ensemble (dim 0).
            target_q = torch.quantile(target_q_all, self.quantile, dim=0)
            next_v = self.vf(next_observations)

        # Value function: expectile regression towards the pessimistic target.
        v = self.vf(observations)
        adv = target_q - v
        v_loss = expectile_loss(adv, self.tau)
        self.v_optimizer.zero_grad(set_to_none=True)
        v_loss.backward()
        self.v_optimizer.step()

        # Q function: Huber regression towards the clipped TD target.
        targets = rewards + (1. - terminals.float()) * self.discount * next_v
        targets = torch.clamp(targets, self.target_clip_min, self.target_clip_max)
        # A (1, batch) row broadcasts against each critic's output row.
        targets = targets.view(1, targets.shape[0]).detach()
        qs = self.critic(observations, actions)
        q_loss = sum(huber_loss(targets - q, self.sigma).mean() for q in qs) / len(qs)
        self.q_optimizer.zero_grad(set_to_none=True)
        q_loss.backward()
        self.q_optimizer.step()

        update_exponential_moving_average(self.target_critic, self.critic, self.target_update_rate)

        self.policy_update(observations, adv, actions)

    def policy_update(self, observations, adv, actions):
        """Take one advantage-weighted behaviour-cloning step on the policy.

        The per-sample weight is ``exp(beta * adv)`` clamped at ``EXP_ADV_MAX``.
        The BC term is the negative log-likelihood of ``actions`` for a
        distributional policy, or the squared error summed over the last
        dimension for a deterministic (tensor-output) policy. When
        ``use_lr_scheduler`` is set, the cosine schedule is advanced once.

        NOTE(review): if ``log_prob`` returns per-dimension values
        (batch, act_dim), the product with ``exp_adv`` may not broadcast as
        intended. TODO confirm that the policy's distribution sums over
        action dims.

        Raises:
            NotImplementedError: if the policy output is neither a
                Distribution nor a tensor.
        """
        exp_adv = torch.exp(self.beta * adv.detach()).clamp(max=EXP_ADV_MAX)
        policy_out = self.policy(observations)

        if isinstance(policy_out, torch.distributions.Distribution):
            bc_losses = -policy_out.log_prob(actions.detach())
        elif torch.is_tensor(policy_out):
            bc_losses = torch.sum((policy_out - actions) ** 2, dim=-1)
        else:
            raise NotImplementedError(
                f"Unsupported policy output type: {type(policy_out).__name__}")

        policy_loss = torch.mean(exp_adv * bc_losses)
        self.policy_optimizer.zero_grad(set_to_none=True)
        policy_loss.backward()
        self.policy_optimizer.step()
        if self.use_lr_scheduler:
            self.policy_lr_schedule.step()

    def select_action(self, state, evaluate=False):
        """Return an action for ``state``.

        For a distributional policy, the first element of
        ``policy.sample(state)`` is returned when ``evaluate`` is falsy, and
        the third element (presumably the mode) otherwise. For a
        tensor-output policy, ``policy.act(state)`` is returned regardless
        of ``evaluate``.

        NOTE(review): this runs with autograd enabled; callers that only
        need the action may want to wrap the call in ``torch.no_grad()``.

        Raises:
            NotImplementedError: if the policy output is neither a
                Distribution nor a tensor.
        """
        policy_out = self.policy(state)

        if isinstance(policy_out, torch.distributions.Distribution):
            action_sample, _, action_mode = self.policy.sample(state)
            return action_mode if evaluate else action_sample
        if torch.is_tensor(policy_out):
            return self.policy.act(state)
        raise NotImplementedError(
            f"Unsupported policy output type: {type(policy_out).__name__}")
