import copy
import numpy as np
import os
import glob
import pickle
import logging
from threading import RLock
from train.common.config import Config


class Stats:
    """
    Singleton Stats for global training / evaluation statistics.

    Obtain the shared instance through :meth:`get_instance`; constructing a second
    instance directly raises. All mutating methods, as well as :meth:`copy` and
    :meth:`save`, serialize on one class-level re-entrant lock. :meth:`reset` must be
    called before any other instance method, since they all index into ``self.statistics``.
    """
    __instance = None
    __mutex = RLock()

    def __init__(self, config):
        """
        Virtually private constructor. Use :meth:`get_instance` instead.
        :param config: Configuration object. ``statistics_path`` is read here;
            ``subtask.consensus_code`` is read later by :meth:`reset`.
        :raises Exception: If the singleton instance already exists.
        """
        if Stats.__instance is not None:
            raise Exception("This class is a singleton!")
        self.config = config
        # Filled in by reset(); all other instance methods assume reset() ran first.
        self.statistics = None
        # Set to True by every mutator and back to False by save().
        self.unsaved_changes = None
        # NOTE(review): alpha is not read anywhere in this class — confirm whether it is still needed.
        self.alpha = 0.7
        self.path = self.config.statistics_path
        logging.info('Stats: Created new statistics instance.')

    @staticmethod
    def get_instance(config: 'Config' = None):
        """
        Returns the singleton of the Stats class, creating it on the first call.
        :param config: Configuration used only when the instance does not exist yet
            (ignored afterwards). Must not be None on the first call, because the
            constructor reads ``config.statistics_path``.
        :return: The shared Stats instance.
        """
        # String annotation: keeps the class definable without importing Config at runtime.
        with Stats.__mutex:
            if Stats.__instance is None:
                Stats.__instance = Stats(config)
            return Stats.__instance

    def reset(self, consensus_code):
        """
        Resets all statistics back to default.
        :param consensus_code: Consensus code consisting of the letters to monitor the milestone progress.
            One milestone counter is created per prefix, e.g. 'SP' -> 'S', 'SP'.
        :return:
        """
        with Stats.__mutex:
            self.statistics = {
                'current': {
                    'seed': None
                },
                'runs': {
                    'consensus': consensus_code,
                    'count': 0,
                    'trial_rewards': [],
                    'total_reward': 0,
                    'mean_reward': None,
                    'std_reward': None,
                    'min_reward': None,
                    'max_reward': None
                },
                # One counter per prefix of the consensus code, shortest first.
                'milestones': {consensus_code[:i]: 0.0 for i in range(1, len(consensus_code) + 1)},
                # One accumulator per value of config.subtask.consensus_code; add_resources()
                # only updates keys that already exist here.
                'resources': {name: 0.0 for name in self.config.subtask.consensus_code.values()}
            }
            self.unsaved_changes = True
            logging.info('Stats: Statistics reset!')

    def copy(self):
        """
        Returns a deep copy of the current statistics dictionary.
        The lock is held so the snapshot cannot observe a half-applied update.
        :return: Statistics dictionary
        """
        with Stats.__mutex:
            return copy.deepcopy(self.statistics)

    def summarize_reward(self, rewards):
        """
        Folds one trial's reward summary into the running reward statistics.
        :param rewards: Dictionary with the keys 'mean_reward', 'total_reward',
            'min_reward' and 'max_reward' for a single trial.
        :return:
        """
        with Stats.__mutex:
            logging.info('Stats: Compute reward statistics.')
            runs = self.statistics['runs']
            runs['trial_rewards'].append(rewards['mean_reward'])
            runs['total_reward'] += rewards['total_reward']

            # mean/std are taken over the per-trial mean rewards.
            runs['mean_reward'] = np.mean(runs['trial_rewards'])
            runs['std_reward'] = np.std(runs['trial_rewards'])
            # None means no trial has been summarized yet; otherwise keep the running extreme.
            if runs['min_reward'] is None:
                runs['min_reward'] = rewards['min_reward']
            else:
                runs['min_reward'] = min(runs['min_reward'], rewards['min_reward'])
            if runs['max_reward'] is None:
                runs['max_reward'] = rewards['max_reward']
            else:
                runs['max_reward'] = max(runs['max_reward'], rewards['max_reward'])
            self.unsaved_changes = True

    def add_resources(self, inventory):
        """
        Updates the resources entries according to the inventory dictionary.
        Inventory keys that were not registered by reset() are ignored.
        :param inventory: Dictionary consisting of the agent resources of its inventory.
        :return:
        """
        with Stats.__mutex:
            resources = self.statistics['resources']
            for key, value in inventory.items():
                if key in resources:
                    resources[key] += value
            self.unsaved_changes = True

    def incr_milestone(self, consensus_sub_code):
        """
        Increments the current progress of the consensus sub code.
        :param consensus_sub_code: sub code to increment (must be a prefix registered by reset()).
        :raises KeyError: If the sub code is not a known milestone.
        :return:
        """
        with Stats.__mutex:
            logging.info('Stats: Incrementing milestone {}.'.format(consensus_sub_code))
            self.statistics['milestones'][consensus_sub_code] += 1
            self.unsaved_changes = True

    def incr_runs_count(self):
        """
        Increments the number of total runs.
        :return:
        """
        with Stats.__mutex:
            logging.info('Stats: Incrementing total runs.')
            self.statistics['runs']['count'] += 1
            self.unsaved_changes = True

    def save(self):
        """
        Persists the statistics to ``<statistics_path>/stats.dict`` as a pickle.
        The file is written to a temporary path first and then swapped in with
        os.replace, so a crash mid-write cannot leave a truncated stats file behind.
        :return:
        """
        with Stats.__mutex:
            stats_path = os.path.join(self.path, 'stats.dict')
            if not os.path.exists(self.path):
                logging.info("Stats: Creating statistics path...")
            # exist_ok: another process may create the directory concurrently.
            os.makedirs(self.path, exist_ok=True)

            # save statistics dictionary
            logging.info("Stats: Saving statistics object...")
            # '.tmp' suffix keeps the partial file out of the '*.dict' glob used for aggregation.
            tmp_path = stats_path + '.tmp'
            with open(tmp_path, 'wb') as f:
                pickle.dump(self.statistics, f)
            os.replace(tmp_path, stats_path)

            self.unsaved_changes = False
            logging.info('Stats: Saving complete.')

    @staticmethod
    def load(stats_file):
        """
        Loads a statistics dictionary previously written by :meth:`save`.
        NOTE: pickle can execute arbitrary code while loading — only load trusted files.
        :param stats_file: Path to a pickled statistics file.
        :return: The unpickled statistics dictionary.
        """
        with open(stats_file, 'rb') as file:
            return pickle.load(file)

    @staticmethod
    def calculate_ensemble_stats(root_dirs, minimum_consensus_match='S', replay_until_correction=False):
        """
        Aggregates milestone statistics over many saved statistics files.

        Every ``*.dict`` file found recursively below each entry of ``root_dirs`` is loaded
        with :meth:`load`. A file is skipped when it has no 'runs' entry, when
        ``minimum_consensus_match`` is not a substring of its consensus code, or when its run
        count is zero. It is also skipped when loading or processing raises (e.g. no 'S'
        milestone); in that case the error is printed.

        Files whose 'S' milestone is zero are treated as "replay until" files.
        NOTE(review): presumably these are runs resumed from a later stage — confirm.
        Their leading zero milestones are excluded from the averages.

        :param root_dirs: Iterable of directories to search.
        :param minimum_consensus_match: Required substring of the consensus code, or None to accept every file.
        :param replay_until_correction: If True, each milestone other than 'S' is blended with
            ``P(x) * P(parent)``, weighted by the fraction of replay entries, and capped at
            ``P(parent)``.
        :return: Dictionary with keys 'count', 'milestones', 'replay_until', 'replay_entries',
            'regular_entries' and 'summary' (milestone -> mean of milestone/count over files;
            NaN with a numpy warning if every file was excluded for that milestone).
        """
        collection = {
            'count': {},
            'milestones': {},
            'replay_until': {},
            'replay_entries': {},
            'regular_entries': {}
        }
        for root_dir in root_dirs:
            for file in glob.glob(os.path.join(root_dir, '**/*.dict'), recursive=True):
                try:
                    # NOTE: unpickling executes code from the file — only point this at trusted directories.
                    stats = Stats.load(file)
                    # check for runs property first: the consensus check below reads from it
                    if 'runs' not in stats:
                        continue
                    # check if the minimum consensus matches are satisfied
                    if minimum_consensus_match is not None and minimum_consensus_match not in stats['runs']['consensus']:
                        continue
                    cnt = stats['runs']['count']
                    # skip empty runs
                    if cnt == 0:
                        continue

                    # Detect replay-until files (first milestone 'S' never reached).
                    # Leading zero milestones are marked -1 so they are excluded below.
                    # From the first non-zero milestone onward, milestones are counted in
                    # replay_until until the next zero.
                    marker = False
                    if stats['milestones']['S'] == 0:
                        for k, v in stats['milestones'].items():
                            if k not in collection['replay_until']:
                                collection['replay_until'][k] = 0
                            if v != 0:
                                marker = True
                                collection['replay_until'][k] += 1
                            if marker and v == 0:
                                break
                            if v == 0:
                                stats['milestones'][k] = -1

                    # Regular (non-replay) files: count milestones that were actually reached.
                    for k, v in stats['milestones'].items():
                        if k not in collection['regular_entries']:
                            collection['regular_entries'][k] = 0
                        if not marker and v != 0 and v != -1:
                            collection['regular_entries'][k] += 1

                    for k, v in stats['milestones'].items():
                        if k not in collection['milestones']:
                            collection['count'][k] = []
                            collection['milestones'][k] = []
                            collection['replay_entries'][k] = 0
                        # In replay files every non-zero value (including the -1 markers) counts.
                        if v != 0 and marker:
                            collection['replay_entries'][k] += 1
                        # ignore runs of replay until for the given stage
                        if v != -1:
                            collection['milestones'][k].append(v)
                            collection['count'][k].append(cnt)

                except Exception as e:
                    # Best effort: one corrupt or incompatible file must not abort the aggregation.
                    print(e, f"Skipping file: {file}")

        collection['summary'] = {}
        print('Statistics summary:')
        # Relies on insertion order: a milestone's parent prefix is always summarized before it.
        for k, v in collection['milestones'].items():
            val = np.array(v)/np.array(collection['count'][k])
            collection['summary'][k] = np.mean(val)

            if replay_until_correction:
                divisor = (collection['regular_entries'][k]+collection['replay_entries'][k])
                if divisor == 0:
                    divisor = 1
                weight = collection['replay_entries'][k]/divisor
                if k != 'S':
                    parent = collection['summary'][k[:-1]]
                    collection['summary'][k] = (1-weight)*collection['summary'][k] + weight*collection['summary'][k]*parent
                    collection['summary'][k] = min(collection['summary'][k], parent)

            print(f"P({k}) = {collection['summary'][k]}")
        print('Used replay until correction P(x) = P(x) * P(x-1)!'
              if replay_until_correction else 'No replay until correction!')
        return collection


if __name__ == "__main__":
    # Ad-hoc entry point: aggregate every pickled statistics file below
    # tmp/rec/stats/ whose consensus code contains 'SSSSSSSPPPPP',
    # without applying the replay-until correction.
    _ensemble_settings = dict(
        root_dirs=['tmp/rec/stats/'],
        minimum_consensus_match='SSSSSSSPPPPP',
        replay_until_correction=False,
    )
    Stats.calculate_ensemble_stats(**_ensemble_settings)
