"""
TD3 implementation based on https://github.com/sfujim/TD3
"""

import mcac.algos.core as core
import mcac.utils.pytorch_utils as ptu

import torch
import torch.nn.functional as F

import copy
import os

import numpy as np


class TD3_SIG:
    """TD3 agent (after https://github.com/sfujim/TD3) with an auxiliary
    self-supervised module (``ssg``) and mixed real/generated replay.

    Each :meth:`update` draws most of its batch from ``replay_buffer`` and a
    fraction in ``[0, MAX_GEN_RATIO]`` from ``replay_buffer_gen``. The fraction
    grows linearly as ``loss_record`` falls below the value recorded at step 0.
    The ssg module is trained only on the real part of each batch.
    """

    # Upper bound on the fraction of each batch drawn from the generated buffer.
    MAX_GEN_RATIO = 0.25

    def __init__(self, params, ssg_module):
        # Baseline loss captured at step 0; used to schedule generated data.
        self.l0 = 0.0
        self.max_action = params['max_action']
        self.discount = params['discount']
        self.tau = params['tau']
        self.policy_noise = params['policy_noise']
        self.noise_clip = params['noise_clip']
        self.policy_freq = params['policy_freq']
        self.batch_size = params['batch_size']
        self.batch_size_demonstrator = params['batch_size_demonstrator']
        self.do_bc_loss = params['do_bc_loss']
        self.bc_weight = params['bc_weight']
        self.bc_decay = params['bc_decay']
        self.do_q_filter = params['do_q_filter']
        self.do_mcac_bonus = params['do_mcac_bonus']
        self.n_demos = params['n_demos']

        self.total_it = 0
        self.running_risk = 1

        self.actor = core.Actor(params['d_obs'], params['d_act'], self.max_action).to(ptu.TORCH_DEVICE)
        self.actor_target = copy.deepcopy(self.actor)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=params['lr_actor'])

        self.critic = core.Critic(params['d_obs'], params['d_act']).to(ptu.TORCH_DEVICE)
        self.critic_target = copy.deepcopy(self.critic)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=params['lr_critic'])

        # ICM-style (ssg) module settings
        self.ssg_scale = 0.3
        self.ssg = ssg_module
        # optimizers
        self.ssg_opt = torch.optim.Adam(self.ssg.parameters(), lr=params['lr'])
        self.ssg.train()

    def update_ssg(self, obs, action, next_obs, reward, nowtimestep):
        """Take one optimizer step on the ssg module.

        The training loss is the sum of the means of the forward, backward and
        reward error tensors returned by ``self.ssg``. The individual terms are
        printed every 1000 steps.

        Returns:
            The total loss as a detached scalar tensor.
        """
        forward_error, backward_error, reward_error = self.ssg(obs, action, next_obs, reward)
        forward_loss = forward_error.mean()
        backward_loss = backward_error.mean()
        reward_loss = reward_error.mean()

        # Only sync to host (``.item()``) on steps where we actually log.
        if nowtimestep % 1000 == 0:
            print("=======================================================================")
            print("forward_loss is {}".format(forward_loss.item()))
            print("backward_loss is {}".format(backward_loss.item()))
            print("reward_loss is {}".format(reward_loss.item()))

        loss = forward_loss + backward_loss + reward_loss

        self.ssg_opt.zero_grad()
        loss.backward()
        self.ssg_opt.step()

        # Detached so callers that keep it around do not pin the autograd graph.
        return loss.detach()

    def compute_intr_reward(self, obs, action, next_obs, reward_useless):
        """Return the intrinsic reward ``log(1 + ssg_scale * forward_error)``.

        ``reward_useless`` is only forwarded to ``self.ssg``. Its reward-error
        output is ignored here.
        """
        forward_error, _, _ = self.ssg(obs, action, next_obs, reward_useless)
        # log1p is more accurate than log(x + 1) for small prediction errors.
        return torch.log1p(forward_error * self.ssg_scale)

    def select_action(self, state, evaluate=False):
        """Return the deterministic actor action for one state as a flat ndarray.

        ``evaluate`` is accepted for interface compatibility and is unused.
        """
        state = torch.FloatTensor(state.reshape(1, -1)).to(ptu.TORCH_DEVICE)
        # Pure inference: no autograd graph needed.
        with torch.no_grad():
            return self.actor(state).cpu().numpy().flatten()

    def _gen_ratio(self, nowtimestep, loss_record):
        """Fraction of the batch to draw from the generated buffer.

        At step 0, ``loss_record`` is stored as the baseline ``self.l0`` and 0.0
        is returned. Afterwards the ratio is
        ``MAX_GEN_RATIO * (1 - loss_record / l0)``, clamped to
        ``[0, MAX_GEN_RATIO]``. If the baseline is 0.0, the ratio is undefined
        and 0.0 is returned instead of dividing by zero.
        """
        if nowtimestep == 0:
            # Store a plain float: keeping a tensor here would pin its graph.
            self.l0 = float(loss_record)
            return 0.0
        baseline = float(self.l0)
        if baseline == 0.0:
            return 0.0
        ratio = self.MAX_GEN_RATIO * (1.0 - float(loss_record) / baseline)
        return min(max(ratio, 0.0), self.MAX_GEN_RATIO)

    def _batch_split(self, gen_ratio):
        """Return ``(n_real, n_gen)`` sample counts summing to ``batch_size``."""
        n_gen = int(self.batch_size * gen_ratio)
        # Derive n_real from n_gen so flooring never shrinks the total batch.
        return self.batch_size - n_gen, n_gen

    @staticmethod
    def _merge_batches(real, gen):
        """Concatenate two sample dicts key by key (keys taken from ``real``).

        ndarray values are joined with ``np.concatenate``. Other values are
        combined with ``+`` (e.g. list concatenation). An empty ``gen`` returns
        ``real`` unchanged.
        """
        if not gen:
            return real
        merged = {}
        for key, value in real.items():
            if isinstance(value, np.ndarray):
                merged[key] = np.concatenate([value, gen[key]])
            else:
                merged[key] = value + gen[key]
        return merged

    def _update_critic(self, obs, action, next_obs, reward, mask, drtg):
        """One clipped double-Q critic step; returns critic diagnostics."""
        with torch.no_grad():
            # Target policy smoothing: clipped Gaussian noise on the target action.
            noise = (torch.randn_like(action) * self.policy_noise) \
                .clamp(-self.noise_clip, self.noise_clip)
            next_action = (self.actor_target(next_obs) + noise) \
                .clamp(-self.max_action, self.max_action)

            # Clipped double-Q target.
            target_Q1, target_Q2 = self.critic_target(next_obs, next_action)
            target_Q = torch.min(target_Q1, target_Q2).squeeze()
            target_Q = (reward + mask * self.discount * target_Q).squeeze()

            # MCAC bonus: lower-bound the target by ``drtg``
            # (presumably the discounted return-to-go stored in the buffer).
            if self.do_mcac_bonus:
                target_Q = torch.max(target_Q, drtg)

        current_Q1, current_Q2 = self.critic(obs, action)
        current_Q1, current_Q2 = current_Q1.squeeze(), current_Q2.squeeze()
        critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()
        return {
            'critic_loss': critic_loss.item(),
            'Q1': current_Q1.mean().item(),
            'Q2': current_Q2.mean().item(),
        }

    def _update_actor(self, obs, replay_buffer):
        """One actor step (Q maximisation plus optional BC); returns diagnostics."""
        actor_q_loss = -self.critic.Q1(obs, self.actor(obs)).mean()

        # Behavior cloning auxiliary loss, inspired by the DDPGfD paper.
        if self.do_bc_loss:
            demos = replay_buffer.sample_positive(self.batch_size_demonstrator, self.n_demos)
            demo_obs, demo_act = ptu.torchify(demos['obs'], demos['act'])

            act_hat = self.actor(demo_obs)
            losses = F.mse_loss(act_hat, demo_act, reduction='none')

            if self.do_q_filter:
                # Q filter: only imitate demo actions the critic rates higher
                # than the actor's own action.
                with torch.no_grad():
                    q_agent = self.critic.Q1(demo_obs, act_hat)
                    q_expert = self.critic.Q1(demo_obs, demo_act)
                    q_filter = torch.gt(q_expert, q_agent).float()

                if torch.sum(q_filter) > 0:
                    bc_loss = torch.sum(losses * q_filter) / torch.sum(q_filter)
                else:
                    bc_loss = actor_q_loss * 0
            else:
                bc_loss = torch.mean(losses)
        else:
            bc_loss = actor_q_loss * 0

        # BC weight decays geometrically with the number of updates.
        lambda_bc = self.bc_decay ** self.total_it * self.bc_weight
        actor_loss = actor_q_loss + lambda_bc * bc_loss

        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()
        return {
            'actor_loss': actor_loss.item(),
            'actor_q_loss': actor_q_loss.item(),
            'actor_bc_loss': bc_loss.item(),
        }

    def update(self, replay_buffer, replay_buffer_gen, nowtimestep, loss_record):
        """Run one training step on a mixed real/generated batch.

        Args:
            replay_buffer: real-experience buffer. Used via ``sample``,
                ``__len__`` and, when ``do_bc_loss``, ``sample_positive``.
            replay_buffer_gen: generated-experience buffer. Used via ``sample``
                and ``__len__``.
            nowtimestep: current step. At 0, ``loss_record`` becomes the baseline.
                Every 1000 steps, diagnostics are printed.
            loss_record: scalar (float or 0-d tensor) compared with the baseline
                to pick the generated-data fraction.

        Returns:
            ``(info, ssg_loss)``. ``info`` holds critic diagnostics, plus actor
            diagnostics on delayed-policy steps. ``ssg_loss`` is the value
            returned by :meth:`update_ssg`.
        """
        gen_ratio = self._gen_ratio(nowtimestep, loss_record)
        n_real, n_gen = self._batch_split(gen_ratio)

        if nowtimestep % 1000 == 0:
            print("rb1 len is {}".format(len(replay_buffer)))
            print("rb2_img len is {}".format(len(replay_buffer_gen)))

        out_dict1 = replay_buffer.sample(n_real)
        out_dict2 = replay_buffer_gen.sample(n_gen) if n_gen > 0 else {}
        out_dict = self._merge_batches(out_dict1, out_dict2)

        obs, action, next_obs, reward, mask, drtg = ptu.torchify(
            out_dict['obs'], out_dict['act'], out_dict['next_obs'],
            out_dict['rew'], out_dict['mask'], out_dict['drtg'])

        # ssg (ICM-style) update on the real transitions only.
        obs1, action1, next_obs1, reward1 = ptu.torchify(
            out_dict1['obs'], out_dict1['act'], out_dict1['next_obs'], out_dict1['rew'])
        ssg_loss = self.update_ssg(obs1, action1, next_obs1, reward1, nowtimestep)

        if nowtimestep % 1000 == 0:
            print("reward of this step is{}".format(reward))

        info = self._update_critic(obs, action, next_obs, reward, mask, drtg)

        # Delayed policy and target-network updates.
        if self.total_it % self.policy_freq == 0:
            info.update(self._update_actor(obs, replay_buffer))
            ptu.soft_update(self.critic, self.critic_target, 1 - self.tau)
            ptu.soft_update(self.actor, self.actor_target, 1 - self.tau)

        self.total_it += 1
        return info, ssg_loss

    def save(self, folder):
        """Save actor/critic weights and optimizer states into ``folder``.

        NOTE(review): the ssg module and its optimizer are not saved.
        """
        os.makedirs(folder, exist_ok=True)

        torch.save(self.critic.state_dict(), os.path.join(folder, "critic.pth"))
        torch.save(self.critic_optimizer.state_dict(), os.path.join(folder, "critic_optimizer.pth"))

        torch.save(self.actor.state_dict(), os.path.join(folder, "actor.pth"))
        torch.save(self.actor_optimizer.state_dict(), os.path.join(folder, "actor_optimizer.pth"))

    def load(self, folder):
        """Restore what :meth:`save` wrote and reset both target networks.

        Tensors are mapped onto ``ptu.TORCH_DEVICE``, so checkpoints written
        on a GPU can be loaded on a CPU-only host.
        """
        def _load(name):
            return torch.load(os.path.join(folder, name), map_location=ptu.TORCH_DEVICE)

        self.critic.load_state_dict(_load("critic.pth"))
        self.critic_optimizer.load_state_dict(_load("critic_optimizer.pth"))
        self.critic_target = copy.deepcopy(self.critic)

        self.actor.load_state_dict(_load("actor.pth"))
        self.actor_optimizer.load_state_dict(_load("actor_optimizer.pth"))
        # Without this the target actor keeps its pre-load (untrained) weights.
        self.actor_target = copy.deepcopy(self.actor)

