import torch
import torch.optim as optim
import torch.nn.functional as F

from utils.base_agent import ContinuousBaseAgent
from utils.vectorised_networks import DDPGVectorisedActor
#from utils.decomp_networks import DecoupledQNetwork


class Agent(ContinuousBaseAgent):
    """Behaviour-cloning agent trained to reproduce dataset actions.

    In the default (continuous) mode the policy is a ``DDPGVectorisedActor``
    fitted to dataset actions with an MSE loss. When ``self.dm_suite`` is set,
    a ``DecoupledQNetwork`` (``num_heads=action_dims``,
    ``num_states=kwargs['action_bins']``) is fitted with a cross-entropy loss
    and actions are chosen by ``argmax`` over the last output dimension.
    """

    def __init__(self, obs_dims, action_dims, bc_lr, batch_size, model_info,
                    algo_name='bc_action', dataset=None, **kwargs):
        """Build the policy network, its optimiser and (offline) the replay data.

        Args:
            obs_dims: Observation dimensionality, forwarded to the base agent
                and to the policy network.
            action_dims: Number of action dimensions.
            bc_lr: Learning rate passed to ``self.optimiser``.
            batch_size: Minibatch size sampled in :meth:`learn`.
            model_info: Network configuration forwarded to the actor.
            algo_name: Algorithm name forwarded to the base agent.
            dataset: Offline dataset; stored in the replay buffer when the
                agent is not online.
            **kwargs: Forwarded to the base agent. ``action_bins`` is required
                when ``dm_suite`` is set; ``normalise_state`` is required when
                the agent is offline.
        """
        super().__init__(obs_dims=obs_dims, action_dims=action_dims,
                         dataset=dataset, batch_size=batch_size,
                         algo_name=algo_name, model_info=model_info, **kwargs)

        if self.dm_suite:
            # Imported here rather than at module level: the module-level
            # import is commented out, which made this branch raise NameError.
            # A local import keeps the continuous path independent of it.
            from utils.decomp_networks import DecoupledQNetwork

            self.model = DecoupledQNetwork(state_dim=obs_dims,
                                           hidden_dim=256,
                                           num_states=kwargs['action_bins'],
                                           num_heads=action_dims)
        else:
            # Acts like a plain BC model, but (per the original note) with a
            # tanh output presumably scaled to [min_action_val, max_action_val]
            # to cap the action — confirm against DDPGVectorisedActor.
            self.model = DDPGVectorisedActor(obs_dims=obs_dims,
                                             action_dims=action_dims,
                                             model_info=model_info,
                                             min_val=self.min_action_val,
                                             max_val=self.max_action_val)

        self.model_optimiser = self.optimiser(self.model.parameters(), lr=bc_lr)

        if not self.online:
            self.replay_buffer.store_offline_data(
                dataset=dataset, normalise_state=kwargs['normalise_state'])

        self.move_to(self.device)
        self.batch_size = batch_size
        self.total_it = 0

    def move_to(self, device):
        """Move the base agent's state and the policy network to ``device``."""
        super().move_to(device)
        self.model.to(device=device)

    def choose_action(self, state, **kwargs):
        """Run the policy on ``state`` and return the network's output dict.

        In ``dm_suite`` mode ``action_info['action']`` is replaced by the index
        of the highest-scoring entry along its last dimension.

        Args:
            state: Observation(s) convertible to a float tensor.
            **kwargs: Unused; accepted for interface compatibility.

        Returns:
            The dict produced by ``self.model``; it contains ``'action'``.
        """
        # Acting never back-propagates, so don't build an autograd graph.
        with torch.no_grad():
            state = torch.as_tensor(state, dtype=torch.float,
                                    device=self.device)
            action_info = self.model(state)
        if self.dm_suite:
            action_info['action'] = action_info['action'].argmax(dim=-1)

        return action_info

    def update_model(self, samples):
        """Take one behaviour-cloning gradient step on a sampled minibatch.

        Args:
            samples: Sequence ``(states, next_states, diff_states, actions,
                rewards, done_batch)``; only ``states`` and ``actions`` are used.

        Returns:
            The loss tensor for this minibatch (already back-propagated).
        """
        states, _next_states, _diff_states, actions, _rewards, _dones = samples

        model_actions = self.model(states)['action']

        if self.dm_suite:
            # NOTE(review): assumes model_actions is (batch, heads, bins)
            # logits and actions is (batch, heads) integer bin indices, so
            # flattening the first two dims yields one classification per
            # (sample, head) pair — confirm against the buffer/network.
            loss = F.cross_entropy(model_actions.flatten(0, 1),
                                   actions.flatten(0, 1))
        else:
            loss = F.mse_loss(model_actions, actions)

        self.log_dict['train/loss'] = loss.item()

        self.model_optimiser.zero_grad()
        loss.backward()
        self.model_optimiser.step()

        return loss

    def learn(self, sample_range=None, **kwargs):
        """Advance the iteration counter and, when offline, do one BC update.

        Args:
            sample_range: Forwarded to ``replay_buffer.sample``.
            **kwargs: Unused; accepted for interface compatibility.

        Returns:
            The loss tensor from :meth:`update_model` when an update ran,
            otherwise ``None``.
        """
        self.total_it += 1
        # Wait until the buffer holds at least one minibatch of the size we
        # actually request below.
        if self.replay_buffer.mem_cntr < self.batch_size:
            return None

        loss = None

        # NOTE(review): a minibatch is drawn even when online, where it is then
        # discarded. Kept because sampling presumably advances self.rng, and
        # skipping it would change seeded runs — confirm before removing.
        *samples, _batch_idx = self.replay_buffer.sample(
            rng=self.rng, sample_range=sample_range, batch_size=self.batch_size)

        if not self.online:
            loss = self.update_model(samples)

        # Periodic checkpoint every 100k iterations when saving is enabled.
        if self.total_it % 100000 == 0 and self.model_save:
            self.save_model()

        return loss

    def save_model(self):
        """Save the policy weights to ``create_filepath('models') + '-<total_it>'``.

        Returns:
            The path written; it is also stored on ``self.model_path``.
        """
        model_path = self.create_filepath(path='models')
        model_path += '-' + str(self.total_it)

        self.model_path = model_path

        print(f'Saving models to {model_path}')
        torch.save({'model_state_dict': self.model.state_dict()}, model_path)

        return model_path

    def load_model(self, iter_no):
        """Load policy weights from the checkpoint saved at ``iter_no``.

        Args:
            iter_no: Iteration suffix of the checkpoint file to load.
        """
        # NOTE(review): presumably create_filepath() encodes critic_factor in
        # the path, so forcing it to 1 resolves the checkpoint of a
        # critic_factor=1 run. save_model() does not do this — confirm the
        # asymmetry is intended.
        saved_critic_factor = self.critic_factor
        self.critic_factor = 1
        try:
            model_path = self.create_filepath(path='models')
        finally:
            # Restore even if path construction fails, so the agent's
            # configuration is never left modified.
            self.critic_factor = saved_critic_factor
        model_path += '-' + str(iter_no)

        print(f"\nLoading models from {model_path}...")
        # map_location lets checkpoints saved on a GPU load on a CPU-only host.
        # torch.load unpickles the file: only load trusted checkpoints.
        model_checkpoint = torch.load(model_path, map_location=self.device)

        self.model.load_state_dict(model_checkpoint['model_state_dict'])





