import wandb

import torch
import torch.optim as optim
import torch.nn.functional as F

from utils.base_agent import BaseAgent
from utils.vectorised_networks import DDPGVectorisedActor

from utils.decomp_networks import DecoupledQNetwork
from utils.replay_buffer import ReplayBuffer 

from utils.vectorised_networks import DDPGVectorisedActor
from inverse_model import InverseModel, DiscreteInverseModel


class Agent(BaseAgent):
    """Behaviour-cloning agent built around a discretised state-difference model.

    Two networks are involved:

    * ``self.model`` -- a ``DecoupledQNetwork`` with one head per observation
      dimension, each producing raw logits over ``discrete_bins`` classes. It
      is trained (when ``self.online`` is false) by ``update_model`` with
      cross-entropy against the discretised state differences sampled from
      ``self.replay_buffer``.
    * ``self.actor`` -- an inverse model mapping ``(state, diff_state)`` to an
      action. Its weights are loaded from disk at construction time and, when
      ``self.online`` is true, fine-tuned by ``update_actor``.

    Several attributes used here (``dm_suite``, ``online``, ``device``, ``rng``,
    ``optimiser``, ``env_id``, ``log_dict``, ``wandb_log_iter``, ``model_save``,
    ``critic_factor``, ``max_action_val``, ``create_filepath``) are not set in
    this class; presumably ``BaseAgent`` provides them -- TODO confirm.
    """

    def __init__(self, obs_dims, action_dims, bc_lr, batch_size, model_info,
                    algo_name='bc_diff_state', dataset=None, **kwargs):
        """Build the networks, optimisers and buffers.

        Args:
            obs_dims: observation dimensionality; also the number of heads of
                the diff-state model.
            action_dims: action dimensionality.
            bc_lr: learning rate of the diff-state model optimiser.
            batch_size: minibatch size used when sampling the buffers.
            model_info: forwarded to ``BaseAgent``.
            algo_name: forwarded to ``BaseAgent``.
            dataset: offline dataset stored into the buffer(s).
            **kwargs: must contain ``env``, ``discrete_bins``, ``mem_size``,
                ``discrete_eps`` and ``normalise_state`` (plus ``action_bins``
                when ``dm_suite`` is true); may contain ``epsilon``,
                ``epsilon_min`` and ``exploration_decay``. Also forwarded to
                ``BaseAgent``.
        """
        super().__init__(obs_dims=obs_dims,action_dims=action_dims,dataset=dataset,
                        batch_size=batch_size, algo_name=algo_name,
                        model_info=model_info, **kwargs)

        self.epsilon = kwargs.get('epsilon',1)
        self.epsilon_min = kwargs.get('epsilon_min',0.05)
        self.exploration_decay = kwargs.get('exploration_decay',0.999)
        self._action_space = kwargs['env'].action_space

        # One head per observation dimension, each classifying that dimension's
        # state difference into `discrete_bins` bins.
        self.model = DecoupledQNetwork(state_dim=obs_dims,
                                        hidden_dim=512,
                                        num_states=kwargs['discrete_bins'],
                                        num_heads=obs_dims)

        self.model_optimiser = self.optimiser(self.model.parameters(),lr=bc_lr)

        self.replay_buffer = ReplayBuffer(obs_dims=obs_dims,
                                            action_dims=action_dims,
                                            batch_size=batch_size,
                                            mem_size=kwargs['mem_size'],
                                            dataset=dataset,
                                            discrete_eps=kwargs['discrete_eps'],
                                            discrete_bins=kwargs['discrete_bins'],
                                            discrete_action=self.dm_suite)

        if self.dm_suite:
            self.actor = DiscreteInverseModel(state_dim=obs_dims,
                                            action_dim=action_dims,
                                            action_bins=kwargs['action_bins'],
                                            model_type='diff_state')
        else:
            self.actor = InverseModel(state_dim=obs_dims,
                                        action_dim=action_dims,
                                        max_action=self.max_action_val,
                                        model_type='diff_state')

        # The inverse model is always initialised from a pre-trained checkpoint.
        self.actor.load_model(self.env_id,f'1000000_big_eps-{kwargs["discrete_eps"]}_bins-{kwargs["discrete_bins"]}')

        if self.online:
            self.actor_optimiser = self.optimiser(self.actor.parameters(),lr=1e-3)

            # Online fine-tuning starts from the diff-state model saved offline
            # at iteration 1,000,000.
            self.load_model(1000000)

            # Separate buffer holding the offline dataset; it also returns
            # returns (sample_returns=True), so its samples have 7 fields.
            # NOTE(review): unlike self.replay_buffer it is not given
            # discrete_action=self.dm_suite -- confirm this is intended.
            self.dataset_buffer = ReplayBuffer(batch_size=batch_size,
                                            obs_dims=obs_dims,
                                            action_dims=action_dims,
                                            dataset=dataset,
                                            mem_size=kwargs['mem_size'],
                                            discrete_eps=kwargs['discrete_eps'],
                                            discrete_bins=kwargs['discrete_bins'],
                                            sample_returns=True
                                            )

            self.dataset_buffer.to(self.device)

            self.dataset_buffer.store_offline_data(dataset=dataset,
                                                normalise_state=kwargs['normalise_state'],
                                                env_id=self.env_id)

        else:
            self.replay_buffer.store_offline_data(dataset=dataset,
                                          normalise_state=kwargs['normalise_state'])

        self.move_to(self.device)
        self.batch_size = batch_size
        self.total_it = 0

        if wandb.run is not None:
            wandb.define_metric('train/loss',step_metric='total_step')

    def move_to(self, device):
        """Move the base agent, the diff-state model and the actor to ``device``."""
        super().move_to(device)
        self.model.to(device=device)
        self.actor.to(device=device)

    def choose_action(self, state, diff_state=None, **kwargs):
        """Return ``{'action': action}`` for ``state``.

        If ``diff_state`` is not given it is predicted by ``choose_diff_state``.
        For ``dm_suite`` the actor output is reduced with ``argmax`` over its
        last dimension (i.e. action-bin indices are returned).
        """
        state = torch.as_tensor(state, dtype=torch.float, device=self.device)

        if diff_state is None:
            diff_state = self.choose_diff_state(state)
        else:
            diff_state = torch.as_tensor(diff_state, dtype=torch.float, device=self.device)

        action = self.actor(state.squeeze(), diff_state.squeeze())

        if self.dm_suite:
            action = action.argmax(dim=-1)

        return {'action': action}

    def choose_diff_state(self, state, prev_diff_state=None, **kwargs):
        """Predict the diff-state bin index per head (``prev_diff_state`` is unused)."""
        # as_tensor avoids the copy + warning torch.tensor() emits when
        # choose_action passes in an already-built tensor.
        state = torch.as_tensor(state, dtype=torch.float, device=self.device).squeeze(0)
        return self.model(state).argmax(dim=-1)

    def _inverse_model_loss(self, action_pred, targets):
        """Mean inverse-model loss: cross-entropy over action bins (dm_suite) or L1."""
        if self.dm_suite:
            return F.cross_entropy(action_pred.flatten(0, 1), targets.flatten(0, 1))
        return F.l1_loss(action_pred, targets)

    def _teacher_dataset_loss(self, agent):
        """Actor loss on a ``dataset_buffer`` batch labelled by ``agent``'s deterministic actions."""
        *dataset_samples, _ = self.dataset_buffer.sample(rng=self.rng,
                                                         batch_size=self.batch_size)
        d_states, _, d_diff_states, _, _, _, _ = dataset_samples

        # The teacher only provides targets; no gradient flows into it.
        with torch.no_grad():
            d_approx_actions = agent.choose_action(d_states.squeeze(), deterministic=True)['action']

        d_action_pred = self.actor(d_states, d_diff_states)
        return self._inverse_model_loss(d_action_pred, d_approx_actions)

    def update_actor(self, samples, iter_no=None, agent=None):
        """Take one optimiser step on the inverse model (actor).

        Args:
            samples: tuple ``(states, next_states, diff_states, actions,
                rewards, dones)`` sampled from ``self.replay_buffer``.
            iter_no: unused; kept for interface compatibility.
            agent: optional teacher agent. When given, an extra batch is drawn
                from ``self.dataset_buffer``, labelled with the teacher's
                deterministic actions, and its loss is added to the
                replay-buffer loss.

        Returns:
            The total actor loss tensor.
        """
        states, _, diff_states, actions, _, _ = samples

        action_pred = self.actor(states, diff_states)
        online_loss = self._inverse_model_loss(action_pred, actions)

        offline_loss = None
        actor_loss = online_loss
        if agent is not None:
            offline_loss = self._teacher_dataset_loss(agent)
            # Out-of-place so `online_loss` keeps its own value for logging.
            actor_loss = online_loss + offline_loss

        self.actor_optimiser.zero_grad()
        actor_loss.backward()
        self.actor_optimiser.step()

        if self.total_it%self.wandb_log_iter == 0:
            self.log_dict['train/inverse_loss'] = actor_loss.item()
            self.log_dict['train/online_buffer_inverse_loss'] = online_loss.item()
            # Only available when a teacher agent supplied dataset targets.
            if offline_loss is not None:
                self.log_dict['train/offline_buffer_inverse_loss'] = offline_loss.item()

        return actor_loss

    def update_model(self, samples):
        """Take one cross-entropy step on the diff-state model and return the loss.

        ``samples`` is ``(states, next_states, diff_states, actions, rewards,
        dones)``; ``diff_states`` holds the target bin index for each head.
        """
        states, _, diff_states, _, _, _ = samples

        # cross_entropy expects raw logits, not softmax probabilities.
        diff_state_logits = self.model(states)

        # Flatten (batch, heads) into one axis so every head is a separate
        # classification example over `discrete_bins` classes.
        diff_state_logits = diff_state_logits.reshape(-1,self.replay_buffer.discrete_bins)
        diff_states = diff_states.reshape(-1)

        loss = F.cross_entropy(diff_state_logits,diff_states)
        self.log_dict['train/loss'] = loss.item()

        self.model_optimiser.zero_grad()
        loss.backward()
        self.model_optimiser.step()

        return loss

    def learn(self, sample_range=None, **kwargs):
        """Run one training iteration.

        Updates the actor in online mode and the diff-state model otherwise.
        Returns ``None`` while the replay buffer holds fewer than
        ``batch_size`` entries, else the loss. Every 1,000,000 iterations the
        diff-state model is saved when ``self.model_save`` is set.
        """
        self.total_it += 1
        if self.replay_buffer.mem_cntr < self.replay_buffer.batch_size:
            return

        *samples, _ = self.replay_buffer.sample(rng=self.rng,
                                                sample_range=sample_range,
                                                batch_size=self.batch_size)

        if self.online:
            loss = self.update_actor(samples)
        else:
            loss = self.update_model(samples)

        if self.total_it%1000000 == 0 and self.model_save:
            self.save_model()

        return loss

    def save_model(self):
        """Save the diff-state model's state dict and return the path used.

        The path is ``create_filepath(path='models')`` followed by
        ``-<total_it>`` and, in online mode, ``-online``.
        """
        model_path = self.create_filepath(path='models')
        model_path += '-' + str(self.total_it)

        if self.online:
            model_path += '-online'

        self.model_path = model_path

        print(f'Saving models to {model_path}')
        torch.save({'model_state_dict':self.model.state_dict()},
                    model_path)

        return model_path

    def load_model(self, iter_no):
        """Load the diff-state model saved at iteration ``iter_no`` (no ``-online`` suffix).

        The path is built with ``critic_factor`` temporarily forced to 1
        (presumably because ``create_filepath`` embeds it -- TODO confirm); the
        original value is restored even if path creation raises.
        """
        cf = self.critic_factor
        self.critic_factor = 1
        try:
            model_path = self.create_filepath(path='models')
        finally:
            self.critic_factor = cf
        model_path += '-' + str(iter_no)

        print(f"\nLoading models from {model_path}...")
        model_checkpoint = torch.load(model_path,map_location=self.device)

        self.model.load_state_dict(model_checkpoint['model_state_dict'])


