import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from copy import deepcopy
import numpy as np
import os

from utils.vectorised_networks import ContinuousVectorisedCritic
from utils.misc import soft_update
from utils.base_agent import ContinuousBaseAgent, BaseActorCritic

from torch.distributions import Normal


class Actor(nn.Module):
    """Gaussian policy network whose sampled actions are squashed with tanh.

    Two hidden ReLU layers produce a per-dimension mean and log standard
    deviation. The raw log-std is mapped smoothly into
    ``[log_std_min, log_std_max]``, and sampled actions are passed through
    ``tanh`` so every component lies in ``[-1, 1]``.

    Args:
        obs_dims: Size of the flat observation vector.
        action_dims: Number of action dimensions.
        log_std_min: Lower bound of the predicted log standard deviation.
        log_std_max: Upper bound of the predicted log standard deviation.
        hidden_dim: Width of both hidden layers.
        **kwargs: Optional ``max_val`` / ``min_val`` action bounds
            (default ``+inf`` / ``-inf``). They are stored on the module but
            are not applied by ``sample``.
    """

    def __init__(self, obs_dims, action_dims, log_std_min=-20, log_std_max=2, hidden_dim=256, **kwargs):
        super().__init__()
        self.fc1 = nn.Linear(obs_dims, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.mean = nn.Linear(hidden_dim, action_dims)
        self.log_std = nn.Linear(hidden_dim, action_dims)

        # NOTE(review): plain attributes rather than registered buffers, so
        # ``.to(device)`` does not move them. They are currently unused here.
        self.max_val = torch.tensor(kwargs.get('max_val', np.inf))
        self.min_val = torch.tensor(kwargs.get('min_val', -np.inf))
        self.log_std_min = log_std_min
        self.log_std_max = log_std_max

    def forward(self, state):
        """Return ``(mean, log_std)`` of the pre-tanh Gaussian for ``state``."""
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))

        mean = self.mean(x)
        # tanh maps the raw output into (-1, 1). The affine step then
        # rescales it onto [log_std_min, log_std_max].
        log_std = torch.tanh(self.log_std(x))
        log_std = self.log_std_min + 0.5 * (self.log_std_max - self.log_std_min) * (log_std + 1)
        return mean, log_std

    def sample(self, state, epsilon=1e-6, reparameterise=True, deterministic=False, **kwargs):
        """Draw a tanh-squashed action for ``state``.

        Args:
            state: Observation tensor accepted by ``forward``.
            epsilon: Small constant added inside the log of the tanh
                change-of-variables correction. It keeps the log finite as
                ``|action| -> 1``.
            reparameterise: If True, use ``rsample``, which keeps the sample
                differentiable with respect to the policy parameters. If
                False, use ``sample``, which does not.
            deterministic: If True, return the squashed mean without sampling.
            **kwargs: Ignored. Accepted for call-site compatibility.

        Returns:
            ``{'action': tanh(mean)}`` when ``deterministic`` is set.
            Otherwise ``{'action': a, 'log_prob': lp}``, where ``lp`` is the
            log-density of ``a`` summed over the last dimension
            (``keepdim=True``).
        """
        mean, log_std = self(state)

        if deterministic:
            # No randomness needed: skip sampling and log-prob computation.
            return {'action': torch.tanh(mean)}

        normal = Normal(mean, log_std.exp())
        # rsample = mean + std * N(0, 1), which keeps gradients flowing
        # through the sample (reparameterisation trick).
        x_t = normal.rsample() if reparameterise else normal.sample()
        action = torch.tanh(x_t)

        # Change-of-variables correction for the tanh squashing:
        # log pi(a) = log N(x) - log(1 - tanh(x)^2), summed over action dims.
        log_prob = normal.log_prob(x_t) - torch.log(1 - action.pow(2) + epsilon)
        log_prob = log_prob.sum(-1, keepdim=True)

        return {'action': action, 'log_prob': log_prob}



