import numpy as np

import torch
import torch.optim as optim
from torch import nn as nn

from rlkit.torch.networks import stochastic_actor2, critic, SVF_Policy
from rlkit.data_management.env_replay_buffer import SimpleReplayBuffer

class SDSoftActorCritic():
    def __init__(self, obs_dim, action_dim, net_size, latent_action_dim, device,
                 policy_lr=1e-3, qf_lr=1e-3, sparse_rewards=False, **kwargs):
        """Build the networks, optimizers, entropy temperature and replay buffer.

        Args:
            obs_dim: Observation size; used to size every network and the
                replay buffer.
            action_dim: Action size; used to size every network, the replay
                buffer and the target entropy.
            net_size: Hidden width handed to the critics, the policy and the
                SVF policy.
            latent_action_dim: Latent size handed to the policy (as
                ``latent_dim``) and to the SVF policy.
            device: Device every network and ``log_alpha`` are placed on.
            policy_lr: Learning rate for the policy, SVF-policy and
                ``log_alpha`` optimizers.
            qf_lr: Learning rate for the joint optimizer of both critics.
            sparse_rewards: Stored on the instance; not used in this method.
            **kwargs: Must contain ``batch_size``, ``max_path_length``,
                ``reward_scale`` and ``discount`` (a missing key raises
                ``KeyError``). Any other key is ignored here.
        """
        # Fixed capacity plus the required run-time hyper-parameters.
        self.replay_buffer_size = 1000000
        self.batch_size = kwargs['batch_size']
        self.max_path_length = kwargs['max_path_length']
        self.reward_scale = kwargs['reward_scale']
        self.discount = kwargs['discount']

        self.sparse_rewards = sparse_rewards
        self.device = device

        def build_critic():
            # All four critics share the same architecture.
            net = critic(input_dim=obs_dim + action_dim, hidden_dim=net_size)
            return net.to(self.device)

        # Two online critics and two targets; creation order is kept as-is.
        self.qf1 = build_critic()
        self.qf2 = build_critic()
        self.target_qf1 = build_critic()
        self.target_qf2 = build_critic()
        # Targets start as exact copies of their online critics.
        for online, target in ((self.qf1, self.target_qf1),
                               (self.qf2, self.target_qf2)):
            target.load_state_dict(online.state_dict())

        self.policy = stochastic_actor2(
            obs_dim, action_dim, net_size, latent_dim=latent_action_dim,
        ).to(self.device)
        self.policy.shared_layer.requires_grad_(True)

        self.svf_policy = SVF_Policy(
            obs_dim, action_dim, net_size, latent_action_dim,
        ).to(self.device)

        self.svf_optimizer = optim.Adam(self.svf_policy.parameters(), lr=policy_lr)
        self.policy_optimizer = optim.Adam(self.policy.parameters(), lr=policy_lr)
        # A single optimizer steps both online critics together.
        critic_params = [*self.qf1.parameters(), *self.qf2.parameters()]
        self.qf_optimizer = optim.Adam(critic_params, lr=qf_lr)

        # Entropy temperature is learned in log-space; ``alpha`` caches the
        # current value as a plain Python float.
        self.target_entropy = -np.prod((action_dim,)).item()
        self.log_alpha = torch.tensor(-0.5, requires_grad=True, device=self.device)
        self.alpha = self.log_alpha.exp().item()
        self.alpha_optimizer = optim.Adam([self.log_alpha], lr=policy_lr)

        self.replay_buffer = SimpleReplayBuffer(
            self.replay_buffer_size, obs_dim, action_dim,
        )
        self.meta_loss = []
        self.update_step = 0

    def save_backup_params(self):
        self.backup_params = {name: param.detach().clone()
                         for name, param in self.svf_policy.named_parameters()}
        
    def reset_params(self):
        for name, param in self.svf_policy.named_parameters():
            if name in self.backup_params:
                param.data.copy_(self.backup_params[name])

    def load_shared_network(self, shared_network):
        """Install ``shared_network`` as the policy's ``shared_layer``.

        The object is attached by reference (no copy is made).
        """
        # NOTE(review): ``policy_optimizer`` is built in ``__init__`` from the
        # policy's parameters at that time and is not rebuilt here, so the
        # parameters of a newly installed layer are presumably trained by an
        # optimizer owned elsewhere - confirm against the caller.
        policy = self.policy
        policy.shared_layer = shared_network

    def alpha_update(self,obs):
        with torch.no_grad():
            _, log_prob = self.policy.action(obs)
        alpha_loss = -(self.log_alpha * (log_prob + self.target_entropy).detach()).mean()
        self.alpha_optimizer.zero_grad()
        alpha_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.log_alpha, max_norm=0.1)
        self.alpha_optimizer.step()
        self.alpha = self.log_alpha.exp().item()

    def collet_data_and_train_filter(self, env):
        """Roll out one episode with the stochastic policy, training as data arrives.

        Uses the 4-tuple environment API (``reset()`` returns the observation,
        ``step()`` returns ``(obs, reward, done, info)``). Every transition is
        stored in the replay buffer; once the buffer holds more than
        ``batch_size`` samples, ``update()`` runs once per environment step.
        ``self.update_step`` is reset to 0 and counts those updates.

        Args:
            env: Environment exposing ``reset()`` and ``step(action)``.

        Returns:
            int: Number of environment steps taken, matching what
            ``collect_data_and_train_filter_mw`` returns.
        """
        obs = env.reset()
        self.update_step = 0
        env_step = 0
        while env_step < self.max_path_length:
            env_step += 1
            action, _ = self.policy.select_action(obs)
            next_obs, reward, done, env_info = env.step(action)
            self.replay_buffer.add_sample(obs, action, reward, done, next_obs,
                                          env_info=env_info)
            # Train only once a full batch can be sampled.
            if self.replay_buffer.size() > self.batch_size:
                self.update()
                self.update_step += 1
            if done:
                break
            obs = next_obs
        return env_step

    def collect_data_and_train_filter_mw(self, env, shared_network):
        """Roll out one episode on a 5-tuple (terminated/truncated) env, training online.

        ``reset()`` must return ``(obs, info)`` and ``step()`` must return
        ``(obs, reward, terminated, truncated, info)``. The episode runs until
        either flag is set. Only ``terminated`` is stored as the terminal flag
        in the replay buffer. Once the buffer holds more than ``batch_size``
        samples, ``update()`` runs once per environment step and
        ``self.update_step`` counts those updates.

        Args:
            env: Environment following the API described above.
            shared_network: Shared layer to install on the policy before the
                rollout; ``None`` leaves the policy unchanged.

        Returns:
            int: Number of environment steps taken.
        """
        # ``update()`` takes no arguments, so forwarding ``shared_network`` to
        # it raised TypeError on the first training step. The network is
        # installed on the policy through the dedicated loader instead.
        # NOTE(review): assumes this is the intended use of ``shared_network``
        # - confirm against the caller.
        if shared_network is not None:
            self.load_shared_network(shared_network)

        obs, _ = env.reset()
        self.update_step = 0
        env_step = 0
        terminated = False
        truncated = False
        while not (truncated or terminated):
            env_step += 1
            action, _ = self.policy.select_action(obs)
            next_obs, reward, terminated, truncated, info = env.step(action)
            self.replay_buffer.add_sample(obs, action, reward, terminated, next_obs,
                                          env_info=info)
            # Train only once a full batch can be sampled.
            if self.replay_buffer.size() > self.batch_size:
                self.update()
                self.update_step += 1
            obs = next_obs
        return env_step

    def _min_q(self, obs, actions):
        q1 = self.qf1(obs, actions)
        q2 = self.qf2(obs, actions)
        min_q = torch.min(q1, q2)
        return min_q

    def soft_target_update(self, main, target, tau: float = 0.005):
        for main_param, target_param in zip(main.parameters(), target.parameters()):
            target_param.data.copy_(tau * main_param.data + (1.0 - tau) * target_param.data)

    def update(self):
        random_batch = self.replay_buffer.sample_batch(self.batch_size)
        obs = torch.Tensor(random_batch['observations']).to(self.device)
        actions = torch.Tensor(random_batch['actions']).to(self.device)
        rewards = torch.Tensor(random_batch['rewards']).to(self.device)
        terms = torch.Tensor(random_batch['terminals']).to(self.device)
        next_obs = torch.Tensor(random_batch['next_observations']).to(self.device)
    
        with torch.no_grad():
            next_action, next_log_prob = self.policy.action(next_obs)
            next_q1 = self.target_qf1(next_obs,next_action)
            next_q2 = self.target_qf2(next_obs,next_action)
            min_next_q = torch.min(next_q1,next_q2) - self.alpha * next_log_prob
            target_q = self.reward_scale*rewards + self.discount*(1-terms)*min_next_q 

        q1_pred = self.qf1(obs, actions)
        q2_pred = self.qf2(obs, actions)
        qf1_loss = (q1_pred-target_q).pow(2).mean()
        qf2_loss = (q2_pred-target_q).pow(2).mean()
        qf_loss = qf1_loss + qf2_loss
        self.qf_optimizer.zero_grad()
        qf_loss.backward()
        self.qf_optimizer.step()


        cur_actions, log_prob = self.policy.action(obs)
        min_q = torch.min(self.qf1(obs, cur_actions),
                          self.qf2(obs, cur_actions))
        policy_loss = (self.alpha * log_prob - min_q).mean()

        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        self.soft_target_update(self.qf1,self.target_qf1)
        self.soft_target_update(self.qf2,self.target_qf2)
        self.alpha_update(obs)

    def update_svf_policy(self,svf):
        if self.batch_size > self.replay_buffer.size():
            return None
        random_batch = self.replay_buffer.sample_batch(self.batch_size)
        obs = torch.Tensor(random_batch['observations']).to(self.device)
        self.svf_policy_filter_update(obs)
        flow_backbone_loss = self.get_svf_loss(svf,obs)
        return flow_backbone_loss    

    def svf_policy_filter_update(self,obs):
        actions = self.policy.action(obs,deterministic=True)
        pred_flow = self.svf_policy(obs,actions)
        flow = self.svf_policy.svf.d_flow(actions)
        flow_loss = nn.functional.mse_loss(pred_flow,flow)
        self.svf_optimizer.zero_grad()
        flow_loss.backward()
        self.svf_optimizer.step()

    def get_svf_loss(self,svf,obs):
        with torch.no_grad():
            actions = self.policy.action(obs,deterministic=True)
        t = torch.rand(actions.shape[0],1).to(self.device)
        x0 = torch.randn(actions.shape[0],actions.shape[1]).to(self.device)
        latent_states = self.svf_policy.detached_task_filtered_state(obs)
        xt = svf.flow(t,x0,actions)
        latent_actions = svf(t,xt,latent_states)
        pred_flow = self.svf_policy.detached_task_filtered_action(latent_actions)
        flow = svf.d_flow(actions)
        flow_loss = nn.functional.mse_loss(pred_flow,flow)
        return flow_loss  

    def svf_action(self,obs):
        obs = torch.Tensor(obs).to(next(self.svf_policy.parameters()).device)
        latent_state = self.svf_policy.detached_task_filtered_state(obs)
        svf_action = self.svf_policy.euler_integration(latent_state)
        
        return svf_action.cpu().numpy()
    
    def svf_eval(self, env):
        """Evaluate the SVF policy over three episodes and return the mean return.

        Uses the 4-tuple environment API (``reset()`` returns the observation,
        ``step()`` returns ``(obs, reward, done, info)``). Each episode stops
        on ``done`` or after ``max_path_length`` steps. Runs without gradient
        tracking.

        Args:
            env: Environment exposing ``reset()`` and ``step(action)``.

        Returns:
            Sum of the episode returns divided by the number of episodes.
        """
        n_eval_epi = 3
        episode_totals = []
        with torch.no_grad():
            for _ in range(n_eval_epi):
                total = 0
                steps = 0
                finished = False
                obs = env.reset()
                while not finished and steps < self.max_path_length:
                    steps += 1
                    obs, reward, finished, _info = env.step(self.svf_action(obs))
                    total += reward
                episode_totals.append(total)
        return sum(episode_totals) / n_eval_epi


    def evaluation(self, env):
        """Evaluate the deterministic policy over three episodes; return the mean return.

        Uses the 4-tuple environment API (``reset()`` returns the observation,
        ``step()`` returns ``(obs, reward, done, info)``). Each episode stops
        on ``done`` or after ``max_path_length`` steps. Runs without gradient
        tracking.

        Args:
            env: Environment exposing ``reset()`` and ``step(action)``.

        Returns:
            Sum of the episode returns divided by the number of episodes.
        """
        n_eval_epi = 3
        episode_totals = []
        with torch.no_grad():
            for _ in range(n_eval_epi):
                total = 0
                steps = 0
                finished = False
                obs = env.reset()
                while not finished and steps < self.max_path_length:
                    steps += 1
                    action = self.policy.select_action(obs, deterministic=True)
                    obs, reward, finished, _info = env.step(action)
                    total += reward
                episode_totals.append(total)
        return sum(episode_totals) / n_eval_epi

    def evaluation_mw(self, env):
        """Evaluate the deterministic policy for 20 episodes on a 5-tuple env.

        ``reset()`` must return ``(obs, info)`` and ``step()`` must return
        ``(obs, reward, terminated, truncated, info)``. An episode counts as
        a success if ``info['success'] == 1.0`` on any of its steps. Runs
        without gradient tracking.

        Args:
            env: Environment following the API described above.

        Returns:
            tuple: ``(average_success, average_reward)`` - the fraction of
            successful episodes and the total reward over all episodes
            divided by the number of episodes.
        """
        n_eval_epi = 20
        n_success = 0
        total_reward = 0
        with torch.no_grad():
            for _ in range(n_eval_epi):
                obs, _ = env.reset()
                finished = False
                succeeded = False
                while not finished:
                    action = self.policy.select_action(obs, deterministic=True)
                    obs, reward, terminated, truncated, info = env.step(action)
                    total_reward += reward
                    # A single successful step marks the whole episode.
                    if info['success'] == 1.0:
                        succeeded = True
                    finished = bool(truncated or terminated)
                n_success += float(succeeded)
        return n_success / n_eval_epi, total_reward / n_eval_epi