import os
import torch
import math
import torch.nn.functional as F
from torch.optim import Adam
from sac.utils import soft_update, hard_update
from sac.model import PolicyNetwork, PolicyNetworkNorm, QNetwork, EmbeddingNetwork
import numpy as np
from collections import defaultdict
import random


class SAC(object):
    """Actor-critic agent with twin critics and prioritized-replay feedback.

    Despite the class name, no entropy term appears in any update here. The
    learning rule is TD3-like:

    - a deterministic policy;
    - two critics, whose minimum forms the target;
    - clipped Gaussian noise on the target action;
    - policy and target-network updates only every
      ``target_update_interval`` calls;
    - soft target updates via ``soft_update`` with rate ``tau``.

    The Bellman backup takes ``max(reward, gamma * Q')`` rather than a sum
    (see ``update_parameters``).
    """

    def __init__(self, num_inputs, action_space, args):
        """Build networks and optimizers from the configuration ``args``.

        Args:
            num_inputs: Input size passed to the embedding network.
            action_space: Object whose ``shape[0]`` is the action dimension.
            args: Configuration namespace (gamma, tau, emb_size, hidden_size,
                lr, noise settings, epsilon, ...). ``args.detect_anomaly`` is
                optional and defaults to True.
        """
        # Process-wide autograd switch. Anomaly detection pinpoints the op
        # that produced NaN/inf during backward, but slows every backward pass.
        torch.autograd.set_detect_anomaly(getattr(args, "detect_anomaly", True))
        self.gamma = args.gamma
        self.tau = args.tau
        self.num_inputs = num_inputs
        self.num_outputs = args.emb_size
        self.policy_noise = args.target_noise
        self.noise_clip = args.noise_clip
        self.num_actions = action_space.shape[0]
        self.exploration_noise = args.act_noise
        # Running maximum of observed batch rewards; used to cap critic targets.
        # NOTE(review): it starts at 0, so with all-negative rewards the cap
        # stays at 0 - confirm that is intended.
        self.max_reward = 0
        self.per = args.per
        self.epsilon = args.epsilon

        self.policy_type = args.policy
        self.target_update_interval = args.target_update_interval
        self.automatic_entropy_tuning = args.automatic_entropy_tuning

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # NOTE(review): the embedder is not moved to self.device, and no method
        # of this class calls it; confirm where (and on which device) it is used.
        self.embedder = EmbeddingNetwork(num_inputs, args.hidden_size, args.emb_size)
        self.load_embedder(args.emb_name)

        # Networks are built in this fixed order: parameter initialization
        # draws from the global RNG, so reordering would change seeded weights.
        self.critic = QNetwork(args.emb_size, self.num_actions, args.hidden_size).to(device=self.device)
        self.critic_target = QNetwork(args.emb_size, self.num_actions, args.hidden_size).to(self.device)
        hard_update(self.critic_target, self.critic)

        # Training logs appended to by update_parameters.
        self.policy_losses = []
        self.critic1_losses = []
        self.critic2_losses = []
        self.policy_qvals = []
        self.entropy_losses = []
        self.qvals = []
        self.target_qvals = []

        policy_cls = PolicyNetworkNorm if args.action_norm else PolicyNetwork
        self.policy = policy_cls(args.emb_size, self.num_actions, args.hidden_size).to(self.device)
        self.policy_target = policy_cls(args.emb_size, self.num_actions, args.hidden_size).to(self.device)
        hard_update(self.policy_target, self.policy)

        # Only policy and critic parameters are optimized.
        # NOTE(review): args.train_emb does not change this; if the embedder is
        # meant to be trained, its parameters must be added to an optimizer.
        self.policy_optim = Adam(self.policy.parameters(), lr=args.lr)
        self.critic_optim = Adam(self.critic.parameters(), lr=args.lr)

    def select_action(self, state, eval=False, noise=None):
        """Choose an action for a single state.

        During training (``eval=False``), with probability ``self.epsilon``
        a uniform random action in [-1, 1] is returned. Otherwise the policy
        output is used, perturbed by Gaussian noise and clipped to [-1, 1].

        Args:
            state: One observation; it is converted with ``torch.FloatTensor``
                and given a leading batch dimension.
            eval: If True, skip the epsilon-random branch and exploration noise.
            noise: Standard deviation of the exploration noise. ``None`` means
                ``self.exploration_noise``.

        Returns:
            A numpy array with the chosen action.
        """
        if noise is None:
            noise = self.exploration_noise

        if not eval and random.random() <= self.epsilon:
            # NOTE(review): this branch returns shape (1, num_actions), while
            # the policy branch strips the batch dimension; callers must
            # handle both shapes.
            return np.expand_dims(np.random.uniform(low=-1, high=1, size=self.num_actions), axis=0)

        state = torch.FloatTensor(state).to(self.device).unsqueeze(0)
        # Inference only: no autograd graph is needed to pick an action.
        with torch.no_grad():
            action = self.policy(state).cpu().numpy()[0]
        if eval:
            return action

        noise_size = self.num_actions
        if action.ndim > 1:
            noise_size = (len(action), noise_size)
        perturbation = np.random.normal(0, noise, noise_size)
        return (action + perturbation).clip(-1, 1)

    def qloss(self, q_val, next_q_val, weights):
        """Importance-weighted MSE between Q estimates and targets.

        Returns:
            ``(loss, per_error)``. ``loss`` is the mean of
            ``weights * (q_val - next_q_val) ** 2``. ``per_error`` is
            ``|q_val - next_q_val|``, computed without gradient tracking and
            used as the new replay priority.
        """
        with torch.no_grad():
            per_error = torch.abs(q_val - next_q_val)

        mse_loss = F.mse_loss(q_val, next_q_val, reduction='none')
        loss = (weights * mse_loss).mean()
        return loss, per_error

    def _to_tensor(self, array):
        """Convert a numpy array to a float32 tensor on ``self.device``."""
        return torch.from_numpy(array).float().to(self.device)

    @staticmethod
    def _update_priorities(buffer, indices, errors):
        """Send per-sample TD errors to ``buffer.update``, skipping NaN indices."""
        # One device-to-host copy instead of one .item() sync per sample.
        errors = errors.detach().cpu().numpy().reshape(-1)
        for index, error in zip(indices, errors):
            if not math.isnan(index):
                buffer.update(int(index), float(error))

    def update_parameters(self, memory, updates, buffer):
        """Run one critic update, plus a delayed policy and target update.

        Args:
            memory: Tuple ``(states, actions, rewards, next_states, masks,
                indices, weights)`` of per-sample sequences. A NaN index means
                the sample has no buffer priority to update.
            updates: Update counter. The policy and target networks are
                updated only when it is a multiple of ``target_update_interval``.
            buffer: Object with an ``update(index, priority)`` method.
        """
        (state_batch, action_batch, reward_batch, next_state_batch,
         mask_batch, idx_batch, weights_batch) = memory

        states = self._to_tensor(np.stack(state_batch))
        actions = self._to_tensor(np.stack(action_batch))
        rewards = self._to_tensor(np.expand_dims(reward_batch, axis=1).astype(float))
        next_states = self._to_tensor(np.stack(next_state_batch))
        # Assumes a mask of 1 marks a terminal transition (targets use 1 - done);
        # confirm against the replay buffer.
        dones = self._to_tensor(np.expand_dims(mask_batch, axis=1).astype(float))
        weights = self._to_tensor(np.expand_dims(weights_batch, axis=1).astype(float))
        # Indices stay host-side float64: NaN marks "no index", and float64
        # represents integer indices exactly (float32 does not above 2**24).
        indices = np.asarray(idx_batch).astype(float).reshape(-1)

        with torch.no_grad():
            # Target policy smoothing: clipped Gaussian noise on the target action.
            smoothing = torch.clamp(torch.randn_like(actions) * self.policy_noise,
                                    -self.noise_clip, self.noise_clip)
            next_actions = torch.clamp(self.policy_target(next_states) + smoothing, -1, 1)

            qf1_next, qf2_next = self.critic_target(next_states, next_actions)
            min_qf_next = torch.min(qf1_next, qf2_next)
            # NOTE: the backup is max(r, gamma * Q'), not r + gamma * Q'. This is
            # presumably intentional; the reward-based cap below is consistent
            # with a max-reward objective.
            next_q_value = torch.max(rewards, self.gamma * min_qf_next * (1 - dones))
            self.max_reward = max(rewards.max().item(), self.max_reward)
            q_thresh = torch.full_like(next_q_value, self.max_reward * 1.01)
            next_q_value = torch.min(next_q_value, q_thresh)

        qf1, qf2 = self.critic(states, actions)
        qf1_loss, per_error = self.qloss(qf1, next_q_value, weights)
        qf2_loss, _ = self.qloss(qf2, next_q_value, weights)

        self.critic_optim.zero_grad()
        (qf1_loss + qf2_loss).backward()
        self.critic_optim.step()

        self._update_priorities(buffer, indices, per_error)

        if updates % self.target_update_interval == 0:
            policy_loss = -self.critic.Q1(states, self.policy(states)).mean()

            self.policy_optim.zero_grad()
            policy_loss.backward()
            self.policy_optim.step()

            soft_update(self.critic_target, self.critic, self.tau)
            soft_update(self.policy_target, self.policy, self.tau)
            # Logged once per skipped step, presumably so this list grows at the
            # same rate as the per-call critic logs below.
            self.policy_losses.extend(
                [policy_loss.detach().cpu().numpy()] * self.target_update_interval)

        # .cpu() is required: Tensor.numpy() fails on CUDA tensors.
        self.critic1_losses.append(qf1_loss.detach().cpu().numpy())
        self.critic2_losses.append(qf2_loss.detach().cpu().numpy())
        self.qvals.append(qf1.detach().cpu().numpy().mean())
        self.target_qvals.append(next_q_value.detach().cpu().numpy().mean())

    # Save model parameters
    def save_model(self, env_name, suffix="", actor_path=None, critic_path=None):
        """Save the policy and critic state dicts.

        The default paths are ``models/sac_actor_<env>_<suffix>`` and
        ``models/sac_critic_<env>_<suffix>``. The ``models/`` directory is
        always created if missing.
        """
        os.makedirs('models/', exist_ok=True)

        if actor_path is None:
            actor_path = "models/sac_actor_{}_{}".format(env_name, suffix)
        if critic_path is None:
            critic_path = "models/sac_critic_{}_{}".format(env_name, suffix)
        print('Saving models to {} and {}'.format(actor_path, critic_path))
        torch.save(self.policy.state_dict(), actor_path)
        torch.save(self.critic.state_dict(), critic_path)

    # Load model parameters
    def load_model(self, actor_path=None, critic_path=None):
        """Load policy and/or critic weights and copy them into their targets.

        Tensors are mapped onto ``self.device``, so checkpoints saved on a GPU
        load on CPU-only machines.
        """
        print('Loading models from {} and {}'.format(actor_path, critic_path))
        if actor_path is not None:
            self.policy.load_state_dict(torch.load(actor_path, map_location=self.device))
            hard_update(self.policy_target, self.policy)
        if critic_path is not None:
            self.critic.load_state_dict(torch.load(critic_path, map_location=self.device))
            hard_update(self.critic_target, self.critic)

    def load_embedder(self, embedder_name=None):
        """Load embedder weights from ``models/<embedder_name>``; no-op for None."""
        if embedder_name is not None:
            self.embedder.load_state_dict(
                torch.load('models/{}'.format(embedder_name), map_location=self.device))

