import wandb

import torch
import torch.optim as optim
import torch.nn.functional as F

import numpy as np

from utils.base_agent import BaseAgent
from utils.vectorised_networks import DDPGVectorisedActor
from utils.replay_buffer import ReplayBuffer 

from inverse_model import InverseModel, DiscreteInverseModel


class Agent(BaseAgent):
    """Behaviour-cloning agent that regresses continuous state differences.

    ``self.model`` is trained (see :meth:`update_model`) to predict
    ``next_state - state`` from ``state``. Actions are produced by a
    separately pre-trained inverse model (``self.actor``) that maps a state
    and a state difference to an action; the inverse model is loaded from
    disk in ``__init__`` and is not optimised by this class.
    """

    def __init__(self, obs_dims, action_dims, bc_lr, batch_size, model_info,
                    algo_name='bc_cont_diff_state', dataset=None, **kwargs):
        """Build the replay buffer, the state-difference model and the inverse model.

        Args:
            obs_dims: observation dimensionality; also used as the output size
                of the state-difference model.
            action_dims: action dimensionality.
            bc_lr: learning rate for the state-difference model optimiser.
            batch_size: mini-batch size used when sampling the replay buffer.
            model_info: network configuration forwarded to BaseAgent and
                DDPGVectorisedActor.
            algo_name: algorithm name forwarded to BaseAgent.
            dataset: offline dataset used to populate the replay buffer.
            **kwargs: forwarded to BaseAgent. Read here: 'env', 'mem_size',
                'discrete_eps', 'discrete_bins', 'normalise_state' (required),
                'action_bins' (required when ``self.dm_suite`` is true), and
                optionally 'epsilon', 'epsilon_min', 'exploration_decay'.
        """
        super().__init__(obs_dims=obs_dims,action_dims=action_dims,dataset=dataset,
                        batch_size=batch_size, algo_name=algo_name,
                        model_info=model_info, **kwargs)

        # Exploration schedule parameters (not used by the methods of this class).
        self.epsilon = kwargs.get('epsilon',1)
        self.epsilon_min = kwargs.get('epsilon_min',0.05)
        self.exploration_decay = kwargs.get('exploration_decay',0.999)
        self._action_space = kwargs['env'].action_space

        self.replay_buffer = ReplayBuffer(obs_dims=obs_dims,
                                            action_dims=action_dims,
                                            batch_size=batch_size,
                                            mem_size=kwargs['mem_size'],
                                            dataset=dataset,
                                            discrete_eps=kwargs['discrete_eps'],
                                            discrete_bins=kwargs['discrete_bins'],
                                            discrete_action=self.dm_suite)

        # Per-dimension state bounds used to bound the model output.
        # NOTE(review): these are computed BEFORE store_offline_data() below
        # (and before any state normalisation it applies). This is only
        # meaningful if the ReplayBuffer constructor already fills
        # state_memory from `dataset` — confirm.
        # NOTE(review): the bounds are on states, but self.model predicts
        # state *differences*; confirm this range is intended.
        self.min_state_val = torch.tensor(np.amin(self.replay_buffer.state_memory,0),dtype=torch.float).to(self.device)
        self.max_state_val = torch.tensor(np.amax(self.replay_buffer.state_memory,0),dtype=torch.float).to(self.device)

        # The "actor" network is reused as a state -> state-difference
        # regressor, hence action_dims=obs_dims.
        self.model = DDPGVectorisedActor(obs_dims=obs_dims,
                                        action_dims=obs_dims,
                                        model_info=model_info,
                                        min_val=self.min_state_val,
                                        max_val=self.max_state_val)

        if self.dm_suite:
            self.actor = DiscreteInverseModel(state_dim=obs_dims,
                                                action_dim=action_dims,
                                                action_bins=kwargs['action_bins'],
                                                model_type='diff_state')
        else:
            self.actor = InverseModel(state_dim=obs_dims,
                                        action_dim=action_dims,
                                        max_action=self.max_action_val,
                                        model_type='diff_state')

        # Load the pre-trained inverse model checkpoint.
        self.actor.load_model(self.env_id,f'1000000_big_eps-{kwargs["discrete_eps"]}_bins-{kwargs["discrete_bins"]}')

        self.replay_buffer.store_offline_data(dataset=dataset,
                                              normalise_state=kwargs['normalise_state'])

        # Only the state-difference model is optimised; the inverse model stays fixed.
        self.model_optimiser = self.optimiser(self.model.parameters(),lr=bc_lr)

        self.move_to(self.device)
        self.batch_size = batch_size
        self.total_it = 0

        if wandb.run is not None:
            wandb.define_metric('train/loss',step_metric='total_step')

    def move_to(self, device):
        """Move base-agent state, the state-difference model and the inverse model to `device`."""
        super().move_to(device)
        self.model.to(device=device)
        self.actor.to(device=device)

    def choose_action(self, state, cont_diff_state=None, **kwargs):
        """Select an action with the inverse model.

        Args:
            state: current observation (array-like); converted to a float
                tensor on ``self.device``.
            cont_diff_state: optional continuous state difference to condition
                on. When None it is predicted by ``self.model`` from `state`.

        Returns:
            dict with key 'action': the inverse model's output, reduced with
            ``argmax`` over the last dimension when ``self.dm_suite`` is true.
        """
        state = torch.tensor(state,dtype=torch.float).to(self.device)

        if cont_diff_state is None:
            # `state` is already a float tensor on self.device; re-wrapping it
            # with torch.tensor only made a redundant copy and triggered
            # PyTorch's copy-construct warning.
            state = state.squeeze((0,1))
            cont_diff_state = self.model(state)['action']
        else:
            cont_diff_state = torch.tensor(cont_diff_state,dtype=torch.float).to(self.device)

        action = self.actor(state.squeeze(), cont_diff_state.squeeze())

        if self.dm_suite:
            # Discrete control: pick the highest-scoring bin per action dimension.
            action = action.argmax(dim=-1)

        action_info = {'action':action}

        return action_info

    def choose_diff_state(self, state, **kwargs):
        """Predict a state difference for `state` and return it as an int tensor.

        Args:
            state: current observation (array-like); converted to a float
                tensor on ``self.device`` and squeezed on dims 0 and 1.

        Returns:
            torch.int tensor on ``self.device`` built from the model prediction.
        """
        state = torch.tensor(state,dtype=torch.float).to(self.device).squeeze((0,1))
        cont_diff_state = self.model(state)['action'].detach().cpu().numpy()

        # NOTE(review): the return value of discretise_diff_state is discarded.
        # Unless it discretises the array in place, the int cast below
        # simply truncates the continuous prediction toward zero — confirm
        # against ReplayBuffer.discretise_diff_state.
        self.replay_buffer.discretise_diff_state(cont_diff_state)

        diff_state = torch.tensor(cont_diff_state,dtype=torch.int).to(self.device)

        return diff_state

    def update_model(self, samples):
        """Take one gradient step regressing ``next_states - states``.

        Args:
            samples: tuple ``(states, next_states, diff_states, actions,
                rewards, done_batch)`` as produced by the replay buffer. Only
                `states` and `next_states` are used.

        Returns:
            The scalar MSE loss tensor (also logged under 'train/loss').
        """
        states, next_states, _diff_states, _actions, _rewards, _done_batch = samples

        pred_diff_state = self.model(states)['action']

        # Regression target: the continuous one-step state change.
        cont_diff_states = next_states - states
        loss = F.mse_loss(pred_diff_state,cont_diff_states)
        self.log_dict['train/loss'] = loss.item()

        self.model_optimiser.zero_grad()
        loss.backward()
        self.model_optimiser.step()

        return loss

    def learn(self, sample_range=None, **kwargs):
        """Run one training iteration on a replay-buffer mini-batch.

        Args:
            sample_range: forwarded to ``ReplayBuffer.sample``.

        Returns:
            The loss tensor from :meth:`update_model`. Returns None when the
            buffer holds fewer samples than its batch size. The iteration
            counter is still incremented in that case.
        """
        self.total_it += 1
        if self.replay_buffer.mem_cntr < self.replay_buffer.batch_size:
            return

        *samples, batch_idx = self.replay_buffer.sample(rng=self.rng,
                                                        sample_range=sample_range,
                                                        batch_size=self.batch_size)

        loss = self.update_model(samples)

        # Periodic checkpoint every 1,000,000 iterations.
        if self.total_it%1000000 == 0:
            self.save_model()

        return loss

    def save_model(self):
        """Save the state-difference model's state_dict.

        The path is ``create_filepath(path='models')`` suffixed with
        ``-<total_it>``, plus ``-online`` when ``self.online`` is true.

        Returns:
            The path written to (also stored in ``self.model_path``).
        """
        model_path = self.create_filepath(path='models')

        model_path+= ('-'+str(self.total_it))

        if self.online:
            model_path += '-online'

        self.model_path = model_path

        print(f'Saving models to {model_path}')
        torch.save({'model_state_dict':self.model.state_dict()},
                    model_path)

        return model_path

    def load_model(self, iter_no):
        """Load the state-difference model checkpoint saved at iteration `iter_no`.

        The path is built with ``critic_factor`` temporarily forced to 1.
        The original value is always restored, even if path construction
        raises. Unlike :meth:`save_model`, no ``-online`` suffix is added.

        Args:
            iter_no: iteration number used as the checkpoint path suffix.
        """
        # Presumably create_filepath embeds critic_factor in the file name;
        # force it to 1 so the lookup matches the saved checkpoint name.
        cf = self.critic_factor
        self.critic_factor = 1
        try:
            model_path = self.create_filepath(path='models')
        finally:
            # Restore unconditionally so a failure cannot leave the agent
            # with a modified critic_factor.
            self.critic_factor = cf
        model_path+= ('-'+str(iter_no))

        print(f"\nLoading models from {model_path}...")
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        model_checkpoint = torch.load(model_path, map_location=self.device)

        self.model.load_state_dict(model_checkpoint['model_state_dict'])

