import wandb

import torch
import torch.optim as optim
import torch.nn.functional as F

from torch.distributions.categorical import Categorical

from copy import deepcopy
import numpy as np
import os, time

from utils.critic import Critic
from utils.actor import Actor

from utils.misc import soft_update
from utils.base_agent import BaseAgent
from utils.replay_buffer import ReplayBuffer



class Agent(BaseAgent):
    """Discrete-action actor-critic agent (``algo_name`` defaults to ``'maddpg'``).

    The actor emits per-action-dimension logits; during training, actions are
    drawn with ``F.gumbel_softmax(..., hard=True)`` so the critic's gradient can
    reach the actor. The critic is built with ``ensemble_size=critic_factor``
    and its target outputs are averaged over dim 0 when forming TD targets.
    When ``self.online`` (not set in this class) is False, the actor objective
    adds a behaviour-cloning cross-entropy term against the sampled actions and
    rescales the Q term by ``bc_alpha``.

    NOTE(review): tensor layouts are inferred from the reshapes below. Actor
    logits appear to be (batch, action_dims, action_bins), and critic outputs
    appear to be (ensemble, batch, action_dims). Confirm against
    ``utils.actor`` / ``utils.critic``.
    """

    def __init__(self, obs_dims, action_dims, critic_lr, actor_lr, gamma, tau, mem_size,
                 batch_size, critic_factor, model_info=None, dataset=None,
                 algo_name='maddpg', **kwargs):
        """Build the networks, their optimisers and the replay buffer.

        Args:
            obs_dims: observation size, forwarded to the networks and the buffer.
            action_dims: number of discrete action dimensions.
            critic_lr: learning rate passed to ``self.optimiser`` for the critic.
            actor_lr: learning rate passed to ``self.optimiser`` for the actor.
            gamma: discount, passed to ``BaseAgent`` (read back as ``self.gamma``).
            tau: soft-update rate, passed to ``BaseAgent`` (read back as ``self.tau``).
            mem_size: replay-buffer capacity.
            batch_size: minibatch size for sampling.
            critic_factor: critic ensemble size.
            model_info: accepted for interface compatibility; unused here.
            dataset: optional dataset handed to the replay buffer.
            algo_name: name forwarded to ``BaseAgent``.
            **kwargs: must contain ``epsilon``, ``epsilon_min``,
                ``exploration_decay``, ``sample_type``, ``policy_update_freq``,
                ``env``, ``action_bins``, ``critic_hidden_dim``,
                ``actor_hidden_dim`` and ``normalise_state``. ``bc_alpha`` is
                read only when the agent is offline. All kwargs are also
                forwarded to ``BaseAgent``.
        """
        super().__init__(obs_dims=obs_dims, action_dims=action_dims, gamma=gamma,
                         tau=tau, algo_name=algo_name, critic_factor=critic_factor,
                         **kwargs)

        self.epsilon = kwargs['epsilon']
        self.epsilon_min = kwargs['epsilon_min']
        self.exploration_decay = kwargs['exploration_decay']
        # The BC weight only matters offline; online training never reads it.
        self.bc_alpha = 0 if self.online else kwargs['bc_alpha']
        self.transform_reward = False

        self.sample_type = kwargs['sample_type']
        self.policy_update_freq = kwargs['policy_update_freq']

        self._action_space = kwargs['env'].action_space
        self.action_bins = kwargs['action_bins']

        self.critic = Critic(state_dim=obs_dims,
                             hidden_dim=kwargs['critic_hidden_dim'],
                             action_dims=action_dims,
                             action_bins=self.action_bins,
                             ensemble_size=critic_factor)
        self.target_critic = deepcopy(self.critic)
        self.critic_optimiser = self.optimiser(self.critic.parameters(), lr=critic_lr)

        self.actor = Actor(state_dim=obs_dims,
                           hidden_dim=kwargs['actor_hidden_dim'],
                           action_bins=self.action_bins,
                           action_dims=action_dims)
        self.target_actor = deepcopy(self.actor)
        self.actor_optimiser = self.optimiser(self.actor.parameters(), lr=actor_lr)

        self.replay_buffer = ReplayBuffer(batch_size=batch_size,
                                          obs_dims=obs_dims,
                                          action_dims=action_dims,
                                          dataset=dataset,
                                          mem_size=mem_size,
                                          discrete_action=True,
                                          use_data=False,
                                          normalise_state=kwargs['normalise_state'])

        self.move_to()
        self.batch_size = batch_size
        self.total_it = 0

        if wandb.run is not None:
            # Plot every scalar this class writes into log_dict against 'total_step'.
            for metric in ('train/critic_loss', 'train/critic_values', 'train/actor_loss'):
                wandb.define_metric(metric, step_metric='total_step')

    def move_to(self):
        """Call ``BaseAgent.move_to(self.device)`` and move all four networks to ``self.device``."""
        super().move_to(self.device)
        for net in (self.critic, self.target_critic, self.actor, self.target_actor):
            net.to(device=self.device)

    def choose_action(self, state, deterministic=False, **kwargs):
        """Pick an epsilon-greedy action for ``state``.

        Args:
            state: observation convertible to a float tensor.
            deterministic: if True, always act greedily (no exploration).
            **kwargs: ignored; accepted for interface compatibility.

        Returns:
            ``{'action': tensor}`` on ``self.device``. The action is the
            per-dimension argmax of the actor's logits, or a sample from the
            environment's action space when ``self.rng.uniform() <= epsilon``.

        Note: ``epsilon`` is decayed (floored at ``epsilon_min``) only on steps
        where a random action is taken.
        """
        state = torch.as_tensor(state, dtype=torch.float32, device=self.device)

        # Inference only: no autograd graph is needed to select an action.
        with torch.no_grad():
            act_logits = self.actor(state)

        if deterministic or self.rng.uniform() > self.epsilon:
            act = act_logits.argmax(dim=-1)
        else:
            # Put the random action on the same device as the greedy branch so
            # callers always receive a tensor on self.device.
            act = torch.as_tensor(self._action_space.sample(), device=self.device)
            self.epsilon = max(self.epsilon * self.exploration_decay, self.epsilon_min)

        return {'action': act.squeeze()}

    def update_critic(self, samples, iter_no=None):
        """Take one gradient step on the critic ensemble towards an n-step TD target.

        Args:
            samples: ``(states, next_states, actions, rewards, done_batch)``
                from the replay buffer. ``actions`` must be integer indices,
                because they go through ``F.one_hot``. ``done_batch`` is used as
                an index mask, so it is presumably boolean (TODO confirm).
            iter_no: unused; kept for interface compatibility.

        Returns:
            The critic loss tensor (already back-propagated and stepped).
        """
        states, next_states, actions, rewards, done_batch = samples

        # NOTE(review): rewards/dones appear to be (1, batch). Transpose them to
        # (batch, 1) and tile across action dimensions to match the critic's
        # per-dimension values. Confirm against ReplayBuffer.sample.
        done_batch = done_batch.permute(1, 0).repeat(1, self.action_dim)
        rewards = rewards.permute(1, 0).repeat(1, self.action_dim)

        with torch.no_grad():
            next_action_logits = self.target_actor(next_states)
            next_actions = F.gumbel_softmax(next_action_logits, hard=True)
            next_action_values = self.target_critic(next_states, next_actions.flatten(1, 2))
            # Average over dim 0 (presumably the ensemble dimension).
            next_action_values = next_action_values.mean(dim=0)
            # Do not bootstrap from terminal transitions.
            next_action_values[done_batch] = 0
            est_critic_val = rewards + (self.gamma ** self.n_steps) * next_action_values

        actions = F.one_hot(actions, num_classes=self.action_bins)
        q_values = self.critic(states, actions.flatten(1, 2))

        # Every ensemble member regresses onto the same target. Broadcast
        # explicitly, exactly as F.huber_loss would internally. This makes the
        # intended broadcast visible and avoids the size-mismatch warning
        # F.huber_loss emits when input and target shapes differ.
        q_pred, q_target = torch.broadcast_tensors(q_values, est_critic_val)
        critic_loss = F.huber_loss(q_pred, q_target)

        self.critic_optimiser.zero_grad()
        critic_loss.backward()
        self.critic_optimiser.step()

        self.log_dict['train/critic_loss'] = critic_loss.item()
        self.log_dict['train/critic_values'] = q_values.mean().item()

        return critic_loss

    def update_actor(self, samples, iter_no=None):
        """Take one gradient step on the actor.

        Online, the actor maximises the mean critic value of its own
        Gumbel-softmax (``hard=True``) actions. Offline (``not self.online``) it
        uses a TD3+BC-style objective instead:
        ``bc_alpha * (-mean Q) / mean|Q| + cross_entropy(logits, actions)``.

        Args:
            samples: ``(states, next_states, actions, rewards, done_batch)``.
                Only ``states`` and, offline, ``actions`` are used.
            iter_no: unused; kept for interface compatibility.

        Returns:
            The actor loss tensor (already back-propagated and stepped).
        """
        states, _, actions, _, _ = samples

        action_logits = self.actor(states)
        new_actions = F.gumbel_softmax(action_logits, hard=True)
        q_vals = self.critic(states, new_actions.flatten(1, 2))
        actor_loss = -q_vals.mean()

        if not self.online:
            regularisation = F.cross_entropy(action_logits.flatten(0, 1),
                                             actions.flatten(0, 1))
            # Normalise the Q term by the detached mean |Q| (as in TD3+BC), so
            # its magnitude is ~bc_alpha regardless of reward scale. The divisor
            # must be positive. Dividing by the signed loss (-mean Q) would
            # reverse the gradient whenever mean Q > 0 and make the actor
            # minimise Q.
            q_scale = q_vals.abs().mean().detach().clamp_min(1e-8)
            actor_loss = self.bc_alpha * actor_loss / q_scale + regularisation

        # Gradients also land on the critic's parameters here. update_critic
        # zeroes them before its own backward pass.
        self.actor_optimiser.zero_grad()
        actor_loss.backward()
        self.actor_optimiser.step()

        self.log_dict['train/actor_loss'] = actor_loss.item()

        return actor_loss

    def learn(self, sample_range=None, **kwargs):
        """Run one training iteration.

        Each call:
        - updates the critic and soft-updates the target critic;
        - every ``policy_update_freq`` calls, also updates the actor and
          soft-updates its target;
        - every 1,000,000 calls, calls ``self.save_model()`` if
          ``self.model_save`` is set.

        ``total_it`` is incremented on every call, including warm-up calls
        that return early.

        Returns:
            ``None`` while ``replay_buffer.mem_cntr < replay_buffer.batch_size``.
            Otherwise ``(critic_loss, actor_loss)``, where ``actor_loss`` is
            ``None`` on iterations without an actor update.
        """
        self.total_it += 1
        actor_loss = None

        if self.replay_buffer.mem_cntr < self.replay_buffer.batch_size:
            return None

        *samples, _batch_idx = self.replay_buffer.sample(rng=self.rng,
                                                          sample_range=sample_range,
                                                          batch_size=self.batch_size)

        critic_loss = self.update_critic(samples)

        # Delayed policy updates: the actor and its target move less often
        # than the critic.
        if self.total_it % self.policy_update_freq == 0:
            actor_loss = self.update_actor(samples)
            soft_update(self.target_actor, self.actor, tau=self.tau)

        soft_update(self.target_critic, self.critic, tau=self.tau)

        if self.total_it % 1000000 == 0 and self.model_save:
            self.save_model()

        return critic_loss, actor_loss

