from abc import ABC, abstractmethod
from collections import deque

import os, pickle, time, wandb
import numpy as np
import torch
import torch.optim as optim
import torch.nn.functional as F

from utils.replay_buffer import ReplayBuffer

class BaseAgent(ABC):

    def __init__(self,obs_dims,algo_name,**kwargs):
        """Read the configuration shared by every agent from ``kwargs``.

        Args:
            obs_dims: Observation dimensionality. Not used by this base
                initialiser itself.
            algo_name: Name of the algorithm; ``'_BC'`` is appended when
                ``w_BC`` is truthy and the name is not ``'combined'``.
            **kwargs: Configuration values. Required keys: ``device``,
                ``optimiser`` (name of an attribute of ``torch.optim``),
                ``seed``, ``dm_suite``, ``is_continuous``, ``algo_type``
                (``'online'`` or ``'offline'``) and ``wandb_log_iter``; plus
                ``action_bins``/``action_dims`` when ``is_continuous`` is
                falsy, or ``min_val``/``max_val`` otherwise. All other keys
                are optional.

        Raises:
            KeyError: If a required key is missing.
            ValueError: If ``algo_type`` is neither ``'online'`` nor
                ``'offline'``.
        """
        self.algo_name = algo_name
        if kwargs.get('w_BC', False):
            if self.algo_name != 'combined':
                self.algo_name += '_BC'
            self.bc_factor = kwargs.get('bc_factor', 0)
            self.decay_factor = kwargs.get('decay_factor', 0)
            # The floor may never exceed the starting factor.
            self.min_bc_factor = min(self.bc_factor, kwargs.get('min_bc_factor', 0))
        else:
            self.bc_factor = 0

        self.log_dict = {}
        self.algo_type = kwargs.get('algo_type')
        self.env_id = kwargs.get('env_id')

        # ContinuousBaseAgent forwards gamma=None / tau=None when the caller
        # left them unset, so the key is present and dict.get's default would
        # never apply. Fall back explicitly on None instead.
        gamma = kwargs.get('gamma')
        self.gamma = 0.99 if gamma is None else gamma
        tau = kwargs.get('tau')
        self.tau = 0 if tau is None else tau

        self.device = kwargs['device']
        self.optimiser = getattr(optim, kwargs['optimiser'])
        self.seed = kwargs['seed']
        self.dm_suite = kwargs['dm_suite']
        self.is_continuous = kwargs['is_continuous']
        print('is algo continuous:',self.is_continuous)
        if not self.is_continuous:
            self.action_bins = kwargs['action_bins']
            self.action_dim = kwargs['action_dims']
        else:
            self.min_action_val = torch.tensor(kwargs['min_val'], device=self.device, dtype=torch.float)
            self.max_action_val = torch.tensor(kwargs['max_val'], device=self.device, dtype=torch.float)
            # The action dimensionality is taken from the bounds vector.
            self.action_dim = self.min_action_val.shape[0]

        self.rng = kwargs.get('rng')
        self.swap_critics = kwargs.get('swap_critics')
        self.ensemble_num = kwargs.get('ensemble_num')
        self.critic_factor = kwargs.get('critic_factor', 1)
        self.policy_update_freq = kwargs.get('policy_update_freq')
        self.redQ = kwargs.get('redQ')
        self.utd_ratio = kwargs.get('utd_ratio', 1)
        self.model_save = kwargs.get('save_model', False)

        self.update_ratio = kwargs.get('update_ratio', 1)

        algo_type = kwargs['algo_type']
        if algo_type == 'online':
            self.online = True
        elif algo_type == 'offline':
            self.online = False
        else:
            # Previously this fell through and failed later with an opaque
            # AttributeError on self.online.
            raise ValueError(
                f"algo_type must be 'online' or 'offline', got {algo_type!r}")

        # Multi-step returns are only used when training online.
        self.n_steps = kwargs.get('n_steps', 1) if self.online else 1

        if wandb.run is not None:
            wandb.define_metric('total_step', hidden=True)
            wandb.define_metric('eval/d4rl_normalised_score', step_metric='total_step')

        self.wandb_log_iter = kwargs['wandb_log_iter']

    def move_to(self, device):
        """Ask the replay buffer to move itself to ``device``."""
        buffer = self.replay_buffer
        buffer.to(device=device)


    def store_transition(self, state, next_state, action, reward, terminate, trunc):
        """Forward a single transition, positionally, to the replay buffer."""
        transition = (state, next_state, action, reward, terminate, trunc)
        self.replay_buffer.store_transition(*transition)

    def sample(self, **kwargs):
        """Return whatever the replay buffer's ``sample`` yields for ``kwargs``."""
        batch = self.replay_buffer.sample(**kwargs)
        return batch

    def _evaluate_performance(self, env, iteration, config_dict, **kwargs):
        """Run one evaluation episode and return its summed reward.

        Args:
            env: Environment whose ``reset()`` returns a sequence with the
                observation first and whose ``step()`` returns
                ``(obs, reward, terminate, trunc, info)``.
            iteration: Unused; kept for interface compatibility.
            config_dict: Must contain ``'normalise_state'``. When truthy,
                observations are standardised with the replay buffer's
                ``mean`` and ``std``.
            **kwargs: Accepted and ignored (callers may still pass
                ``wandb_dict``; it is no longer required).

        Returns:
            The sum of the rewards collected until the episode terminates or
            is truncated.
        """
        normalise_state = config_dict['normalise_state']

        obs = env.reset()[0]
        if normalise_state:
            obs = (obs - self.replay_buffer.mean)/self.replay_buffer.std

        done = False
        total_reward = 0

        while not done:
            # Evaluation always asks the policy for its deterministic action.
            act = self.choose_action(obs, deterministic=True)['action']
            act = act.cpu().detach().numpy()

            obs, reward, terminate, trunc, info = env.step(act.squeeze())
            if normalise_state:
                obs = (obs - self.replay_buffer.mean)/self.replay_buffer.std

            done = terminate | trunc
            total_reward += reward

        return total_reward


    def evaluate_performance(self, config_dict, iteration):
        """Evaluate the agent over several episodes, then log and print scores.

        Args:
            config_dict: Must contain ``'test_env'``, ``'num_evals'`` and
                ``'normalise_state'``; ``'env_id'`` is also read when the
                agent is online.
            iteration: Caller's step counter; forwarded to the per-episode
                helper and printed as the env-step count for online agents.

        Returns:
            The mean unnormalised return over ``num_evals`` episodes.
        """
        env = config_dict['test_env']
        num_evals = config_dict['num_evals']

        # Re-seed the test environment with the agent's seed before each
        # evaluation round.
        env.reset(seed=self.seed)

        agent_return_list = [
            self._evaluate_performance(env=env,
                                       iteration=iteration,
                                       config_dict=config_dict,
                                       wandb_dict=None)
            for _ in range(num_evals)
        ]

        avg_return = sum(agent_return_list)/num_evals

        try:
            norm_avg_return = 100*env.get_normalized_score(avg_return)
            norm_return_list = 100*env.get_normalized_score(np.array(agent_return_list))
        except AttributeError:
            # The environment offers no normalised score. Keep the fallback
            # an array so the .min()/.max()/.std() calls below still work
            # (a plain int 0 used to crash the offline summary).
            norm_avg_return = 0
            norm_return_list = np.zeros(len(agent_return_list))

        if wandb.run is not None:
            wandb.log({
                'eval/d4rl_normalised_score': norm_avg_return,
                'eval/d4rl_unnormalised_score': avg_return,
                'total_step': self.total_it,
            })

        if self.online:
            env_name_str = f'env name: {config_dict["env_id"]}'
            env_steps_str = f' total env steps: {iteration},'
            grad_steps_str = f' total grad steps: {self.total_it},'
            score_str = f' raw_score: {round(avg_return,5)}, norm score: {round(norm_avg_return,5)}'
            print(env_name_str+env_steps_str+grad_steps_str+score_str)
            print(agent_return_list)
        else:
            output_str = 10*'-' + f'Agent using {self.algo_name} {self.algo_type}' +\
                    f' averaged over {config_dict["num_evals"]}'+ \
                    f' episodes with {self.env_id} dataset' + 10*'-'
            print(f'\n{output_str}')
            print(f'Unnormalised return: {avg_return}')
            print(f'Avg normalised return: {norm_avg_return}')
            print(f'Min normalised return: {norm_return_list.min()}')
            print(f'Max normalised return: {norm_return_list.max()}')
            print(f'Std normalised return: {norm_return_list.std()}')

            print(f'\nAlgo has an ensemble of {self.ensemble_num} actors and {self.critic_factor} critics per actor')

        return avg_return

    @abstractmethod
    def choose_action(self, state, **kwargs):
        """Select an action for ``state``; must be implemented by subclasses.

        This class calls it as ``choose_action(obs)`` while training online
        and as ``choose_action(obs, deterministic=True)`` during evaluation,
        so implementations have to accept a ``deterministic`` keyword.

        Returns:
            A mapping with an ``'action'`` entry. Callers in this class invoke
            ``.cpu().detach().numpy()`` on that entry.
        """

    @abstractmethod
    def learn(self, **kwargs):
        """Run one learning update; must be implemented by subclasses.

        ``train_online`` calls it without arguments, whereas ``train_offline``
        passes ``dep_targ=...``, so implementations used offline have to
        accept that keyword.

        Returns:
            A value the training loops print as the loss (``train_offline``
            skips the print when it is ``None``).
        """


    def train_online(self, config_dict):
        """Fill the replay buffer with random data, then train online.

        Phase 1 steps the environment with actions drawn from
        ``env.action_space.sample()`` until the replay buffer holds at least
        ``burn_in_steps`` transitions. Phase 2 acts with ``choose_action``,
        stores every transition, calls ``learn()`` whenever the step count is
        a multiple of ``self.update_ratio`` and evaluates whenever it is a
        multiple of ``eval_counter``, until ``online_steps`` environment
        steps have been taken (the running episode is always finished).

        Args:
            config_dict: Must contain ``'env'``, ``'env_id'``,
                ``'burn_in_steps'``, ``'normalise_state'``, ``'online_steps'``
                and ``'eval_counter'``, plus whatever
                ``evaluate_performance`` reads.
        """
        env = config_dict['env']
        normalise_state = config_dict['normalise_state']
        # The base class never sets transform_reward; a missing flag means
        # "do not shift rewards" instead of an AttributeError.
        transform_reward = getattr(self, 'transform_reward', False)

        total_steps = 0

        if 'ant' in config_dict['env_id']:
            env.reset()
        else:
            env.reset(seed=self.seed)

        # ------------------ burn-in: random transitions ------------------
        print('Adding random data')
        while len(self.replay_buffer) < config_dict['burn_in_steps']:
            obs = env.reset()[0]
            if normalise_state:
                obs = (obs - self.replay_buffer.mean)/self.replay_buffer.std

            done = False
            n_step_buffer = deque()

            while not done:
                act = env.action_space.sample()
                next_obs, reward, terminate, trunc, info = env.step(act.squeeze())
                done = terminate | trunc

                if normalise_state:
                    next_obs = (next_obs - self.replay_buffer.mean)/self.replay_buffer.std

                if transform_reward:
                    reward = reward + 1

                self._store_n_step(self.store_transition, n_step_buffer,
                                   obs, next_obs, act, reward, terminate, trunc)

                obs = next_obs

        print('Data adding process complete...')

        self.evaluate_performance(config_dict,total_steps)

        # `loss` is only assigned when an update runs; initialise it so the
        # periodic print below cannot hit an unbound local when evaluation
        # fires before the first update.
        loss = None
        v = time.time()
        while total_steps < config_dict['online_steps']:
            obs = env.reset()[0]
            if normalise_state:
                obs = (obs - self.replay_buffer.mean)/self.replay_buffer.std

            done = False
            total_reward = 0
            n_step_buffer = deque()

            while not done:
                total_steps += 1

                act = self.choose_action(obs)['action']
                act = act.cpu().detach().numpy()

                if self.is_continuous:
                    # Gaussian exploration noise.
                    # NOTE(review): the noisy action is not clipped to the
                    # action bounds before it is executed and stored.
                    act += np.random.normal(scale=0.1,size=act.shape)

                next_obs, reward, terminate, trunc, info = env.step(act.squeeze())
                done = terminate | trunc
                if normalise_state:
                    next_obs = (next_obs - self.replay_buffer.mean)/self.replay_buffer.std

                if transform_reward:
                    reward = reward + 1

                total_reward += reward

                self._store_n_step(self.replay_buffer.store_transition, n_step_buffer,
                                   obs, next_obs, act, reward, terminate, trunc)

                if total_steps%self.update_ratio == 0:
                    loss = self.learn()
                    if wandb.run is not None and self.total_it%self.wandb_log_iter == 0 :
                        self.log_dict['total_step'] = self.total_it
                        wandb.log(self.log_dict)

                obs = next_obs

                if total_steps%config_dict['eval_counter'] == 0:
                    print('time:',time.time()-v)
                    v = time.time()
                    print(loss)
                    self.evaluate_performance(config_dict,total_steps)

            if transform_reward:
                print('positive: ',total_reward - 1000)
            else:
                print(total_reward)

    def _store_n_step(self, store, n_step_buffer, obs, next_obs, act, reward, terminate, trunc):
        """Store one environment step, folding it into n-step returns if needed.

        ``store`` is the callable that receives the finished transition and
        ``n_step_buffer`` is the per-episode deque of pending
        ``(obs, act, reward)`` tuples. With ``self.n_steps == 1`` the step is
        stored directly. Otherwise it is queued and the oldest entry is
        emitted once ``n_steps`` entries are pending. On ``terminate`` every
        pending entry is emitted with its shorter return; entries pending
        when an episode is merely truncated are not emitted.
        """
        if self.n_steps == 1:
            store(obs, next_obs, act, reward, terminate, trunc)
        else:
            n_step_buffer.append((obs, act, reward))
            if len(n_step_buffer) == self.n_steps:
                self._emit_oldest_n_step(store, n_step_buffer, next_obs, terminate, trunc)

        if terminate:
            while n_step_buffer:
                self._emit_oldest_n_step(store, n_step_buffer, next_obs, terminate, trunc)

    def _emit_oldest_n_step(self, store, n_step_buffer, next_obs, terminate, trunc):
        """Store the oldest pending entry with its discounted return and drop it."""
        state_0, action_0, _ = n_step_buffer[0]
        disc_returns = np.sum(
            [r * self.gamma ** count for count, (_, _, r) in enumerate(n_step_buffer)],
            axis=0)
        store(state_0, next_obs, action_0, disc_returns, terminate, trunc)
        n_step_buffer.popleft()

    def train_offline(self, config_dict):
        """Run offline training for ``num_env_steps`` learning updates.

        Each iteration calls ``learn(dep_targ=config_dict['dep_targ'])`` and,
        when a wandb run is active and ``self.total_it`` is a multiple of
        ``self.wandb_log_iter``, logs ``self.log_dict``. Every
        ``eval_counter`` iterations the last loss (if not ``None``), the
        iteration number, an evaluation and the elapsed wall time since the
        previous evaluation are printed. ``wandb.finish()`` is called at the
        end.

        Args:
            config_dict: Must contain ``'num_env_steps'``, ``'dep_targ'`` and
                ``'eval_counter'``, plus whatever ``evaluate_performance``
                reads. A training ``'env'`` entry is not required.
        """
        eval_counter = config_dict['eval_counter']
        dep_targ = config_dict['dep_targ']

        s = time.time()
        i = 0

        # A while loop (rather than range) so a float step budget still works.
        while i < config_dict['num_env_steps']:

            loss = self.learn(dep_targ=dep_targ)
            if wandb.run is not None and self.total_it%self.wandb_log_iter == 0 :
                self.log_dict['total_step'] = self.total_it
                wandb.log(self.log_dict)

            if (i+1)%eval_counter == 0:

                if loss is not None:
                    print(loss)

                print(f'\nIteration: {i+1}')
                self.evaluate_performance(config_dict,iteration=i)

                print(time.time()-s)
                s = time.time()

            i += 1

        wandb.finish()



    def create_filepath(self, path, file_ext=None, path_list=None, root_dir=None):
        """Build a file path for this agent and create its parent directories.

        Args:
            path: Sub-path appended to ``root_dir``; only used when
                ``path_list`` is ``None``.
            file_ext: Optional suffix appended verbatim to the file name.
            path_list: Optional directory components, joined in order. When
                ``None`` it defaults to
                ``[root_dir + path, self.algo_name, self.env_id]``.
            root_dir: Prefix that ``path`` is concatenated to. Defaults to
                the storage location previously hard-coded here.

        Returns:
            The joined directory plus a file name encoding the ensemble
            configuration and the seed. For algorithms without ``'td3'`` in
            their name it also encodes the replay buffer's ``discrete_eps``
            and ``discrete_bins``.
        """
        if path_list is None:
            if root_dir is None:
                root_dir = '/montana-storage04/fast/Users/neggat_n/attention_net/'
            path_list = [root_dir + path, self.algo_name, self.env_id]

        if 'td3' in self.algo_name:
            file_name = f'actor_{self.ensemble_num}-critic_{self.ensemble_num*self.critic_factor}'
        else:
            file_name = f'critic_{self.ensemble_num*self.critic_factor}'
            file_name += f'-eps_{self.replay_buffer.discrete_eps}-bins_{self.replay_buffer.discrete_bins}'

        file_name += f'-seed_{self.seed}'

        if file_ext is not None:
            file_name += file_ext

        # Join once and create the whole chain in one call; exist_ok avoids
        # the check-then-create race of testing os.path.exists first.
        dir_path = os.path.join('', *path_list)
        if dir_path:
            os.makedirs(dir_path, exist_ok=True)

        return os.path.join(dir_path, file_name)
		


