import os
import pickle
import threading
import time
from abc import ABC

import numpy as np
import wandb

from .context import Context

# Module-wide lock; Exp3.end_episode holds it around the save countdown and
# the save() call it may trigger.
lock = threading.Lock()

# Path templates; the first placeholder is the model's ``abr`` name.
_MODEL_ROOT = "./weights/models/{}"
# File whose contents are the path of the most recently written checkpoint.
CHECKPOINT_FILE = _MODEL_ROOT + "/last_checkpoint.out"
# Despite the name, this formats a checkpoint *file* path; the second
# placeholder is the zero-padded checkpoint version.
CHECKPOINTS_DIR = _MODEL_ROOT + "/checkpoint_{}"


class Exp3(ABC):
    """Bandit agent holding one ``Context`` arm-learner per observation index.

    Tracks in-flight episodes, credits each reward to the (context, arm) pair
    chosen by the most recent ``get_action`` call (only when
    ``type == 'train'``), and every ``saving_time`` finished training
    episodes pickles all context states to disk and logs stats to wandb.

    Expected ``config`` keys: ``type``, ``abr``, ``lr``, ``observation_space``,
    ``action_space``, ``saving_time``.
    """

    def __init__(self, config):
        self._type = config["type"]  # only 'train' updates contexts / logs
        self._abr = config["abr"]  # used to build checkpoint paths
        self._lr = config['lr']
        self._observation_space = config["observation_space"]  # number of contexts
        self._action_space = config["action_space"]  # arms per context
        self._time_to_save = config["saving_time"]  # finished episodes between saves
        self._contexts = [Context(num_of_arms=self._action_space, lr=self._lr)
                          for _ in range(self._observation_space)]

        # Countdown (in finished training episodes) until the next save().
        self._time_left_to_save = self._time_to_save
        # Checkpoint counter; also the wandb step for logged metrics.
        self._version = 0
        # episode_id -> per-episode state (rewards, last arm, last context).
        self._episodes = {}

        self._init_logger()

    def _init_logger(self):
        """Reset training statistics.

        ``*_cum`` lists and ``total_visits`` accumulate for the whole run;
        ``episode_len``, ``episode_reward_avg`` and ``num_episode_failed`` are
        reset after every wandb upload in ``_log_wandb``.
        """
        self._log = {
            "timestamps_total": 0,
            "episodes_total": 0,
            "episode_len": [],
            "episode_len_cum": [],
            "num_episode_failed": 0,
            "num_episode_failed_cum": 0,
            "episode_reward_avg": [],
            "episode_reward_avg_cum": [],
            # Per-context count of rewards received via log_returns.
            "total_visits": [0 for _ in range(len(self._contexts))],
            "start_time": time.time()
        }

    def start_episode(self, episode_id):
        """Register a new episode.

        Returns:
            ``episode_id`` on success, or ``False`` if it is already active.
            NOTE(review): a falsy id (e.g. ``0``) makes success look like
            failure to truthiness checks; callers should compare ``is False``.
        """
        if episode_id in self._episodes:
            return False

        self._episodes[episode_id] = {
            "log": {
                "rewards": []
            },
            "last_arm": None,
            "last_context_idx": None
        }
        return episode_id

    def end_episode(self, episode_id):
        """Finish ``episode_id`` and drop its state.

        In training mode, records the episode length and mean reward, and
        every ``saving_time`` finished episodes calls ``save()``.

        Returns:
            True if the episode existed, False otherwise.
        """
        if episode_id not in self._episodes:
            return False

        if self._type == 'train':
            rewards = self._episodes[episode_id]["log"]["rewards"]
            n_steps = len(rewards)

            # NOTE: failed-episode detection (previously: n_steps < 60*8/2 - 5)
            # is disabled, so the num_episode_failed counters stay at 0.

            # Mutate the log under the lock: save() -> _log_wandb() reads and
            # then resets these lists while holding the same lock, so an
            # append racing with it could otherwise be lost.
            with lock:
                self._log["episodes_total"] += 1
                self._log["episode_len"].append(n_steps)
                self._log["episode_len_cum"].append(n_steps)
                # np.mean([]) is NaN and would poison every reward statistic
                # (including the *_cum ones, permanently), so only record a
                # mean for episodes that actually produced rewards.
                if rewards:
                    reward_avg = np.mean(rewards)
                    self._log["episode_reward_avg"].append(reward_avg)
                    self._log["episode_reward_avg_cum"].append(reward_avg)

                self._time_left_to_save -= 1
                if self._time_left_to_save <= 0:
                    self._time_left_to_save = self._time_to_save
                    self.save()

        del self._episodes[episode_id]
        return True

    def get_action(self, episode_id, context_idx):
        """Pick an arm from context ``context_idx`` for ``episode_id``.

        The (context, arm) pair is remembered so the next ``log_returns`` call
        for this episode credits its reward to it.

        Raises:
            KeyError: if ``episode_id`` was not started.
        """
        arm = self._contexts[context_idx].predict()
        self._episodes[episode_id]["last_arm"] = arm
        self._episodes[episode_id]["last_context_idx"] = context_idx
        return arm

    def log_returns(self, episode_id, reward):
        """Feed ``reward`` to the context/arm chosen by the last ``get_action``.

        No-op unless ``type == 'train'``. Must be preceded by ``get_action``
        for this episode; otherwise the stored context index is ``None``.
        NOTE(review): not guarded by ``lock``; if episodes run on several
        threads, updates to shared contexts/counters may race — confirm.
        """
        if self._type != 'train':
            return
        episode = self._episodes[episode_id]
        arm = episode["last_arm"]
        context_idx = episode["last_context_idx"]
        self._contexts[context_idx].update(reward, arm)

        self._log["total_visits"][context_idx] += 1
        self._log["timestamps_total"] += 1
        episode["log"]["rewards"].append(reward)

    def _log_wandb(self):
        """Upload aggregated stats, visit counts and weights to wandb.

        Everything is logged at ``step=self._version``. Afterwards the
        per-interval stats (``episode_len``, ``episode_reward_avg``,
        ``num_episode_failed``) are reset; cumulative ones keep growing.
        """
        def get_data(arr):
            # np.min / np.max raise on empty input; report 0 instead.
            return arr if len(arr) > 0 else [0]

        episode_len = get_data(self._log["episode_len"])
        episode_len_cum = get_data(self._log["episode_len_cum"])
        reward_avg = get_data(self._log["episode_reward_avg"])
        reward_avg_cum = get_data(self._log["episode_reward_avg_cum"])

        metrics = {
            "total_time_s": time.time() - self._log["start_time"],
            "timestamps_total": self._log["timestamps_total"],
            "episodes_total": self._log["episodes_total"],
            "num_episode_failed": self._log["num_episode_failed"],
            "num_episode_failed_cum": self._log["num_episode_failed_cum"],
            "episode_len_median": np.median(episode_len),
            "episode_len_min": np.min(episode_len),
            "episode_len_max": np.max(episode_len),
            "episode_len_cum_mean": np.mean(episode_len_cum),
            "episode_reward_mean": np.mean(reward_avg),
            "episode_reward_max": np.max(reward_avg),
            "episode_reward_min": np.min(reward_avg),
            "episode_reward_mean_cum": np.mean(reward_avg_cum),
            "episode_reward_max_cum": np.max(reward_avg_cum),
            "episode_reward_min_cum": np.min(reward_avg_cum),
        }
        metrics.update({
            "visits_context_{}".format(i): visits
            for i, visits in enumerate(self._log["total_visits"])
        })
        for i, c in enumerate(self._contexts):
            for j, w in enumerate(c.weights):
                metrics["weights_context_{}_weight_{}".format(i, j)] = w

        # One call for the whole step, then an explicit commit (as before).
        wandb.log(metrics, step=self._version)
        wandb.log({}, commit=True)

        # reset per-interval logger state
        self._log["episode_reward_avg"] = []
        self._log["episode_len"] = []
        self._log["num_episode_failed"] = 0

    def save(self):
        """Checkpoint all contexts, log stats to wandb, then bump the version.

        Pickles ``{'version': int, <context index>: c.get_dict(), ...}`` to
        ``CHECKPOINTS_DIR`` (formatted with ``abr`` and the zero-padded
        version) and writes that path into ``CHECKPOINT_FILE``.
        """
        data = {'version': self._version}
        for i, c in enumerate(self._contexts):
            data[i] = c.get_dict()

        checkpoint = CHECKPOINTS_DIR.format(self._abr, str(self._version).zfill(6))
        # The per-model directory may not exist yet on a fresh run.
        os.makedirs(os.path.dirname(checkpoint), exist_ok=True)
        with open(checkpoint, 'wb') as f:
            pickle.dump(data, f)

        # Swap the pointer in atomically so a crash mid-write never leaves a
        # truncated path in the "last checkpoint" file.
        pointer_file = CHECKPOINT_FILE.format(self._abr)
        tmp_file = pointer_file + ".tmp"
        with open(tmp_file, 'w') as f:
            f.write(checkpoint)
        os.replace(tmp_file, pointer_file)

        self._log_wandb()
        self._version += 1

    def restore(self, path):
        """Load context states from a checkpoint written by ``save()``.

        Sets ``self._version`` to the stored version + 1 so the next
        ``save()`` writes a new checkpoint file instead of overwriting it.
        """
        # SECURITY: pickle.load can execute arbitrary code; only restore
        # checkpoints from trusted locations.
        with open(path, "rb") as f:
            data = pickle.load(f)
        self._version = data['version'] + 1
        for key, state in data.items():
            if key == 'version':
                continue
            self._contexts[int(key)].load_dict(state)
