import os
import glob
import logging
import numpy as np
from copy import deepcopy

import time
import json
import textwrap
import logging
import logging.config
from datetime import datetime
from collections import defaultdict
from scipy import stats


import torch

from .utils import nested_getattr, nested_setattr
# from safelife import StreamingJSONWriter
# from safelife.side_effects import side_effect_score
# from safelife.render_text import cell_name

# Module-level logger named after this module's import path.
logger = logging.getLogger(__name__)


class StreamingJSONWriter(object):
    """
    Serialize streaming data to JSON.

    This class holds onto an open file reference to which it carefully
    appends new JSON data. Individual entries are stored in a single
    top-level list, and after every entry the list is closed so that the
    file remains valid JSON. When a new item is added, the file cursor is
    moved backwards to overwrite the list's closing bracket.

    Parameters
    ----------
    filename : str
        Path of the output file. A non-empty existing file is assumed to have
        been written by this class and is appended to; a missing or empty
        file starts a new list.
    encoder : type
        ``json.JSONEncoder`` subclass passed as ``cls`` to ``json.dumps``.

    Instances can be used as context managers; the file is closed on exit.
    """
    # Written after every entry and overwritten by the next ``dump``.
    _close_str = "\n]\n"

    def __init__(self, filename, encoder=json.JSONEncoder):
        # An empty file (e.g. one pre-created by a caller) contains no list
        # yet, so it must be started with '[' just like a missing file;
        # continuing it with ',' would produce invalid JSON.
        if os.path.exists(filename) and os.path.getsize(filename) > 0:
            self.file = open(filename, 'r+')
            self.delimeter = ','
        else:
            self.file = open(filename, 'w')
            self.delimeter = '['
        self.encoder = encoder

    def dump(self, obj):
        """
        Append a JSON-serializable object to the file and flush it.

        ``obj`` is serialized before the file is touched, so an encoding
        error leaves the file unchanged.
        """
        data = json.dumps(obj, cls=self.encoder)
        end = self.file.seek(0, os.SEEK_END)
        # Step back over the previous closing bracket (if any) so the new
        # entry overwrites it; clamp at 0 for a freshly created file.
        self.file.seek(max(end - len(self._close_str), 0))
        self.file.write("%s\n    %s%s" % (self.delimeter, data, self._close_str))
        self.file.flush()
        self.delimeter = ','

    def close(self):
        """Close the underlying file."""
        self.file.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False

