from copy import deepcopy

import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from ml_collections import ConfigDict
from torch import nn as nn

from CQL.SimpleSAC.model import (FullyConnectedQFunction, SamplerPolicy,
                                 Scalar, TanhGaussianPolicy,
                                 soft_target_update)
from CQL.SimpleSAC.utils import prefix_metrics
from src.add_lambda_heuristic import get_heuristic_mix_h_v
from src.util import DEFAULT_DEVICE


class ConservativeSAC(nn.Module):

    @staticmethod
    def get_default_config(updates=None):
        """Return the default hyperparameters, optionally overridden.

        Args:
            updates: Optional overrides. When given, it is wrapped in a
                ``ConfigDict``, its references are resolved, and the result
                is merged on top of the defaults.

        Returns:
            ConfigDict: The defaults with ``updates`` applied.
        """
        defaults = (
            # Entropy coefficient and entropy backup.
            ('alpha_multiplier', 1.0),
            ('backup_entropy', False),
            ('target_entropy', 0.0),
            # Optimizer learning rates.
            ('policy_lr', 3e-4),
            ('qf_lr', 3e-4),
            # Target-network tracking.
            ('soft_target_update_rate', 5e-3),
            ('target_update_period', 1),
            # Conservative (CQL) regularizer.
            ('cql_n_actions', 10),
            ('cql_importance_sample', True),
            ('cql_target_action_gap', -1.0),
            ('cql_temp', 1.0),
            ('cql_min_q_weight', 5.0),
            ('cql_max_target_backup', False),
            ('cql_clip_diff_min', -np.inf),
            ('cql_clip_diff_max', np.inf),
        )

        config = ConfigDict()
        for field_name, field_value in defaults:
            setattr(config, field_name, field_value)

        if updates is None:
            return config

        overrides = ConfigDict(updates).copy_and_resolve_references()
        config.update(overrides)
        return config

    def __init__(self, state_dim, action_dim, cql_min_q_weight, temperature, method, discount):
        """Build the policy, twin Q-functions, target copies and optimizers.

        Args:
            state_dim: Observation dimensionality handed to the networks.
            action_dim: Action dimensionality handed to the networks; the
                negated product of it becomes ``config.target_entropy``.
            cql_min_q_weight: Replaces ``config.cql_min_q_weight``, the
                multiplier applied to the conservative penalty in ``update``.
            temperature: Stored and forwarded to ``get_heuristic_mix_h_v``
                on every ``update`` call.
            method: Stored and forwarded to ``get_heuristic_mix_h_v`` on
                every ``update`` call.
            discount: Discount factor used when forming the TD target.
        """
        super().__init__()

        # Architecture settings shared by the policy and both critics; the
        # string is interpreted by the model classes themselves.
        hidden_arch = '256-256'
        orthogonal_init = False

        # Networks are created in the order policy, qf1, qf2 so that weight
        # initialisation consumes the random stream in a fixed order.
        policy = TanhGaussianPolicy(
            state_dim,
            action_dim,
            arch=hidden_arch,
            log_std_multiplier=1.0,
            log_std_offset=-1.0,
            orthogonal_init=orthogonal_init,
        )

        def build_q_function():
            # Both critics share one specification; only their weights differ.
            return FullyConnectedQFunction(
                state_dim,
                action_dim,
                arch=hidden_arch,
                orthogonal_init=orthogonal_init,
            )

        qf1 = build_q_function()
        qf2 = build_q_function()

        self.config = ConservativeSAC.get_default_config()
        self.config.cql_min_q_weight = cql_min_q_weight
        self.config.target_entropy = -np.prod(action_dim).item()

        self.policy = policy.to(DEFAULT_DEVICE)
        self.qf1 = qf1.to(DEFAULT_DEVICE)
        self.qf2 = qf2.to(DEFAULT_DEVICE)
        self.target_qf1 = deepcopy(self.qf1)
        self.target_qf2 = deepcopy(self.qf2)
        self.discount = discount
        self.method = method
        self.temperature = temperature

        optimizer_class = torch.optim.Adam

        self.policy_optimizer = optimizer_class(
            self.policy.parameters(),
            lr=self.config.policy_lr,
        )
        critic_parameters = list(self.qf1.parameters()) + list(self.qf2.parameters())
        self.qf_optimizer = optimizer_class(
            critic_parameters,
            lr=self.config.qf_lr,
        )

        # Keep the learnable log-alpha on the same device as the networks,
        # and move it *before* its optimizer is built so the optimizer tracks
        # the parameter that is actually trained.
        self.log_alpha = Scalar(0.0).to(DEFAULT_DEVICE)
        self.alpha_optimizer = optimizer_class(
            self.log_alpha.parameters(),
            lr=self.config.policy_lr,
        )

        # A rate of 1.0 synchronises the targets with the online critics.
        self.update_target_network(1.0)
        self._total_steps = 0
    def update_target_network(self, soft_target_update_rate):
        """Update both target Q-networks from their online counterparts.

        Args:
            soft_target_update_rate: Rate forwarded to ``soft_target_update``
                for each (online, target) pair; ``__init__`` passes 1.0 and
                ``update`` passes ``config.soft_target_update_rate``.
        """
        critic_pairs = (
            (self.qf1, self.target_qf1),
            (self.qf2, self.target_qf2),
        )
        for online_qf, target_qf in critic_pairs:
            soft_target_update(online_qf, target_qf, soft_target_update_rate)

    def update(self, observations, actions, next_observations, rewards, terminals, **kwargs):
        """Perform one training step and report scalar diagnostics.

        The entropy coefficient, the policy and the two Q-functions are each
        stepped once, in that order; afterwards the target Q-networks are
        soft-updated whenever the step counter is a multiple of
        ``config.target_update_period``.

        Args:
            observations: Batch of observations.
            actions: Batch of dataset actions; ``shape[0]`` is used as the
                batch size and ``shape[-1]`` as the action dimension.
            next_observations: Batch of successor observations.
            rewards: Batch of rewards.
            terminals: Batch of termination flags, converted with ``.float()``.
            **kwargs: Must provide ``'returns'`` and ``'lambda'``; both are
                handed to ``get_heuristic_mix_h_v`` unchanged.

        Returns:
            dict: Python scalars for logging. The CQL-specific entries are
            passed through ``prefix_metrics(..., 'cql')`` before merging.

        Raises:
            KeyError: If ``'returns'`` or ``'lambda'`` is absent from
                ``kwargs``.
        """
        cfg = self.config
        self._total_steps += 1

        done_mask = terminals.float()

        # ---- Entropy coefficient -------------------------------------------
        sampled_actions, log_pi = self.policy(observations)

        entropy_gap = (log_pi + cfg.target_entropy).detach()
        alpha_loss = -(self.log_alpha() * entropy_gap).mean()
        alpha = self.log_alpha().exp() * cfg.alpha_multiplier

        # ---- Policy loss ---------------------------------------------------
        q_new_actions = torch.min(
            self.qf1(observations, sampled_actions),
            self.qf2(observations, sampled_actions),
        )
        policy_loss = (alpha * log_pi - q_new_actions).mean()

        # ---- Q-function (TD) loss ------------------------------------------
        q1_pred = self.qf1(observations, actions)
        q2_pred = self.qf2(observations, actions)

        if cfg.cql_max_target_backup:
            # Sample several next actions and back up the best of the
            # element-wise minimum of the two target critics.
            next_actions, next_log_pi = self.policy(
                next_observations, repeat=cfg.cql_n_actions
            )
            clipped_next_q = torch.min(
                self.target_qf1(next_observations, next_actions),
                self.target_qf2(next_observations, next_actions),
            )
            target_q_values, best_indices = torch.max(clipped_next_q, dim=-1)
            next_log_pi = torch.gather(
                next_log_pi, -1, best_indices.unsqueeze(-1)
            ).squeeze(-1)
        else:
            next_actions, next_log_pi = self.policy(next_observations)
            target_q_values = torch.min(
                self.target_qf1(next_observations, next_actions),
                self.target_qf2(next_observations, next_actions),
            )

        if cfg.backup_entropy:
            target_q_values = target_q_values - alpha * next_log_pi

        # The bootstrap value comes from the external helper, which receives
        # the target Q-values together with the caller-supplied returns and
        # lambda (its exact mixing rule lives in that helper).
        heuristics, mix_hu_v = get_heuristic_mix_h_v(
            kwargs['returns'],
            rewards,
            self.discount,
            self.method,
            target_q_values,
            self.temperature,
            kwargs['lambda'],
        )

        # No gradient flows through the regression target.
        td_target = (rewards + (1. - done_mask) * self.discount * mix_hu_v).detach()

        qf1_loss = F.mse_loss(q1_pred, td_target)
        qf2_loss = F.mse_loss(q2_pred, td_target)

        # ---- CQL regularizer -----------------------------------------------
        batch_size = actions.shape[0]
        action_dim = actions.shape[-1]
        n_samples = cfg.cql_n_actions

        # Three families of sampled actions: uniform in [-1, 1], policy
        # samples at the current observations and policy samples at the next
        # observations. Policy samples are detached so the penalty only
        # trains the critics.
        uniform_actions = actions.new_empty(
            (batch_size, n_samples, action_dim), requires_grad=False
        ).uniform_(-1, 1)
        current_actions, current_log_pis = self.policy(observations, repeat=n_samples)
        successor_actions, successor_log_pis = self.policy(next_observations, repeat=n_samples)
        current_actions = current_actions.detach()
        current_log_pis = current_log_pis.detach()
        successor_actions = successor_actions.detach()
        successor_log_pis = successor_log_pis.detach()

        # Every sampled action is scored at the *current* observations.
        cql_q1_rand = self.qf1(observations, uniform_actions)
        cql_q2_rand = self.qf2(observations, uniform_actions)
        cql_q1_current_actions = self.qf1(observations, current_actions)
        cql_q2_current_actions = self.qf2(observations, current_actions)
        cql_q1_next_actions = self.qf1(observations, successor_actions)
        cql_q2_next_actions = self.qf2(observations, successor_actions)

        stacked_q1 = torch.cat(
            [cql_q1_rand, q1_pred.unsqueeze(1), cql_q1_next_actions, cql_q1_current_actions],
            dim=1,
        )
        stacked_q2 = torch.cat(
            [cql_q2_rand, q2_pred.unsqueeze(1), cql_q2_next_actions, cql_q2_current_actions],
            dim=1,
        )
        # The spread statistics are always taken over the uncorrected stack.
        cql_std_q1 = torch.std(stacked_q1, dim=1)
        cql_std_q2 = torch.std(stacked_q2, dim=1)

        if cfg.cql_importance_sample:
            # Subtract each sample's log-density; for the uniform samples
            # that is log(0.5 ** action_dim).
            random_density = np.log(0.5 ** action_dim)
            logits_q1 = torch.cat(
                [
                    cql_q1_rand - random_density,
                    cql_q1_next_actions - successor_log_pis,
                    cql_q1_current_actions - current_log_pis,
                ],
                dim=1,
            )
            logits_q2 = torch.cat(
                [
                    cql_q2_rand - random_density,
                    cql_q2_next_actions - successor_log_pis,
                    cql_q2_current_actions - current_log_pis,
                ],
                dim=1,
            )
        else:
            logits_q1 = stacked_q1
            logits_q2 = stacked_q2

        cql_qf1_ood = torch.logsumexp(logits_q1 / cfg.cql_temp, dim=1) * cfg.cql_temp
        cql_qf2_ood = torch.logsumexp(logits_q2 / cfg.cql_temp, dim=1) * cfg.cql_temp

        # Penalty: soft-maximum over sampled actions minus Q at the dataset
        # action, clipped and then averaged over the batch.
        cql_qf1_diff = torch.clamp(
            cql_qf1_ood - q1_pred,
            cfg.cql_clip_diff_min,
            cfg.cql_clip_diff_max,
        ).mean()
        cql_qf2_diff = torch.clamp(
            cql_qf2_ood - q2_pred,
            cfg.cql_clip_diff_min,
            cfg.cql_clip_diff_max,
        ).mean()

        cql_min_qf1_loss = cql_qf1_diff * cfg.cql_min_q_weight
        cql_min_qf2_loss = cql_qf2_diff * cfg.cql_min_q_weight
        # Constant zeros, reported only so the metric keys are present.
        alpha_prime_loss = observations.new_tensor(0.0)
        alpha_prime = observations.new_tensor(0.0)

        qf_loss = qf1_loss + qf2_loss + cql_min_qf1_loss + cql_min_qf2_loss

        # ---- Optimisation --------------------------------------------------
        # Order matters: alpha, then policy, then critics. Gradients that the
        # policy loss leaves on the critics are cleared by the critic
        # optimizer's zero_grad() before its own backward pass.
        optimisation_schedule = (
            (self.alpha_optimizer, alpha_loss),
            (self.policy_optimizer, policy_loss),
            (self.qf_optimizer, qf_loss),
        )
        for optimizer, loss in optimisation_schedule:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if self.total_steps % cfg.target_update_period == 0:
            self.update_target_network(cfg.soft_target_update_rate)

        # ---- Diagnostics ---------------------------------------------------
        metrics = {
            'log_pi': log_pi.mean().item(),
            'policy_loss': policy_loss.item(),
            'qf1_loss': qf1_loss.item(),
            'qf2_loss': qf2_loss.item(),
            'alpha_loss': alpha_loss.item(),
            'alpha': alpha.item(),
            'average_qf1': q1_pred.mean().item(),
            'average_qf2': q2_pred.mean().item(),
            'average_target_q': target_q_values.mean().item(),
            'total_steps': self.total_steps,
        }

        # Fraction of samples whose heuristic exceeds the target Q-value.
        heuristic_margin = heuristics.reshape(-1, 1) - target_q_values.reshape(-1, 1).detach()
        cql_metrics = {
            'cql_std_q1': cql_std_q1.mean().item(),
            'cql_std_q2': cql_std_q2.mean().item(),
            'cql_q1_rand': cql_q1_rand.mean().item(),
            'cql_q2_rand': cql_q2_rand.mean().item(),
            'cql_min_qf1_loss': cql_min_qf1_loss.mean().item(),
            'cql_min_qf2_loss': cql_min_qf2_loss.mean().item(),
            'cql_qf1_diff': cql_qf1_diff.mean().item(),
            'cql_qf2_diff': cql_qf2_diff.mean().item(),
            'cql_q1_current_actions': cql_q1_current_actions.mean().item(),
            'cql_q2_current_actions': cql_q2_current_actions.mean().item(),
            'cql_q1_next_actions': cql_q1_next_actions.mean().item(),
            'cql_q2_next_actions': cql_q2_next_actions.mean().item(),
            'alpha_prime_loss': alpha_prime_loss.item(),
            'alpha_prime': alpha_prime.item(),
            'heuristic_rates': (heuristic_margin > 0).type(torch.float).mean().item(),
        }
        metrics.update(prefix_metrics(cql_metrics, 'cql'))

        return metrics

    def torch_to_device(self, device):
        """Move every component listed by ``self.modules`` to ``device``.

        Args:
            device: Target passed to each component's ``.to()``.
        """
        for component in self.modules:
            component.to(device)
    def select_action(self, state):
        """Return the deterministic policy action for a single observation.

        Args:
            state: One observation exposing ``reshape`` (e.g. a NumPy array).

        Returns:
            numpy.ndarray: The flattened action.
        """
        # The policy consumes batches, so present the single state as a batch
        # of one to make the dimensions match.
        batched_state = torch.as_tensor(
            state.reshape(1, -1), dtype=torch.float32, device=DEFAULT_DEVICE
        )
        # Pure inference: do not record an autograd graph.
        with torch.no_grad():
            action = self.policy(batched_state, deterministic=True)[0]
        return action.detach().cpu().numpy().flatten()
    @property
    def modules(self):
        """List the trainable components: policy, critics, targets, log-alpha.

        NOTE(review): this property shadows the ``nn.Module.modules()``
        method inherited from the base class, so ``instance.modules()`` is
        not callable here. It is kept as an attribute because
        ``torch_to_device`` reads it that way.

        Returns:
            list: A new list built on every access.
        """
        return [
            self.policy,
            self.qf1,
            self.qf2,
            self.target_qf1,
            self.target_qf2,
            self.log_alpha,
        ]

    @property
    def total_steps(self):
        """Number of ``update`` calls completed so far (read-only view)."""
        return self._total_steps
