import wandb

import torch
import torch.optim as optim
import torch.nn.functional as F

from torch.distributions.categorical import Categorical

from copy import deepcopy
import numpy as np
import os, time

from reinmax import reinmax

from utils.critic import DiscreteCritic
from utils.actor import DiscreteActor

from utils.misc import soft_update
from utils.base_agent import BaseAgent
from utils.replay_buffer import ReplayBuffer



class Agent(BaseAgent):
    """Discrete-action actor-critic agent with an ensemble critic.

    The critic is a ``DiscreteCritic`` ensemble (``ensemble_size=critic_factor``)
    whose output is averaged over its leading (ensemble) dim and indexed on its
    last dim by bin index. The actor is a ``DiscreteActor`` producing logits over
    ``action_bins`` bins for each action dimension.

    The actor is trained with cross-entropy towards the critic's greedy bins.
    When ``self.online`` is false, a behaviour-cloning cross-entropy on the
    dataset actions is added.

    Many attributes used below are supplied by ``BaseAgent``, which is not
    visible here: ``online``, ``device``, ``rng``, ``optimiser``, ``n_steps``,
    ``action_dim``, ``critic_factor``, ``policy_update_freq``, ``model_save``,
    ``log_dict``, ``sample`` and ``save_model``.
    """

    # Checkpoint every this many calls to ``learn`` (only when ``model_save``).
    SAVE_INTERVAL = 1_000_000
    # Max global grad-norm applied to the critic before each optimiser step.
    CRITIC_MAX_GRAD_NORM = 40

    def __init__(self, obs_dims, action_dims, critic_lr, actor_lr, gamma, tau,
                 mem_size, batch_size, critic_factor, model_info=None,
                 dataset=None, algo_name='maac', **kwargs):
        """Build the actor, critic, their target copies and the replay buffer.

        Args:
            obs_dims: observation size, forwarded to the networks and the buffer.
            action_dims: number of action dimensions.
            critic_lr: learning rate of the critic optimiser.
            actor_lr: learning rate of the actor optimiser.
            gamma: discount factor (stored by ``BaseAgent``).
            tau: soft-update rate for the target networks (stored by ``BaseAgent``).
            mem_size: replay-buffer capacity.
            batch_size: minibatch size used for sampling.
            critic_factor: ensemble size of the critic.
            model_info: accepted for interface compatibility; unused here.
            dataset: forwarded to ``ReplayBuffer``; used when not online.
            algo_name: algorithm name forwarded to ``BaseAgent``.
            **kwargs: must contain ``epsilon``, ``epsilon_min``,
                ``exploration_decay``, ``softmax_tau``, ``transform_reward``,
                ``env``, ``action_bins``, ``critic_hidden_dim``,
                ``action_constrained_dim``, ``actor_hidden_dim``,
                ``normalise_state`` and, when offline, ``bc_alpha``. The whole
                dict is also forwarded to ``BaseAgent``.
        """
        super().__init__(obs_dims=obs_dims, action_dims=action_dims, gamma=gamma,
                         tau=tau, algo_name=algo_name, critic_factor=critic_factor,
                         **kwargs)

        # Epsilon-greedy exploration schedule (consumed by choose_action).
        self.epsilon = kwargs['epsilon']
        self.epsilon_min = kwargs['epsilon_min']
        self.exploration_decay = kwargs['exploration_decay']
        # bc_alpha is only read on the offline branch of update_actor.
        self.bc_alpha = 0 if self.online else kwargs['bc_alpha']
        use_data = not self.online
        self.softmax_tau = kwargs['softmax_tau']
        self.transform_reward = kwargs['transform_reward']

        self._action_space = kwargs['env'].action_space
        self.action_bins = kwargs['action_bins']

        self.critic = DiscreteCritic(state_dim=obs_dims,
                                     hidden_dim=kwargs['critic_hidden_dim'],
                                     action_dims=action_dims,
                                     action_bins=self.action_bins,
                                     ensemble_size=critic_factor,
                                     constrained_dim=kwargs['action_constrained_dim'])
        self.target_critic = deepcopy(self.critic)
        self.critic_optimiser = self.optimiser(self.critic.parameters(), lr=critic_lr)

        self.actor = DiscreteActor(state_dim=obs_dims,
                                   hidden_dim=kwargs['actor_hidden_dim'],
                                   action_bins=self.action_bins,
                                   action_dims=action_dims)
        self.target_actor = deepcopy(self.actor)
        self.actor_optimiser = self.optimiser(self.actor.parameters(), lr=actor_lr)

        self.replay_buffer = ReplayBuffer(batch_size=batch_size,
                                          obs_dims=obs_dims,
                                          action_dims=action_dims,
                                          dataset=dataset,
                                          mem_size=mem_size,
                                          discrete_action=True,
                                          use_data=use_data,
                                          normalise_state=kwargs['normalise_state'])

        self.move_to()
        self.batch_size = batch_size
        self.total_it = 0

        if wandb.run is not None:
            wandb.define_metric('train/critic_loss', step_metric='total_step')
            wandb.define_metric('train/critic_values', step_metric='total_step')

    def move_to(self):
        """Delegate to ``BaseAgent.move_to``, then move all four networks to ``self.device``."""
        super().move_to(self.device)
        for net in (self.critic, self.target_critic, self.actor, self.target_actor):
            net.to(device=self.device)

    def choose_action(self, state, deterministic=False, **kwargs):
        """Pick an epsilon-greedy action for a single, unbatched state.

        Args:
            state: one observation; a leading batch dim of 1 is added here.
            deterministic: if True, always take the actor's argmax bins.

        Returns:
            ``{'action': tensor}`` with the batch dim squeezed away. The tensor
            lives on ``self.device`` in both the greedy and the random branch.
        """
        state = torch.tensor(state, dtype=torch.float, device=self.device).unsqueeze(0)

        # Acting needs no gradients; skip building an autograd graph.
        with torch.no_grad():
            act_logits, _ = self.actor(state)

        if deterministic or self.rng.uniform() > self.epsilon:
            act = act_logits.argmax(dim=-1)
        else:
            # Place the random sample on the same device as the greedy branch
            # so callers always receive a consistently-placed tensor.
            act = torch.tensor(self._action_space.sample(), device=self.device)
            # NOTE(review): epsilon decays only on exploratory steps, so the
            # decay slows as epsilon shrinks -- confirm this is intended.
            self.epsilon = max(self.epsilon * self.exploration_decay, self.epsilon_min)

        return {'action': act.squeeze()}

    def update_critic(self, samples, iter_no=None):
        """Take one gradient step on the critic ensemble.

        For each action dimension the target is ``rewards + gamma**n_steps * V'``.
        ``V'`` is the max over bins of the ensemble-mean target-critic output at
        ``next_states``. The target critic is conditioned on next actions sampled
        from ``softmax(softmax_tau * target_actor_logits)``. ``V'`` is zeroed
        wherever ``done_batch`` is set.

        Args:
            samples: ``(states, next_states, actions, rewards, done_batch)``.
            iter_no: unused; kept for interface compatibility.

        Returns:
            The scalar critic loss tensor (already back-propagated and stepped).
        """
        states, next_states, actions, rewards, done_batch = samples

        # NOTE(review): assumes rewards/done_batch arrive shaped (1, batch), so
        # that permute+repeat yields (batch, action_dim) to match the per-dim
        # target values below -- confirm against BaseAgent.sample.
        done_batch = done_batch.permute(1, 0).repeat(1, self.action_dim)
        rewards = rewards.permute(1, 0).repeat(1, self.action_dim)

        with torch.no_grad():
            next_logits, _ = self.target_actor(next_states)
            next_probs = F.softmax(self.softmax_tau * next_logits, dim=-1)
            # One bin per (sample, action dim). Reshape using the probs' own
            # leading dims rather than self.batch_size, so a batch shorter
            # than batch_size does not break the reshape.
            next_actions = torch.multinomial(next_probs.flatten(0, 1), 1)
            next_actions = next_actions.reshape(next_probs.shape[:-1])

            # Average over the ensemble, then take the best bin per action dim.
            next_values = self.target_critic(next_states, next_actions)
            next_values = next_values.mean(dim=0).max(dim=-1).values
            next_values[done_batch] = 0

            target = rewards + (self.gamma ** self.n_steps) * next_values

        q_values = self.critic(states, actions)
        # Select, in every ensemble member, the Q-value of the bin actually taken.
        gather_idx = actions.unsqueeze(0).unsqueeze(-1).expand(self.critic_factor, -1, -1, -1)
        q_values = q_values.gather(-1, gather_idx).squeeze(-1)

        # Broadcast the target over the ensemble dim explicitly. The loss is
        # identical to the implicit broadcast, but F.huber_loss no longer warns
        # about mismatched input/target sizes.
        critic_loss = F.huber_loss(q_values, target.unsqueeze(0).expand_as(q_values))

        self.critic_optimiser.zero_grad()
        critic_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.critic.parameters(), self.CRITIC_MAX_GRAD_NORM)
        self.critic_optimiser.step()

        self.log_dict['train/critic_loss'] = critic_loss.item()
        self.log_dict['train/critic_values'] = q_values.mean().item()

        return critic_loss

    def update_actor(self, samples, iter_no=None):
        """Take one gradient step on the actor.

        The actor is pushed (cross-entropy) towards the bins that maximise the
        ensemble-mean critic output, evaluated at the actor's own argmax
        actions. When not online, that term becomes
        ``bc_alpha * loss / loss.detach()``, which has value ``bc_alpha`` but
        keeps the gradient direction. A cross-entropy towards the dataset
        ``actions`` is then added.

        Args:
            samples: ``(states, next_states, actions, rewards, done_batch)``.
            iter_no: unused; kept for interface compatibility.

        Returns:
            The scalar actor loss tensor (already back-propagated and stepped).
        """
        states, _next_states, actions, _rewards, _done_batch = samples

        logits, _ = self.actor(states)
        greedy_actions = logits.argmax(dim=-1)

        with torch.no_grad():
            q_vals = self.critic(states, greedy_actions).mean(dim=0)
            max_act = q_vals.argmax(dim=-1)

        flat_logits = logits.flatten(0, 1)
        actor_loss = F.cross_entropy(flat_logits, max_act.flatten(0, 1))

        if not self.online:
            regularisation = F.cross_entropy(flat_logits, actions.flatten(0, 1))
            # Normalise the greedy term to magnitude bc_alpha so the weighting
            # against the behaviour-cloning term does not depend on loss scale.
            actor_loss = self.bc_alpha * actor_loss / actor_loss.detach() + regularisation

        self.actor_optimiser.zero_grad()
        actor_loss.backward()
        self.actor_optimiser.step()

        self.log_dict['train/actor_loss'] = actor_loss.item()
        return actor_loss

    def learn(self, sample_range=None, **kwargs):
        """Run one training iteration.

        The critic is updated on every call. The actor and its target are
        updated every ``policy_update_freq`` calls. The critic target is
        soft-updated on every call, and a checkpoint is saved every
        ``SAVE_INTERVAL`` calls when ``model_save`` is set.

        Args:
            sample_range: forwarded to ``self.sample``.

        Returns:
            ``None`` while ``replay_buffer.mem_cntr < replay_buffer.batch_size``.
            Otherwise ``(critic_loss, actor_loss)``, where ``actor_loss`` is
            ``None`` on iterations without an actor update.
        """
        self.total_it += 1

        if self.replay_buffer.mem_cntr < self.replay_buffer.batch_size:
            return None

        *samples, _batch_idx = self.sample(rng=self.rng, sample_range=sample_range,
                                           batch_size=self.batch_size)

        critic_loss = self.update_critic(samples)

        actor_loss = None
        if self.total_it % self.policy_update_freq == 0:
            actor_loss = self.update_actor(samples)
            soft_update(self.target_actor, self.actor, tau=self.tau)

        soft_update(self.target_critic, self.critic, tau=self.tau)

        if self.total_it % self.SAVE_INTERVAL == 0 and self.model_save:
            self.save_model()

        return critic_loss, actor_loss

