import copy
import json
from multiprocessing import Queue
from collections import defaultdict

import numpy as np

from .interfaces import Grid
from utils.collections_util import defaultdict_to_dict
from .abstraction_mode import *


def score_fn(info, return_, min_avg_return, max_avg_return):
    """Fold a new return into ``info`` and refresh its normalized score.

    ``info['return']`` is moved towards ``return_`` by ``1 / info['times']``
    of the difference (an incremental mean when ``times`` counts the samples
    including this one).  The updated mean is rescaled linearly so that
    ``min_avg_return`` maps to 0 and ``max_avg_return`` maps to 1, clipped to
    ``[0, 1]`` and stored in ``info['score']``.

    Args:
        info: Mutable mapping holding at least ``'return'`` and ``'times'``.
        return_: Newly observed return.
        min_avg_return: Lower bound used for normalization.
        max_avg_return: Upper bound used for normalization.

    Returns:
        The same ``info`` mapping, updated in place.
    """
    running_mean = info['return']
    running_mean = running_mean + (return_ - running_mean) / info['times']
    info['return'] = running_mean

    span = max_avg_return - min_avg_return
    info['score'] = np.clip((running_mean - min_avg_return) / span, 0, 1)

    return info

def v_update_fn(info, rewards, next_info, alpha=0.1, gamma=0.99):
    """Apply one n-step TD update to the state value stored in ``info``.

    The target is the discounted sum of ``rewards`` plus the value stored in
    ``next_info`` discounted by ``gamma ** len(rewards)``.  Missing
    ``'v_value'`` entries are read as 0.

    Args:
        info: Mutable mapping for the current state; ``'v_value'`` is written.
        rewards: Sequence of rewards, earliest first.
        next_info: Mapping for the successor state; ``'v_value'`` is read.
        alpha: Step size of the update.
        gamma: Discount factor.

    Returns:
        The same ``info`` mapping, updated in place.
    """
    current = info.get('v_value', 0)
    bootstrap = next_info.get('v_value', 0)

    discounted_rewards = sum(gamma ** k * reward for k, reward in enumerate(rewards))
    target = discounted_rewards + gamma ** len(rewards) * bootstrap

    info['v_value'] = current + alpha * (target - current)
    return info

def q_update_fn(info, rewards, next_info, alpha=0.1, gamma=0.99):
    """Apply one n-step Q-learning style update to ``info['q_value']``.

    The target is the discounted sum of ``rewards`` plus the largest
    ``'q_value'`` found among the per-action entries of ``next_info``,
    discounted by ``gamma ** len(rewards)``.

    Missing ``'q_value'`` entries are read as 0, and an empty ``next_info``
    bootstraps from 0, mirroring the defaults used for ``info`` itself.

    Args:
        info: Mutable mapping for the current (state, action); ``'q_value'``
            is written.
        rewards: Sequence of rewards, earliest first.
        next_info: Mapping from action key to a per-action info mapping for
            the successor state.
        alpha: Step size of the update.
        gamma: Discount factor.

    Returns:
        The same ``info`` mapping, updated in place.
    """
    q_sa = info.get('q_value', 0)

    # ``default=0`` covers a successor with no recorded actions; ``.get``
    # covers plain dicts that do not carry a 'q_value' yet.
    best_next_q = max(
        (action_info.get('q_value', 0) for action_info in next_info.values()),
        default=0,
    )

    discounted_rewards = sum(gamma ** i * reward for i, reward in enumerate(rewards))
    target = discounted_rewards + gamma ** len(rewards) * best_next_q

    info['q_value'] = q_sa + alpha * (target - q_sa)
    return info


