"""
SAC Implementation based on https://github.com/pranz24/pytorch-soft-actor-critic/blob/master/sac.py
"""

import mcac.algos.core as core
import mcac.utils.pytorch_utils as ptu

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
import url_benchmark.utils as utils

import copy
import os

import numpy as np


class SAC_SIG:
    def __init__(self, params, ssg_module):
        """Build the critic ensemble, the policy, their optimizers and the SSG trainer.

        Args:
            params: Mapping of hyperparameters. Keys read here: ``tau``,
                ``alpha``, ``max_action``, ``discount``, ``batch_size``,
                ``do_mcac_bonus``, ``policy``, ``target_update_interval``,
                ``automatic_entropy_tuning``, ``d_obs``, ``d_act``,
                ``q_ensemble_size``, ``lr`` and ``hidden_size``.
            ssg_module: Auxiliary model invoked elsewhere in this class as
                ``ssg(obs, action, next_obs, reward)``. It must expose
                ``parameters()`` and ``train()``.
        """
        # Reference SSG loss; ``update`` overwrites it when called with
        # nowtimestep == 0 and uses it to schedule the generated-data ratio.
        self.l0 = 0.0

        self.tau = params['tau']
        self.alpha = params['alpha']
        self.max_action = params['max_action']
        self.discount = params['discount']
        self.batch_size = params['batch_size']
        self.do_mcac_bonus = params['do_mcac_bonus']
        # Number of completed ``update`` calls; gates the target-network update.
        self.total_it = 0
        self.running_risk = 1

        self.policy_type = params['policy']
        self.target_update_interval = params['target_update_interval']
        self.automatic_entropy_tuning = params['automatic_entropy_tuning']

        self.critic = core.Critic(params['d_obs'], params['d_act'],
                                  ensemble_size=params['q_ensemble_size']).to(ptu.TORCH_DEVICE)
        # The target critic starts as an exact copy of the online critic.
        self.critic_target = copy.deepcopy(self.critic)
        self.critic_optim = torch.optim.Adam(self.critic.parameters(), lr=params['lr'])

        if self.policy_type == "Gaussian":
            # Target Entropy = -dim(A) (e.g. -6 for HalfCheetah-v2) as given in the paper
            if self.automatic_entropy_tuning is True:
                # np.prod yields the action dimensionality whether ``d_act`` is
                # a shape sequence such as (6,) or a plain int. The previous
                # torch.Tensor(d_act) form only worked for sequences: given an
                # int it allocates an uninitialised tensor of that length.
                self.target_entropy = -float(np.prod(params['d_act']))
                self.log_alpha = torch.zeros(1, requires_grad=True, device=ptu.TORCH_DEVICE)
                self.alpha_optim = Adam([self.log_alpha], lr=params['lr'])

            self.policy = core.GaussianPolicy(params['d_obs'], params['d_act'],
                                              params['hidden_size'], params['max_action']) \
                .to(ptu.TORCH_DEVICE)
        else:
            # Deterministic policy: no entropy term and no temperature tuning.
            self.alpha = 0
            self.automatic_entropy_tuning = False
            self.policy = core.DeterministicPolicy(params['d_obs'], params['d_act'],
                                                   params['hidden_size'], params['max_action']) \
                .to(ptu.TORCH_DEVICE)

        # Both policy types are trained with the same optimizer configuration.
        self.policy_optim = Adam(self.policy.parameters(), lr=params['lr'])

        # SSG (ICM-style) module settings
        # Scale applied to the forward error in ``compute_intr_reward``.
        self.ssg_scale = 0.3
        self.ssg = ssg_module
        # Optimizer for the SSG module
        self.ssg_opt = torch.optim.Adam(self.ssg.parameters(), lr=params['lr'])
        self.ssg.train()

    def update_ssg(self, obs, action, next_obs, reward, nowtimestep):

        forward_error, backward_error, reward_error = self.ssg(obs, action, next_obs, reward)

        forward_loss = (torch.mean(forward_error)).item()
        if nowtimestep % 1000 == 0:
            print("=======================================================================")
            print("forward_loss is {}".format(forward_loss))

        backward_loss = (torch.mean(backward_error)).item()
        if nowtimestep % 1000 == 0:
            print("backward_loss is {}".format(backward_loss))

        reward_loss = (torch.mean(reward_error)).item()
        if nowtimestep % 1000 == 0:
            print("reward_loss is {}".format(reward_loss))

        loss = forward_error.mean() + backward_error.mean() + reward_error.mean()

        self.ssg_opt.zero_grad()
        loss.backward()
        self.ssg_opt.step()

        return loss

    def compute_intr_reward(self, obs, action, next_obs, reward_useless):
        forward_error, _, _ = self.ssg(obs, action, next_obs, reward_useless)
        reward = forward_error * self.ssg_scale
        reward = torch.log(reward + 1.0)
        return reward

    def select_action(self, state, evaluate=False):
        """Pick an action for a single state and return it as a NumPy array.

        The state is given a leading batch axis before it is passed to
        ``self.policy.sample``. When ``evaluate`` is exactly ``False`` the
        first element of the sample tuple is used; otherwise the third one is
        (presumably the deterministic/mean action -- confirm against the
        policy class). The result is scaled by ``self.max_action``.
        """
        batched_state = ptu.torchify(state).unsqueeze(0)
        exploratory, _, alternative = self.policy.sample(batched_state)
        chosen = exploratory if evaluate is False else alternative
        return chosen.detach().cpu().numpy()[0] * self.max_action

    def select_action_batch(self, states, evaluate=False):
        """Pick actions for a batch of states and return them as a tensor.

        Unlike ``select_action`` no batch axis is added and the result is a
        detached tensor rather than a NumPy array. When ``evaluate`` is
        exactly ``False`` the first element of ``self.policy.sample`` is used,
        otherwise the third. The result is scaled by ``self.max_action``.
        """
        state_batch = ptu.torchify(states)
        exploratory, _, alternative = self.policy.sample(state_batch)
        chosen = exploratory if evaluate is False else alternative
        return chosen.detach() * self.max_action

    def update(self, replay_buffer, replay_buffer_gen, nowtimestep, loss_record):
        """Run one training step of the SSG module, the critics and the policy.

        The batch mixes samples from ``replay_buffer`` with samples from
        ``replay_buffer_gen``. The generated share grows linearly from 0 to
        0.1 as ``loss_record`` falls from the reference loss ``self.l0``
        (captured when ``nowtimestep == 0``) towards 0. The SSG module is
        trained on the ``replay_buffer`` part of the batch only.

        Args:
            replay_buffer: Buffer supporting ``len()`` and ``sample(n)``;
                ``sample`` must return a dict with at least the keys ``obs``,
                ``act``, ``next_obs``, ``rew``, ``mask`` and ``drtg``.
            replay_buffer_gen: Second buffer with the same interface.
            nowtimestep: Current step counter; 0 records the reference loss,
                and multiples of 1000 trigger diagnostic prints.
            loss_record: Most recent SSG loss used to schedule the mix ratio.

        Returns:
            Tuple ``(info, ssg_loss)`` where ``info`` is a dict of scalar
            diagnostics and ``ssg_loss`` is the value returned by
            ``update_ssg`` for this step.
        """
        # Upper bound on the share of the batch drawn from replay_buffer_gen.
        max_gen_ratio = 0.1

        if nowtimestep == 0:
            # First step: remember the reference loss and use real data only.
            self.l0 = loss_record
            rb2_ratio = 0
        elif self.l0 == 0:
            # No usable reference loss (step 0 was never seen, or its loss was
            # exactly zero). The schedule below would divide by zero, so fall
            # back to real data only.
            rb2_ratio = 0
        else:
            # Straight line through (l0, 0) and (0, max_gen_ratio), clamped to
            # the interval [0, max_gen_ratio].
            slope = (0 - max_gen_ratio) / (self.l0 - 0.0)
            intercept = -self.l0 * slope
            rb2_ratio = slope * loss_record + intercept
            if rb2_ratio > max_gen_ratio:
                rb2_ratio = max_gen_ratio
            elif rb2_ratio < 0:
                rb2_ratio = 0
        rb1_ratio = 1 - rb2_ratio

        if nowtimestep % 1000 == 0:
            print("rb1 len is {}".format(len(replay_buffer)))
            print("rb2_img len is {}".format(len(replay_buffer_gen)))

        out_dict1 = replay_buffer.sample(int(self.batch_size * rb1_ratio))

        n_generated = int(self.batch_size * rb2_ratio)
        out_dict2 = replay_buffer_gen.sample(n_generated) if n_generated != 0 else {}

        if not out_dict2:
            out_dict = out_dict1
        else:
            # Real samples come first, generated samples are appended.
            out_dict = {}
            for key in out_dict1:
                if isinstance(out_dict1[key], np.ndarray):
                    out_dict[key] = np.concatenate([out_dict1[key], out_dict2[key]])
                else:
                    out_dict[key] = out_dict1[key] + out_dict2[key]

        obs, action, next_obs, reward, mask, drtg = ptu.torchify(
            out_dict['obs'], out_dict['act'], out_dict['next_obs'],
            out_dict['rew'], out_dict['mask'], out_dict['drtg'])

        # ===================================== SSG / ICM =====================================
        # The SSG module only ever sees transitions from ``replay_buffer``.
        obs1, action1, next_obs1, reward1 = ptu.torchify(
            out_dict1['obs'], out_dict1['act'], out_dict1['next_obs'], out_dict1['rew'])

        ssg_loss = self.update_ssg(obs1, action1, next_obs1, reward1, nowtimestep)

        with torch.no_grad():
            # NOTE(review): the intrinsic reward is computed but never added to
            # ``reward`` below. The call is kept because the SSG forward pass
            # may have side effects in train mode -- confirm the intent.
            intr_reward = self.compute_intr_reward(obs1, action1, next_obs1, reward1)

        if nowtimestep % 1000 == 0:
            print("reward of this step is{}".format(reward))
        # ===================================== SSG / ICM =====================================

        # Compute targets using bellman backup and target function
        with torch.no_grad():
            next_state_action, next_state_log_pi, _ = self.policy.sample(next_obs)
            qf_list_next_target = self.critic_target(next_obs, next_state_action)
            # Minimum over the ensemble, minus the entropy term.
            min_qf_next_target = torch.min(torch.cat(qf_list_next_target, dim=1), dim=1)[0] \
                - self.alpha * next_state_log_pi.squeeze()
            min_qf_next_target = min_qf_next_target.squeeze()
            next_q_value = reward + mask * self.discount * min_qf_next_target
            # Apply MCAC bonus: take the element-wise maximum of the target and drtg.
            if self.do_mcac_bonus:
                next_q_value = torch.max(next_q_value, drtg)

        # Compute Q losses: one MSE term per ensemble member, summed.
        # JQ = E(st,at)~D[0.5(Q(st,at) - r(st,at) - gamma * E st+1~p[V(st+1)])^2]
        qf_list = self.critic(obs, action)
        qf_losses = [
            F.mse_loss(qf.squeeze(), next_q_value)
            for qf in qf_list
        ]
        qf_loss = sum(qf_losses)

        # Q function backward pass
        self.critic_optim.zero_grad()
        qf_loss.backward()
        self.critic_optim.step()

        # Sample from policy, compute minimum Q value of sampled action
        pi, log_pi, _ = self.policy.sample(obs)
        qf_list_pi = self.critic(obs, pi)
        min_qf_pi = torch.min(torch.cat(qf_list_pi, dim=1), dim=1)[0]

        # Calculate policy loss
        # Jpi = E st~D, eps~N[alpha * log pi(f(eps;st)|st) - Q(st, f(eps;st))]
        policy_loss = ((self.alpha * log_pi) - min_qf_pi).mean()

        # Policy backward pass
        self.policy_optim.zero_grad()
        policy_loss.backward()
        self.policy_optim.step()

        # Automatic entropy tuning
        if self.automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()

            self.alpha_optim.zero_grad()
            alpha_loss.backward()
            self.alpha_optim.step()

            self.alpha = self.log_alpha.exp()
            alpha_tlogs = self.alpha.clone()  # For TensorboardX logs
        else:
            alpha_loss = torch.tensor(0.).to(ptu.TORCH_DEVICE)
            alpha_tlogs = torch.tensor(self.alpha)  # For TensorboardX logs

        if self.total_it % self.target_update_interval == 0:
            ptu.soft_update(self.critic, self.critic_target, 1 - self.tau)

        info = {
            'policy_loss': policy_loss.item(),
            # Key spelling is kept as-is so existing log consumers keep working.
            'alphpa_loss': alpha_loss.item(),
            'alpha_tlogs': alpha_tlogs.item()
        }
        for i, (qf, member_loss) in enumerate(zip(qf_list, qf_losses)):
            if i > 3:
                break  # don't log absurd number of Q functions
            info['Q%d' % (i + 1)] = qf.mean().item()
            info['Q%d_loss' % (i + 1)] = member_loss.item()

        self.total_it += 1
        return info, ssg_loss

    def save(self, folder):
        os.makedirs(folder, exist_ok=True)

        torch.save(self.critic.state_dict(), os.path.join(folder, "critic.pth"))
        torch.save(self.critic_optim.state_dict(), os.path.join(folder, "critic_optimizer.pth"))

        torch.save(self.policy.state_dict(), os.path.join(folder, "actor.pth"))
        torch.save(self.policy_optim.state_dict(), os.path.join(folder, "actor_optimizer.pth"))

    def load(self, folder):
        """Restore critic and policy checkpoints previously written by ``save``.

        Reads ``critic.pth``, ``critic_optimizer.pth``, ``actor.pth`` and
        ``actor_optimizer.pth`` from ``folder``, mapping tensors onto
        ``ptu.TORCH_DEVICE``. The target critic is rebuilt as a deep copy of
        the freshly loaded critic.
        """
        def read_state(filename):
            # Load one checkpoint file onto the configured device.
            return torch.load(os.path.join(folder, filename),
                              map_location=ptu.TORCH_DEVICE)

        self.critic.load_state_dict(read_state("critic.pth"))
        self.critic_optim.load_state_dict(read_state("critic_optimizer.pth"))
        # Re-sync the target network with the restored online critic.
        self.critic_target = copy.deepcopy(self.critic)

        self.policy.load_state_dict(read_state("actor.pth"))
        self.policy_optim.load_state_dict(read_state("actor_optimizer.pth"))
