import logging
import time
import random
import pickle
import os
from sys import maxsize

import torch
from tensorboardX import SummaryWriter
from baselines.common.schedules import LinearSchedule

from abp.utils import clear_summary_path
from abp.models import DQNModel
from abp.adaptives.common.prioritized_memory.memory import PrioritizedReplayBuffer, ReplayBuffer
import numpy as np

# Logger used by every agent in this module.
logger = logging.getLogger('root')

# Decide once, at import time, whether CUDA tensors are available, and bind
# the tensor constructors from the matching namespace (torch.cuda or torch).
use_cuda = torch.cuda.is_available()
FloatTensor, LongTensor, IntTensor, ByteTensor = (
    getattr(torch.cuda if use_cuda else torch, tensor_kind)
    for tensor_kind in ("FloatTensor", "LongTensor", "IntTensor", "ByteTensor")
)
# Default tensor constructor used throughout the module.
Tensor = FloatTensor


class SADQAdaptive_exp(object):
    """Adaptive which uses the SADQ algorithm.

    The agent holds three ``DQNModel`` instances:

    * ``true_q_model`` -- its greedy choice is the reference the evaluation
      model is compared against, and it picks the next state stored in the
      replay memory.
    * ``eval_model`` -- the model trained by :meth:`update`.
    * ``target_model`` -- periodically replaced by ``eval_model`` and used to
      compute the bootstrap target.

    For every :meth:`predict` call the agent records whether ``eval_model``
    and ``true_q_model`` agree on the greedy choice; the per-episode agreement
    ratio decides when a ``"_best"`` checkpoint is written.
    """

    def __init__(self, name, state_length, network_config, reinforce_config):
        """Build the replay memory, the three models and the schedules.

        Args:
            name: Agent name; used for summary tags, the summary directory
                and as prefix of the model names.
            state_length: Stored on the instance; not used by this class.
            network_config: Configuration object; this class reads
                ``restore_network`` and ``network_path`` and forwards the
                object to every ``DQNModel``.
            reinforce_config: Configuration object; this class reads
                ``use_prior_memory``, ``memory_size``, ``summaries_path``,
                the ``beta_*`` / ``*epsilon*`` schedule settings,
                ``replace_frequency``, ``update_start``, ``update_steps``,
                ``batch_size`` and ``discount_factor``.
        """
        super(SADQAdaptive_exp, self).__init__()
        self.name = name
        self.network_config = network_config
        self.reinforce_config = reinforce_config
        if self.reinforce_config.use_prior_memory:
            # NOTE(review): 0.6 is presumably the prioritization exponent
            # (alpha) of the buffer -- confirm against PrioritizedReplayBuffer.
            self.memory = PrioritizedReplayBuffer(self.reinforce_config.memory_size, 0.6)
        else:
            self.memory = ReplayBuffer(self.reinforce_config.memory_size)
        self.learning = True
        self.state_length = state_length

        # Global counters that survive across episodes (and are persisted).
        self.steps = 0
        self.match_history = []
        self.best_match = -maxsize
        self.episode = 0

        self.reset()

        reinforce_summary_path = self.reinforce_config.summaries_path + "/" + self.name

        if not self.network_config.restore_network:
            clear_summary_path(reinforce_summary_path)
        else:
            self.restore_state()

        self.summary = SummaryWriter(log_dir=reinforce_summary_path)

        self.true_q_model = DQNModel(self.name + "ture_q", self.network_config, use_cuda)
        self.target_model = DQNModel(self.name + "_target", self.network_config, use_cuda, is_sigmoid = True)
        self.eval_model = DQNModel(self.name + "_eval", self.network_config, use_cuda, is_sigmoid = True)

        self.beta_schedule = LinearSchedule(self.reinforce_config.beta_timesteps,
                                            initial_p=self.reinforce_config.beta_initial,
                                            final_p=self.reinforce_config.beta_final)

        self.epsilon_schedule = LinearSchedule(self.reinforce_config.epsilon_timesteps,
                                               initial_p=self.reinforce_config.starting_epsilon,
                                               final_p=self.reinforce_config.final_epsilon)

    def __del__(self):
        """Save the agent and close the summary writer on destruction."""
        # __init__ may have raised before the models or the writer were
        # created, so only touch what actually exists; always try to close
        # the writer even when saving fails.
        try:
            if hasattr(self, "eval_model") and hasattr(self, "target_model"):
                self.save()
        finally:
            summary = getattr(self, "summary", None)
            if summary is not None:
                summary.close()

    def should_explore(self):
        """Return True when the next action should be random.

        Updates ``self.epsilon`` from the epsilon schedule at the current
        step and logs it to the summary writer as a side effect.
        """
        self.epsilon = self.epsilon_schedule.value(self.steps)
        self.summary.add_scalar(tag='%s/Epsilon' % self.name,
                                scalar_value=self.epsilon,
                                global_step=self.steps)

        return random.random() < self.epsilon

    def predict(self, state, isGreedy = False, is_random = False):
        """Choose one of the candidate entries of ``state``.

        Args:
            state: Indexable collection of candidates; it is passed to the
                models as one batch, and the selected entry must support
                ``.copy()``.
            isGreedy: When True, never take the random (exploration) action.
            is_random: Unused; kept for interface compatibility.

        Returns:
            ``(action, q_values)``. ``action`` is a plain ``int`` when the
            exploration branch was taken, otherwise the index tensor returned
            by ``max``. ``q_values`` are the flattened ``true_q_model`` values,
            or ``None`` when the action was drawn at random.

        Side effects (only while learning): increments ``self.steps``, stores
        the previous transition in memory, and may replace the target model
        and run :meth:`update`.
        """
        if self.learning:
            self.steps += 1

        with torch.no_grad():
            q_values = FloatTensor(self.true_q_model.predict_batch(Tensor(state))[1]).view(-1)
            _, choice = q_values.max(0)

            q_values_eval = FloatTensor(self.eval_model.predict_batch(Tensor(state))[1]).view(-1)
            _, choice_eval = q_values_eval.max(0)

            # Track whether the trained model agrees with the reference model.
            self.match_history.append((choice == choice_eval).item())

        # The transition of the previous decision is completed now that the
        # reward has been accumulated; the stored next state is the greedy
        # pick of true_q_model, not necessarily the action taken below.
        if self.previous_state is not None and self.learning:
            self.memory.add(self.previous_state,
                            None,
                            self.current_reward,
                            state[choice.item()].copy(), 0)

        # should_explore() is evaluated before isGreedy on purpose: it also
        # logs epsilon for this step.
        if self.learning and self.should_explore() and not isGreedy:
            q_values = None
            choice = random.choice(range(len(state)))
        action = choice

        # Outside of training the evaluation model drives the agent.
        if not self.learning:
            action = choice_eval

        if self.learning and self.steps % self.reinforce_config.replace_frequency == 0:
            logger.debug("Replacing target model for %s" % self.name)
            self.target_model.replace(self.eval_model)

        if (self.learning and
                self.steps > self.reinforce_config.update_start and
                self.steps % self.reinforce_config.update_steps == 0):
            self.update()

        self.current_reward = 0
        self.previous_state = state[action]

        return action, q_values

    def disable_learning(self, is_save = False):
        """Stop learning; optionally save the agent first."""
        logger.info("Disabled Learning for %s agent" % self.name)
        if is_save:
            self.save()
        self.learning = False

    def enable_learning(self):
        """Resume learning and clear the per-episode state."""
        logger.info("enabled Learning for %s agent" % self.name)
        self.learning = True
        self.reset()

    def end_episode(self, state):
        """Store the terminal transition and checkpoint on a new best match.

        Does nothing when learning is disabled.

        Args:
            state: Final state of the episode; must support ``.copy()``.
        """
        if not self.learning:
            return

        # Without a prior predict() call there is no transition to complete;
        # storing a None state would poison later batches.
        if self.previous_state is not None:
            self.memory.add(self.previous_state,
                            None,
                            self.current_reward,
                            state.copy(), 1)

        # Only episodes with more than 15 decisions compete for "best"; the
        # length check comes first so an empty history never divides by zero.
        decisions = len(self.match_history)
        if decisions > 15:
            match_percent = sum(self.match_history) / decisions
            if match_percent > self.best_match:
                self.best_match = match_percent
                print("best")
                self.save(appendix="_best")

        self.reset()

    def reset(self):
        """Clear all per-episode state."""
        self.match_history = []
        self.current_reward = 0
        self.total_reward = 0
        self.previous_state = None
        self.previous_action = None

    def restore_state(self):
        """Reload counters and replay memory from ``network_path``.

        Silently does nothing when no network path is configured or no
        ``adaptive.info`` file exists there.
        """
        network_path = self.network_config.network_path
        # Check the path before building the file name: concatenating onto a
        # missing (None) path would raise instead of skipping the restore.
        if not network_path:
            return

        restore_path = network_path + "/adaptive.info"
        if not os.path.exists(restore_path):
            return

        logger.info("Restoring state from %s" % network_path)

        # SECURITY: pickle executes arbitrary code on load; only restore
        # checkpoints from a trusted location.
        with open(restore_path, "rb") as file:
            info = pickle.load(file)

        self.steps = info["steps"]
        self.best_match = info["best_match"]
        self.episode = info["episode"]
        self.memory.load(network_path)
        print("lenght of memeory: ", len(self.memory))

    def save(self, force=False, appendix=""):
        """Write the eval/target networks, counters and replay memory.

        Args:
            force: Unused; kept for interface compatibility.
            appendix: Passed to ``save_network`` of both models (for example
                ``"_best"``).
        """
        info = {
            "steps": self.steps,
            "best_match": self.best_match,
            "episode": self.episode
        }

        print("*************saved*****************")
        self.eval_model.save_network(appendix = appendix)
        self.target_model.save_network(appendix = appendix)
        with open(self.network_config.network_path + "/adaptive.info", "wb") as file:
            pickle.dump(info, file, protocol=pickle.HIGHEST_PROTOCOL)
        self.memory.save(self.network_config.network_path)
        print("lenght of memeory: ", len(self.memory))

    def reward(self, r):
        """Add ``r`` to both the step reward and the episode total."""
        self.total_reward += r
        self.current_reward += r

    def update(self):
        """Run one training step of ``eval_model`` on a sampled batch.

        Returns early while the memory holds no more than ``batch_size``
        transitions. With prioritized memory the sampled priorities are
        refreshed from the absolute TD error afterwards.
        """
        if len(self.memory._storage) <= self.reinforce_config.batch_size:
            return

        beta = self.beta_schedule.value(self.steps)
        self.summary.add_scalar(tag='%s/Beta' % self.name,
                                scalar_value=beta, global_step=self.steps)
        if self.reinforce_config.use_prior_memory:
            batch = self.memory.sample(self.reinforce_config.batch_size, beta)
            # NOTE(review): the sampled `weights` are not applied to the loss
            # below -- confirm whether that is intended.
            (states, actions, reward, next_states,
             is_terminal, weights, batch_idxes) = batch
            self.summary.add_histogram(tag='%s/Batch Indices' % self.name,
                                       values=Tensor(batch_idxes),
                                       global_step=self.steps)
        else:
            batch = self.memory.sample(self.reinforce_config.batch_size)
            (states, actions, reward, next_states, is_terminal) = batch

        states = FloatTensor(states)
        next_states = FloatTensor(next_states)
        terminal = FloatTensor([1 if t else 0 for t in is_terminal])
        reward = FloatTensor(reward)

        # Current Q values of the stored states.
        q_values = self.eval_model.predict_batch(states)[1]
        q_values = q_values.view(-1)

        # Bootstrap target. Detached so that fitting eval_model never
        # backpropagates into (or accumulates gradients on) the target model.
        q_next = self.target_model.predict_batch(next_states)[1].detach()
        q_next = q_next.view(-1)
        # Terminal transitions contribute no future value.
        q_next = (1 - terminal) * q_next
        q_target = reward + self.reinforce_config.discount_factor * q_next

        # Update model.
        self.eval_model.fit(q_values, q_target, self.steps)

        # Update priorities.
        if self.reinforce_config.use_prior_memory:
            td_errors = q_values - q_target
            new_priorities = torch.abs(td_errors) + 1e-6  # prioritized_replay_eps
            self.memory.update_priorities(batch_idxes, new_priorities.data)

    def load_model_eval(self, model):
        """Replace the evaluation model's contents with ``model``."""
        self.eval_model.replace(model)

    def load_weight_true_q(self, weight_dict):
        """Load ``weight_dict`` into the reference (true-Q) model."""
        self.true_q_model.load_weight(weight_dict)