class ScoreInspector:
    """Track return-based scores for abstract state (and action) patterns.

    Trajectories are discretized with ``Grid`` objects, consecutive abstract
    states are joined into patterns of ``step`` states, and per-pattern
    statistics (``'times'``, ``'return'``, ``'score'``) are kept in
    ``abs_info``.  When ``mode`` is in ``ACTION_ENABLED`` the statistics are
    stored per (state pattern, action pattern) pair instead of per state
    pattern.

    ``pattern_abstract`` produces new statistics and puts them on
    ``s_token``; ``sync_scores`` takes them off the queue and merges them into
    ``abs_info``.
    """

    def __init__(
            self, step, grid_num, raw_state_dim, state_dim, state_min, state_max, action_dim, action_min, action_max,
            mode, reduction, necsa_lr, necsa_gamma, necsa_advantage, score_type, env_type='any', *args, **kwargs):
        """Store the configuration and build the grids.

        Args:
            step: Number of consecutive abstract states joined into a pattern.
            grid_num: Passed to ``Grid`` for both the state and action grid.
            raw_state_dim: Length of the state bound vectors before projection.
            state_dim: Number of columns of the state projection matrix.
            state_min: Value repeated ``raw_state_dim`` times as lower bound.
            state_max: Value repeated ``raw_state_dim`` times as upper bound.
            action_dim: Length of the action bound vectors.
            action_min: Value repeated ``action_dim`` times as lower bound.
            action_max: Value repeated ``action_dim`` times as upper bound.
            mode: Abstraction mode; ``HS_A`` and ``HS_HA`` are rejected.
            reduction: When truthy, bounds are passed through random
                projection matrices before the grids are built.
            necsa_lr: Stored as-is; not read by any method of this class.
            necsa_gamma: Stored as-is; not read by any method of this class.
            necsa_advantage: When truthy, ``inquire`` subtracts the per-pattern
                baseline instead of the global average score.
            score_type: Key read from the stored info by ``inquire``.
            env_type: ``'atari'`` switches action discretization to plain
                ``str`` conversion and changes the state projection range.

        Raises:
            NotImplementedError: If ``mode`` is ``HS_A`` or ``HS_HA``.
        """
        self.step = step
        self.grid_num = grid_num
        self.raw_state_dim = raw_state_dim
        self.state_dim = state_dim
        self.state_min = state_min
        self.state_max = state_max
        self.action_dim = action_dim
        self.action_min = action_min
        self.action_max = action_max
        self.mode = mode
        self.reduction = reduction
        self.necsa_lr = necsa_lr
        self.necsa_gamma = necsa_gamma

        self.necsa_advantage = necsa_advantage
        self.score_type = score_type
        self.env_type = env_type

        # Reject unsupported modes before allocating the queue and grids.
        # ``NotImplemented`` is a non-callable constant; the exception type is
        # ``NotImplementedError``.
        if self.mode in [HS_A, HS_HA]:
            raise NotImplementedError(f"{HS_A} and {HS_HA} are not implemented yet")

        self.score_avg = None

        # NOTE(review): bounded to 10 pending results; ``put`` blocks once the
        # queue is full, so ``sync_scores`` has to be called regularly.
        self.s_token = Queue(10)

        self.setup()

    def setup(self):
        """Build bound vectors, (optionally) projection matrices, and grids.

        Also resets the return bounds to ``[0, 1]``, ``score_avg`` to 0 and
        ``abs_info`` to an empty dict.
        """
        self.min_state = np.array([self.state_min for _ in range(self.raw_state_dim)])
        self.max_state = np.array([self.state_max for _ in range(self.raw_state_dim)])
        self.min_action = np.array([self.action_min for _ in range(self.action_dim)])
        self.max_action = np.array([self.action_max for _ in range(self.action_dim)])

        if self.reduction:
            # Project the bounds with the same matrices later applied to data.
            self.setup_projection_matrix()
            self.min_state = np.dot(self.min_state, self.state_project_matrix)
            self.max_state = np.dot(self.max_state, self.state_project_matrix)
            self.min_action = np.dot(self.min_action, self.action_project_matrix)
            self.max_action = np.dot(self.max_action, self.action_project_matrix)

        # Normalization range for scores; only ever widened afterwards.
        self.min_avg_return = 0
        self.max_avg_return = 1

        self.score_avg = 0

        self.abs_info = dict()

        self.state_grid = Grid(self.min_state, self.max_state, self.grid_num)
        self.action_grid = Grid(self.min_action, self.max_action, self.grid_num)

    def setup_projection_matrix(self):
        """Draw the random projection matrices for states and actions.

        The state matrix is ``(raw_state_dim, state_dim)`` drawn uniformly
        from ``[0, 0.1)``, or from ``[-1, 1)`` when ``mode == HS`` and
        ``env_type == 'atari'``.  The action matrix is
        ``(action_dim, action_dim)`` drawn uniformly from ``[0, 0.1)``.
        """
        # The order and number of draws are kept unchanged so that seeded
        # runs consume the global NumPy random stream exactly as before.
        self.state_project_matrix = np.random.uniform(0, 0.1, (self.raw_state_dim, self.state_dim))
        if self.mode == HS and self.env_type == 'atari':
            self.state_project_matrix = np.random.uniform(-1, 1, (self.raw_state_dim, self.state_dim))
        self.action_project_matrix = np.random.uniform(0, 0.1, (self.action_dim, self.action_dim))

    def save(self, env_name):
        """Write ``abs_info`` as JSON to ``<env_name>.json``."""
        with open(env_name + '.json', 'w') as f:
            json.dump(self.abs_info, f)

    def load(self, env_name):
        """Replace ``abs_info`` with the contents of ``<env_name>.json``."""
        with open(env_name + '.json', 'r') as f:
            self.abs_info = json.load(f)

    def discretize_states(self, states):
        """Return ``states`` abstracted by the state grid."""
        return self.state_grid.state_abstract(states)

    def discretize_actions(self, actions):
        """Return abstract actions.

        For ``env_type == 'atari'`` each action is simply converted with
        ``str``; otherwise the action grid is used.
        """
        if self.env_type == 'atari':
            return list(map(str, actions))

        return self.action_grid.state_abstract(actions)

    def get_baseline(self, abs_pattern):
        """Return the mean ``score_type`` value over all actions of a pattern.

        Raises:
            KeyError: If ``abs_pattern`` is unknown or an action entry lacks
                the ``score_type`` key.
        """
        return np.mean([
            actions_info[self.score_type]
            for actions_info in self.abs_info[abs_pattern].values()
        ])

    def inquire(self, abs_pattern, abs_action):
        """Look up the stored value for a (state pattern, action pattern).

        When ``score_type`` is ``'score'`` the value is returned relative to
        a baseline: the per-pattern mean if ``necsa_advantage`` is truthy,
        otherwise the global ``score_avg``.  Other score types are returned
        unchanged.

        Returns:
            The (possibly baseline-adjusted) value, or ``None`` when any
            lookup raises ``KeyError``.
        """
        try:
            score = self.abs_info[abs_pattern][abs_action][self.score_type]
            if self.score_type != 'score':
                return score
            if self.necsa_advantage:
                return score - self.get_baseline(abs_pattern)
            return score - self.score_avg
        except KeyError:
            return None

    def sync_scores(self):
        """Merge one pending result from ``s_token`` into ``abs_info``.

        Does nothing when the queue reports no pending items.  Otherwise the
        new per-pattern entries overwrite the existing ones, the return
        bounds are widened if needed, and ``score_avg`` is recomputed as the
        mean ``'score'`` over all stored entries.
        """
        if self.s_token.qsize() <= 0:
            return

        new_abs_info, min_avg_return, max_avg_return = self.s_token.get()

        self.abs_info.update(new_abs_info)

        # Bounds only ever widen.
        self.min_avg_return = min(self.min_avg_return, min_avg_return)
        self.max_avg_return = max(self.max_avg_return, max_avg_return)

        if self.mode in ACTION_ENABLED:
            scores = [
                abs_action_info['score']
                for abs_state_info in self.abs_info.values()
                for abs_action_info in abs_state_info.values()
            ]
        else:
            scores = [
                abs_state_info['score']
                for abs_state_info in self.abs_info.values()
            ]
        self.score_avg = np.mean(scores)

    def start_pattern_abstract(self, states, actions, rewards):
        """Convert states and actions to arrays and run ``pattern_abstract``.

        The call is synchronous; no separate process is started.
        """
        states = np.array(states)
        actions = np.array(actions)

        self.pattern_abstract(states, actions, rewards)

    def pattern_abstract(self, states, actions, rewards):
        """Update pattern statistics from one trajectory and enqueue them.

        The trajectory return is the sum over steps of the per-step mean of
        ``rewards`` (``np.average(rewards, axis=1)``).  For every visited
        window of ``step`` abstract states, the visit count is incremented
        and ``score_fn`` updates the running return and score.  The result,
        together with the possibly widened return bounds, is put on
        ``s_token``.

        Entries already present in ``abs_info`` are reused as the starting
        point, so those dict objects are updated in place.

        Only the running return and score are maintained here; the TD-style
        updates (``q_update_fn`` / ``v_update_fn``) are not applied.
        """
        abs_states = self.discretize_states(states)
        abs_actions = self.discretize_actions(actions)
        action_enabled = self.mode in ACTION_ENABLED

        if action_enabled:
            new_abs_info = defaultdict(lambda: defaultdict(lambda: defaultdict(int)))
        else:
            new_abs_info = defaultdict(lambda: defaultdict(int))

        return_ = sum(np.average(rewards, axis=1))
        min_avg_return = min(self.min_avg_return, return_)
        max_avg_return = max(self.max_avg_return, return_)

        # Walk window start indices from len - 1 - step down to step.
        # NOTE(review): windows starting below index ``step`` are never
        # recorded; confirm this is intended.
        for i in range(len(abs_states) - 1 - self.step, self.step - 1, -1):
            window_end = i + self.step
            state_pattern = '-'.join(abs_states[i:window_end])
            # Only the action taken at the last state of the window is used.
            action_pattern = '-'.join(abs_actions[window_end - 1:window_end])

            # Prefer this trajectory's entry, then the stored one, else new.
            state_info = new_abs_info.get(
                state_pattern, self.abs_info.get(state_pattern, defaultdict(int)))
            new_abs_info[state_pattern] = state_info

            if action_enabled:
                info = state_info.get(action_pattern, defaultdict(int))
                state_info[action_pattern] = info
            else:
                info = state_info

            info['times'] += 1
            score_fn(info, return_, min_avg_return, max_avg_return)

        # defaultdict doesn't work in queues
        new_abs_info = defaultdict_to_dict(new_abs_info)

        self.s_token.put((new_abs_info, min_avg_return, max_avg_return))