class ContinuousBaseAgent(BaseAgent):
    """Agent base that runs the shared setup and then builds its ReplayBuffer."""

    def __init__(self, obs_dims, action_dims, batch_size, algo_name, gamma=None, 
                tau=None, mem_size=None, dataset=None, **kwargs):
        """Initialise the base agent, then create ``self.replay_buffer``.

        ``action_dims``, ``gamma`` and ``tau`` are forwarded to the base
        initialiser as keywords alongside ``**kwargs``. ``kwargs`` must also
        contain ``'normalise_state'``.
        """
        super().__init__(obs_dims=obs_dims,
                         action_dims=action_dims,
                         gamma=gamma,
                         tau=tau,
                         algo_name=algo_name,
                         **kwargs)

        # Built after the base initialiser because the discrete flag depends
        # on self.is_continuous, which is set there.
        buffer_config = {
            'obs_dims': obs_dims,
            'action_dims': action_dims,
            'mem_size': mem_size,
            'batch_size': batch_size,
            'dataset': dataset,
            'normalise_state': kwargs['normalise_state'],
            'discrete_action': not self.is_continuous,
        }
        self.replay_buffer = ReplayBuffer(**buffer_config)



class BaseActorCritic(ContinuousBaseAgent):
    """Base for actor-critic agents: stores learning rates and reduces critic outputs."""

    ## I believe actor critic can only be used with continuous action spaces for differentiability of policy
    def __init__(self, obs_dims, action_dims, batch_size, algo_name, gamma=None, tau=None, mem_size=None, dataset=None, **kwargs):
        """Record ``critic_lr``/``actor_lr`` from ``kwargs``, then run the parent initialiser.

        Raises:
            KeyError: If ``critic_lr`` or ``actor_lr`` is missing.
        """
        self.critic_lr = kwargs['critic_lr']
        self.actor_lr = kwargs['actor_lr']

        super().__init__(obs_dims=obs_dims, action_dims=action_dims, gamma=gamma, tau=tau,
                         algo_name=algo_name, mem_size=mem_size, batch_size=batch_size, dataset=dataset,
                         **kwargs)

    def _calc_critic_value(self,critic_values,log_probs=None,done_batch=None):
        """Reduce critic outputs to the value used by the update rules.

        Steps, in order:
          1. When ``self.critic_factor != 1``, take the minimum over dim 1.
             If ``self.redQ`` is truthy, only two columns drawn without
             replacement via ``self.rng`` are considered.
          2. When ``done_batch`` is given, set the entries it indexes to 0.
          3. When ``self.alpha`` exists and ``log_probs`` is given, subtract
             ``alpha.detach() * log_probs.permute((1, 0))``.

        Args:
            critic_values: Critic output tensor. Presumably dim 1 indexes the
                critics; confirm against callers.
            log_probs: Optional 2-D tensor of log-probabilities.
            done_batch: Optional index or mask usable as
                ``critic_values[done_batch]``.

        Returns:
            The reduced tensor. The ``critic_values`` argument is left
            unmodified.
        """
        reduced = self.critic_factor != 1
        if reduced:
            if self.redQ:
                critic_ensemble = self.rng.choice(self.critic_factor,2,replace=False)
                critic_values = critic_values[:,critic_ensemble]
            critic_values = torch.min(critic_values,dim=1).values

        if done_batch is not None:
            if not reduced:
                # Without the min-reduction `critic_values` is still the
                # caller's tensor. Copy it so the masking below does not
                # overwrite the caller's data in place.
                critic_values = critic_values.clone()
            critic_values[done_batch] = 0

        if getattr(self,'alpha',None) is not None and log_probs is not None:
            critic_values = critic_values - self.alpha.detach()*log_probs.permute((1,0))

        return critic_values

    @abstractmethod
    def update_critic(self):
        """Update the critic network(s); must be implemented by subclasses."""
        pass

    @abstractmethod
    def update_actor(self):
        """Update the actor network(s); must be implemented by subclasses."""
        pass

class DiscreteBaseAgent(BaseAgent):
    """Agent base that builds a ReplayBuffer flagged for discrete actions."""

    def __init__(self, obs_dims, n_actions, gamma, tau, mem_size, batch_size, algo_name, dataset=None, **kwargs):
        """Initialise the base agent, then create ``self.replay_buffer``.

        ``n_actions`` is handed to the buffer as ``action_dims``. ``kwargs``
        must contain ``'normalise_state'``.
        """
        super().__init__(obs_dims=obs_dims,
                         gamma=gamma,
                         tau=tau,
                         algo_name=algo_name,
                         **kwargs)

        # NOTE(review): `dataset` is accepted but never forwarded to the
        # buffer - confirm whether that is intentional.
        buffer_config = {
            'mem_size': mem_size,
            'batch_size': batch_size,
            'obs_dims': obs_dims,
            'action_dims': n_actions,
            'normalise_state': kwargs['normalise_state'],
            'discrete_action': True,
        }
        self.replay_buffer = ReplayBuffer(**buffer_config)

