import wandb

import torch
import torch.optim as optim
import torch.nn.functional as F

from torch.distributions.categorical import Categorical

from copy import deepcopy
import numpy as np
import os, time

from utils.decomp_networks import EnsembleDecoupledQNetwork
from utils.misc import soft_update
from utils.base_agent import BaseAgent
from utils.replay_buffer import ReplayBuffer

from utils.vectorised_networks import Revalued


class Agent(BaseAgent):
    """Epsilon-greedy agent backed by an ensemble of decoupled Q-networks.

    The critic is an ``EnsembleDecoupledQNetwork`` with ``critic_factor``
    ensemble members. It is trained online with a Huber TD loss on the
    Q-value obtained by averaging the selected per-sub-action utilities, plus
    a regulariser that pulls those utilities toward the target network's
    utilities, weighted by ``1 - exp(-|utility - td_target|)``.

    NOTE(review): relies on attributes not set in this class and presumably
    provided by ``BaseAgent`` (``device``, ``rng``, ``online``, ``optimiser``,
    ``log_dict``, ``n_steps``, ``model_save``, ``create_filepath``) — confirm.
    """

    def __init__(self, obs_dims, action_dims, critic_lr, actor_lr, gamma, tau, mem_size,
                 batch_size, critic_factor, model_info=None, dataset=None,
                 algo_name='true_revalued', **kwargs):
        """Build the critic, its target copy, the optimiser and the replay buffer.

        Args:
            obs_dims: observation size, forwarded to the critic and buffer.
            action_dims: number of sub-actions.
            critic_lr: learning rate of the critic optimiser.
            actor_lr: accepted for interface compatibility; unused in this class.
            gamma: discount factor, forwarded to ``BaseAgent``.
            tau: soft-update rate for the target critic, forwarded to ``BaseAgent``.
            mem_size: replay buffer capacity.
            batch_size: minibatch size sampled in ``learn``.
            critic_factor: number of ensemble members in the critic.
            model_info: accepted for interface compatibility; unused in this class.
            dataset: forwarded to the replay buffer (constructed with ``use_data=False``).
            algo_name: algorithm name forwarded to ``BaseAgent``.
            **kwargs: must contain ``epsilon``, ``epsilon_min``,
                ``exploration_decay``, ``env`` (its ``action_space`` supplies
                random actions), ``critic_hidden_dim`` and ``action_bins``.
                All kwargs are also forwarded to ``BaseAgent``.
        """
        super().__init__(obs_dims=obs_dims, action_dims=action_dims, gamma=gamma,
                         tau=tau, algo_name=algo_name, critic_factor=critic_factor,
                         **kwargs)

        # Epsilon-greedy exploration schedule.
        self.epsilon = kwargs['epsilon']
        self.epsilon_min = kwargs['epsilon_min']
        self.exploration_decay = kwargs['exploration_decay']
        print('mem_size:', mem_size)
        print('epsilon_min:', self.epsilon_min)
        self.transform_reward = False

        self._action_space = kwargs['env'].action_space

        self.critic = EnsembleDecoupledQNetwork(state_dim=obs_dims,
                                                hidden_dim=kwargs['critic_hidden_dim'],
                                                action_bins=kwargs['action_bins'],  # bins per sub-action
                                                action_dims=action_dims,  # number of sub-actions
                                                ensemble_size=critic_factor)

        self.target_critic = deepcopy(self.critic)
        self.critic_optimiser = self.optimiser(self.critic.parameters(), lr=critic_lr)

        self.replay_buffer = ReplayBuffer(batch_size=batch_size,
                                          obs_dims=obs_dims,
                                          action_dims=action_dims,
                                          dataset=dataset,
                                          mem_size=mem_size,
                                          discrete_action=True,
                                          use_data=False
                                          )

        assert self.online, 'only use online'

        self.move_to()
        self.batch_size = batch_size

        # Number of calls to ``learn``; drives checkpoint naming/frequency.
        self.total_it = 0

        if wandb.run is not None:
            # Every key written to ``log_dict`` in ``update_critic`` is plotted
            # against the same step metric.
            wandb.define_metric('train/critic_loss', step_metric='total_step')
            wandb.define_metric('train/critic_values', step_metric='total_step')
            wandb.define_metric('train/total_loss', step_metric='total_step')
            wandb.define_metric('train/regulariser_loss', step_metric='total_step')

    def move_to(self):
        """Move the base agent state, critic and target critic to ``self.device``."""
        super().move_to(self.device)
        self.critic.to(device=self.device)
        self.target_critic.to(device=self.device)

    def choose_action(self, state, **kwargs):
        """Select an action epsilon-greedily from the ensemble-averaged utilities.

        Args:
            state: a single (unbatched) observation convertible to a float tensor.
            **kwargs: ``deterministic=True`` forces the greedy action.

        Returns:
            ``{'action': action}``. The greedy branch returns a tensor on
            ``self.device`` that keeps the leading batch dim of size 1. The random
            branch returns a tensor built from ``action_space.sample()``.
            NOTE(review): the two branches differ in shape/device; callers must
            handle both.
        """
        state = torch.tensor(state, dtype=torch.float).to(self.device).unsqueeze(0)
        # Inference only: no autograd graph is needed for action selection.
        with torch.no_grad():
            action_vals = self.critic(state)

        if kwargs.get('deterministic') or self.rng.uniform() > self.epsilon:
            # Average over ensemble members (dim 1), then pick the best bin per sub-action.
            action = action_vals.mean(dim=1).argmax(dim=-1)
        else:
            action = torch.tensor(self._action_space.sample())
            # Epsilon decays only on steps where a random action was taken.
            self.epsilon = max(self.epsilon * self.exploration_decay, self.epsilon_min)

        return {'action': action}

    def update_critic(self, samples, iter_no=None):
        """Take one gradient step on the critic and return the TD (Huber) loss.

        Args:
            samples: ``(states, next_states, actions, rewards, done_batch)`` as
                returned by the replay buffer (without the index tensor).
            iter_no: unused beyond defaulting to ``self.total_it``; kept for
                interface compatibility.

        Returns:
            The scalar TD loss tensor (without the regulariser term).
        """
        states, next_states, actions, rewards, done_batch = samples

        # NOTE(review): the buffer presumably stores rewards/dones transposed;
        # after the permute they must be (batch, 1) to match the target below.
        rewards = rewards.permute((1, 0))
        done_batch = done_batch.permute((1, 0))

        if iter_no is None:
            iter_no = self.total_it

        utility_values = self.critic(states)

        # One action index per (sample, ensemble member, sub-action).
        rep_actions = actions.unsqueeze(1).repeat(1, self.critic_factor, 1)
        selected_utility_values = utility_values.gather(-1, rep_actions.unsqueeze(-1))
        # Q(s, a) per ensemble member = mean utility over sub-actions.
        # squeeze(-1) only, so a batch of size 1 keeps its batch dimension.
        q_values = selected_utility_values.mean(dim=2).squeeze(-1)

        with torch.no_grad():
            target_values = self.target_critic(next_states)

            # Ensemble mean, greedy bin per sub-action, then mean over sub-actions.
            target_q_vals = target_values.mean(dim=1).max(dim=-1).values.mean(dim=1, keepdim=True)

            target_q_vals[done_batch] = 0
            est_critic_val = rewards + (self.gamma ** self.n_steps) * target_q_vals

            # One TD target per sample, reshaped to broadcast over the ensemble
            # and sub-action dims of ``selected_utility_values``. (Repeating a
            # squeezed 1-D target with more sizes than it has dims prepends a
            # dimension and mixes targets across samples.)
            utility_targets = est_critic_val.reshape(-1, 1, 1, 1)
            diff = (selected_utility_values - utility_targets).abs()
            # Weight -> 0 when a utility already matches the TD target.
            regulariser_weight = 1 - (-diff).exp()
            selected_target_utility = self.target_critic(states).gather(-1, rep_actions.unsqueeze(-1))

        unweighted_regulariser = F.huber_loss(selected_utility_values, selected_target_utility,
                                              reduction='none')
        regulariser_loss = (regulariser_weight * unweighted_regulariser).mean()

        # Expand the per-sample target across ensemble members explicitly
        # instead of relying on huber_loss broadcasting.
        critic_loss = F.huber_loss(q_values, est_critic_val.expand_as(q_values))

        total_loss = critic_loss + 0.5 * regulariser_loss

        self.log_dict['train/critic_loss'] = critic_loss.item()
        self.log_dict['train/total_loss'] = total_loss.item()
        self.log_dict['train/regulariser_loss'] = regulariser_loss.item()
        self.log_dict['train/critic_values'] = q_values.mean().item()

        self.critic_optimiser.zero_grad()
        total_loss.backward()

        torch.nn.utils.clip_grad_norm_(self.critic.parameters(), 40)
        self.critic_optimiser.step()

        return critic_loss

    def get_reg_loss(self, states, actions, utilities, targets):
        """Alternative regulariser computation; not called anywhere in this class.

        Returns ``0`` until ``self.grad_steps`` exceeds ``self.burnin``. After
        that it returns ``self.beta`` times the mean weighted Huber distance
        between ``utilities`` and the target critic's utilities for ``actions``,
        and calls ``self.anneal_beta()``.

        NOTE(review): ``grad_steps``, ``burnin``, ``num_heads``, ``temp``,
        ``huber``, ``beta`` and ``anneal_beta`` are not defined in this class,
        and the tensor layout assumed here differs from ``update_critic``.
        Confirm this method is still needed before using it.
        """
        if self.grad_steps > self.burnin:

            with torch.no_grad():
                diff = (utilities - targets.unsqueeze(dim=-1).repeat(1, 1, self.num_heads)) / self.temp
                diff = diff.abs()
                weights = 1 - (-diff).exp()
                # This class stores its target network as ``target_critic``.
                old_utilities = self.target_critic.forward(states).gather(
                    -1, actions.unsqueeze(dim=-1)).squeeze(dim=-1)

            unweighted_loss = self.huber(utilities, old_utilities)
            loss = self.beta * (unweighted_loss * weights).mean()
            self.anneal_beta()
            return loss
        else:
            return 0

    def learn(self, **kwargs):
        """Run one training iteration.

        The iteration sample a minibatch, update the critic, soft-update the
        target, and checkpoint every 100000 iterations when ``model_save`` is set.

        Returns:
            The TD loss from ``update_critic``, or ``None`` while the buffer
            holds fewer than ``batch_size`` transitions.
        """
        self.total_it += 1

        if self.replay_buffer.mem_cntr < self.replay_buffer.batch_size:
            return

        *samples, _batch_idx = self.replay_buffer.sample(rng=self.rng,
                                                         batch_size=self.batch_size)

        critic_loss = self.update_critic(samples)
        soft_update(self.target_critic, self.critic, tau=self.tau)

        if self.total_it % 100000 == 0 and self.model_save:
            self.save_model()

        return critic_loss

    def save_model(self):
        """Save critic and target-critic state dicts; return the checkpoint path.

        The path is ``create_filepath(path='models')`` suffixed with
        ``-<total_it>`` and, for online agents, ``-online``.
        """
        model_path = self.create_filepath(path='models')

        model_path += ('-' + str(self.total_it))

        if self.online:
            model_path += '-online'

        self.model_path = model_path

        print(f'Saving models to {model_path}')
        torch.save({'critic_state_dict': self.critic.state_dict(),
                    'target_critic_state_dict': self.target_critic.state_dict(), },
                   model_path)
        return model_path

    def load_model(self, iter_no, model_path=None, evaluate=False, online=False):
        """Load critic and target-critic weights saved by ``save_model``.

        Args:
            iter_no: iteration suffix used when the checkpoint was saved.
            model_path: base path; defaults to ``create_filepath(path='models')``.
            evaluate: put both networks in eval mode if True, else train mode.
            online: append the ``-online`` suffix used by online agents.
        """
        if model_path is None:
            model_path = self.create_filepath(path='models')

        model_path += ('-' + str(iter_no))

        if online:
            model_path += '-online'

        print(f"\nLoading models from {model_path}...")

        # map_location lets checkpoints saved on another device (e.g. GPU) load here.
        model_checkpoint = torch.load(model_path, map_location=self.device)

        self.critic.load_state_dict(model_checkpoint['critic_state_dict'])
        self.target_critic.load_state_dict(model_checkpoint['target_critic_state_dict'])

        if evaluate:
            self.critic.eval()
            self.target_critic.eval()
        else:
            self.critic.train()
            self.target_critic.train()