class Abstracter:
    """Buffer one episode of transitions and shape rewards via an inspector.

    ``inspector`` must be assigned after construction; it is expected to
    provide ``reduction``, ``state_project_matrix``, ``discretize_states``,
    ``discretize_actions``, ``inquire`` and ``start_pattern_abstract``.
    """

    def __init__(self, step, epsilon, *args, **kwargs):
        """Create an empty buffer.

        Args:
            step: Number of consecutive states that form one pattern.
            epsilon: Scale applied (times 10) to the inspector score when
                shaping a reward.
        """
        self.states = []
        self.actions = []
        self.rewards = []
        self.dones = []
        self.step = step
        self.epsilon = epsilon
        self.inspector = None

    def dim_reduction(self, state):
        """Project ``state`` with the inspector's state projection matrix."""
        return np.dot(state, self.inspector.state_project_matrix)

    def append(self, state, action, reward, done):
        """Record one transition; on ``done`` hand the episode over and reset.

        When ``done`` is truthy the buffered states (projected first if the
        inspector uses reduction), actions and rewards are passed to
        ``inspector.start_pattern_abstract`` and the buffer is cleared.
        """
        self.states.append(state)
        self.actions.append(action)
        self.rewards.append(reward)
        self.dones.append(done)

        if not done:
            return

        if self.inspector.reduction:
            self.states = self.dim_reduction(self.states)
        self.inspector.start_pattern_abstract(self.states, self.actions, self.rewards)
        self.clear()

    def clear(self):
        """Drop all buffered transitions."""
        self.states = []
        self.actions = []
        self.rewards = []
        self.dones = []

    def handle_pattern(self, states, actions, rewards):
        """Return the first reward of a window, shaped by the inspector score.

        The window is discretized and looked up with ``inspector.inquire``.
        If a score is available, ``score * epsilon * 10`` is added to
        ``rewards[0]``.  If the window does not contain exactly ``step``
        abstract states, or no score is known, ``rewards[0]`` is returned
        unchanged.

        ``rewards`` is never modified: an in-place ``+=`` here would write
        through to the caller's data whenever the rewards are NumPy arrays or
        ``rewards`` is an array view.
        """
        abs_state = self.inspector.discretize_states(states)
        abs_action = self.inspector.discretize_actions(actions)

        reward = rewards[0]
        if len(abs_state) != self.step:
            return reward

        abs_state = '-'.join(abs_state)
        abs_action = '-'.join(abs_action)
        score = self.inspector.inquire(abs_state, abs_action)
        if score is None:
            return reward

        return reward + score * self.epsilon * 10

    def reward_shaping(self, state_list, action_list, reward_list):
        """Return a shaped copy of ``reward_list`` as a NumPy array.

        For every index ``i`` in ``range(len(state_list) - step)`` the window
        of ``step`` states starting at ``i``, together with the action at the
        window's last state, is scored and the result replaces entry ``i`` of
        the copy.  Remaining entries keep their original values.  The input
        ``reward_list`` is left untouched.
        """
        if self.inspector.reduction:
            state_list = self.dim_reduction(state_list)

        shaping_reward_list = copy.deepcopy(reward_list)

        for i in range(len(state_list) - self.step):
            window_end = i + self.step
            shaping_reward_list[i] = self.handle_pattern(
                state_list[i:window_end],
                action_list[window_end - 1:window_end],
                reward_list[i:window_end],
            )

        return np.array(shaping_reward_list)
