"""
Built on top of the SAC implementation from https://github.com/pranz24/pytorch-soft-actor-critic/blob/master/sac.py
"""

import mcac.algos.core as core
import mcac.utils.pytorch_utils as ptu

import torch
import torch.nn.functional as F
from torch.optim import Adam

import copy
import os

import numpy as np


class GQE_SIG:
    def __init__(self, params, ssg_module):
        """Build the critic ensemble, the policy, their optimizers and the SSG optimizer.

        Args:
            params: Mapping of hyperparameters. Keys read here: 'tau', 'alpha',
                'max_action', 'discount', 'batch_size', 'do_mcac_bonus',
                'gqe_lambda', 'gqe_n', 'policy', 'target_update_interval',
                'automatic_entropy_tuning', 'd_obs', 'd_act',
                'q_ensemble_size', 'lr' and 'hidden_size'.
            ssg_module: Model that is trained by ``update_ssg`` and queried by
                ``compute_intr_reward``. It must expose ``parameters()`` and
                ``train()`` and be callable as
                ``ssg_module(obs, action, next_obs, reward)``.
        """
        # Baseline SSG loss; overwritten by ``update`` when it is called with
        # ``nowtimestep == 0``.
        self.l0 = 0.0

        self.tau = params['tau']
        self.alpha = params['alpha']
        self.max_action = params['max_action']
        self.discount = params['discount']
        self.batch_size = params['batch_size']
        self.do_mcac_bonus = params['do_mcac_bonus']
        self.gqe_lambda = params['gqe_lambda']
        self.gqe_n = params['gqe_n']
        self.total_it = 0
        self.running_risk = 1

        self.policy_type = params['policy']
        self.target_update_interval = params['target_update_interval']
        self.automatic_entropy_tuning = params['automatic_entropy_tuning']

        self.critic = core.Critic(params['d_obs'], params['d_act'],
                                  ensemble_size=params['q_ensemble_size']).to(ptu.TORCH_DEVICE)
        self.critic_target = copy.deepcopy(self.critic)
        self.critic_optim = torch.optim.Adam(self.critic.parameters(), lr=params['lr'])

        if self.policy_type == "Gaussian":
            # Truthiness (not ``is True``) so this agrees with the check in
            # ``update``; otherwise a truthy non-bool flag would skip creating
            # ``log_alpha`` here and crash later when ``update`` uses it.
            if self.automatic_entropy_tuning:
                # Target entropy = -dim(A) (e.g. -6 for HalfCheetah-v2), as in
                # the SAC paper. ``np.prod`` accepts both an int and a shape
                # tuple; ``torch.Tensor(int)`` would instead allocate an
                # uninitialised tensor of that length.
                self.target_entropy = -float(np.prod(params['d_act']))
                self.log_alpha = torch.zeros(1, requires_grad=True, device=ptu.TORCH_DEVICE)
                self.alpha_optim = Adam([self.log_alpha], lr=params['lr'])
            policy_cls = core.GaussianPolicy
        else:
            # Non-Gaussian policies get no entropy term and no temperature tuning.
            self.alpha = 0
            self.automatic_entropy_tuning = False
            policy_cls = core.DeterministicPolicy

        self.policy = policy_cls(params['d_obs'], params['d_act'],
                                 params['hidden_size'], params['max_action']) \
            .to(ptu.TORCH_DEVICE)
        self.policy_optim = Adam(self.policy.parameters(), lr=params['lr'])

        # ICM-related settings: scale applied to the SSG forward error in
        # ``compute_intr_reward``.
        self.ssg_scale = 0.3
        self.ssg = ssg_module
        # Optimizer for the SSG model; the model is put in training mode.
        self.ssg_opt = torch.optim.Adam(self.ssg.parameters(), lr=params['lr'])
        self.ssg.train()

    def update_ssg(self, obs, action, next_obs, reward, nowtimestep):
        """Run one optimizer step on the SSG model.

        The loss is the sum of the means of the three error tensors returned by
        ``self.ssg(obs, action, next_obs, reward)`` (forward, backward and
        reward errors, in that order).

        Args:
            obs, action, next_obs, reward: Inputs forwarded unchanged to ``self.ssg``.
            nowtimestep: Current step counter; the three component losses are
                printed whenever it is a multiple of 1000.

        Returns:
            The summed loss tensor, evaluated before the optimizer step.
        """
        forward_error, backward_error, reward_error = self.ssg(obs, action, next_obs, reward)

        forward_mean = forward_error.mean()
        backward_mean = backward_error.mean()
        reward_mean = reward_error.mean()

        # ``.item()`` forces a host/device sync, so only pay for it on the
        # steps that actually print.
        if nowtimestep % 1000 == 0:
            print("=======================================================================")
            print("forward_loss is {}".format(forward_mean.item()))
            print("backward_loss is {}".format(backward_mean.item()))
            print("reward_loss is {}".format(reward_mean.item()))

        loss = forward_mean + backward_mean + reward_mean

        self.ssg_opt.zero_grad()
        loss.backward()
        self.ssg_opt.step()

        return loss

    def compute_intr_reward(self, obs, action, next_obs, reward_useless):
        """Compute an intrinsic reward from the SSG forward error.

        The first value returned by ``self.ssg(...)`` is scaled by
        ``self.ssg_scale`` and passed through ``log(1 + x)``.

        Args:
            obs, action, next_obs: Inputs forwarded unchanged to ``self.ssg``.
            reward_useless: Also forwarded to ``self.ssg``; only the first of
                the three returned errors is used here.

        Returns:
            Tensor ``log1p(forward_error * self.ssg_scale)``.
        """
        forward_error, _, _ = self.ssg(obs, action, next_obs, reward_useless)
        # log1p stays accurate for tiny errors, where ``log(x + 1.0)`` would
        # round ``x + 1.0`` to exactly 1 and return 0.
        return torch.log1p(forward_error * self.ssg_scale)


    def select_action(self, state, evaluate=False):
        """Choose an action for a single observation.

        Args:
            state: One observation; it is converted with ``ptu.torchify`` and
                given a leading batch dimension of 1.
            evaluate: When exactly ``False`` the first value returned by
                ``self.policy.sample`` is used; for any other value the third
                one is used (presumably the deterministic action — confirm
                against the policy class).

        Returns:
            NumPy array holding the chosen action multiplied by ``self.max_action``.
        """
        batched_state = ptu.torchify(state).unsqueeze(0)
        sampled, _, alternative = self.policy.sample(batched_state)
        chosen = sampled if evaluate is False else alternative
        # Drop the batch dimension that was added above before rescaling.
        return chosen.detach().cpu().numpy()[0] * self.max_action

    def select_action_batch(self, states, evaluate=False):
        """Choose actions for a batch of observations.

        Args:
            states: Batch of observations, converted with ``ptu.torchify``.
            evaluate: When exactly ``False`` the first value returned by
                ``self.policy.sample`` is used; for any other value the third
                one is used (presumably the deterministic action — confirm
                against the policy class).

        Returns:
            Detached tensor of actions multiplied by ``self.max_action``.
        """
        batch = ptu.torchify(states)
        sampled, _, alternative = self.policy.sample(batch)
        chosen = sampled if evaluate is False else alternative
        return chosen.detach() * self.max_action

    def _buffer_mix_ratios(self, nowtimestep, loss_record):
        """Return ``(rb1_ratio, rb2_ratio)``, the batch fractions for the two buffers.

        At ``nowtimestep == 0`` ``loss_record`` is stored as the baseline
        ``self.l0`` and the whole batch comes from the first buffer. Afterwards
        the second buffer's share grows linearly from 0 (loss equal to the
        baseline) to 0.25 (loss equal to zero) and is clamped to ``[0, 0.25]``.
        """
        max_rb2_ratio = 0.25
        if nowtimestep == 0:
            self.l0 = loss_record
            return 1, 0
        if self.l0 == 0:
            # No usable baseline (never recorded, or recorded as zero): the
            # linear map below would divide by zero, so use only the first buffer.
            return 1, 0

        # Line through (l0, 0) and (0, max_rb2_ratio).
        slope = (0 - max_rb2_ratio) / (self.l0 - 0.0)
        intercept = -self.l0 * slope
        rb2_ratio = slope * loss_record + intercept
        if rb2_ratio > max_rb2_ratio:
            rb2_ratio = max_rb2_ratio
        elif rb2_ratio < 0:
            rb2_ratio = 0
        return 1 - rb2_ratio, rb2_ratio

    @staticmethod
    def _concat_batches(first, second):
        """Concatenate two sampled batch dicts key by key (keys taken from ``first``)."""
        merged = {}
        for key, value in first.items():
            if isinstance(value, np.ndarray):
                merged[key] = np.concatenate([value, second[key]])
            else:
                merged[key] = value + second[key]
        return merged

    def update(self, replay_buffer, replay_buffer_gen, nowtimestep, loss_record, init=False):
        """Run one training step for the SSG model, critic, policy and temperature.

        Steps, in order:
          1. Sample chunks of length ``self.gqe_n`` from ``replay_buffer`` and,
             when its share is non-zero, from ``replay_buffer_gen``; the split
             is driven by ``loss_record`` relative to the baseline ``self.l0``.
          2. Update the SSG model on the ``replay_buffer`` sample only.
          3. Build the multi-step critic target from the combined sample and,
             if ``self.do_mcac_bonus`` is set, take its element-wise maximum
             with the first-step ``drtg`` value.
          4. Update the critic ensemble, the policy and (if enabled) the
             entropy temperature, then soft-update the target critic every
             ``self.target_update_interval`` calls.

        Args:
            replay_buffer: Buffer exposing ``sample_chunk(batch, n)`` and ``len()``.
                Its sample must contain the keys 'obs', 'act', 'next_obs',
                'rew', 'mask' and 'drtg'.
            replay_buffer_gen: Second buffer with the same interface.
            nowtimestep: Step counter; 0 records the loss baseline, and
                multiples of 1000 trigger diagnostic prints.
            loss_record: SSG loss value used to choose the buffer mix.
            init: Unused; kept for interface compatibility.

        Returns:
            ``(info, ssg_loss)`` where ``info`` is a dict of scalar diagnostics
            and ``ssg_loss`` is the value returned by ``update_ssg``.

        Note:
            ``self.gqe_lambda`` must differ from 1, since the target divides by
            ``1 - self.gqe_lambda``.
        """
        rb1_ratio, rb2_ratio = self._buffer_mix_ratios(nowtimestep, loss_record)
        rb1_count = int(self.batch_size * rb1_ratio)
        rb2_count = int(self.batch_size * rb2_ratio)

        if nowtimestep % 1000 == 0:
            print("rb1 len is {}".format(len(replay_buffer)))
            print("rb2_img len is {}".format(len(replay_buffer_gen)))

        out_dict1 = replay_buffer.sample_chunk(rb1_count, self.gqe_n)
        out_dict2 = replay_buffer_gen.sample_chunk(rb2_count, self.gqe_n) if rb2_count != 0 else {}
        out_dict = self._concat_batches(out_dict1, out_dict2) if out_dict2 else out_dict1

        obs_chunk, action_chunk, next_obs_chunk, reward_chunk, mask_chunk, drtg_chunk = \
            ptu.torchify(out_dict['obs'], out_dict['act'], out_dict['next_obs'],
                         out_dict['rew'], out_dict['mask'], out_dict['drtg'])

        # The SSG model is trained on the first buffer's sample only.
        obs1, action1, next_obs1, reward1 = \
            ptu.torchify(out_dict1['obs'], out_dict1['act'],
                         out_dict1['next_obs'], out_dict1['rew'])
        ssg_loss = self.update_ssg(obs1, action1, next_obs1, reward1, nowtimestep)

        if nowtimestep % 1000 == 0:
            print("reward of this step is{}".format(reward_chunk))

        with torch.no_grad():
            # Flatten (batch, n, ...) to (batch * n, ...) so that every step of
            # every chunk is evaluated in one forward pass.
            next_obs_chunk = next_obs_chunk.reshape((next_obs_chunk.shape[0]*next_obs_chunk.shape[1], -1))
            next_state_action, next_state_log_pi, _ = self.policy.sample(next_obs_chunk)
            qf_list_next_target = self.critic_target(next_obs_chunk, next_state_action)
            min_qf_next_target = torch.min(torch.cat(qf_list_next_target, dim=1), dim=1)[0] \
                                    - self.alpha * next_state_log_pi.squeeze()
            min_qf_next_target = min_qf_next_target.reshape(reward_chunk.shape)

            # Construct `multiplier`, which multiplies each reward value to calculate necessary
            # finite geometric sums
            totals = torch.sum(mask_chunk, dim=1, keepdim=True)
            totals = totals.repeat((1, self.gqe_n))

            # Use the sampled batch size rather than recomputing it from the
            # ratios, so a buffer that returns fewer rows does not break shapes.
            n_rows = mask_chunk.shape[0]
            mask_shifted = torch.cat((torch.ones((n_rows, 1), device=ptu.TORCH_DEVICE),
                                      mask_chunk[:, :-1]), dim=1)

            totals_shifted = torch.sum(mask_shifted, dim=1, keepdim=True)
            totals_shifted = totals_shifted.repeat((1, self.gqe_n))

            # Step index 0..n-1 repeated for every row: shape (n_rows, gqe_n).
            arange = torch.arange(self.gqe_n, device=ptu.TORCH_DEVICE).repeat((n_rows, 1))

            multiplier = torch.pow((self.gqe_lambda * self.discount), arange) \
                         * (1 - torch.pow(self.gqe_lambda, totals_shifted - arange) + 1e-8) \
                         / (1 - torch.pow(self.gqe_lambda, totals_shifted) + 1e-8)
            r_mult = reward_chunk * multiplier * mask_shifted
            q_mult = torch.pow(self.discount * self.gqe_lambda, arange + 1) * min_qf_next_target
            q_divisor = self.gqe_lambda * (1 - torch.pow(self.gqe_lambda, totals) + 1e-8) \
                        / (1 - self.gqe_lambda)
            everything = (r_mult + q_mult / q_divisor * mask_chunk)

            next_q_value = torch.sum(everything, dim=1)

            # Apply MCAC bonus
            if self.do_mcac_bonus:
                drtg_chunk = drtg_chunk[:, 0]
                next_q_value = torch.max(next_q_value, drtg_chunk)

        # Only the first transition of each chunk is regressed onto the target.
        obs = obs_chunk[:, 0]
        action = action_chunk[:, 0]

        # Compute Q losses: every ensemble member regresses onto the same target.
        # JQ = 𝔼(st,at)~D[0.5(Q1(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2]
        qf_list = self.critic(obs, action)
        qf_losses = [
            F.mse_loss(qf.squeeze(), next_q_value)
            for qf in qf_list
        ]
        qf_loss = sum(qf_losses)

        # Q function backward pass
        self.critic_optim.zero_grad()
        qf_loss.backward()
        self.critic_optim.step()

        # Sample from policy, compute minimum Q value of sampled action
        pi, log_pi, _ = self.policy.sample(obs)
        qf_list_pi = self.critic(obs, pi)
        min_qf_pi = torch.min(torch.cat(qf_list_pi, dim=1), dim=1)[0]

        # Calculate policy loss
        # Jπ = 𝔼st∼D,εt∼N[α * logπ(f(εt;st)|st) − Q(st,f(εt;st))]
        # The mean of the difference equals the difference of the means, and
        # taking the means separately avoids broadcasting a (batch, 1) log_pi
        # against the (batch,) min_qf_pi into a (batch, batch) tensor.
        # NOTE(review): log_pi's trailing dim is assumed from the .squeeze()
        # applied to it in the target computation — confirm against the policy.
        policy_loss = (self.alpha * log_pi).mean() - min_qf_pi.mean()

        # Policy backward pass
        self.policy_optim.zero_grad()
        policy_loss.backward()
        self.policy_optim.step()

        # Automatic entropy tuning
        if self.automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()

            self.alpha_optim.zero_grad()
            alpha_loss.backward()
            self.alpha_optim.step()

            self.alpha = self.log_alpha.exp()
            alpha_tlogs = self.alpha.clone()  # For TensorboardX logs
        else:
            alpha_loss = torch.tensor(0.).to(ptu.TORCH_DEVICE)
            alpha_tlogs = torch.tensor(self.alpha)  # For TensorboardX logs

        if self.total_it % self.target_update_interval == 0:
            ptu.soft_update(self.critic, self.critic_target, 1 - self.tau)

        info = {
            'policy_loss': policy_loss.item(),
            # Key spelling kept as-is: callers may already read 'alphpa_loss'.
            'alphpa_loss': alpha_loss.item(),
            'alpha_tlogs': alpha_tlogs.item()
        }
        for i, (qf, member_loss) in enumerate(zip(qf_list, qf_losses)):
            if i > 3:
                break  # don't log absurd number of Q functions
            info['Q%d' % (i + 1)] = qf.mean().item()
            info['Q%d_loss' % (i + 1)] = member_loss.item()

        self.total_it += 1
        return info, ssg_loss

    def save(self, folder):
        """Write the critic, policy and their optimizer state dicts into ``folder``.

        The folder is created if it does not exist. Four files are written:
        ``critic.pth``, ``critic_optimizer.pth``, ``actor.pth`` and
        ``actor_optimizer.pth``.

        Args:
            folder: Destination directory path.
        """
        os.makedirs(folder, exist_ok=True)

        checkpoints = (
            ("critic.pth", self.critic),
            ("critic_optimizer.pth", self.critic_optim),
            ("actor.pth", self.policy),
            ("actor_optimizer.pth", self.policy_optim),
        )
        for filename, component in checkpoints:
            torch.save(component.state_dict(), os.path.join(folder, filename))

    def load(self, folder):
        """Restore the critic, policy and their optimizers from ``folder``.

        Reads ``critic.pth``, ``critic_optimizer.pth``, ``actor.pth`` and
        ``actor_optimizer.pth``, mapping tensors onto ``ptu.TORCH_DEVICE``.
        The target critic is rebuilt as a deep copy of the restored critic.

        Args:
            folder: Directory containing the four checkpoint files.
        """
        def _restore(component, filename):
            # Load one state dict from ``folder`` into ``component``.
            checkpoint = torch.load(os.path.join(folder, filename),
                                    map_location=ptu.TORCH_DEVICE)
            component.load_state_dict(checkpoint)

        _restore(self.critic, "critic.pth")
        _restore(self.critic_optim, "critic_optimizer.pth")
        # The target network is not checkpointed; it restarts equal to the critic.
        self.critic_target = copy.deepcopy(self.critic)

        _restore(self.policy, "actor.pth")
        _restore(self.policy_optim, "actor_optimizer.pth")
