import os
import pickle
import random
from collections import defaultdict

import numpy as np

from train.experience_replay.helper_functions import LockFile


class ExperienceSuccessMapping(object):
    """Track replay outcomes per experience file and sample files by success.

    The mapping is a dict ``{filename: np.array([successes, trials])}`` that is
    pickled to ``success_mapping_file_path``. Writes go through ``LockFile``
    (presumably an inter-process lock -- confirm against helper_functions).

    NOTE(security): the mapping is read with ``pickle.load``; the mapping file
    must therefore only ever come from a trusted source.
    """

    def __init__(self, experiences_folder, success_mapping_file_path):
        """Load (or create) the mapping and register every experience file.

        :param experiences_folder: directory whose entries are the experience
            files that can be sampled.
        :param success_mapping_file_path: path of the pickled mapping file.
        """
        self.experiences_folder = experiences_folder
        self.success_mapping_file_path = success_mapping_file_path

        self.success_mapping = None
        # Test for the configured file itself instead of a hard-coded name in
        # its parent directory; this also works for a path without directory.
        if os.path.isfile(self.success_mapping_file_path):
            self.load_success_mapping()
        if self.success_mapping is None:
            # No file yet, or it could not be read: start with an empty
            # mapping. Without a default factory this acts like a plain dict.
            self.success_mapping = defaultdict()

        # For each file in the experiences_folder, make a key in the dict if
        # it doesn't exist, yet. The mapping file itself is not an experience.
        self.experience_files = set(os.listdir(self.experiences_folder))
        self.experience_files.discard('success_mapping.p')
        for filename in self.experience_files:
            if filename not in self.success_mapping:
                self.success_mapping[filename] = np.array([0, 0])

        # save changes on disk
        self.save_success_mapping()

    def load_success_mapping(self):
        """Replace ``self.success_mapping`` with the version stored on disk.

        Best effort: if the file cannot be read or unpickled, the problem is
        printed and the current in-memory mapping is left untouched.
        """
        try:
            with open(self.success_mapping_file_path, 'rb') as f:
                self.success_mapping = pickle.load(f)
        except Exception as e:
            print('Could not load success mapping from {}: {}'.format(
                self.success_mapping_file_path, e))

    def save_success_mapping(self):
        """Pickle the in-memory mapping to disk while holding the lock."""
        with LockFile(self.success_mapping_file_path):
            with open(self.success_mapping_file_path, 'wb') as f:
                pickle.dump(self.success_mapping, f)

    def update(self, experience_replay_filename, result):
        """Record one replay outcome for ``experience_replay_filename``.

        The mapping is re-read from disk, updated and written back while the
        lock is held.

        :param experience_replay_filename: name of the replayed experience file.
        :param result: 1 if the replay was a success, 0 otherwise.
        """
        with LockFile(self.success_mapping_file_path):
            # read in current version of success_mapping
            with open(self.success_mapping_file_path, 'rb') as f:
                self.success_mapping = pickle.load(f)

            if experience_replay_filename not in self.success_mapping:
                # initialize with (0, 0)
                self.success_mapping[experience_replay_filename] = np.array([0, 0])

            # Always count the outcome, also for a file seen for the first
            # time; otherwise its first result would be lost.
            counts = self.success_mapping[experience_replay_filename]
            counts[0] += result
            counts[1] += 1

            with open(self.success_mapping_file_path, 'wb') as f:
                pickle.dump(self.success_mapping, f)

    def sample(self):
        """Sample one experience file, weighted by its past replay successes.

        Each file gets the weight ``(successes + 1) / (trials + 1)``, so a file
        without any recorded replay has weight 1.

        :return: the name of the sampled experience file.
        :raises IndexError: if no experience files are registered.
        """
        # update from disk
        self.load_success_mapping()

        # not all files in success_mapping.p are present in the current
        # folder. Take the relevant subset:
        filenames = list(self.experience_files)
        if not filenames:
            raise IndexError('Cannot sample: no experience files are registered')

        weights = []
        for filename in filenames:
            # A file that is missing from the reloaded mapping counts as
            # never replayed instead of raising a KeyError.
            successes, trials = self.success_mapping.get(filename, (0, 0))
            weights.append((successes + 1) / (trials + 1))

        return random.choices(filenames, weights=weights, k=1)[0]

    def remove(self, filename):
        """Stop sampling ``filename`` and drop its in-memory counts.

        Only in-memory state changes; the mapping file on disk is not
        rewritten. Removing an unknown filename is a no-op.
        """
        # sample() iterates over experience_files, so the name has to leave
        # this set as well; popping the key alone is undone by the next reload.
        self.experience_files.discard(filename)
        self.success_mapping.pop(filename, None)