import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR

from iql.src.util import compute_batched, update_exponential_moving_average
from src.add_lambda_heuristic import get_heuristic_mix_h_v
from src.util import DEFAULT_DEVICE, simple_lambda

# Upper bound applied to exp(beta * advantage) in the policy update, so a
# single large advantage cannot dominate the advantage-weighted loss.
EXP_ADV_MAX = 100.


def asymmetric_l2_loss(u, tau):
    """Return the mean of ``|tau - 1[u < 0]| * u**2`` over all elements of *u*.

    Elements with ``u >= 0`` are weighted by ``|tau|`` and elements with
    ``u < 0`` by ``|tau - 1|``, so ``tau = 0.5`` gives half the plain mean
    squared value.
    """
    # 1.0 where the residual is negative, 0.0 elsewhere.
    negative_mask = (u < 0).float()
    weights = (tau - negative_mask).abs()
    weighted_squares = weights * u.pow(2)
    return weighted_squares.mean()


class ImplicitQLearning(nn.Module):
    """Implicit Q-Learning trainer with heuristic-mixed Q targets and an optional gate.

    Holds a Q network (plus a frozen deep copy used as target), a value
    network, a policy and a gate network, each with its own optimizer, and
    performs one gradient step on all of them per call to :meth:`update`.

    Policy-update modes, selected by ``gate_threshold``:

    * ``0``: advantage-weighted behaviour cloning
      (``mean(exp_adv * bc_losses)``).
    * ``0 < gate_threshold < 1``: the gate is trained to separate policy
      samples (label 0) from dataset actions (label 1); its thresholded output
      mixes a Q-maximisation term with the behaviour-cloning term. Requires
      the policy to return a ``torch.distributions.Distribution``.
    * anything else: ``NotImplementedError``.
    """

    def __init__(self, qf, vf, policy, gate, optimizer_factory, max_steps,
                 tau, beta, discount=0.99, alpha=0.005, temperature=1000, method='softmax', gate_threshold=0):
        """Move the networks to ``DEFAULT_DEVICE`` and build their optimizers.

        Args:
            qf: Q network; called as ``qf(obs, act)`` and must expose
                ``qf.both(obs, act)`` returning an iterable of Q estimates.
            vf: value network, called as ``vf(obs)``.
            policy: policy network; ``policy(obs)`` returns either a
                ``torch.distributions.Distribution`` or an action tensor.
            gate: gate network, called on concatenated ``(obs, act)`` rows;
                its output is fed to binary cross-entropy, so it is expected
                to lie in ``[0, 1]``.
            optimizer_factory: callable mapping a parameter iterable to an
                optimizer.
            max_steps: ``T_max`` of the cosine learning-rate schedule applied
                to the policy optimizer.
            tau: asymmetry coefficient passed to ``asymmetric_l2_loss`` for
                the value loss.
            beta: scale applied to the advantage before exponentiation.
            discount: discount factor used in the Q target.
            alpha: rate passed to ``update_exponential_moving_average`` when
                refreshing the target Q network.
            temperature: forwarded to ``get_heuristic_mix_h_v``.
            method: forwarded to ``get_heuristic_mix_h_v``.
            gate_threshold: policy-update mode selector (see class docstring).
        """
        super().__init__()
        self.qf = qf.to(DEFAULT_DEVICE)
        # Frozen copy of the Q network; only changed through the EMA update.
        self.q_target = copy.deepcopy(qf).requires_grad_(False).to(DEFAULT_DEVICE)
        self.vf = vf.to(DEFAULT_DEVICE)
        self.policy = policy.to(DEFAULT_DEVICE)
        self.gate = gate.to(DEFAULT_DEVICE)
        self.v_optimizer = optimizer_factory(self.vf.parameters())
        self.q_optimizer = optimizer_factory(self.qf.parameters())
        self.policy_optimizer = optimizer_factory(self.policy.parameters())
        self.gate_optimizer = optimizer_factory(self.gate.parameters())
        self.gate_threshold = gate_threshold
        self.policy_lr_schedule = CosineAnnealingLR(self.policy_optimizer, max_steps)
        self.tau = tau
        self.beta = beta
        self.discount = discount
        self.alpha = alpha
        self.temperature = temperature
        self.method = method

    def update(self, observations, actions, next_observations, rewards, terminals, **kwargs):
        """Run one training step on the value, Q, (optional) gate and policy networks.

        Args:
            observations, actions, next_observations, rewards, terminals:
                one batch of transitions as tensors.
            **kwargs: must contain ``'returns'`` and ``'lambda'``; both are
                forwarded to ``get_heuristic_mix_h_v``.

        Returns:
            dict mapping metric names to Python floats.

        Raises:
            NotImplementedError: if ``gate_threshold`` is neither ``0`` nor in
                ``(0, 1)``, if the policy output is neither a distribution nor
                a tensor, or if the gated mode is used with a policy that
                returns a tensor.
            KeyError: if ``'returns'`` or ``'lambda'`` is missing from
                ``kwargs``.
        """
        # Reject unsupported modes before any optimizer has been stepped, so a
        # bad configuration cannot leave the networks partially updated.
        gate_enabled = 0 < self.gate_threshold < 1
        if not gate_enabled and self.gate_threshold != 0:
            raise NotImplementedError(
                f"gate_threshold must be 0 (disabled) or in (0, 1); got {self.gate_threshold!r}")

        with torch.no_grad():
            target_q = self.q_target(observations, actions)
            next_v = self.vf(next_observations)

        # Update value function
        v = self.vf(observations)
        adv = target_q - v

        # Debug statistics only: compare the target Q of the policy's own
        # action with V and with the target Q of the dataset action. Nothing
        # here is back-propagated, so no graph is built.
        with torch.no_grad():
            debug_policy_out = self.policy(observations)
            if isinstance(debug_policy_out, torch.distributions.Distribution):
                policy_actions = debug_policy_out.mean
            elif torch.is_tensor(debug_policy_out):
                # Deterministic policy: the output already is the action.
                policy_actions = debug_policy_out
            else:
                raise NotImplementedError(
                    "policy must return a torch.distributions.Distribution or a tensor")
            onestep_q = self.q_target(observations, policy_actions)
            onestep_q_minus_v = onestep_q - v
            onestep_q_minus_q = onestep_q - target_q

        v_loss = asymmetric_l2_loss(adv, self.tau)
        self.v_optimizer.zero_grad(set_to_none=True)
        v_loss.backward()
        self.v_optimizer.step()

        # Update Q function. The bootstrap value is produced by the project
        # helper from the returns, rewards and next-state values.
        heuristics, mix_hu_v = get_heuristic_mix_h_v(kwargs['returns'], rewards, self.discount,
            self.method, next_v, self.temperature, kwargs['lambda'])
        targets = rewards + (1. - terminals.float()) * self.discount * mix_hu_v

        # Debug: fraction of samples whose heuristic exceeds the next-state value.
        heuristic_rates = ((heuristics.reshape(-1,1)-next_v.reshape(-1,1).detach())>0).type(torch.float).mean().item()

        qs = self.qf.both(observations, actions)
        q_loss = sum(F.mse_loss(q, targets) for q in qs) / len(qs)
        self.q_optimizer.zero_grad(set_to_none=True)
        q_loss.backward()
        self.q_optimizer.step()

        # Update target Q network
        update_exponential_moving_average(self.q_target, self.qf, self.alpha)

        # Update policy
        exp_adv = torch.exp(self.beta * adv.detach()).clamp(max=EXP_ADV_MAX)
        policy_out = self.policy(observations)

        gate_predict = None
        if isinstance(policy_out, torch.distributions.Distribution):
            bc_losses = -policy_out.log_prob(actions)
            if gate_enabled:
                # Train the gate as a binary classifier: rows built from
                # policy samples get label 0, rows built from dataset actions
                # get label 1.
                sampled_actions = policy_out.sample()
                gate_input = torch.cat((torch.cat((observations, sampled_actions), 1),
                                        torch.cat((observations, actions), 1)), 0)
                gate_response = torch.cat((
                    torch.zeros(len(sampled_actions), device=DEFAULT_DEVICE),
                    torch.ones(len(actions), device=DEFAULT_DEVICE),
                )).reshape(-1, 1)
                gate_predict = self.gate(gate_input)
                gate_loss = F.binary_cross_entropy(gate_predict, gate_response)
                self.gate_optimizer.zero_grad(set_to_none=True)
                gate_loss.backward()
                self.gate_optimizer.step()
        elif torch.is_tensor(policy_out):
            if gate_enabled:
                # The gated objective needs sample()/rsample() from the policy.
                raise NotImplementedError(
                    "gate_threshold in (0, 1) requires a policy that returns a distribution")
            assert policy_out.shape == actions.shape
            bc_losses = torch.sum((policy_out - actions)**2, dim=1)
        else:
            raise NotImplementedError

        if gate_enabled:
            # NOTE(review): gate_predict has one row per gate_input row (the
            # policy-sample rows followed by the dataset-action rows), while
            # policy_q and bc_losses appear to be per-observation. The product
            # below therefore presumably broadcasts instead of pairing
            # element-wise -- TODO confirm the intended pairing.
            gate_threshold_output = (gate_predict.detach() > self.gate_threshold).float()
            policy_q = self.q_target(observations, policy_out.rsample())
            policy_loss = torch.mean(-gate_threshold_output*policy_q + (1-gate_threshold_output)*bc_losses)
        else:
            policy_loss = torch.mean(exp_adv * bc_losses)

        self.policy_optimizer.zero_grad(set_to_none=True)
        policy_loss.backward()
        self.policy_optimizer.step()
        self.policy_lr_schedule.step()

        info_dict = {
            "V loss": v_loss.item(),
            "Q loss": q_loss.item(),
            "Policy loss": policy_loss.item(),
            "Heuristic rates": heuristic_rates,
            "Average target Q function value": target_q.mean().item(),
            "Average value function value": v.mean().item(),
            "Average heuristic value": heuristics.mean().item(),
            "adv": adv.mean().item(),
            "exp_adv": exp_adv.mean().item(),
            "one step q minus v": onestep_q_minus_v.mean().item(),
            "q_t(pi) - q_t(a)": onestep_q_minus_q.mean().item(),
        }

        return info_dict

