import torch
import torch.optim as optim
import torch.nn.functional as F

from copy import deepcopy
import numpy as np
import os, time

from utils.decomp_networks import DecoupledQNetwork
from utils.misc import soft_update
from utils.base_agent import BaseAgent
from utils.replay_buffer import ReplayBuffer



class Agent(BaseAgent):
    """Value-based agent built around a ``DecoupledQNetwork`` critic.

    The critic output is treated as a distribution whose ``mean`` holds one
    utility per (sub-action head, bin). The Q-value of a joint action is the
    mean, over heads, of the utilities of the chosen bins. A target copy of
    the critic provides the bootstrap values.

    Attributes such as ``device``, ``rng``, ``gamma``, ``tau``, ``n_steps``,
    ``log_dict``, ``optimiser``, ``model_save``, ``online`` and
    ``create_filepath`` are not defined in this class; they are presumably
    provided by ``BaseAgent`` -- TODO confirm.
    """

    def __init__(self, obs_dims, action_dims, critic_lr, actor_lr, gamma, tau, mem_size,
                 batch_size, critic_factor, model_info=None, dataset=None,
                 algo_name='decqn_action', **kwargs):
        """Build the critic, its target copy, the optimiser and the replay buffer.

        Args:
            obs_dims: Observation dimensionality, passed to the critic and buffer.
            action_dims: Number of sub-actions; used as the number of critic heads.
            critic_lr: Learning rate of the critic optimiser.
            actor_lr: Unused here; kept for interface compatibility.
            gamma: Discount factor (forwarded to ``BaseAgent``).
            tau: Target-update coefficient (forwarded to ``BaseAgent``).
            mem_size: Replay-buffer capacity.
            batch_size: Mini-batch size used when sampling the buffer.
            critic_factor: Unused here; kept for interface compatibility.
            model_info: Unused here; kept for interface compatibility.
            dataset: Optional dataset handed to the replay buffer.
            algo_name: Algorithm name forwarded to ``BaseAgent``.
            **kwargs: Must contain ``env``, ``discrete_eps`` and
                ``discrete_bins``; may contain ``epsilon``, ``epsilon_min``
                and ``exploration_decay``. Everything is also forwarded to
                ``BaseAgent``.

        Raises:
            KeyError: If ``env``, ``discrete_eps`` or ``discrete_bins`` is
                missing from ``kwargs``.
        """
        super().__init__(obs_dims=obs_dims, action_dims=action_dims, gamma=gamma, tau=tau,
                         algo_name=algo_name, **kwargs)

        # Exploration settings are stored but not consumed anywhere in this class.
        self.epsilon = kwargs.get('epsilon', 1)
        self.epsilon_min = kwargs.get('epsilon_min', 0.05)
        self.exploration_decay = kwargs.get('exploration_decay', 0.999)

        self._action_space = kwargs['env'].action_space

        # NOTE(review): the number of bins per sub-action is hard-coded to 3
        # while the replay buffer receives kwargs['discrete_bins'] -- confirm
        # that the two always agree.
        self.critic = DecoupledQNetwork(state_dim=obs_dims,
                                        hidden_dim=512,
                                        num_states=3,
                                        num_heads=action_dims)

        self.target_critic = deepcopy(self.critic)
        self.critic_optimiser = self.optimiser(self.critic.parameters(), lr=critic_lr)

        self.replay_buffer = ReplayBuffer(batch_size=batch_size,
                                          obs_dims=obs_dims,
                                          action_dims=action_dims,
                                          dataset=dataset,
                                          mem_size=mem_size,
                                          discrete_eps=kwargs['discrete_eps'],
                                          discrete_bins=kwargs['discrete_bins'],
                                          discrete_action=True,
                                          )

        self.move_to()
        self.batch_size = batch_size

        # Counts calls to learn(), including those made before the buffer is ready.
        self.total_it = 0

    def move_to(self):
        """Move the base-class state and both critics to ``self.device``."""
        super().move_to(self.device)
        self.critic.to(device=self.device)
        self.target_critic.to(device=self.device)

    def choose_action(self, state, **kwargs):
        """Pick, for every sub-action head, the bin with the largest value.

        Args:
            state: Observation convertible to a float tensor; a leading axis
                of size 1 is dropped before the critic is queried.
            **kwargs: ``deterministic`` (truthy/falsy) is forwarded to
                ``critic.sample``.

        Returns:
            dict: ``{'action': tensor}`` holding the arg-max over the last
            axis of the ``'q_vals'`` entry returned by ``critic.sample``.
        """
        deterministic = bool(kwargs.get('deterministic'))

        # as_tensor avoids the copy (and the warning) torch.tensor incurs when
        # the observation already is a tensor.
        state = torch.as_tensor(state, dtype=torch.float, device=self.device)
        state = state.squeeze(0)

        # Action selection never needs gradients.
        with torch.no_grad():
            q_vals = self.critic.sample(state, deterministic=deterministic)['q_vals']

        return {'action': q_vals.argmax(dim=-1)}

    def update_critic(self, samples, iter_no=None):
        """Run one Huber-loss TD update of the critic.

        Args:
            samples: Six-element sequence ``(states, next_states, diff_states,
                actions, rewards, done_batch)``. ``rewards`` and ``done_batch``
                must be 2-D (they are transposed); ``actions`` is assumed to
                be an integer tensor of shape (batch, heads) -- TODO confirm
                against ``ReplayBuffer.sample``.
            iter_no: Iteration number used for the periodic debug print;
                defaults to ``self.total_it``.

        Returns:
            torch.Tensor: The scalar critic loss of this step.
        """
        states, next_states, _diff_states, actions, rewards, done_batch = samples

        rewards = rewards.permute((1, 0))
        done_batch = done_batch.permute((1, 0))

        if iter_no is None:
            iter_no = self.total_it

        utility_values = self.critic(states).mean
        # Squeeze only the gathered axis: a bare squeeze() would also drop a
        # batch or head axis of size 1 and break the mean over dim=1 below.
        selected_utility_values = utility_values.gather(-1, actions.unsqueeze(-1)).squeeze(-1)
        q_values = selected_utility_values.mean(dim=1)

        with torch.no_grad():
            target_utilities = self.target_critic(next_states).mean

            # Greedy bin per head, then averaged over heads.
            target_q_vals = target_utilities.max(dim=-1).values.mean(dim=1)

            # No bootstrapping from terminal transitions. The explicit bool
            # cast keeps this a mask whatever dtype the buffer stores.
            target_q_vals = target_q_vals.masked_fill(done_batch.squeeze().bool(), 0.0)

            est_critic_val = rewards.squeeze() + (self.gamma ** self.n_steps) * target_q_vals

        if iter_no % 10000 == 0:
            print(est_critic_val.mean(), q_values.mean())

        critic_loss = F.huber_loss(q_values, est_critic_val)

        self.log_dict['train/critic_loss'] = critic_loss.item()

        self.critic_optimiser.zero_grad()
        critic_loss.backward()
        self.critic_optimiser.step()

        return critic_loss

    def learn(self, **kwargs):
        """Sample a mini-batch, update the critic and refresh the target network.

        Returns:
            The critic loss tensor, or ``None`` when the replay buffer holds
            fewer transitions than its batch size.
        """
        self.total_it += 1

        if self.replay_buffer.mem_cntr < self.replay_buffer.batch_size:
            return None

        # The last element returned by the buffer (batch indices) is not used.
        *samples, _batch_idx = self.replay_buffer.sample(rng=self.rng,
                                                         batch_size=self.batch_size)

        critic_loss = self.update_critic(samples)
        soft_update(self.target_critic, self.critic, tau=self.tau)

        if self.total_it % 100000 == 0 and self.model_save:
            self.save_model()

        return critic_loss

    def save_model(self):
        """Save both critics' state dicts and return the path written.

        The path is ``create_filepath(path='models')`` followed by
        ``-<total_it>`` and, when ``self.online`` is truthy, ``-online``.
        """
        model_path = self.create_filepath(path='models')
        model_path += ('-' + str(self.total_it))

        if self.online:
            model_path += '-online'

        self.model_path = model_path

        print(f'Saving models to {model_path}')
        checkpoint = {
            'critic_state_dict': self.critic.state_dict(),
            'target_critic_state_dict': self.target_critic.state_dict(),
        }
        torch.save(checkpoint, model_path)
        return model_path

    def load_model(self, iter_no, model_path=None, evaluate=False, online=False):
        """Restore both critics from a checkpoint written by ``save_model``.

        Args:
            iter_no: Iteration suffix of the checkpoint to load.
            model_path: Path prefix; defaults to ``create_filepath(path='models')``.
            evaluate: If true the networks are put in eval mode, otherwise
                in train mode.
            online: If true the ``-online`` suffix is appended to the path.
        """
        if model_path is None:
            model_path = self.create_filepath(path='models')

        model_path += ('-' + str(iter_no))

        if online:
            model_path += '-online'

        # Report the path that is actually read, suffixes included.
        print(f"\nLoading models from {model_path}...")

        # map_location lets a checkpoint saved on another device (e.g. GPU)
        # be restored on the device this agent runs on.
        model_checkpoint = torch.load(model_path, map_location=self.device)

        self.critic.load_state_dict(model_checkpoint['critic_state_dict'])
        self.target_critic.load_state_dict(model_checkpoint['target_critic_state_dict'])

        if evaluate:
            self.critic.eval()
            self.target_critic.eval()
        else:
            self.critic.train()
            self.target_critic.train()

