from abc import ABC, abstractmethod
from collections import deque

import os, pickle, time, wandb
import numpy as np
import torch
import torch.optim as optim
import torch.nn.functional as F

from utils.replay_buffer import ReplayBuffer

class BaseAgent(ABC):
    """Abstract base shared by the agents in this module.

    Reads the hyper-parameters common to every agent and implements the online
    and offline training loops, policy evaluation and checkpoint-path creation.

    Subclasses must implement ``choose_action`` and ``learn``, and must set
    ``self.replay_buffer`` (the concrete subclasses in this file do so right
    after ``super().__init__``). Several methods also read ``self.total_it``,
    which is not set in this class; presumably it is maintained by the
    subclass's ``learn`` (verify against subclasses).
    """

    def __init__(self, obs_dims, algo_name, **kwargs):
        """Read shared hyper-parameters from ``kwargs``.

        Args:
            obs_dims: Observation dimensions. Not used by this base class;
                accepted so subclasses can forward it uniformly.
            algo_name: Algorithm name. ``'_BC'`` is appended when ``w_BC`` is
                truthy, unless the name is ``'combined'``.
            **kwargs: Hyper-parameters. Required keys: ``device``,
                ``optimiser`` (name of a ``torch.optim`` attribute), ``seed``,
                ``dm_suite``, ``algo_type`` and ``wandb_log_iter``; plus
                ``action_bins`` when ``dm_suite`` is truthy, otherwise
                ``min_val`` and ``max_val``. ``gamma`` / ``tau`` default to
                0.99 / 0 when missing or None.

        Raises:
            KeyError: If a required key is missing.
        """
        self.algo_name = algo_name
        if kwargs.get('w_BC', False):
            if self.algo_name != 'combined':
                self.algo_name += '_BC'
            self.bc_factor = kwargs.get('bc_factor', 0)
            self.decay_factor = kwargs.get('decay_factor', 0)
            # The floor is capped so it never exceeds the starting BC factor.
            self.min_bc_factor = min(self.bc_factor, kwargs.get('min_bc_factor', 0))
        else:
            self.bc_factor = 0

        self.log_dict = {}
        self.algo_type = kwargs.get('algo_type')
        self.env_id = kwargs.get('env_id')

        # Subclasses forward gamma/tau explicitly with None defaults, so a plain
        # kwargs.get(key, default) would store None; treat None as "not given".
        gamma = kwargs.get('gamma')
        self.gamma = 0.99 if gamma is None else gamma
        tau = kwargs.get('tau')
        self.tau = 0 if tau is None else tau

        self.device = kwargs['device']
        self.optimiser = getattr(optim, kwargs['optimiser'])
        self.seed = kwargs['seed']
        self.dm_suite = kwargs['dm_suite']
        if self.dm_suite:
            self.action_bins = kwargs['action_bins']
        else:
            self.min_action_val = torch.tensor(kwargs['min_val'], device=self.device, dtype=torch.float)
            self.max_action_val = torch.tensor(kwargs['max_val'], device=self.device, dtype=torch.float)
            self.action_dim = self.min_action_val.shape[0]

        self.rng = kwargs.get('rng')
        self.swap_critics = kwargs.get('swap_critics')
        self.ensemble_num = kwargs.get('ensemble_num')
        self.critic_factor = kwargs.get('critic_factor', 1)
        self.policy_update_freq = kwargs.get('policy_update_freq')
        self.redQ = kwargs.get('redQ')
        self.model_save = kwargs.get('save_model', False)

        self.update_ratio = kwargs.get('update_ratio', 1)
        self.n_steps = kwargs.get('n_steps', 1)

        # NOTE(review): any algo_type other than 'online'/'offline' leaves
        # self.online unset, so evaluate_performance would later raise
        # AttributeError.
        if kwargs['algo_type'] == 'online':
            self.online = True
        elif kwargs['algo_type'] == 'offline':
            self.online = False

        if wandb.run is not None:
            wandb.define_metric('total_step', hidden=True)
            wandb.define_metric('eval/d4rl_normalised_score', step_metric='total_step')
            wandb.define_metric('eval/average state diff prediction error', step_metric='total_step')

        self.wandb_log_iter = kwargs['wandb_log_iter']

    def move_to(self, device):
        """Move the replay buffer to ``device``."""
        self.replay_buffer.to(device=device)

    def store_transition(self, state, next_state, action, reward, done):
        """Forward a single transition to the replay buffer.

        NOTE(review): this passes one ``done`` flag, whereas the training loops
        below call ``replay_buffer.store_transition`` with separate
        ``terminate`` and ``trunc`` flags — confirm the buffer's signature.
        """
        self.replay_buffer.store_transition(state, next_state, action, reward, done)

    def sample(self, **kwargs):
        """Sample a batch from the replay buffer; ``kwargs`` are forwarded unchanged."""
        return self.replay_buffer.sample(**kwargs)

    # ------------------------------------------------------------------ helpers

    def _initial_reset(self, env, env_id):
        """Reset ``env`` once before a run; seeded unless ``env_id`` contains 'ant'."""
        if 'ant' in env_id:
            env.reset()
        else:
            env.reset(seed=self.seed)

    def _reset_obs(self, env, normalise_state):
        """Reset ``env`` and return its first observation, preprocessed.

        With ``normalise_state`` the observation is standardised using the
        replay buffer's ``mean`` / ``std``; otherwise a leading axis is added.
        """
        obs = env.reset()[0]
        if normalise_state:
            return (obs - self.replay_buffer.mean) / self.replay_buffer.std
        return obs[np.newaxis]

    def _normalise_next_obs(self, next_obs, normalise_state):
        """Standardise ``next_obs`` when ``normalise_state`` is set; otherwise return it unchanged."""
        if normalise_state:
            return (next_obs - self.replay_buffer.mean) / self.replay_buffer.std
        return next_obs

    @staticmethod
    def _transform_reward(reward, env_id):
        """Return ``4 * (reward - 0.5)`` for env ids containing 'ant', else ``reward``."""
        if 'ant' in env_id:
            return 4 * (reward - 0.5)
        return reward

    def _pop_n_step_transition(self, n_step_buffer, next_obs, terminate, trunc):
        """Store the oldest entry of ``n_step_buffer`` as an n-step transition and drop it.

        The stored reward is the discounted sum of every reward still in the
        buffer. When a second entry exists, its state is passed on as
        ``true_next_state`` (the one-step successor of the stored state).
        """
        state_0, action_0, _ = n_step_buffer[0]
        disc_returns = np.sum(
            [r * self.gamma ** k for k, (_, _, r) in enumerate(n_step_buffer)], axis=0)
        if len(n_step_buffer) > 1:
            self.replay_buffer.store_transition(state_0, next_obs, action_0, disc_returns,
                                                terminate, trunc,
                                                true_next_state=n_step_buffer[1][0])
        else:
            self.replay_buffer.store_transition(state_0, next_obs, action_0, disc_returns,
                                                terminate, trunc)
        n_step_buffer.pop(0)

    def _store_step(self, n_step_buffer, obs, act, reward, next_obs, terminate, trunc, done):
        """Record one environment step in the replay buffer.

        With ``self.n_steps == 1`` the step is stored directly. Otherwise it is
        queued in ``n_step_buffer`` (mutated in place) and the oldest entry is
        emitted once ``n_steps`` entries are queued. At episode end (``done``)
        every remaining queued entry is flushed with the final ``next_obs``.
        """
        if self.n_steps == 1:
            self.replay_buffer.store_transition(obs, next_obs, act, reward, terminate, trunc)
        else:
            n_step_buffer.append((obs, act, reward))
            if len(n_step_buffer) == self.n_steps:
                self._pop_n_step_transition(n_step_buffer, next_obs, terminate, trunc)

        if done:
            while n_step_buffer:
                self._pop_n_step_transition(n_step_buffer, next_obs, terminate, trunc)

    # --------------------------------------------------------------- evaluation

    def _evaluate_performance(self, env, iteration, config_dict, **kwargs):
        """Roll out one deterministic episode and measure prediction errors.

        Args:
            env: Environment whose ``step`` returns
                ``(obs, reward, terminated, truncated, info)``.
            iteration: Unused; kept for interface compatibility.
            config_dict: Must contain ``'normalise_state'``.
            **kwargs: Accepted for compatibility (callers pass ``wandb_dict``);
                unused.

        Returns:
            ``(total_reward, avg_diff_error, avg_iss_act_error)``. The two error
            terms are per-step averages. They stay 0 for ``td3_n``, ``sac_n``,
            ``combined`` and names containing ``'action'``. For other algorithms
            they compare ``self.choose_diff_state`` (defined outside this class)
            with the observed state difference, and the policy action with
            ``choose_action(obs, true_diff_obs, ...)`` (infinity norm).
        """
        normalise = config_dict['normalise_state']
        predicts_diff = (self.algo_name not in ('td3_n', 'sac_n', 'combined')
                         and 'action' not in self.algo_name)

        obs = self._reset_obs(env, normalise)
        done = False
        total_reward = 0
        diff_state_error = 0
        iss_act_error = 0
        step = 0

        while not done:
            step += 1
            # NOTE(review): without normalisation only the reset observation gets
            # the extra leading axis from _reset_obs, so the first step's input
            # has one more dimension than later steps — confirm this is intended.
            obs = obs[np.newaxis, np.newaxis]

            act = self.choose_action(obs, deterministic=True, transform=True)['action']
            act = act.cpu().detach().numpy()

            next_obs, reward, terminate, trunc, _ = env.step(act.squeeze())
            next_obs = self._normalise_next_obs(next_obs, normalise)

            true_diff_obs = next_obs - obs
            # NOTE(review): the return value is discarded — confirm this call is
            # needed for a side effect on the buffer.
            self.replay_buffer.discretise_diff_state(true_diff_obs)

            if predicts_diff:
                diff_obs = self.choose_diff_state(obs)
                np_diff_obs = diff_obs.detach().cpu().numpy()
                diff_state_error += np.sum(np.abs(np_diff_obs - true_diff_obs))

                iss_act = self.choose_action(obs, true_diff_obs, deterministic=True)['action']
                iss_act = iss_act.detach().cpu().numpy()
                iss_act_error += np.linalg.norm(act.squeeze() - iss_act, np.inf)

            done = terminate | trunc
            total_reward += reward
            obs = next_obs

        # step >= 1 because the loop body always runs at least once.
        return total_reward, diff_state_error / step, iss_act_error / step

    def evaluate_performance(self, config_dict, iteration):
        """Evaluate the current policy over ``config_dict['num_evals']`` episodes.

        Logs to wandb when a run is active and prints a summary (offline and
        online formats differ).

        Args:
            config_dict: Needs ``env``, ``env_id``, ``num_evals`` and
                ``normalise_state``. ``env`` must provide
                ``get_normalized_score``.
            iteration: Environment-step count shown in the online summary.

        Returns:
            ``(avg_return, norm_avg_return, avg_diff_error)``, where
            ``norm_avg_return = 100 * env.get_normalized_score(avg_return)``
            and ``avg_diff_error`` is the mean over episodes of the per-step
            state-difference prediction error.
        """
        env = config_dict['env']
        num_evals = config_dict['num_evals']
        self._initial_reset(env, config_dict['env_id'])

        agent_return_list = []
        agent_diff_error_list = []
        agent_iss_act_error_list = []
        for _ in range(num_evals):
            ep_return, ep_diff_error, ep_iss_act_error = self._evaluate_performance(
                env=env, iteration=iteration, config_dict=config_dict, wandb_dict=None)
            agent_return_list.append(ep_return)
            agent_diff_error_list.append(ep_diff_error)
            agent_iss_act_error_list.append(ep_iss_act_error)

        avg_return = sum(agent_return_list) / num_evals
        norm_avg_return = 100 * env.get_normalized_score(avg_return)
        norm_return_list = 100 * env.get_normalized_score(np.array(agent_return_list))

        # Average over all episodes regardless of wandb, so the logged, printed
        # and returned error values always agree.
        avg_diff_error = sum(agent_diff_error_list) / len(agent_diff_error_list)
        avg_iss_act_error = sum(agent_iss_act_error_list) / len(agent_iss_act_error_list)

        if wandb.run is not None:
            eval_dict = {
                'eval/d4rl_normalised_score': norm_avg_return,
                'eval/d4rl_unnormalised_score': avg_return,
            }
            if avg_diff_error != 0:
                eval_dict['eval/average state diff prediction error'] = avg_diff_error
                eval_dict['eval/average iss action prediction error'] = avg_iss_act_error
            eval_dict['total_step'] = self.total_it * self.update_ratio
            wandb.log(eval_dict)

        if self.online:
            env_name_str = f'env name: {config_dict["env_id"]}'
            env_steps_str = f' total env steps: {iteration},'
            grad_steps_str = f' total grad steps: {self.total_it},'
            score_str = f' raw_score: {round(avg_return,5)}, norm score: {round(norm_avg_return,5)}'
            print(env_name_str + env_steps_str + grad_steps_str + score_str)
        else:
            output_str = 10*'-' + f'Agent using {self.algo_name} {self.algo_type}' +\
                    f' averaged over {config_dict["num_evals"]}' + \
                    f' episodes with {self.env_id} dataset' + 10*'-'
            print(f'\n{output_str}')
            print(f'Unnormalised return: {avg_return}')
            print(f'Avg normalised return: {norm_avg_return}')
            print(f'Min normalised return: {norm_return_list.min()}')
            print(f'Max normalised return: {norm_return_list.max()}')
            print(f'Std normalised return: {norm_return_list.std()}')

            print(f'\nAlgo has an ensemble of {self.ensemble_num} actors and {self.critic_factor} critics per actor')

        if avg_diff_error != 0:
            print('individual returns')
            print([round(x, 2) for x in norm_return_list])
            print(f'average diff error per step: {avg_diff_error}')
            print([round(x, 2) for x in agent_diff_error_list])
            print(f'average iss act error per step: {avg_iss_act_error}')
            print([round(x, 2) for x in agent_iss_act_error_list])

        return avg_return, norm_avg_return, avg_diff_error

    @abstractmethod
    def choose_action(self, state):
        pass

    @abstractmethod
    def learn(self):
        pass

    # ----------------------------------------------------------------- training

    def train_online(self, config_dict):
        """Fill the replay buffer with random actions, then train online.

        Burn-in keeps running whole random-action episodes until
        ``len(self.replay_buffer)`` reaches ``config_dict['burn_in_steps']``.
        Online training then runs policy episodes until ``online_steps``
        environment steps have been taken. It calls ``learn`` every
        ``self.update_ratio`` steps and evaluates every ``eval_counter`` steps.

        ``config_dict`` keys read here: ``env``, ``env_id``,
        ``normalise_state``, ``burn_in_steps``, ``online_steps``, ``dm_suite``
        and ``eval_counter`` (plus those read by ``evaluate_performance``).
        """
        env = config_dict['env']
        env_id = config_dict['env_id']
        normalise = config_dict['normalise_state']

        total_steps = 0
        self._initial_reset(env, env_id)
        n_step_buffer = []

        # ---- Burn-in: store transitions from uniformly random actions.
        print('Adding random data')
        while len(self.replay_buffer) < config_dict['burn_in_steps']:
            obs = self._reset_obs(env, normalise)
            done = False
            while not done:
                obs = obs[np.newaxis, np.newaxis]
                act = env.action_space.sample()
                next_obs, reward, terminate, trunc, _ = env.step(act.squeeze())
                done = terminate | trunc
                next_obs = self._normalise_next_obs(next_obs, normalise)
                reward = self._transform_reward(reward, env_id)
                self._store_step(n_step_buffer, obs, act, reward, next_obs, terminate, trunc, done)
                obs = next_obs
        print('Data adding process complete...')

        self.evaluate_performance(config_dict, total_steps)
        s = time.time()
        loss = None  # stays None until learn() has run at least once

        # ---- Online interaction and learning.
        while total_steps < config_dict['online_steps']:
            obs = self._reset_obs(env, normalise)
            done = False
            while not done:
                total_steps += 1
                obs = obs[np.newaxis, np.newaxis]

                act = self.choose_action(obs)['action']
                try:
                    act = act.cpu().detach().numpy()
                except AttributeError:
                    pass  # no .cpu(): presumably already a NumPy array

                if not config_dict['dm_suite']:
                    # Gaussian exploration noise.
                    # NOTE(review): not clipped to the action bounds.
                    act += np.random.normal(scale=0.1, size=act.shape)

                next_obs, reward, terminate, trunc, _ = env.step(act.squeeze())
                done = terminate | trunc
                next_obs = self._normalise_next_obs(next_obs, normalise)
                # Same reward transform as burn-in, so the buffer holds one scale.
                reward = self._transform_reward(reward, env_id)
                self._store_step(n_step_buffer, obs, act, reward, next_obs, terminate, trunc, done)

                if total_steps % self.update_ratio == 0:
                    loss = self.learn()
                    if wandb.run is not None and total_steps % self.wandb_log_iter == 0:
                        self.log_dict['total_step'] = self.total_it * self.update_ratio
                        wandb.log(self.log_dict)

                obs = next_obs

                if total_steps % config_dict['eval_counter'] == 0:
                    if loss is not None:
                        print(loss)
                    self.evaluate_performance(config_dict, total_steps)
                    print(time.time() - s)
                    s = time.time()

    def train_offline(self, config_dict):
        """Run ``config_dict['num_env_steps']`` learning updates on the stored dataset.

        Each update calls ``self.learn(dep_targ=config_dict['dep_targ'])``.
        Every ``eval_counter`` updates it prints the latest loss and evaluates,
        except for algorithms whose name contains ``'qss_n'``. ``wandb.finish()``
        is called at the end.
        """
        s = time.time()
        i = 0
        # A while loop (not range) keeps non-integer step counts such as 1e6 working.
        while i < config_dict['num_env_steps']:
            loss = self.learn(dep_targ=config_dict['dep_targ'])
            if wandb.run is not None and self.total_it % self.wandb_log_iter == 0:
                self.log_dict['total_step'] = self.total_it
                wandb.log(self.log_dict)

            if (i + 1) % config_dict['eval_counter'] == 0:
                if loss is not None:
                    print(loss)

                print(f'\nIteration: {i+1}')

                if 'qss_n' not in self.algo_name:
                    self.evaluate_performance(config_dict, iteration=i)

                print(time.time() - s)
                s = time.time()

            i += 1

        wandb.finish()

    def create_filepath(self, path, file_ext=None, path_list=None):
        """Build a checkpoint file path, creating its parent directories.

        Args:
            path: Root directory, used when ``path_list`` is None.
            file_ext: Optional suffix appended verbatim (e.g. ``'.pt'``).
            path_list: Directory components; defaults to
                ``[path, self.algo_name, self.env_id]``.

        Returns:
            ``<joined path_list>/<file_name>``. For names containing ``'td3'``
            the file name encodes actor/critic counts; otherwise it encodes the
            critic count and the buffer's ``discrete_eps`` / ``discrete_bins``.
            Both variants end with ``-seed_<seed>``.
        """
        if path_list is None:
            path_list = [path, self.algo_name, self.env_id]

        n_critics = self.ensemble_num * self.critic_factor
        if 'td3' in self.algo_name:
            file_name = f'actor_{self.ensemble_num}-critic_{n_critics}'
        else:
            file_name = (f'critic_{n_critics}'
                         f'-eps_{self.replay_buffer.discrete_eps}'
                         f'-bins_{self.replay_buffer.discrete_bins}')
        file_name += f'-seed_{self.seed}'
        if file_ext is not None:
            file_name += file_ext

        dir_path = ''
        for component in path_list:
            dir_path = os.path.join(dir_path, component)
            if dir_path:
                # exist_ok avoids a check-then-create race between concurrent runs.
                os.makedirs(dir_path, exist_ok=True)

        return os.path.join(dir_path, file_name)
		


