from abc import ABC, abstractmethod
import os
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

from distributionalrl.memory import LazyMultiStepMemory, \
    LazyPrioritizedMultiStepMemory
from distributionalrl.utils import RunningMeanStats, LinearAnneaer, state_pipeline


class BaseAgent(ABC):

    def __init__(self, env, test_env, log_dir, frame_stack, num_steps=5*(10**7), algorithm="TD",
                 batch_size=32, memory_size=10**6, gamma=0.99, multi_step=1,
                 update_interval=4, target_update_interval=10000,
                 start_steps=50000, epsilon_train=0.01, epsilon_eval=0.001,
                 epsilon_decay_steps=250000, double_q_learning=False, decision_function="greedy",
                 dueling_net=False, noisy_net=False, epistemic_method="ensemble", use_per=False,
                 log_interval=100, eval_interval=250000, num_eval_steps=125000,
                 max_episode_steps=27000, grad_cliping=5.0, cuda=True, seed=0, epi_memory_size=10**5, noise_hyps=(0.0, 0.0)
                 ):
        """Set up seeding, device, replay memories, logging and bookkeeping.

        ``online_net`` and ``target_net`` are initialised to ``None`` here;
        nothing in this constructor builds a network.

        Args:
            env: Training environment. Its ``observation_space`` (``shape``,
                ``dtype``) and ``action_space.n`` are read here.
            test_env: Environment stepped by ``evaluate``.
            log_dir: Root output directory; ``model`` and ``summary``
                sub-directories are created inside it.
            frame_stack: Its absolute value is stored and multiplies the step
                counter when scalars are written to TensorBoard.
            num_steps: Step budget checked by ``run``.
            algorithm: Stored as-is. ``"TD"`` and ``"SAT"`` enable periodic
                target synchronisation in ``train_step_interval``.
            batch_size, double_q_learning, dueling_net, grad_cliping: Stored
                on the instance; not otherwise used by this base class.
            memory_size, epi_memory_size: First constructor argument of the
                main / epistemic replay memory (presumably the capacity).
            gamma, multi_step: Forwarded to the replay memories;
                ``gamma ** multi_step`` is stored as ``gamma_n``.
            update_interval: ``learn`` is due every this many steps.
            target_update_interval: Steps between target synchronisations.
            start_steps: Steps during which actions are purely random and no
                learning update is due.
            epsilon_train, epsilon_decay_steps: Passed to
                ``LinearAnneaer(1.0, epsilon_train, epsilon_decay_steps)``
                (presumably an anneal from 1.0 down to ``epsilon_train``).
            epsilon_eval: Exploration probability used with ``eval=True``.
            decision_function: ``"softgreedy"`` selects ``softepsgreedy``
                during training; any other value selects ``epsgreedy``.
            noisy_net: When true, training actions always exploit once the
                warm-up phase is over.
            epistemic_method: ``"epinet"`` adds a second replay memory
                (``epi_memory``) and makes the agent use ``epistemic_net``.
            use_per: Use the prioritized replay memory implementation.
            log_interval: Window given to ``RunningMeanStats`` and number of
                episodes between training log lines.
            eval_interval: Steps between evaluations.
            num_eval_steps: Step budget of one evaluation.
            max_episode_steps: Bound used by the episode loops.
            cuda: Request CUDA; falls back to CPU when it is unavailable.
            seed: Seed given to ``torch.manual_seed`` and ``np.random.seed``.
            noise_hyps: ``(state_noise_scale, reward_noise_scale)`` read by
                ``randomize``.
        """

        self.env = env
        self.test_env = test_env
        self.frame_stack = np.abs(frame_stack)
        self.algorithm = algorithm

        torch.manual_seed(seed)
        np.random.seed(seed)
        self.seed = seed

        # Query CUDA once so the warning, the device choice and the summary
        # line below are guaranteed to agree with each other.
        cuda_available = torch.cuda.is_available()
        if cuda and not cuda_available:
            print("Warning : cuda is not available on this machine")

        self.device = torch.device(
            "cuda" if (cuda and cuda_available) else "cpu")

        print("(user selection :", cuda, " | cuda available :", cuda_available, ") -> selecting device ", self.device)

        self.online_net = None
        self.target_net = None

        # Replay memory which is memory-efficient to store stacked frames.
        obs_space = self.env.observation_space
        if use_per:
            memory_cls = LazyPrioritizedMultiStepMemory
            memory_kwargs = {
                'beta_steps': (num_steps - start_steps) / update_interval,
            }
        else:
            memory_cls = LazyMultiStepMemory
            memory_kwargs = {}

        def build_memory(size):
            # Every buffer shares the same observation spec, device and
            # multi-step settings; only the first argument differs.
            return memory_cls(
                size, obs_space.shape, obs_space.dtype,
                self.device, gamma, multi_step, **memory_kwargs)

        self.memory = build_memory(memory_size)
        if epistemic_method == "epinet":
            self.epi_memory = build_memory(epi_memory_size)

        self.log_dir = log_dir
        self.model_dir = os.path.join(log_dir, 'model')
        self.summary_dir = os.path.join(log_dir, 'summary')
        # exist_ok avoids the check-then-create race of a separate
        # os.path.exists() test (e.g. several runs sharing one log_dir).
        os.makedirs(self.model_dir, exist_ok=True)
        os.makedirs(self.summary_dir, exist_ok=True)

        self.writer = SummaryWriter(log_dir=self.summary_dir)
        self.train_return = RunningMeanStats(log_interval)

        self.steps = 0
        self.learning_steps = 0
        self.episodes = 0
        self.best_eval_score = -np.inf
        self.num_actions = self.env.action_space.n
        self.num_steps = num_steps
        self.batch_size = batch_size
        self.noise_hyps = noise_hyps

        self.double_q_learning = double_q_learning
        self.dueling_net = dueling_net
        self.noisy_net = noisy_net
        self.epistemic_method = epistemic_method
        self.decision_function = decision_function
        self.use_per = use_per

        self.log_interval = log_interval
        self.eval_interval = eval_interval
        self.num_eval_steps = num_eval_steps
        self.gamma_n = gamma ** multi_step
        self.start_steps = start_steps
        self.epsilon_train = LinearAnneaer(
            1.0, epsilon_train, epsilon_decay_steps)
        self.epsilon_eval = epsilon_eval
        self.update_interval = update_interval
        self.target_update_interval = target_update_interval
        self.max_episode_steps = max_episode_steps
        self.grad_cliping = grad_cliping

    def run(self):
        """Train episode by episode until the step budget is exceeded.

        At least one episode is always run. Training stops after the first
        episode that leaves ``self.steps`` strictly greater than
        ``self.num_steps``.
        """
        self.train_episode()
        while self.steps <= self.num_steps:
            self.train_episode()

    def is_update(self):
        """Tell whether a learning update is due at the current step.

        Returns:
            True when ``self.steps`` is a multiple of ``self.update_interval``
            and at least ``self.start_steps``; False otherwise.
        """
        if self.steps % self.update_interval != 0:
            return False
        return self.steps >= self.start_steps

    def is_random(self, eval=False):
        """Decide whether the next action should be taken at random.

        Args:
            eval: When true, compare against ``self.epsilon_eval`` instead of
                the training epsilon (e-greedy is used for evaluation too).

        Returns:
            True unconditionally while ``self.steps < self.start_steps``;
            False for training with noisy nets; otherwise the outcome of
            comparing one uniform draw with the applicable epsilon.
        """
        # Warm-up phase: always act randomly.
        if self.steps < self.start_steps:
            return True

        # With noisy nets the training policy never takes e-greedy actions.
        if not eval and self.noisy_net:
            return False

        return np.random.rand() < (
            self.epsilon_eval if eval else self.epsilon_train.get())

    def update_target(self):
        self.target_net.load_state_dict(
            self.online_net.state_dict())
        
    def randomize(self, state, reward):
        if(self.env.observation_space.dtype == np.bool_):
            return state, reward + self.noise_hyps[1]*np.random.randn(1).item()
        else:
            return state + (self.env.observation_space.high - self.env.observation_space.low)*self.noise_hyps[0]*np.random.randn(*(np.array(state).shape)), reward + self.noise_hyps[1]*np.random.randn(1).item()

    def explore(self, probas=None):
        # Act with randomness.
        if(probas is None):
            return np.random.choice(self.num_actions)
        else:
            return torch.multinomial(probas, 1).item()
        
    def exploit(self, state):
        """Return the greedy action for ``state`` (no randomness).

        The state is converted with ``state_pipeline`` (``unsqueeze=True``)
        and Q-values are computed under ``torch.no_grad()``. With the
        ``"epinet"`` method they come from ``self.epistemic_net`` applied to
        the embeddings of the first ensemble member of ``self.online_net``;
        otherwise from ``self.online_net`` directly.

        Returns:
            The argmax index of the computed Q-values as a Python int.
        """
        obs_space = self.env.observation_space
        _state = state_pipeline(
            state, obs_space.shape, obs_space.dtype, self.device,
            unsqueeze=True)

        with torch.no_grad():
            if self.epistemic_method == "epinet":
                embeddings = self.online_net.ensemble[0].calculate_state_embeddings(_state)
                q_values = self.epistemic_net.calculate_q(states=embeddings)
            else:
                q_values = self.online_net.calculate_q(
                    states=_state, step=self.steps)
            best_action = q_values.argmax().item()

        return best_action
    
    def epsgreedy(self, state, eval=False):
        u = np.random.rand()
        if self.steps < self.start_steps:
            return self.explore()
        
        if eval:
            if(u < self.epsilon_eval):
                return self.explore()
            else:
                return self.exploit(state)
            
        if self.noisy_net:
            return self.exploit(state)
        
        if u < self.epsilon_train.get():
            return self.explore()
        else:
            return self.exploit(state)
        
    def softepsgreedy(self, state):
        """Choose an action with an epsilon-"soft-greedy" rule.

        During warm-up, or with probability given by the training epsilon, a
        uniformly random action is returned. Otherwise the first row of the
        network's ``calculate_q`` output is passed to ``explore`` as sampling
        weights (NOTE(review): this assumes those values are valid
        non-negative weights for ``torch.multinomial`` - confirm against the
        networks used).
        """
        if self.steps < self.start_steps:
            return self.explore()

        if np.random.rand() < self.epsilon_train.get():
            return self.explore()

        obs_space = self.env.observation_space
        _state = state_pipeline(
            state, obs_space.shape, obs_space.dtype, self.device,
            unsqueeze=True)

        with torch.no_grad():
            if self.epistemic_method == "epinet":
                embeddings = self.online_net.ensemble[0].calculate_state_embeddings(_state)
                actions_probas = self.epistemic_net.calculate_q(states=embeddings)[0]
            else:
                actions_probas = self.online_net.calculate_q(
                    states=_state, step=self.steps)[0]

        return self.explore(probas=actions_probas)

    @abstractmethod
    def learn(self):
        """Perform one learning update; concrete agents must implement this.

        Called by ``train_step_interval`` whenever ``is_update()`` is true.
        """

    def save_models(self, save_dir):
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        torch.save(
            self.online_net.state_dict(),
            os.path.join(save_dir, 'online_net.pth'))
        
        if(self.algorithm == "TD"):
            torch.save(
                self.target_net.state_dict(),
                os.path.join(save_dir, 'target_net.pth'))
        if(self.epistemic_method == "epinet"):
            torch.save(
            self.epistemic_net.state_dict(),
            os.path.join(save_dir, 'epistemic_net.pth'))

    def load_models(self, save_dir):
        self.online_net.load_state_dict(torch.load(
            os.path.join(save_dir, 'online_net.pth')))
        self.target_net.load_state_dict(torch.load(
            os.path.join(save_dir, 'target_net.pth')))
        
        if(self.epistemic_method == "epinet"):
            self.epistemic_net.load_state_dict(torch.load(
                os.path.join(save_dir, 'epistemic_net.pth')))

    def train_episode(self):
        """Play one training episode, storing transitions as it goes.

        For every step: sample network noise, pick an action with the
        configured decision function, step the environment, perturb the
        result with ``randomize``, append the transition to a replay memory,
        advance the counters and call ``train_step_interval``. Afterwards the
        episode return is pushed to the running statistics and, every
        ``log_interval`` episodes, logged to TensorBoard and stdout.
        """
        # Networks must be in training mode while acting here.
        self.online_net.train()
        if self.epistemic_method == "epinet":
            self.epistemic_net.train()

        self.episodes += 1
        episode_return = 0.
        episode_steps = 0

        state, _ = self.env.reset()
        # Only the perturbed observation is kept; 0.0 is a dummy reward.
        state = self.randomize(state, 0.0)[0]

        finished = False
        while not finished and episode_steps <= self.max_episode_steps:
            # NOTE: Noises can be sampled only after self.learn(). However,
            # sampling them before every action seems to lead to better
            # performance.
            self.online_net.sample_noise()

            soft = self.decision_function == "softgreedy"
            choose_action = self.softepsgreedy if soft else self.epsgreedy
            action = choose_action(state)

            next_state, reward, term, trunc, _ = self.env.step(action)
            next_state, reward = self.randomize(next_state, reward)
            finished = term or trunc

            # With the epinet method, transitions recorded on even step
            # counts go to the epistemic memory and the rest to the main one.
            # Priorities are left to the memory (max priority) for efficiency.
            to_epi = self.epistemic_method == "epinet" and self.steps % 2 == 0
            buffer = self.epi_memory if to_epi else self.memory
            buffer.append(state, action, reward, next_state, finished)

            self.steps += 1
            episode_steps += 1
            episode_return += reward
            state = next_state

            self.train_step_interval()

        # We log running mean of stats.
        self.train_return.append(episode_return)

        if self.episodes % self.log_interval != 0:
            return

        # The x-axis is expressed in training frames = frame_stack * steps.
        self.writer.add_scalar(
            'return/train', self.train_return.get(),
            self.frame_stack * self.steps)

        print(f'Episode: {self.episodes:<4}  '
              f'episode steps: {episode_steps:<4}  '
              f'return: {episode_return:<5.1f}')

    def train_step_interval(self):
        """Run the per-step housekeeping that follows an environment step.

        In order: advance the training epsilon schedule, synchronise the
        target network when due (``"TD"``/``"SAT"`` algorithms only), call
        ``learn`` when ``is_update()`` says so, and every ``eval_interval``
        steps evaluate, save the models under ``<model_dir>/final`` and put
        the networks back into training mode.
        """
        self.epsilon_train.step()

        uses_target = self.algorithm in ("TD", "SAT")
        if uses_target and self.steps % self.target_update_interval == 0:
            self.update_target()

        if self.is_update():
            self.learn()

        if self.steps % self.eval_interval != 0:
            return

        self.evaluate()
        self.save_models(os.path.join(self.model_dir, 'final'))

        # evaluate() switches the networks to eval mode; undo that.
        self.online_net.train()
        if self.epistemic_method == "epinet":
            self.epistemic_net.train()

    def evaluate(self):
        """Evaluate the current policy on ``self.test_env``.

        Whole episodes are played with ``epsgreedy(..., eval=True)`` until
        more than ``self.num_eval_steps`` steps have been taken in total.
        The mean episode return is logged to TensorBoard and stdout; when it
        beats ``self.best_eval_score`` the score is updated and the models
        are saved under ``<model_dir>/best``. The networks are left in eval
        mode.
        """
        self.online_net.eval()
        if self.epistemic_method == "epinet":
            self.epistemic_net.eval()

        episodes_played = 0
        steps_taken = 0
        return_sum = 0.0

        while steps_taken <= self.num_eval_steps:
            state, _ = self.test_env.reset()
            # Only the perturbed observation is kept; 0.0 is a dummy reward.
            state = self.randomize(state, 0.0)[0]

            ep_steps = 0
            ep_return = 0.0
            finished = False
            while not finished and ep_steps <= self.max_episode_steps:
                action = self.epsgreedy(state, eval=True)

                state, reward, term, trunc, _ = self.test_env.step(action)
                state, reward = self.randomize(state, reward)

                finished = term or trunc
                steps_taken += 1
                ep_steps += 1
                ep_return += reward

            episodes_played += 1
            return_sum += ep_return

        mean_return = return_sum / episodes_played

        if mean_return > self.best_eval_score:
            self.best_eval_score = mean_return
            self.save_models(os.path.join(self.model_dir, 'best'))

        # The x-axis is expressed in training frames = frame_stack * steps.
        self.writer.add_scalar(
            'return/test', mean_return, self.frame_stack * self.steps)

        separator = '-' * 60
        print(separator)
        print(f'Num steps: {self.steps:<5}  '
              f'return: {mean_return:<5.1f}')
        print(separator)

    def __del__(self):
        """Close the environments and the summary writer, if they exist.

        ``__del__`` also runs on instances whose ``__init__`` raised before
        every attribute was assigned, so each resource is looked up
        defensively instead of assuming it is present.
        """
        for attr in ('env', 'test_env', 'writer'):
            resource = getattr(self, attr, None)
            if resource is not None:
                resource.close()