import numpy as np

import torch
import torch.optim as optim
from torch import nn as nn

from rlkit.torch.networks import stochastic_actor2, critic
from rlkit.data_management.env_replay_buffer import SimpleReplayBuffer

class SRSoftActorCritic2():
    """Soft Actor-Critic learner whose policy can use a layer shared across tasks.

    Holds twin Q-networks with soft-updated target copies, a stochastic policy
    (``stochastic_actor2``) with a learned entropy temperature ``alpha``, and a
    replay buffer owned by this agent.  ``load_shared_network`` swaps in the
    shared policy layer; ``compute_meta_loss`` / ``get_meta_loss`` accumulate a
    SAC-style policy loss for an externally supplied ``shared_network`` and
    reduce it to a single scalar for the caller to optimize.
    """

    def __init__(self,
                 obs_dim,
                 action_dim,
                 net_size,
                 no_sr,
                 latent_action_dim,
                 device,
                 policy_lr=1e-3,
                 qf_lr=1e-3,
                 sparse_rewards=False,
                 **kwargs):
        """Build networks, optimizers, the temperature parameter and the buffer.

        Args:
            obs_dim: Observation size; the critics take ``obs_dim + action_dim`` inputs.
            action_dim: Action size; also sets ``target_entropy = -action_dim``.
            net_size: Hidden size passed to the critics and the policy.
            no_sr: Stored on the instance; not read by any method of this class.
            latent_action_dim: ``latent_dim`` passed to ``stochastic_actor2``.
            device: Torch device all networks and sampled tensors are placed on.
            policy_lr: Adam learning rate for the policy and for ``log_alpha``.
            qf_lr: Adam learning rate shared by both Q-networks.
            sparse_rewards: Stored on the instance; not read by any method of this class.
            **kwargs: Must contain ``batch_size``, ``max_path_length``,
                ``reward_scale`` and ``discount``.  May contain
                ``replay_buffer_size`` (default 1000000) and ``tau``
                (Polyak rate for the target critics, default 0.005).
        """
        self.replay_buffer_size = kwargs.get('replay_buffer_size', 1000000)
        self.batch_size = kwargs['batch_size']
        self.max_path_length = kwargs['max_path_length']
        self.reward_scale = kwargs['reward_scale']
        self.discount = kwargs['discount']
        self.tau = kwargs.get('tau', 0.005)

        self.sparse_rewards = sparse_rewards
        self.no_sr = no_sr
        self.device = device

        # Construction order is kept (qf1, qf2, targets, policy) so random
        # initialisation matches earlier versions under the same seed.
        critic_input_dim = obs_dim + action_dim
        self.qf1 = critic(input_dim=critic_input_dim,
                          hidden_dim=net_size).to(self.device)
        self.qf2 = critic(input_dim=critic_input_dim,
                          hidden_dim=net_size).to(self.device)
        self.target_qf1 = critic(input_dim=critic_input_dim,
                                 hidden_dim=net_size).to(self.device)
        self.target_qf2 = critic(input_dim=critic_input_dim,
                                 hidden_dim=net_size).to(self.device)
        self.target_qf1.load_state_dict(self.qf1.state_dict())
        self.target_qf2.load_state_dict(self.qf2.state_dict())
        # Target critics change only through soft_target_update (which writes
        # .data directly) and belong to no optimizer.  Freezing them stops the
        # meta loss in compute_meta_loss, which backprops through them, from
        # accumulating unused .grad buffers; gradients still reach the actions.
        for target in (self.target_qf1, self.target_qf2):
            for param in target.parameters():
                param.requires_grad_(False)

        self.policy = stochastic_actor2(obs_dim,
                                        action_dim,
                                        net_size,
                                        latent_dim=latent_action_dim).to(self.device)

        self.policy_optimizer = optim.Adam(self.policy.parameters(),
                                           lr=policy_lr)
        self.qf_optimizer = optim.Adam(list(self.qf1.parameters()) + list(self.qf2.parameters()),
                                       lr=qf_lr)

        # Standard SAC heuristic: target entropy equals -action_dim.
        self.target_entropy = -np.prod((action_dim,)).item()
        self.log_alpha = torch.tensor(-0.5, requires_grad=True, device=self.device)
        # Cached as a Python float so the losses treat alpha as a constant.
        self.alpha = self.log_alpha.exp().item()
        self.alpha_optimizer = optim.Adam([self.log_alpha], lr=policy_lr)

        self.replay_buffer = SimpleReplayBuffer(self.replay_buffer_size,
                                                obs_dim,
                                                action_dim)
        self.meta_loss = []
        self.update_step = 0

    def load_shared_network(self, shared_network):
        "Load shared network across tasks"
        self.policy.shared_layer = shared_network

    def _to_tensor(self, array):
        """Convert array-like batch data to a float32 tensor on ``self.device``."""
        return torch.as_tensor(array, dtype=torch.float32, device=self.device)

    def alpha_update(self, obs):
        """Take one gradient step on ``log_alpha`` and refresh ``self.alpha``.

        Minimises ``-log_alpha * (log_prob + target_entropy)`` over fresh
        policy samples at ``obs``.  The gradient norm is clipped to 0.1.
        """
        with torch.no_grad():
            _, log_prob = self.policy.action(obs)
        alpha_loss = -(self.log_alpha * (log_prob + self.target_entropy).detach()).mean()
        self.alpha_optimizer.zero_grad()
        alpha_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.log_alpha, max_norm=0.1)
        self.alpha_optimizer.step()
        self.alpha = self.log_alpha.exp().item()

    def compute_meta_loss(self, shared_network):
        """Append a SAC policy loss, computed through ``shared_network``, to ``self.meta_loss``.

        A batch of observations is passed through the policy's detached task
        filter, then ``shared_network``, then the policy's detached action head
        (the method name is spelled as on the policy class).  The loss is
        ``mean(alpha * log_prob - min(target_qf1, target_qf2))``.  It is not
        backpropagated here; ``get_meta_loss`` returns the accumulated mean.
        """
        random_batch = self.replay_buffer.sample_batch(self.batch_size)
        obs = self._to_tensor(random_batch['observations'])
        latent_state = self.policy.detached_task_filtered_state(obs)
        latent_action = shared_network(latent_state)

        action, log_prob = self.policy.detached_task_filterd_action(latent_action)
        min_q = self._min_target_q(obs, action)
        shared_loss = (self.alpha * log_prob - min_q).mean()
        self.meta_loss.append(shared_loss)

    def get_meta_loss(self):
        """Return the mean of the accumulated meta losses and clear them.

        Returns ``None`` if nothing has been accumulated since the last call.
        """
        if not self.meta_loss:
            return None
        meta_loss = torch.mean(torch.stack(self.meta_loss, dim=0))
        self.meta_loss = []
        return meta_loss

    def collect_data_and_train_filter(self, env):
        """Run one episode with the old Gym API, storing transitions and training.

        Stops when the environment reports done or after ``max_path_length``
        steps.  After each stored transition, ``update`` runs once the buffer
        holds more than ``batch_size`` samples.  ``update_step`` is reset at
        the start and counts the updates performed in this episode.

        Returns:
            The number of environment steps taken.
        """
        o = env.reset()
        env_step = 0
        self.update_step = 0
        while env_step < self.max_path_length:
            env_step += 1
            a, _ = self.policy.select_action(o)
            next_o, r, d, env_info = env.step(a)
            self.replay_buffer.add_sample(o, a, r, d, next_o, **{'env_info': env_info})
            if self.replay_buffer.size() > self.batch_size:
                self.update()
                self.update_step += 1
            o = next_o
            if d:
                break
        return env_step

    def collect_data_and_train_filter_mw(self, env, shared_network):
        """Run one episode with the Gymnasium API, storing transitions and training.

        The episode ends on ``terminated`` or ``truncated``.  Only ``terminated``
        is stored as the terminal flag, so time-limit truncation still
        bootstraps.  ``shared_network`` is accepted but not used here.

        Returns:
            The number of environment steps taken.
        """
        obs, _ = env.reset()
        env_step = 0
        self.update_step = 0
        truncated = False
        terminated = False
        while not (truncated or terminated):
            env_step += 1
            action, _ = self.policy.select_action(obs)
            next_obs, reward, terminated, truncated, info = env.step(action)
            self.replay_buffer.add_sample(obs, action, reward, terminated, next_obs, **{'env_info': info})
            if self.replay_buffer.size() > self.batch_size:
                self.update()
                self.update_step += 1
            obs = next_obs
        return env_step

    def _min_q(self, obs, actions):
        """Element-wise minimum of the two online critics at ``(obs, actions)``."""
        q1 = self.qf1(obs, actions)
        q2 = self.qf2(obs, actions)
        min_q = torch.min(q1, q2)
        return min_q

    def _min_target_q(self, obs, actions):
        """Element-wise minimum of the two target critics at ``(obs, actions)``."""
        return torch.min(self.target_qf1(obs, actions),
                         self.target_qf2(obs, actions))

    def soft_target_update(self, main, target, tau: float = 0.005):
        """Polyak-average ``main``'s parameters into ``target`` in place.

        ``target <- tau * main + (1 - tau) * target``, written through ``.data``
        so autograd does not track it.
        """
        for main_param, target_param in zip(main.parameters(), target.parameters()):
            target_param.data.copy_(tau * main_param.data + (1.0 - tau) * target_param.data)

    def sample_batch(self):
        """Sample ``batch_size`` transitions as float32 tensors on ``self.device``.

        Returns:
            ``(obs, actions, rewards, terms, next_obs)``.
        """
        random_batch = self.replay_buffer.sample_batch(self.batch_size)
        obs = self._to_tensor(random_batch['observations'])
        actions = self._to_tensor(random_batch['actions'])
        rewards = self._to_tensor(random_batch['rewards'])
        terms = self._to_tensor(random_batch['terminals'])
        next_obs = self._to_tensor(random_batch['next_observations'])

        return obs, actions, rewards, terms, next_obs

    def update(self):
        """Perform one SAC step: critics, then policy, then targets, then alpha."""
        obs, actions, rewards, terms, next_obs = self.sample_batch()

        # Clipped double-Q target with an entropy bonus.  Assumes rewards/terms
        # broadcast against the critic output (e.g. both (batch, 1)); TODO confirm.
        with torch.no_grad():
            next_action, next_log_prob = self.policy.action(next_obs)
            min_next_q = self._min_target_q(next_obs, next_action) - self.alpha * next_log_prob
            target_q = self.reward_scale * rewards + self.discount * (1 - terms) * min_next_q

        q1_pred = self.qf1(obs, actions)
        q2_pred = self.qf2(obs, actions)
        qf1_loss = (q1_pred - target_q).pow(2).mean()
        qf2_loss = (q2_pred - target_q).pow(2).mean()
        qf_loss = qf1_loss + qf2_loss
        self.qf_optimizer.zero_grad()
        qf_loss.backward()
        self.qf_optimizer.step()

        cur_actions, log_prob = self.policy.action(obs)
        min_q = self._min_q(obs, cur_actions)
        policy_loss = (self.alpha * log_prob - min_q).mean()

        # Grads that this leaves on the critics are cleared by the
        # qf_optimizer.zero_grad() at the start of the next update.
        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        self.soft_target_update(self.qf1, self.target_qf1, tau=self.tau)
        self.soft_target_update(self.qf2, self.target_qf2, tau=self.tau)
        self.alpha_update(obs)

    def evaluation(self, env, n_eval_epi=3):
        """Average undiscounted return of deterministic rollouts (old Gym API).

        Each episode stops on done or after ``max_path_length`` steps.  Assumes
        ``select_action(..., deterministic=True)`` returns the action itself
        (not a tuple); TODO confirm against ``stochastic_actor2``.
        """
        with torch.no_grad():
            returns = 0
            for _ in range(n_eval_epi):
                episode_returns = 0
                env_step = 0
                o = env.reset()
                while env_step < self.max_path_length:
                    env_step += 1
                    a = self.policy.select_action(o, deterministic=True)
                    next_o, r, d, env_info = env.step(a)
                    episode_returns += r
                    o = next_o
                    if d:
                        break
                returns += episode_returns
        return returns / n_eval_epi

    def evaluation_mw(self, env, n_eval_epi=20):
        """Evaluate deterministic rollouts with the Gymnasium API.

        An episode counts as a success if ``info['success'] == 1.0`` at any step.

        Returns:
            ``(success_rate, mean_episode_reward)`` over ``n_eval_epi`` episodes.
        """
        with torch.no_grad():
            n_success = 0
            total_reward = 0
            for _ in range(n_eval_epi):
                obs, _ = env.reset()
                terminated = False
                truncated = False
                success = False
                while not (truncated or terminated):
                    action = self.policy.select_action(obs, deterministic=True)
                    next_obs, reward, terminated, truncated, info = env.step(action)
                    total_reward += reward
                    obs = next_obs
                    if info['success'] == 1.0:
                        success = True
                n_success += float(success)
            average_success = n_success / n_eval_epi
            average_reward = total_reward / n_eval_epi
        return average_success, average_reward