class BaseAlgo(object):
    """
    Common methods for model checkpointing and policy tournaments in pytorch.

    Attributes
    ----------
    checkpoint_directory : str
        The directory where checkpoints are stored. If not set, the checkpoint
        directory will be taken from ``self.data_logger.logdir``.
    data_logger : object
        Logger exposing a ``logdir`` attribute and a
        ``log_scalars(data, step, tag)`` method. May be None, in which case no
        checkpoint directory is inferred and tournament results are not logged.
    num_steps : int
        Total number of training steps. It's assumed that subclasses will
        increment this in their training loops.
    checkpoint_interval : int
        Interval (in steps) between subsequent checkpoints.
    max_checkpoints : int
        Total number of checkpoints to keep in the checkpoint directory.
        Older checkpoints that exceed this number are deleted.
    checkpoint_attribs : list
        List of attributes on the algorithm that ought to be saved at each
        checkpoint. This should be overridden by subclasses.
        Note that this implicitly contains ``num_steps``.
    champion_dict : dict
        Created by the first call to ``tournament``. Maps ``'length_champ'``,
        ``'perf_champ'`` and ``'effect_champ'`` to the best model found so far
        under each criterion together with its scores, and ``'env_samples'``
        to the number of evaluation environments used on that first call.
        Subclasses using ``tournament`` must provide ``training_model``.
    """
    data_logger = None

    num_steps = 0

    checkpoint_interval = 100000
    max_checkpoints = 3
    checkpoint_attribs = []

    _last_checkpoint = -1
    _checkpoint_directory = None

    # File name (inside ``data_logger.logdir``) of the streaming tournament log.
    _tournament_log = "tournament-log.json"
    # StreamingJSONWriter for the tournament log; opened lazily.
    tournament_log = None

    # Champion keys kept by ``tournament`` and the tag under which each
    # champion's scores are passed to ``data_logger.log_scalars``.
    _champion_tags = (
        ('length_champ', 'length-champion'),
        ('perf_champ', 'performance-champion'),
        ('effect_champ', 'side-effects-champion'),
    )

    def tournament_JSON(self):
        """Open the streaming tournament log inside ``data_logger.logdir``."""
        self.tournament_log = StreamingJSONWriter(
            os.path.join(self.data_logger.logdir, self._tournament_log))

    def update_tournament_log(self):
        """
        Refresh ``self.champion_log`` from ``self.champion_dict``.

        The log mirrors the champion dictionary without the ``'model'``
        entries and adds the current ``num_steps`` as
        ``'actor_training_steps'``.
        """
        # Only the models are dropped; the remaining values written by
        # ``tournament`` are numbers, so a shallow copy suffices and avoids
        # deep-copying every champion network just to discard the copies.
        log = {}
        for key, value in self.champion_dict.items():
            if isinstance(value, dict):
                value = {k: v for k, v in value.items() if k != 'model'}
            log[key] = value
        log['actor_training_steps'] = self.num_steps
        self.champion_log = log

    @property
    def checkpoint_directory(self):
        return self._checkpoint_directory or (
            self.data_logger and self.data_logger.logdir)

    @checkpoint_directory.setter
    def checkpoint_directory(self, value):
        self._checkpoint_directory = value

    def get_all_checkpoints(self):
        """
        Return a sorted list of all checkpoints in the log directory.

        Only files named ``checkpoint-<int>.data`` are returned, ordered by
        their integer step (oldest first).
        """
        chkpt_dir = self.checkpoint_directory
        if not chkpt_dir:
            return []
        files = glob.glob(os.path.join(chkpt_dir, 'checkpoint-*.data'))

        def step_from_checkpoint(f):
            # Strip the 'checkpoint-' prefix (11 chars) and '.data' suffix.
            try:
                return int(os.path.basename(f)[11:-5])
            except ValueError:
                return -1

        files = [f for f in files if step_from_checkpoint(f) >= 0]
        return sorted(files, key=step_from_checkpoint)

    def save_checkpoint_if_needed(self):
        """Save a checkpoint if none exists yet or the last one is too old."""
        if (self._last_checkpoint < 0 or
                self._last_checkpoint + self.checkpoint_interval < self.num_steps):
            self.save_checkpoint()

    def save_checkpoint(self):
        """
        Save ``num_steps`` and every attribute listed in ``checkpoint_attribs``.

        Attributes that have a ``state_dict`` method are stored via their
        state dict. Does nothing if there is no checkpoint directory. After
        saving, older checkpoint files beyond ``max_checkpoints`` are removed.
        """
        chkpt_dir = self.checkpoint_directory
        if not chkpt_dir:
            return

        data = {'num_steps': self.num_steps}
        for attrib in self.checkpoint_attribs:
            try:
                val = nested_getattr(self, attrib)
            except AttributeError:
                logger.error("Cannot save attribute '%s'", attrib)
                continue
            if hasattr(val, 'state_dict'):
                val = val.state_dict()
            data[attrib] = val

        path = os.path.join(chkpt_dir, 'checkpoint-%i.data' % self.num_steps)
        torch.save(data, path)
        logger.info("Saving checkpoint: '%s'", path)

        # NOTE: with max_checkpoints == 0 the slice [:-0] is empty, so no
        # checkpoint is deleted.
        old_checkpoints = self.get_all_checkpoints()
        for old_checkpoint in old_checkpoints[:-self.max_checkpoints]:
            os.remove(old_checkpoint)

        self._last_checkpoint = self.num_steps

    def load_checkpoint(self, checkpoint_name=None):
        """
        Restore state from a checkpoint file.

        Parameters
        ----------
        checkpoint_name : str or None
            A name containing a directory is used as a complete path; a bare
            file name is looked up in ``checkpoint_directory``. When it is
            None, or is a bare name but no checkpoint directory is set, the
            newest checkpoint from ``get_all_checkpoints`` is used instead.

        Does nothing if no matching checkpoint file exists. Each saved value
        is restored with ``load_state_dict`` when the current attribute has
        that method, and assigned directly otherwise.
        """
        chkpt_dir = self.checkpoint_directory
        if checkpoint_name and os.path.dirname(checkpoint_name):
            # Path includes a directory.
            # Treat it as a complete path name and ignore chkpt_dir
            path = checkpoint_name
        elif chkpt_dir and checkpoint_name:
            path = os.path.join(chkpt_dir, checkpoint_name)
        else:
            checkpoints = self.get_all_checkpoints()
            path = checkpoints and checkpoints[-1]
        if not path or not os.path.exists(path):
            return

        # torch.load unpickles the file: only load checkpoints you trust.
        # Without CUDA, remap any GPU-saved tensors onto the CPU.
        if torch.cuda.is_available():
            checkpoint = torch.load(path)
        else:
            checkpoint = torch.load(path, map_location=torch.device('cpu'))

        for key, val in checkpoint.items():
            orig_val = nested_getattr(self, key, None)
            if hasattr(orig_val, 'load_state_dict'):
                orig_val.load_state_dict(val)
            else:
                try:
                    nested_setattr(self, key, val)
                except AttributeError:
                    logger.error("Cannot load key '%s'", key)

        self._last_checkpoint = self.num_steps

    def take_one_step(self, envs):
        """
        Take one step in each of the environments.

        Returns
        -------
        states : list
        actions : list
        rewards : list
        done : list
            Whether or not each environment reached its end this step.
        """
        raise NotImplementedError

    def take_one_step_safe(self, envs):
        """
        Take one step in each of the environments.

        Returns
        -------
        states : list
        actions : list
        rewards : list
        done : list
            Whether or not each environment reached its end this step.
        """
        raise NotImplementedError

    def take_one_step_triple(self, envs):
        """
        Take one step in each of the environments.

        Returns
        -------
        states : list
        actions : list
        rewards : list
        done : list
            Whether or not each environment reached its end this step.
        """
        raise NotImplementedError

    def take_one_step_champion(self, envs):
        """
        Take one step in each of the environments.

        Returns
        -------
        states : list
        actions : list
        rewards : list
        done : list
            Whether or not each environment reached its end this step.
        """
        raise NotImplementedError

    def _run_until_complete(self, step, envs, num_episodes=None,
                            on_episode_end=None):
        """
        Step environments in parallel until ``num_episodes`` episodes finish.

        A finished environment keeps running only while the number of
        completed episodes plus the number of environments still in progress
        does not exceed ``num_episodes``; otherwise it is dropped so that no
        more episodes are started than needed. The caller's list is not
        modified.

        Parameters
        ----------
        step : callable
            Called with the list of active environments; must return an
            object whose ``done`` attribute holds one flag per environment.
        envs : list
            Environments to run in parallel.
        num_episodes : int or None
            Number of episodes to finish. Defaults to ``len(envs)``.
        on_episode_end : callable or None
            If given, called as ``on_episode_end(data, index)`` for every
            environment that finished an episode, where ``data`` is the value
            returned by ``step`` and ``index`` is that environment's position
            in the list passed to ``step``.
        """
        if num_episodes is None:
            num_episodes = len(envs)
        num_completed = 0

        # Stop once no environments remain: with an empty list nothing can
        # ever complete and the loop would otherwise spin forever.
        while num_completed < num_episodes and envs:
            data = step(envs)
            num_in_progress = len(envs)
            still_running = []
            for index, (env, done) in enumerate(zip(envs, data.done)):
                if done:
                    num_completed += 1
                    if on_episode_end is not None:
                        on_episode_end(data, index)
                if done and num_in_progress + num_completed > num_episodes:
                    # Restarting this environment would overshoot the target.
                    num_in_progress -= 1
                else:
                    still_running.append(env)
            envs = still_running

    def run_episodes(self, envs, num_episodes=None):
        """
        Run each environment to completion.

        Note that no data is logged in this method. It's instead assumed
        that each environment has a wrapper which takes care of the logging.

        Parameters
        ----------
        envs : list
            List of environments to run in parallel.
        num_episodes : int
            Total number of episodes to run. Defaults to the same as number
            of environments.
        """
        self._run_until_complete(self.take_one_step, envs, num_episodes)

    def run_episodes_safe(self, envs, num_episodes=None):
        """
        Run each environment to completion using ``take_one_step_safe``.

        Note that no data is logged in this method. It's instead assumed
        that each environment has a wrapper which takes care of the logging.

        Parameters
        ----------
        envs : list
            List of environments to run in parallel.
        num_episodes : int
            Total number of episodes to run. Defaults to the same as number
            of environments.
        """
        self._run_until_complete(self.take_one_step_safe, envs, num_episodes)

    def run_episodes_triple(self, envs, num_episodes=None):
        """
        Run each environment to completion using ``take_one_step_triple``.

        Note that no data is logged in this method. It's instead assumed
        that each environment has a wrapper which takes care of the logging.

        Parameters
        ----------
        envs : list
            List of environments to run in parallel.
        num_episodes : int
            Total number of episodes to run. Defaults to the same as number
            of environments.
        """
        self._run_until_complete(self.take_one_step_triple, envs, num_episodes)

    # ----- Tournament utilities -----

    @staticmethod
    def _champion_entry(model, length_avg, length_std, perf_avg, perf_std,
                        effect_avg, effect_std):
        """Build one champion record; key order matches the tournament log."""
        return {
            'model': model,
            'avg_length': length_avg,
            'avg_performance': perf_avg,
            'avg_side_effects': effect_avg,
            'std_length': length_std,
            'std_performance': perf_std,
            'std_side_effects': effect_std,
        }

    @staticmethod
    def _champion_summary(champ):
        """Return one champion's scores keyed for ``data_logger.log_scalars``."""
        return {
            "length_avg": champ['avg_length'],
            "length_std": champ['std_length'],
            "performance_avg": champ['avg_performance'],
            "performance_std": champ['std_performance'],
            "side_effect_avg": champ['avg_side_effects'],
            "side_effect_std": champ['std_side_effects'],
        }

    def tournament(self, envs, num_episodes=None):
        """
        Score the current training policy and update the champion models.

        ``self.training_model`` is evaluated with ``training_policy_score``.
        It replaces the length champion if its average episode length is
        lower, the performance champion if its average performance is
        higher, and the side-effect champion if its average side-effect
        score is lower. Every replaced champion stores its own deep copy of
        the model.

        On the first call ``self.champion_dict`` is created, with every
        champion seeded from a copy of the current model and placeholder
        scores (length 1000, performance 0.0, side effects 2.0, stds 0.0).

        If ``data_logger`` is set, the policy's scores are logged, each
        champion's scores are passed to ``data_logger.log_scalars`` and the
        champion summary is appended to the streaming tournament log.

        Parameters
        ----------
        envs : list
            Environments used for evaluation.
        num_episodes : int or None
            Number of evaluation episodes; defaults to ``len(envs)``.
        """
        if not hasattr(self, 'champion_dict'):
            self.champion_dict = {
                name: self._champion_entry(
                    deepcopy(self.training_model), 1000, 0.0, 0.0, 0.0, 2.0, 0.0)
                for name, _ in self._champion_tags
            }
            self.champion_dict['env_samples'] = len(envs)
            self.update_tournament_log()

        scores = self.training_policy_score(envs, num_episodes)
        (policy_length_avg, policy_length_std,
         policy_perf_avg, policy_perf_std,
         policy_effect_avg, policy_effect_std) = scores

        champs = self.champion_dict
        # Each criterion has its own direction of "better".
        improved = {
            'length_champ':
                policy_length_avg < champs['length_champ']['avg_length'],
            'perf_champ':
                policy_perf_avg > champs['perf_champ']['avg_performance'],
            'effect_champ':
                policy_effect_avg < champs['effect_champ']['avg_side_effects'],
        }
        for name, _ in self._champion_tags:
            if improved[name]:
                # Update in place so existing references to the record stay valid.
                champs[name].update(self._champion_entry(
                    deepcopy(self.training_model), *scores))

        if self.data_logger is not None:
            logger.info(
                "n=%i: length_avg=%0.3g, length_std=%0.3g, perf_avg=%0.3g, perf_std=%0.3g, effect_avg=%0.3g, effect_std=%0.3g",
                self.num_steps,
                policy_length_avg,
                policy_length_std,
                policy_perf_avg,
                policy_perf_std,
                policy_effect_avg,
                policy_effect_std)

            for name, tag in self._champion_tags:
                self.data_logger.log_scalars(
                    self._champion_summary(champs[name]), self.num_steps, tag)

            # Opened here (not on first call) so that running without a
            # data_logger never touches ``data_logger.logdir``.
            if self.tournament_log is None:
                self.tournament_JSON()
            self.update_tournament_log()
            self.tournament_log.dump(self.champion_log)

    def training_policy_score(self, envs, num_episodes=None):
        """
        Evaluate the current policy and summarize its finished episodes.

        Runs ``take_one_step`` until ``num_episodes`` episodes (default
        ``len(envs)``) have finished and reads each finished environment's
        ``info['episode']`` record from the step result.

        Returns
        -------
        tuple
            ``(length_mean, length_std, performance_mean, performance_std,
            side_effect_mean, side_effect_std)``, where performance is
            ``reward / reward_possible`` and the side-effect score is
            ``side_effects['life-green'][0] / side_effects['life-green'][1]``.
            The values are NaN if no episode finished.
        """
        length_list = []
        perf_list = []
        side_effect_list = []

        def record(data, index):
            episode = data.info[index]['episode']
            length_list.append(episode['length'])
            # NOTE: raises ZeroDivisionError if either denominator is 0.
            perf_list.append(
                float(episode['reward']) / float(episode['reward_possible']))
            green = episode['side_effects']['life-green']
            side_effect_list.append(float(green[0]) / float(green[1]))

        self._run_until_complete(self.take_one_step, envs, num_episodes, record)

        return (np.mean(length_list), np.std(length_list),
                np.mean(perf_list), np.std(perf_list),
                np.mean(side_effect_list), np.std(side_effect_list))
