import os
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as D
from networks import GaussActor, Actor, Critic, Potential
from utils import log_prob_func, NegAbs, SampledValueBaseline
from torch.optim.lr_scheduler import CosineAnnealingLR

# Device on which every network built by this module is placed. Falls back to
# the CPU when no CUDA device is available instead of failing on `.to('cuda')`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

class XMRL(object):
    """Offline actor-critic agent whose policy is steered by a learned potential.

    Networks (all created in ``__init__``):

    * ``beta`` -- a ``GaussActor`` trained by ``update_beta`` to minimise the
      negative ``log_prob_func`` of dataset actions (logged as ``"bc_loss"``).
    * ``policy`` -- a ``GaussActor`` trained by ``conservative_update_transport``
      to raise both ``critic.Q1`` and the ``potential`` value of its actions.
    * ``critic`` / ``critic_target`` -- a ``Critic`` returning two Q estimates
      and a Polyak-averaged copy of it, trained by ``update_critic`` (targets
      from ``beta``) or ``conservative_update_critic`` (targets from ``policy``).
    * ``potential`` -- a ``Potential`` network over actions, trained by
      ``update_potential`` to score dataset actions (weighted by ``W``) higher
      than policy actions.

    Every ``update_*`` method draws its own minibatch via
    ``replay_buffer.sample(self.batch_size)``, unpacked as
    ``(state, action, next_state, reward, not_done)``, and returns a dict of
    scalar losses for logging.
    """

    def __init__(
        self,
        state_dim,
        action_dim,
        hidden_dim,
        batch_size,
        max_action,
        max_steps=1e6,
        dist_policy=False,
        discount=0.99,
        tau=0.005,
        policy_noise=0.2,
        noise_clip=0.5,
        policy_freq=2,
        potential_freq=10,
        alpha=2.5,
    ):
        """Build the networks and optimizers and store hyper-parameters.

        Args:
            state_dim: State size passed to the actor and critic constructors.
            action_dim: Action size passed to every network constructor.
            hidden_dim: Hidden width passed to every network constructor.
            batch_size: Minibatch size requested from the replay buffer.
            max_action: Bound used to clamp noisy next actions to
                ``[-max_action, max_action]``.
            max_steps: ``T_max`` of the cosine schedule on the policy learning
                rate (cast to ``int``).
            dist_policy: Accepted but not used by this class.
            discount: Discount factor in the TD target.
            tau: Polyak coefficient for the critic target update.
            policy_noise: Std of the Gaussian noise added to next actions.
            noise_clip: Absolute clip applied to that noise.
            policy_freq: Stored; not read by any method of this class.
            potential_freq: Stored; not read by any method of this class.
            alpha: Stored as ``float``; not read by any method of this class.
        """
        self.beta = GaussActor(state_dim, hidden_dim, action_dim).to(device)
        self.policy = GaussActor(state_dim, hidden_dim, action_dim).to(device)
        self.critic = Critic(state_dim, hidden_dim, action_dim).to(device)
        self.potential = Potential(action_dim, hidden_dim).to(device)
        self.critic_target = copy.deepcopy(self.critic)
        # NOTE(review): not read by any method of this class -- confirm it is
        # used elsewhere before relying on it.
        self.value = SampledValueBaseline(device, 3)

        self.beta_optimizer = torch.optim.Adam(self.beta.parameters(), lr=1e-4)
        self.policy_optimizer = torch.optim.Adam(self.policy.parameters(), lr=1e-5)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=1e-5)
        self.potential_optimizer = torch.optim.Adam(self.potential.parameters(), lr=1e-5)
        # Only the policy learning rate is annealed. The schedule is stepped
        # once per conservative_update_transport call, and T_max counts steps,
        # so it is passed as an int rather than the float default 1e6.
        self.policy_lr_schedule = CosineAnnealingLR(self.policy_optimizer, int(max_steps))

        self.max_action = max_action
        self.discount = discount
        self.tau = tau
        self.policy_noise = policy_noise
        self.noise_clip = noise_clip
        self.policy_freq = policy_freq
        self.potential_freq = potential_freq
        self.alpha = float(alpha)
        self.batch_size = batch_size
        # Weight on the dataset term of the potential loss (see update_potential).
        self.W = 8
        # NOTE(review): temp and max_weight are not read by any method here.
        self.temp = 10
        self.max_weight = 100.0

    @staticmethod
    def _soft_update(source, target, tau):
        """Polyak-average ``source`` into ``target``: t <- tau * s + (1 - tau) * t."""
        with torch.no_grad():
            for param, target_param in zip(source.parameters(), target.parameters()):
                target_param.mul_(1.0 - tau).add_(param, alpha=tau)

    def _td_target(self, actor, next_state, reward, not_done, action_like):
        """Return the gradient-free TD target bootstrapped from ``actor``.

        Clipped Gaussian noise (shaped like ``action_like``) is added to
        ``actor(next_state)``, the result is clamped to ``[-max_action,
        max_action]``, and the minimum of the two ``critic_target`` heads is
        discounted and masked by ``not_done``.
        """
        with torch.no_grad():
            noise = (torch.randn_like(action_like) * self.policy_noise).clamp(
                -self.noise_clip, self.noise_clip
            )
            next_action = (actor(next_state) + noise).clamp(-self.max_action, self.max_action)
            target_Q1, target_Q2 = self.critic_target(next_state, next_action)
            return reward + not_done * self.discount * torch.min(target_Q1, target_Q2)

    def update_beta(self, replay_buffer):
        """One behaviour-cloning step for ``beta`` on dataset actions.

        Returns:
            ``{"bc_loss": float}`` -- mean negative ``log_prob_func`` value.
        """
        log_dict = {}
        state, action, _, _, _ = replay_buffer.sample(self.batch_size)

        b_dist = self.beta.get_dist(state)
        log_prob = log_prob_func(b_dist, action)
        bc_loss = (-log_prob).mean()

        self.beta_optimizer.zero_grad()
        bc_loss.backward()
        self.beta_optimizer.step()
        log_dict["bc_loss"] = bc_loss.item()

        return log_dict

    def update_critic(self, replay_buffer, t):
        """One TD step for both critic heads, bootstrapping from ``beta``.

        Args:
            replay_buffer: Source of minibatches.
            t: Accepted but not used.

        Returns:
            ``{"critic_loss": float}`` -- sum of the two heads' MSE losses.
        """
        log_dict = {}
        state, action, next_state, reward, not_done = replay_buffer.sample(self.batch_size)
        target_Q = self._td_target(self.beta, next_state, reward, not_done, action)

        current_Q1, current_Q2 = self.critic(state, action)
        critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()
        log_dict["critic_loss"] = critic_loss.item()

        self._soft_update(self.critic, self.critic_target, self.tau)
        return log_dict

    def conservative_update_critic(self, replay_buffer, t):
        """TD step bootstrapping from ``policy`` plus a penalty on Q(s, policy(s)).

        Args:
            replay_buffer: Source of minibatches.
            t: Accepted but not used.

        Returns:
            ``{"critic_loss": float, "conservative_loss": float}``.
        """
        log_dict = {}
        state, action, next_state, reward, not_done = replay_buffer.sample(self.batch_size)
        # NOTE: bootstrapping from self.beta (as update_critic does) was noted
        # by the author as a working alternative.
        target_Q = self._td_target(self.policy, next_state, reward, not_done, action)
        with torch.no_grad():
            pi = self.policy(state)

        current_Q1, current_Q2 = self.critic(state, action)
        pi_Q1, pi_Q2 = self.critic(state, pi)

        critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)
        # Minimising the mean Q at the policy's actions pushes those estimates
        # down; the term is unweighted and has no counterpart on dataset actions.
        conservative_loss = pi_Q1.mean() + pi_Q2.mean()
        total_loss = conservative_loss + critic_loss

        self.critic_optimizer.zero_grad()
        total_loss.backward()
        self.critic_optimizer.step()
        log_dict["critic_loss"] = critic_loss.item()
        log_dict["conservative_loss"] = conservative_loss.item()

        self._soft_update(self.critic, self.critic_target, self.tau)
        return log_dict

    def conservative_update_transport(self, replay_buffer, step):
        """One policy step maximising ``critic.Q1`` plus the potential.

        Also advances the cosine learning-rate schedule of the policy.

        Args:
            replay_buffer: Source of minibatches.
            step: Accepted but not used.

        Returns:
            ``{"policy_loss": float}``.
        """
        log_dict = {}
        state, _, _, _, _ = replay_buffer.sample(self.batch_size)

        dist = self.policy.get_dist(state)
        # rsample() keeps the sampled action differentiable w.r.t. the policy.
        pi = dist.rsample()
        Q = self.critic.Q1(state, pi)
        f_policy = self.potential(pi)

        policy_loss = (-Q).mean() - f_policy.mean()

        self.policy_optimizer.zero_grad()
        # backward() also writes .grad on critic and potential parameters; only
        # the policy optimizer steps here, and the other optimizers call
        # zero_grad() before their own backward passes.
        policy_loss.backward()
        self.policy_optimizer.step()
        self.policy_lr_schedule.step()
        log_dict["policy_loss"] = policy_loss.item()

        return log_dict

    def update_potential(self, replay_buffer, step):
        """One potential step: lower f(policy action), raise W * f(dataset action).

        Args:
            replay_buffer: Source of minibatches.
            step: Accepted but not used.

        Returns:
            ``{"potential_loss": float}``.
        """
        log_dict = {}
        state, action, _, _, _ = replay_buffer.sample(self.batch_size)

        with torch.no_grad():
            pi = self.policy(state)

        f_policy = self.potential(pi)
        f_data = self.potential(action)

        potential_loss = f_policy.mean() - (self.W * f_data).mean()

        self.potential_optimizer.zero_grad()
        potential_loss.backward()
        self.potential_optimizer.step()
        log_dict["potential_loss"] = potential_loss.item()

        return log_dict