class ContinuousBaseAgent(BaseAgent):
    """Agent base that owns a ``ReplayBuffer`` built from the constructor arguments.

    The buffer's ``discrete_action`` flag follows ``self.dm_suite``. For that
    case ``BaseAgent`` reads ``action_bins`` instead of continuous bounds.
    ``kwargs`` must contain ``normalise_state``, ``discrete_eps`` and
    ``discrete_bins`` in addition to the keys ``BaseAgent`` requires.
    """

    def __init__(self, obs_dims, action_dims, batch_size, algo_name, gamma=None, 
                tau=None, mem_size=None, dataset=None, **kwargs):

        super().__init__(obs_dims=obs_dims, gamma=gamma, tau=tau, algo_name=algo_name, **kwargs)

        # Collect the buffer configuration first, then build it in one call.
        buffer_config = dict(
            obs_dims=obs_dims,
            action_dims=action_dims,
            mem_size=mem_size,
            batch_size=batch_size,
            dataset=dataset,
            normalise_state=kwargs['normalise_state'],
            discrete_eps=kwargs['discrete_eps'],
            discrete_bins=kwargs['discrete_bins'],
            discrete_action=self.dm_suite,
        )
        self.replay_buffer = ReplayBuffer(**buffer_config)



class BaseActorCritic(ContinuousBaseAgent):
    """Continuous-action base for actor-critic agents with separate learning rates.

    Subclasses implement ``update_critic`` and ``update_actor``. ``kwargs`` must
    contain ``critic_lr`` and ``actor_lr`` in addition to the keys required by
    ``ContinuousBaseAgent``.
    """

    def __init__(self, obs_dims, action_dims, batch_size, algo_name, gamma=None, tau=None, mem_size=None, dataset=None, **kwargs):

        # Learning rates are read before the parent constructors run.
        self.critic_lr = kwargs['critic_lr']
        self.actor_lr = kwargs['actor_lr']

        super().__init__(obs_dims=obs_dims, action_dims=action_dims, gamma=gamma, tau=tau,
                         algo_name=algo_name, mem_size=mem_size, batch_size=batch_size, dataset=dataset,
                         **kwargs)

    def _calc_critic_value(self, critic_values, log_probs=None, done_batch=None):
        """Reduce critic estimates to one value per sample and apply masking / entropy terms.

        Args:
            critic_values: Critic outputs. When ``critic_factor != 1`` the
                critics are indexed along dim 1 and reduced with an element-wise
                minimum over that dim. With ``redQ`` set, only two critics
                (drawn without replacement via ``self.rng.choice``) take part.
            log_probs: Optional log-probabilities. When given and the agent has
                a non-None ``alpha`` attribute, ``alpha.detach() * log_probs``
                is subtracted (presumably the SAC entropy term).
            done_batch: Optional index or mask; the selected entries are set to 0.

        Returns:
            The reduced values. The caller's ``critic_values`` tensor is never
            modified in place. It may be returned as-is when no reduction,
            masking or entropy term applies.
        """
        if self.critic_factor != 1:
            if self.redQ:
                critic_ensemble = self.rng.choice(self.critic_factor, 2, replace=False)
                critic_values = critic_values[:, critic_ensemble]
            critic_values = torch.min(critic_values, dim=1).values

        if done_batch is not None:
            # Clone before zeroing: with critic_factor == 1 no reduction above
            # produced a copy, so in-place assignment would mutate the caller's
            # tensor (and can break autograd if it is needed for backward).
            critic_values = critic_values.clone()
            critic_values[done_batch] = 0

        if getattr(self, 'alpha', None) is not None and log_probs is not None:
            critic_values = critic_values - self.alpha.detach() * log_probs

        return critic_values

    @abstractmethod
    def update_critic(self):
        pass

    @abstractmethod
    def update_actor(self):
        pass

class DiscreteBaseAgent(BaseAgent):
    """Agent base for discrete action spaces, backed by a ``ReplayBuffer``.

    The buffer is created with ``discrete_action=True`` and ``n_actions`` as
    its action dimension. ``kwargs`` must contain ``normalise_state`` in
    addition to the keys ``BaseAgent`` requires.
    """

    def __init__(self, obs_dims, n_actions, gamma, tau, mem_size, batch_size, algo_name, dataset=None, **kwargs):

        super().__init__(obs_dims=obs_dims, gamma=gamma, tau=tau, algo_name=algo_name, **kwargs)

        # NOTE(review): ``dataset`` is accepted but not forwarded to the buffer,
        # whereas ContinuousBaseAgent forwards it — confirm this is intended.
        buffer_kwargs = {
            'mem_size': mem_size,
            'batch_size': batch_size,
            'obs_dims': obs_dims,
            'action_dims': n_actions,
            'normalise_state': kwargs['normalise_state'],
            'discrete_action': True,
        }
        self.replay_buffer = ReplayBuffer(**buffer_kwargs)