class Agent(BaseActorCritic):
    """SAC-N agent: soft actor-critic with a vectorised ensemble of critics.

    Builds a ``ContinuousVectorisedCritic`` ensemble plus a deep-copied
    target critic, a tanh-Gaussian ``Actor``, and an entropy temperature.
    The temperature is learnable unless ``learnable_temperature=False`` is
    passed, in which case it is fixed at 0.2. Shared plumbing comes from
    ``BaseActorCritic``: replay buffer, optimiser factory, device,
    ``_calc_critic_value``, file paths and ``log_dict``.
    """

    def __init__(self, obs_dims, action_dims, actor_lr, critic_lr, gamma,
                 tau, mem_size, batch_size, model_info, critic_ensemble_num=1,
                 actor_ensemble_num=1, dataset=None, algo_name='sac_n', **kwargs):

        super().__init__(obs_dims=obs_dims, action_dims=action_dims, actor_lr=actor_lr,
                         critic_lr=critic_lr, gamma=gamma, tau=tau, mem_size=mem_size,
                         batch_size=batch_size, dataset=dataset, algo_name=algo_name,
                         model_info=model_info, **kwargs)

        self.learnable_temperature = kwargs.get('learnable_temperature', True)

        self.critic = ContinuousVectorisedCritic(obs_dims=obs_dims,
                                                 action_dims=action_dims,
                                                 model_info=model_info,
                                                 ensemble_num=critic_ensemble_num,
                                                 algo_name=algo_name,
                                                 critic_factor=self.critic_factor)
        self.target_critic = deepcopy(self.critic)

        self.actor = Actor(obs_dims=obs_dims,
                           action_dims=action_dims,
                           min_val=self.min_action_val,
                           max_val=self.max_action_val)

        self.critic_optimiser = self.optimiser(self.critic.parameters(), lr=critic_lr)
        self.actor_optimiser = self.optimiser(self.actor.parameters(), lr=actor_lr)

        if self.learnable_temperature:
            self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device)
            self.alpha_optimiser = self.optimiser([self.log_alpha], lr=3e-4)
            # Target entropy = -(product of action_dims).
            self.target_entropy = -torch.prod(torch.tensor(action_dims)).item()
        else:
            # Fixed temperature alpha = 0.2. There is no alpha optimiser.
            self.log_alpha = torch.log(torch.tensor(0.2))
            self.alpha_optimiser = None

        if kwargs.get('use_data', False):
            self.replay_buffer.store_offline_data(dataset=dataset,
                                                  normalise_state=kwargs['normalise_state'],
                                                  env_id=self.env_id)
        else:
            print('not using dataset')

        self.move_to(kwargs['device'])
        # NOTE(review): stored as the tuple (1, batch_size) and passed to
        # replay_buffer.sample. update_critic indexes rewards/dones with [0],
        # which suggests a leading singleton dim — confirm against the buffer.
        self.batch_size = 1, batch_size

        self.total_it = 0
        self.actor_loss = 0

    def move_to(self, device):
        """Move the base-class state plus actor and both critics to ``device``."""
        super().move_to(device)

        self.critic.to(device=device)
        self.target_critic.to(device=device)
        self.actor.to(device=device)

    @property
    def alpha(self):
        """Entropy temperature ``exp(log_alpha)``."""
        return self.log_alpha.exp()

    def choose_action(self, state, **kwargs):
        """Convert ``state`` to a float tensor on ``self.device`` and sample the actor.

        ``torch.as_tensor`` avoids copying, and avoids the copy-construct
        warning, when ``state`` is already a tensor, as it is for replay
        batches. ``kwargs`` are forwarded to ``Actor.sample``.
        """
        state = torch.as_tensor(state, dtype=torch.float, device=self.device)
        return self.actor.sample(state, **kwargs)

    def update_critic(self, samples, iter_no=None):
        """Take one gradient step on the critic ensemble towards the soft TD target.

        Args:
            samples: ``(states, next_states, diff_states, actions, rewards,
                done_batch)``. ``rewards`` and ``done_batch`` are indexed with
                ``[0]``, so they presumably carry a leading singleton dim —
                TODO confirm.
            iter_no: Unused beyond defaulting to ``self.total_it``.

        Returns:
            The critic MSE loss tensor.
        """
        if iter_no is None:
            iter_no = self.total_it

        states, next_states, diff_states, actions, rewards, done_batch = samples

        done_batch = done_batch[0].unsqueeze(-1)
        rewards = rewards[0].unsqueeze(-1)

        with torch.no_grad():
            next_action_info = self.choose_action(next_states)

            next_action_values = self.target_critic(next_states, next_action_info['action'])
            # Entropy/terminal handling is delegated to the base class helper.
            next_action_values = self._calc_critic_value(next_action_values,
                                                         next_action_info['log_prob'],
                                                         done_batch)

            est_critic_val = rewards + self.gamma * next_action_values
            est_critic_val = est_critic_val.unsqueeze(0)

        q_values = self.critic(states, actions)
        q_loss = F.mse_loss(q_values, est_critic_val)

        # zero_grad also clears critic grads left over from the actor's backward pass.
        self.critic_optimiser.zero_grad()
        q_loss.backward()
        self.critic_optimiser.step()

        self.log_dict['train/critic_loss'] = q_loss.item()

        return q_loss

    def update_actor(self, samples, iter_no=None):
        """Take one policy gradient step and, if learnable, one temperature step.

        Returns:
            The actor loss tensor.
        """
        if iter_no is None:
            iter_no = self.total_it

        states, next_states, diff_states, actions, rewards, done_batch = samples

        new_action_info = self.choose_action(states)
        q_values = self._get_actor_critic_val(states, new_action_info)
        actor_loss = self._calc_actor_loss(q_values, new_action_info)

        self.actor_optimiser.zero_grad()
        actor_loss.backward()
        self.actor_optimiser.step()

        if self.learnable_temperature:
            alpha_loss = self._calc_alpha_loss(new_action_info)
            self.alpha_optimiser.zero_grad()
            alpha_loss.backward()
            self.alpha_optimiser.step()
            self.log_dict['train/alpha_loss'] = alpha_loss.item()
            self.log_dict['train/alpha'] = self.alpha.item()

        self.log_dict['train/actor_loss'] = actor_loss.item()
        return actor_loss

    def _calc_alpha_loss(self, action_info):
        """Temperature loss. The log-prob term is detached so only ``log_alpha`` is trained."""
        alpha_loss = -(self.alpha * (action_info['log_prob'] + self.target_entropy).detach()).mean()
        return alpha_loss

    def _calc_actor_loss(self, critic_val, action_info, **kwargs):
        """Policy loss ``mean(alpha * log_prob - Q)``, with alpha detached."""
        actor_loss = (self.alpha.detach() * action_info['log_prob'] - critic_val.unsqueeze(-1))
        return actor_loss.mean()

    def _get_actor_critic_val(self, states, action_info):
        """Evaluate the online critic on freshly sampled actions and reduce via the base helper."""
        q_vals = self.critic(states, action_info['action'])
        return self._calc_critic_value(q_vals, action_info['log_prob'])

    def learn(self, sample_range=None, dep_targ=True, samples=None):
        """Run one SAC-N training iteration.

        Args:
            sample_range: Forwarded to ``replay_buffer.sample``.
            dep_targ: Unused. Kept for interface compatibility.
            samples: Optional pre-drawn batch in the 6-tuple form consumed by
                ``update_critic``. If ``None``, a batch is drawn from the
                replay buffer.

        Returns:
            ``(critic_loss, actor_loss)``. ``actor_loss`` is ``None`` on
            iterations without a policy update. Returns ``None`` (no update)
            when ``samples`` is not given and the buffer holds fewer than
            ``replay_buffer.batch_size`` transitions.
        """
        self.total_it += 1
        actor_loss = None

        if samples is None:
            if self.replay_buffer.mem_cntr < self.replay_buffer.batch_size:
                return
            # The buffer appends the batch indices as the last element; they are not needed here.
            *samples, _batch_idx = self.replay_buffer.sample(rng=self.rng,
                                                             sample_range=sample_range,
                                                             batch_size=self.batch_size)

        critic_loss = self.update_critic(samples)

        if self.total_it % self.policy_update_freq == 0:
            actor_loss = self.update_actor(samples)

        if self.total_it % 1000000 == 0 and self.model_save:
            self.save_model()

        soft_update(self.target_critic, self.critic, tau=self.tau)

        return critic_loss, actor_loss

    def save_model(self):
        """Save networks, optimisers and ``log_alpha`` to a checkpoint file.

        ``alpha_optimiser_state_dict`` is ``None`` when the temperature is fixed.

        Returns:
            The checkpoint path, also stored in ``self.model_path``.
        """
        model_path = self.create_filepath(path='models')
        model_path += '-' + str(self.total_it)

        if self.online:
            model_path += '-online'

        self.model_path = model_path

        alpha_optimiser_state = (self.alpha_optimiser.state_dict()
                                 if self.learnable_temperature else None)

        print(f'Saving models to {model_path}')
        torch.save({'actor_state_dict': self.actor.state_dict(),
                    'critic_state_dict': self.critic.state_dict(),
                    'target_critic_state_dict': self.target_critic.state_dict(),
                    'critic_optimiser_state_dict': self.critic_optimiser.state_dict(),
                    'actor_optimiser_state_dict': self.actor_optimiser.state_dict(),
                    'alpha_optimiser_state_dict': alpha_optimiser_state,
                    'log_alpha': self.log_alpha},
                   model_path)
        return model_path

    def load_model(self, iter_no, model_path=None, evaluate=False):
        """Restore a checkpoint written by ``save_model``.

        Args:
            iter_no: Iteration suffix appended to the path.
            model_path: Base path. Defaults to ``create_filepath(path='models')``.
            evaluate: If True, put the networks in eval mode. Otherwise put
                them in train mode.
        """
        if model_path is None:
            model_path = self.create_filepath(path='models')

        model_path += '-' + str(iter_no)

        if self.online:
            model_path += '-online'

        print(f"\nLoading models from {model_path}...")

        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        model_checkpoint = torch.load(model_path, map_location=self.device)

        self.actor.load_state_dict(model_checkpoint['actor_state_dict'])
        self.critic.load_state_dict(model_checkpoint['critic_state_dict'])
        self.target_critic.load_state_dict(model_checkpoint['target_critic_state_dict'])

        self.critic_optimiser.load_state_dict(model_checkpoint['critic_optimiser_state_dict'])
        self.actor_optimiser.load_state_dict(model_checkpoint['actor_optimiser_state_dict'])

        if self.learnable_temperature:
            alpha_opt_state = model_checkpoint.get('alpha_optimiser_state_dict')
            if alpha_opt_state is not None:
                self.alpha_optimiser.load_state_dict(alpha_opt_state)
            # Copy in place: alpha_optimiser holds a reference to this exact
            # tensor, so rebinding self.log_alpha would detach it from training.
            with torch.no_grad():
                self.log_alpha.copy_(model_checkpoint['log_alpha'])
        else:
            self.log_alpha = model_checkpoint['log_alpha'].detach()

        if evaluate:
            self.actor.eval()
            self.critic.eval()
            self.target_critic.eval()
        else:
            self.actor.train()
            self.critic.train()
            self.target_critic.train()