import wandb

import torch
import torch.optim as optim
import torch.nn.functional as F

from torch.distributions.categorical import Categorical

from copy import deepcopy
import numpy as np
import os, time

from utils.decomp_networks import EnsembleDecoupledQNetwork, DecoupledQNetwork
from utils.misc import soft_update
from utils.base_agent import BaseAgent
from utils.replay_buffer import ReplayBuffer

from utils.vectorised_networks import Revalued

from rl_algos.single_agent.TD3.agent import Agent as BaselineTD3Agent

class Agent(BaseAgent):
    """Value-based agent over a factorised discrete action space.

    An ensemble of decoupled Q-networks (``EnsembleDecoupledQNetwork``) scores
    ``action_bins`` discrete choices for each of ``action_dims`` sub-actions.
    Actions are picked epsilon-greedily, and the critic is trained with a Huber
    TD loss against a soft-updated target copy. Only online training is
    supported (``__init__`` asserts ``self.online``).
    """

    def __init__(self, obs_dims, action_dims, critic_lr, actor_lr, gamma, tau, mem_size,
                 batch_size, critic_factor, model_info=None, dataset=None,
                 algo_name='revalued_action', **kwargs):
        """Build the critic ensemble, its target copy, and the replay buffers.

        Args:
            obs_dims: Observation size, passed to the critic and both buffers.
            action_dims: Number of sub-actions (critic heads).
            critic_lr: Learning rate for the critic optimiser.
            actor_lr: Unused in this class.
            gamma, tau: Forwarded to ``BaseAgent``. ``tau`` is the rate used for
                the target-network soft update in ``learn``.
            mem_size: Default capacity for both buffers when
                ``replay_mem_size`` / ``dataset_mem_size`` are missing or falsy.
            batch_size: Sampling batch size.
            critic_factor: Ensemble size of the critic.
            model_info: Unused in this class.
            dataset: Offline data handed to both buffers.
            algo_name: Forwarded to ``BaseAgent``.
            **kwargs: Also forwarded to ``BaseAgent``.
                Required here: ``env``, ``noise_clip``, ``policy_noise_std``,
                ``sample_type`` (``'mean'`` or ``'double_q'``), ``bc_alpha``,
                ``action_bins``, ``discrete_eps``, ``discrete_bins``. When
                ``use_data`` is truthy, ``normalise_state`` is also required.
                Optional: ``epsilon`` (1), ``epsilon_min`` (0.05),
                ``exploration_decay`` (0.9999), ``replay_mem_size``,
                ``dataset_mem_size``, ``use_data`` (False).
        """
        super().__init__(obs_dims=obs_dims, gamma=gamma, tau=tau,
                         algo_name=algo_name, critic_factor=critic_factor,
                         **kwargs)

        # Epsilon-greedy exploration schedule (decayed in choose_action).
        self.epsilon = kwargs.get('epsilon', 1)
        self.epsilon_min = kwargs.get('epsilon_min', 0.05)
        self.exploration_decay = kwargs.get('exploration_decay', 0.9999)

        self._action_space = kwargs['env'].action_space
        # noise_clip, policy_noise_std and bc_alpha are stored but not read
        # anywhere in this class.
        self.noise_clip = kwargs['noise_clip']
        self.policy_noise_std = kwargs['policy_noise_std']
        self.sample_type = kwargs['sample_type']

        self.bc_alpha = kwargs['bc_alpha']

        self.critic = EnsembleDecoupledQNetwork(state_dim=obs_dims,
                                                hidden_dim=512,
                                                num_actions=kwargs['action_bins'],  # number of bins per sub_action
                                                num_heads=action_dims,  # number of sub_actions
                                                ensemble_size=critic_factor)

        self.target_critic = deepcopy(self.critic)
        self.critic_optimiser = self.optimiser(self.critic.parameters(), lr=critic_lr)

        replay_mem_size = kwargs['replay_mem_size'] if kwargs.get('replay_mem_size', None) else mem_size
        dataset_mem_size = kwargs['dataset_mem_size'] if kwargs.get('dataset_mem_size', None) else mem_size
        print(self.n_steps, 'n_steps')

        self.priority = False
        self.replay_buffer = ReplayBuffer(batch_size=batch_size,
                                          obs_dims=obs_dims,
                                          action_dims=action_dims,
                                          dataset=dataset,
                                          mem_size=replay_mem_size,
                                          discrete_eps=kwargs['discrete_eps'],
                                          discrete_bins=kwargs['discrete_bins'],
                                          discrete_action=True,
                                          use_data=kwargs.get('use_data', False)
                                          )
        print(replay_mem_size, dataset_mem_size)

        # Instead of storing n_step transitions, this is used to store 1-step transitions.
        self.dataset_buffer = ReplayBuffer(batch_size=batch_size,
                                           obs_dims=obs_dims,
                                           action_dims=action_dims,
                                           dataset=dataset,
                                           mem_size=dataset_mem_size,
                                           discrete_eps=kwargs['discrete_eps'],
                                           discrete_bins=kwargs['discrete_bins'],
                                           discrete_action=True,
                                           )

        if kwargs.get('use_data', False):
            self.replay_buffer.store_offline_data(dataset=dataset,
                                                  normalise_state=kwargs['normalise_state'])

            self.dataset_buffer.store_offline_data(dataset=dataset,
                                                   normalise_state=kwargs['normalise_state'])
        else:
            print('not using dataset')

        print(self.dataset_buffer.mem_size, self.replay_buffer.mem_size)

        assert self.online, 'only use online'

        self.move_to()
        self.batch_size = batch_size

        # Number of learn() calls so far; drives checkpoint naming and cadence.
        self.total_it = 0

        if wandb.run is not None:
            wandb.define_metric('train/critic_loss', step_metric='total_step')
            wandb.define_metric('train/critic_values', step_metric='total_step')

    def move_to(self):
        """Move the base agent's state, the dataset buffer and both critics to ``self.device``."""
        super().move_to(self.device)
        self.dataset_buffer.to(device=self.device)
        self.critic.to(device=self.device)
        self.target_critic.to(device=self.device)

    def choose_action(self, state, diff_state=None, prev_diff_state=None, **kwargs):
        """Pick an epsilon-greedy action for a single observation.

        Args:
            state: Observation accepted by ``torch.tensor``. A leading
                singleton dimension is squeezed away before the forward pass.
            diff_state, prev_diff_state: Unused; kept for interface compatibility.
            **kwargs: ``deterministic=True`` forces the greedy branch.

        Returns:
            ``{'action': action}``, where ``action`` is a tensor: the argmax
            over the last axis of the critic output (greedy), or a sample from
            the env action space (exploratory).

        Raises:
            ValueError: If the greedy branch is taken with an unsupported
                ``sample_type``.
        """
        state = torch.tensor(state, dtype=torch.float).to(self.device).squeeze(0)
        # Acting needs no gradients, so skip building an autograd graph.
        with torch.no_grad():
            action_vals = self.critic(state)

        if kwargs.get('deterministic') or self.rng.uniform() > self.epsilon:
            if self.sample_type not in ('mean', 'double_q'):
                raise ValueError(f'unsupported sample_type: {self.sample_type!r}')
            # Both sample types act greedily on values averaged over dim 1
            # (presumably the ensemble axis; TODO confirm for unbatched input).
            action = action_vals.mean(dim=1).argmax(dim=-1)
        else:
            action = torch.tensor(self._action_space.sample())
            # Epsilon decays only on steps where a random action was taken.
            self.epsilon = max(self.epsilon * self.exploration_decay, self.epsilon_min)

        return {'action': action}

    def update_critic(self, samples, iter_no=None, weights=None):
        """Run one Huber-loss TD update of the critic ensemble.

        Args:
            samples: ``(states, next_states, diff_states, actions, rewards,
                done_batch)``, i.e. ``ReplayBuffer.sample`` output without the
                trailing index array.
            iter_no: Unused; kept for interface compatibility.
            weights: Unused; no importance weighting is applied.

        Returns:
            The scalar critic loss tensor. The optimiser step has already been
            taken when this returns.

        Raises:
            ValueError: If ``self.sample_type`` is neither ``'mean'`` nor
                ``'double_q'``.
        """
        states, next_states, diff_states, actions, rewards, done_batch = samples

        # Transpose the buffer's reward/done layout. Assumes both are 2-D (TODO confirm).
        rewards = rewards.permute((1, 0))
        done_batch = done_batch.permute((1, 0))

        # 4-D output (required by the gather below), presumably
        # (batch, ensemble, sub_actions, bins).
        utility_values = self.critic(states)

        # Give every ensemble member the same chosen bins: (batch, ensemble, sub_actions).
        rep_actions = actions.unsqueeze(1).repeat(1, self.critic_factor, 1)

        selected_utility_values = utility_values.gather(-1, rep_actions.unsqueeze(-1))
        # Average over sub-actions, then drop only the trailing gather dim, so a
        # batch or ensemble of size 1 is not collapsed as well.
        q_values = selected_utility_values.mean(dim=2).squeeze(-1)

        if self.sample_type == 'double_q':
            # Train a single randomly chosen ensemble member this step.
            double_idx = self.rng.choice(self.critic_factor, size=(1,), replace=False)
            q_values = q_values[:, double_idx]

        with torch.no_grad():
            target_values = self.target_critic(next_states)

            if self.sample_type == 'mean':
                # Ensemble-mean values, greedy bin per sub-action, averaged over sub-actions.
                target_q_vals = target_values.mean(dim=1).max(dim=-1).values.mean(dim=1, keepdim=True)
            elif self.sample_type == 'double_q':
                # Each ensemble member evaluates bins sampled by its neighbour (rolled index).
                roll_idx = np.roll(np.arange(self.critic_factor), 1)
                # NOTE(review): these "next" actions are sampled from the online
                # critic's values at the *current* states, not next_states.
                # Confirm this is intended.
                soft_diff_state_prob = utility_values.softmax(dim=-1)
                multi_dist = Categorical(soft_diff_state_prob)
                pred_next_action = multi_dist.sample()
                pred_next_action = pred_next_action[:, roll_idx]
                target_values = target_values.gather(-1, pred_next_action.unsqueeze(-1)).mean(dim=1)
                target_q_vals = target_values.mean(dim=1)
            else:
                raise ValueError(f'unsupported sample_type: {self.sample_type!r}')

            # No bootstrapping from terminal transitions.
            target_q_vals[done_batch] = 0
            est_critic_val = rewards + (self.gamma ** self.n_steps) * target_q_vals

        q_values = q_values.permute((1, 0))
        est_critic_val = est_critic_val.permute((1, 0))

        critic_loss = F.huber_loss(q_values, est_critic_val)

        self.log_dict['train/critic_loss'] = critic_loss.item()
        self.log_dict['train/critic_values'] = q_values.mean().item()

        self.critic_optimiser.zero_grad()
        critic_loss.backward()

        torch.nn.utils.clip_grad_norm_(self.critic.parameters(), 40)
        self.critic_optimiser.step()

        return critic_loss

    def learn(self, **kwargs):
        """Do one training step: sample a batch, update the critic, soft-update the target.

        Returns:
            ``None`` while the replay buffer holds fewer than one batch,
            otherwise the critic loss tensor from ``update_critic``. A
            checkpoint is saved every 100000 calls when ``self.model_save`` is set.
        """
        self.total_it += 1

        # Not enough stored transitions to draw a full batch yet.
        if self.replay_buffer.mem_cntr < self.replay_buffer.batch_size:
            return None

        # The buffer also returns the sampled indices; they are not needed here.
        *samples, _batch_idx = self.replay_buffer.sample(rng=self.rng,
                                                         batch_size=self.batch_size)

        critic_loss = self.update_critic(samples, weights=None)
        soft_update(self.target_critic, self.critic, tau=self.tau)

        if self.total_it % 100000 == 0 and self.model_save:
            self.save_model()

        return critic_loss

    def _checkpoint_path(self, base_path, iter_no, online):
        """Append the ``-<sample_type>-<iter_no>[-online]`` checkpoint suffix to ``base_path``."""
        path = base_path + '-' + self.sample_type + '-' + str(iter_no)
        if online:
            path += '-online'
        return path

    def save_model(self):
        """Save both critics' state dicts and return the checkpoint path.

        The path is ``create_filepath(path='models')`` plus the
        ``-<sample_type>-<total_it>[-online]`` suffix. It is also stored on
        ``self.model_path``.
        """
        model_path = self._checkpoint_path(self.create_filepath(path='models'),
                                           self.total_it, self.online)

        self.model_path = model_path

        print(f'Saving models to {model_path}')
        torch.save({'critic_state_dict': self.critic.state_dict(),
                    'target_critic_state_dict': self.target_critic.state_dict(), },
                   model_path)
        return model_path

    def load_model(self, iter_no, model_path=None, evaluate=False, online=False):
        """Load both critics from a checkpoint written by ``save_model``.

        Args:
            iter_no: Iteration number encoded in the checkpoint filename.
            model_path: Base path; defaults to ``create_filepath(path='models')``.
                The ``-<sample_type>-<iter_no>[-online]`` suffix is always appended.
            evaluate: If True, put both critics in eval mode, else train mode.
            online: Whether to append the ``-online`` suffix.
        """
        if model_path is None:
            model_path = self.create_filepath(path='models')
        model_path = self._checkpoint_path(model_path, iter_no, online)
        # Print after the suffix is applied so the message names the file actually read.
        print(f"\nLoading models from {model_path}...")

        # map_location lets a checkpoint saved on another device (e.g. GPU) load here.
        model_checkpoint = torch.load(model_path, map_location=self.device)

        self.critic.load_state_dict(model_checkpoint['critic_state_dict'])
        self.target_critic.load_state_dict(model_checkpoint['target_critic_state_dict'])

        if evaluate:
            self.critic.eval()
            self.target_critic.eval()
        else:
            self.critic.train()
            self.target_critic.train()

