import wandb

import torch
import torch.optim as optim
import torch.nn.functional as F

from torch.distributions.categorical import Categorical

from copy import deepcopy
import numpy as np
import os, time

from utils.decomp_networks import EnsembleDecoupledQNetwork
from utils.misc import soft_update, Scalar
from utils.base_agent import BaseAgent
from utils.replay_buffer import ReplayBuffer

from utils.vectorised_networks import Revalued

from rl_algos.single_agent.TD3.agent import Agent as BaselineTD3Agent
from inverse_model import InverseModel, DiscreteInverseModel

class Agent(BaseAgent):
    """Ensemble-critic agent that selects discrete "diff states" and decodes them to actions.

    The critic produces, for each observation dimension (``num_heads=obs_dims``),
    values over ``discrete_bins`` bins, replicated across an ensemble of
    ``critic_factor`` members. ``choose_diff_state`` takes the greedy bin per
    head, and an inverse model (``self.actor``) maps ``(state, diff_state)`` to
    an environment action.

    Many attributes used here (``device``, ``online``, ``dm_suite``, ``rng``,
    ``optimiser``, ``env_id``, ``max_action_val``, ``wandb_log_iter``,
    ``log_dict``, ``model_save``, ``create_filepath``) are not assigned in this
    class. They are presumably provided by ``BaseAgent``.
    """

    def __init__(self, obs_dims, action_dims, critic_lr, actor_lr, gamma, tau, mem_size,
                 batch_size, critic_factor, model_info=None, dataset=None,
                 algo_name='revalued_diff_state', load_actor=True,
                 **kwargs):
        """Build the critic ensemble, inverse model, replay buffer(s) and optimisers.

        Args:
            obs_dims: observation dimensionality; also the number of critic heads.
            action_dims: action dimensionality passed to the inverse model and buffers.
            critic_lr: learning rate for the critic and the ``log_alpha_prime`` scalar.
            actor_lr: accepted but currently unused (the inverse model uses 1e-3).
            gamma, tau: discount and soft-update rate, forwarded to ``BaseAgent``.
            mem_size, batch_size: replay-buffer capacity and sample size.
            critic_factor: ensemble size of the critic.
            model_info, load_actor: accepted but not used in this class.
            dataset: offline dataset stored into the replay/dataset buffer.
            algo_name: forwarded to ``BaseAgent``.
            **kwargs: must contain ``env``, ``noise_clip``, ``policy_noise_std``,
                ``sample_type``, ``bc_alpha``, ``discrete_bins``, ``dm_suite``,
                ``discrete_eps``, ``normalise_state`` (and ``action_bins`` when
                ``dm_suite``). It may contain ``epsilon``, ``epsilon_min`` and
                ``exploration_decay``. All kwargs are also forwarded to ``BaseAgent``.
        """
        super().__init__(obs_dims=obs_dims, gamma=gamma, tau=tau,
                         algo_name=algo_name, critic_factor=critic_factor,
                         **kwargs)
        self.n_steps = 1

        # Epsilon-greedy schedule values; not read elsewhere in this class.
        self.epsilon = kwargs.get('epsilon', 1)
        self.epsilon_min = kwargs.get('epsilon_min', 0.05)
        self.exploration_decay = kwargs.get('exploration_decay', 0.999)

        self._action_space = kwargs['env'].action_space
        self.noise_clip = kwargs['noise_clip']
        self.policy_noise_std = kwargs['policy_noise_std']
        self.sample_type = kwargs['sample_type']

        # Weight of the logsumexp regulariser added to the critic loss.
        self.bc_alpha = kwargs['bc_alpha']

        hidden_dim = 512

        self.critic = EnsembleDecoupledQNetwork(state_dim=obs_dims,
                                                hidden_dim=hidden_dim,
                                                num_actions=kwargs['discrete_bins'],  # bins per head
                                                num_heads=obs_dims,  # one head per observation dimension
                                                ensemble_size=critic_factor)

        self.target_critic = deepcopy(self.critic)
        self.critic_optimiser = self.optimiser(self.critic.parameters(), lr=critic_lr)

        if kwargs['dm_suite']:
            self.actor = DiscreteInverseModel(state_dim=obs_dims,
                                              action_dim=action_dims,
                                              action_bins=kwargs['action_bins'],
                                              model_type='diff_state')
        else:
            self.actor = InverseModel(state_dim=obs_dims,
                                      action_dim=action_dims,
                                      max_action=self.max_action_val,
                                      model_type='diff_state')

        # NOTE(review): the inverse-model learning rate is fixed at 1e-3 and the
        # `actor_lr` argument is ignored -- confirm this is intended.
        self.actor_optimiser = self.optimiser(self.actor.parameters(), lr=1e-3)

        self.replay_buffer = ReplayBuffer(batch_size=batch_size,
                                          obs_dims=obs_dims,
                                          action_dims=action_dims,
                                          dataset=dataset,
                                          mem_size=mem_size,
                                          discrete_eps=kwargs['discrete_eps'],
                                          discrete_bins=kwargs['discrete_bins'],
                                          discrete_action=self.dm_suite,
                                          )

        if self.online:
            # Online mode: load the critic checkpoint saved at iteration
            # 3,000,000 and keep the offline dataset in a separate buffer that
            # update_critic / update_actor sample from.
            self.load_model(3000000)
            self.dataset_buffer = ReplayBuffer(batch_size=batch_size,
                                               obs_dims=obs_dims,
                                               action_dims=action_dims,
                                               dataset=dataset,
                                               mem_size=mem_size,
                                               discrete_eps=kwargs['discrete_eps'],
                                               discrete_bins=kwargs['discrete_bins'],
                                               sample_returns=True
                                               )

            self.dataset_buffer.to(self.device)

            self.dataset_buffer.store_offline_data(dataset=dataset,
                                                   normalise_state=kwargs['normalise_state'],
                                                   env_id=self.env_id)
        else:
            # Offline mode: load saved inverse-model weights and fill the main
            # replay buffer with the dataset.
            self.actor.load_model(self.env_id,
                                  f'1000000_big_eps-{kwargs["discrete_eps"]}_bins-{kwargs["discrete_bins"]}')

            self.replay_buffer.store_offline_data(dataset=dataset,
                                                  normalise_state=kwargs['normalise_state'])

        self.log_alpha_prime = Scalar(1.0)
        self.alpha_prime_optimiser = self.optimiser(self.log_alpha_prime.parameters(), lr=critic_lr)

        self.move_to()
        self.batch_size = batch_size

        self.total_it = 0

        if wandb.run is not None:
            wandb.define_metric('train/critic_loss', step_metric='total_step')
            wandb.define_metric('train/critic_values', step_metric='total_step')

    def move_to(self):
        """Move the base agent state, critics, actor and alpha scalar to ``self.device``."""
        super().move_to(self.device)
        self.critic.to(device=self.device)
        self.target_critic.to(device=self.device)
        self.actor.to(device=self.device)
        self.log_alpha_prime.to(device=self.device)

    def choose_diff_state(self, state, **kwargs):
        """Return the greedy diff state (argmax bin per head) for ``state``.

        Critic outputs are averaged over dim 1 (presumably the ensemble axis --
        confirm against EnsembleDecoupledQNetwork) before the argmax over the
        last (bin) axis. ``state`` may be array-like or a tensor.
        """
        # as_tensor avoids re-copying (and PyTorch's copy-construct warning)
        # when choose_action already passes a tensor.
        state = torch.as_tensor(state, dtype=torch.float, device=self.device).squeeze(0)
        diff_state_vals = self.critic(state)

        return diff_state_vals.mean(dim=1).argmax(dim=-1)

    def choose_action(self, state, diff_state=None, **kwargs):
        """Decode an action for ``state`` via the inverse model.

        Args:
            state: observation (array-like or tensor).
            diff_state: optional target diff state. When omitted, the critic's
                greedy choice from ``choose_diff_state`` is used.

        Returns:
            ``{'action': tensor}``. With ``dm_suite`` the inverse model's output
            is reduced by argmax over its last axis.
        """
        state = torch.as_tensor(state, dtype=torch.float, device=self.device)

        if diff_state is None:
            diff_state = self.choose_diff_state(state)
        else:
            diff_state = torch.as_tensor(diff_state, dtype=torch.float, device=self.device)

        action = self.actor(state.squeeze(), diff_state.squeeze())

        if self.dm_suite:
            action = action.argmax(dim=-1)

        return {'action': action}

    def update_actor(self, samples, iter_no=None, agent=None):
        """Take one gradient step on the inverse model (``self.actor``).

        Args:
            samples: 6-tuple ``(states, next_states, diff_states, actions,
                rewards, dones)``. Only states, diff_states and actions are used.
            iter_no: accepted for interface compatibility; not used.
            agent: optional second agent. When given, a batch is drawn from
                ``self.dataset_buffer``, relabelled with
                ``agent.choose_action(..., deterministic=True)``, and its
                imitation loss is added to the total.

        Returns:
            The total actor loss tensor (``backward`` has already been called).
        """
        states, _, diff_states, actions, _, _ = samples

        action_pred = self.actor(states, diff_states)

        if self.dm_suite:
            # Flatten the leading two axes so cross-entropy sees one row per
            # (sample, action-dimension) pair.
            actor_loss = F.cross_entropy(action_pred.flatten(0, 1), actions.flatten(0, 1))
        else:
            actor_loss = F.l1_loss(action_pred, actions)

        online_buffer_loss = actor_loss.item()
        # Only computed when `agent` is supplied. Previously the logging branch
        # referenced the offline loss unconditionally and raised
        # UnboundLocalError on logging steps when called without `agent`
        # (as `learn` does).
        offline_buffer_loss = None

        if agent is not None:
            *dataset_samples, _ = self.dataset_buffer.sample(rng=self.rng,
                                                             batch_size=self.batch_size)

            d_states, _, d_diff_states, _, _, _, d_returns = dataset_samples

            with torch.no_grad():
                d_approx_actions = agent.choose_action(d_states.squeeze(), deterministic=True)['action']

                if self.dm_suite:
                    d_approx_actions = d_approx_actions.to(device=self.device)
                    actor_q_values = agent.critic(d_states.squeeze()).gather(
                        -1, d_approx_actions.unsqueeze(1).unsqueeze(1)).mean(dim=-1)
                else:
                    actor_q_values = agent.critic(d_states.squeeze(), d_approx_actions).min(dim=1).values

                # NOTE(review): `weights` is computed but never applied to the
                # loss below -- confirm whether weighting was intended.
                utility_values = self.critic(d_states)
                q_values = utility_values.mean(dim=1).max(dim=-1).values.mean(dim=1)
                offline_weights = (actor_q_values.squeeze() - q_values) / q_values.abs()
                d_weights = (actor_q_values.squeeze() - d_returns) / d_returns.abs()

                weights = 0.5 * (offline_weights + d_weights)
                weights = torch.clamp(weights, -1, 1)

            d_action_pred = self.actor(d_states, d_diff_states)
            if self.dm_suite:
                d_actor_loss = F.cross_entropy(d_action_pred.flatten(0, 1), d_approx_actions.flatten(0, 1),
                                               reduction='none')
            else:
                d_actor_loss = F.l1_loss(d_action_pred, d_approx_actions, reduction='none')

            d_actor_loss = d_actor_loss.mean()
            offline_buffer_loss = d_actor_loss.item()
            actor_loss = actor_loss + d_actor_loss

        self.actor_optimiser.zero_grad()
        actor_loss.backward()
        self.actor_optimiser.step()

        if self.total_it % self.wandb_log_iter == 0:
            self.log_dict['train/inverse_loss'] = actor_loss.item()
            self.log_dict['train/online_buffer_inverse_loss'] = online_buffer_loss
            if offline_buffer_loss is not None:
                self.log_dict['train/offline_buffer_inverse_loss'] = offline_buffer_loss

        return actor_loss

    def update_critic(self, samples, iter_no=None):
        """Take one gradient step on the ensemble critic.

        Loss = regression of the stored diff-state's utility (averaged over
        heads) onto ``r + gamma**n_steps * target`` (Huber with ``dm_suite``,
        MSE otherwise), plus ``bc_alpha * mean(logsumexp(U) - U[selected])``.
        The target uses ``self.target_critic`` at ``next_states``, indexed by
        bins sampled from a softmax over the online critic's utilities. Each
        ensemble slot takes the sample drawn for the previous slot (``np.roll``).

        Args:
            samples: 6-tuple ``(states, next_states, diff_states, actions,
                rewards, dones)``. In online mode it is stitched with a batch
                from ``self.dataset_buffer``.
            iter_no: accepted for interface compatibility; not used.

        Returns:
            The critic loss tensor (``backward`` has already been called).
        """
        critic_factor = self.critic_factor

        if self.online:
            *dataset_samples, _ = self.dataset_buffer.sample(rng=self.rng,
                                                             batch_size=self.batch_size)

            samples = self.replay_buffer.stitch_samples(samples, dataset_samples)

        states, next_states, diff_states, _, rewards, done_batch = samples

        if not self.online:
            # Offline rewards/dones are transposed here; the stitched online
            # batch is used as-is.
            rewards = rewards.permute((1, 0))
            done_batch = done_batch.permute((1, 0))

        utility_values = self.critic(states)

        # Gather each ensemble member's utility for the stored diff-state bin
        # (assumes dim 1 of utility_values is the ensemble axis).
        rep_diff_states = diff_states.unsqueeze(1).repeat(1, critic_factor, 1)

        selected_utility_values = utility_values.gather(-1, rep_diff_states.unsqueeze(-1))
        q_values = selected_utility_values.mean(dim=2).squeeze()

        if self.sample_type == 'double_q':
            double_idx = self.rng.choice(critic_factor, size=(1,), replace=False).item()
            q_values = q_values[:, double_idx]

        with torch.no_grad():
            target_values = self.target_critic(next_states)
            roll_idx = np.roll(np.arange(critic_factor), 1)
            # NOTE(review): the bootstrap bins are sampled from utilities at
            # `states`, not `next_states` -- confirm this is intended.
            soft_diff_state_prob = utility_values.softmax(dim=-1)
            multi_dist = Categorical(soft_diff_state_prob)
            pred_next_diff_state = multi_dist.sample()
            pred_next_diff_state = pred_next_diff_state[:, roll_idx]
            target_values = target_values.gather(-1, pred_next_diff_state.unsqueeze(-1)).mean(dim=1)
            target_q_vals = target_values.mean(dim=1)
            target_q_vals[done_batch] = 0
            est_critic_val = rewards + (self.gamma ** self.n_steps) * target_q_vals
            est_critic_val = est_critic_val.squeeze()

        logsumexp = torch.logsumexp(utility_values, dim=-1)
        p = (logsumexp - selected_utility_values.squeeze()).mean()

        if self.dm_suite:
            critic_loss = F.huber_loss(q_values, est_critic_val) + self.bc_alpha * p
        else:
            critic_loss = F.mse_loss(q_values, est_critic_val) + self.bc_alpha * p

        if self.total_it % self.wandb_log_iter == 0:
            self.log_dict['train/critic_loss'] = critic_loss.item()
            self.log_dict['train/critic_values'] = q_values.mean().item()

        self.critic_optimiser.zero_grad()
        critic_loss.backward()
        if self.dm_suite:
            torch.nn.utils.clip_grad_norm_(self.critic.parameters(), 40)
        self.critic_optimiser.step()

        return critic_loss

    def learn(self, samples=None, **kwargs):
        """Run one training iteration and return the critic loss.

        Increments ``total_it``. Returns ``None`` until the replay buffer holds
        at least ``batch_size`` transitions. In online mode the inverse model
        is also updated, and at iteration 1,000,000 it is retrained via
        ``actor.train_inverse``. Critic checkpoints are saved every 1,000,000
        iterations when ``model_save`` is set.
        """
        self.total_it += 1

        if self.replay_buffer.mem_cntr < self.replay_buffer.batch_size:
            return

        if samples is None:
            *samples, _ = self.replay_buffer.sample(rng=self.rng,
                                                    batch_size=self.batch_size)

        if self.online:
            self.update_actor(samples)

            if self.total_it == 1000000:
                f_name = f'1000000_big_eps-{self.replay_buffer.discrete_eps}_bins-{self.replay_buffer.discrete_bins}-online'
                self.actor.train_inverse(self, save_num=f_name)

        critic_loss = self.update_critic(samples)
        soft_update(self.target_critic, self.critic, tau=self.tau)

        if self.total_it % 1000000 == 0 and self.model_save:
            self.save_model()

        return critic_loss

    def save_model(self):
        """Save critic and target-critic weights and return the checkpoint path.

        The path is ``create_filepath(path='models')`` followed by
        ``-<sample_type>-<total_it>``, plus ``-online`` in online mode. It is
        also stored on ``self.model_path``. The inverse model is not saved here.
        """
        model_path = self.create_filepath(path='models')

        model_path += ('-' + self.sample_type)
        model_path += ('-' + str(self.total_it))

        if self.online:
            model_path += '-online'

        self.model_path = model_path

        print(f'Saving models to {model_path}')
        torch.save({'critic_state_dict': self.critic.state_dict(),
                    'target_critic_state_dict': self.target_critic.state_dict(), },
                   model_path)
        return model_path

    def load_model(self, iter_no, model_path=None, evaluate=False, online=False):
        """Load critic and target-critic weights saved by ``save_model``.

        Args:
            iter_no: iteration suffix of the checkpoint to load.
            model_path: base path; defaults to ``create_filepath(path='models')``.
            evaluate: if True, put actor and critics in eval mode, else train mode.
            online: if True, load the ``-online`` checkpoint, matching the
                suffix ``save_model`` appends in online mode. (Previously this
                flag was accepted but ignored.)
        """
        if model_path is None:
            model_path = self.create_filepath(path='models')

        model_path += ('-' + self.sample_type)
        model_path += ('-' + str(iter_no))

        if online:
            model_path += '-online'

        # Print after the suffixes are applied so the log names the file opened.
        print(f"\nLoading models from {model_path}...")

        model_checkpoint = torch.load(model_path, map_location=self.device)

        self.critic.load_state_dict(model_checkpoint['critic_state_dict'])
        self.target_critic.load_state_dict(model_checkpoint['target_critic_state_dict'])

        if evaluate:
            self.actor.eval()
            self.critic.eval()
            self.target_critic.eval()
        else:
            self.actor.train()
            self.critic.train()
            self.target_critic.train()
