import pickle
import re

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import yaml
from networkx.drawing.nx_agraph import to_agraph
from torch.utils.tensorboard import SummaryWriter

from utils import recursive_dictify


class Logger:
    """Experiment logger that mirrors scalar metrics to TensorBoard.

    On construction it opens a ``SummaryWriter`` in ``args.working_dir`` and
    dumps the run arguments to ``params.yml`` there. Every scalar passed
    through :meth:`update_metrics` is written to TensorBoard. The latest value
    per key is kept in ``self.metrics``, and the full ``(global_step, value)``
    history per key is kept in ``self.all_metrics``, which is pickled to
    ``all_metrics.pkl`` by :meth:`close`.
    """

    def __init__(self, args):
        self.writer = SummaryWriter(args.working_dir)
        self.working_dir = args.working_dir
        with open(args.working_dir + '/params.yml', 'w') as yaml_file:
            yaml.dump(recursive_dictify(vars(args)), yaml_file, default_flow_style=False)
        # Latest value seen for each metric key.
        self.metrics = {}
        # Full history: metric key -> list of (global_step, value) tuples.
        self.all_metrics = {}
        # Episode stats buffered between log_train calls: key -> list of values.
        self.env_metrics = {}

    def add_text(self, tag, text_string):
        """Write a text entry to TensorBoard under ``tag``."""
        self.writer.add_text(tag, text_string)

    def log_env(self, info, global_step):
        """Buffer one episode's statistics until the next :meth:`log_train`.

        Does nothing when ``info`` is ``None``. Otherwise ``info`` must provide
        ``info["episode"]["r"]``, ``info["episode"]["l"]`` and ``info["n_accept"]``.
        The episodic return is also printed.
        """
        if info is None:
            return
        episode = info["episode"]
        env_metrics = {
            "episodic_return": episode["r"],
            "episodic_length": episode["l"],
            "n_accept": info["n_accept"],
        }
        print(f"Global Step={global_step}, Episodic Return={episode['r']}")
        # Append in place. The previous `old + [v]` rebuild copied every
        # buffered list on each episode, which is quadratic per interval.
        for key, value in env_metrics.items():
            self.env_metrics.setdefault(key, []).append(value)

    def log_train(self, train_metrics, global_step):
        """Log training metrics and the mean of the buffered episode stats.

        ``train_metrics`` must contain an ``'SPS'`` entry, which is printed.
        The episode buffer is cleared afterwards.
        """
        print("SPS:", train_metrics['SPS'])
        self.update_metrics(train_metrics, global_step)
        env_means = {k: np.mean(v) for k, v in self.env_metrics.items()}
        # Clear the buffer before writing. If writing fails, the buffer must
        # not be left holding means instead of lists, or log_env would crash.
        self.env_metrics = {}
        self.update_metrics(env_means, global_step)

    def log_eval(self, eval_metrics, global_step):
        """Log evaluation metrics. ``eval_metrics`` must contain ``'eval_return'``."""
        self.update_metrics(eval_metrics, global_step)
        print("Eval Return=", eval_metrics['eval_return'])

    def update_metrics(self, new_metrics, global_step):
        """Write each scalar to TensorBoard and record it as latest value and in the history."""
        for k, v in new_metrics.items():
            self.writer.add_scalar(k, v, global_step)
        self.metrics.update(new_metrics)
        for k, v in new_metrics.items():
            self.all_metrics.setdefault(k, []).append((global_step, v))

    def get_metrics(self):
        """Return the dict of latest metric values (the live dict, not a copy)."""
        return self.metrics

    def log_hparams(self, args):
        """Write ``vars(args)`` to TensorBoard as a markdown ``|param|value|`` table."""
        rows = '\n'.join(f'|{k}|{v}|' for k, v in vars(args).items())
        self.add_text('hyperparameters', '|param|value|\n|-|-|\n' + rows)

    def close(self):
        """Close the TensorBoard writer and pickle the metric history to ``all_metrics.pkl``."""
        self.writer.close()
        with open(self.working_dir + '/all_metrics.pkl', 'wb') as f:
            pickle.dump(self.all_metrics, f)

    def log_ldba(self, envs):
        """Render the first env's LDBA graph into ``working_dir``.

        Currently disabled: it returns immediately because rendering requires
        additional libraries. The code below the early return is kept for
        when it is re-enabled.
        """
        return  # requires additional libraries
        try:
            unwrapped = envs.call('unwrapped')[0]
            args = (unwrapped._ldba.graph, unwrapped._ldba.aps, self.working_dir)
        except Exception:
            # Best effort: rendering is optional, so report and continue.
            args = None
            print('Could not render LDBA.')
        if args:
            draw_ldba(*args)

    def store(self, obs):
        """No-op hook: accepts an observation and stores nothing."""
        pass


def draw_ldba(graph, aps, dir):
    """Render an LDBA graph to ``<dir>/ldba.png`` using Graphviz's ``dot`` layout.

    Rendering rules:

    * Nodes carrying a ``value`` attribute are labelled with the node id and
      the value formatted to two decimals.
    * A node is outlined red when its ``accepting`` attribute is truthy and
      black otherwise.
    * Edges with ``jump_id > 0`` are coloured blue.
    * Every edge is labelled with its ``condition_as_str``. Each integer token
      ``i`` in it with ``i < len(aps)`` is replaced by the ``i``-th AP name.

    The caller's graph is not modified; all rendering attributes are set on a copy.

    Args:
        graph: networkx multigraph (edges are addressed by key). Every edge
            must have ``jump_id`` and ``condition_as_str``, and every node must
            have ``accepting``.
        aps: sized iterable of atomic-proposition names, in index order.
        dir: output directory. It shadows the ``dir`` builtin but is kept for
            backward compatibility with keyword callers.
    """
    # Decorate a copy so the caller's LDBA keeps only its own attributes.
    graph = graph.copy()
    n_attrs = {node: {'label': f'{node}\n{value:.2f}'}
               for node, value in nx.get_node_attributes(graph, 'value').items()}
    nx.set_node_attributes(graph, n_attrs)
    graph.graph['graph'] = {'ranksep': 3}

    ap_names = list(aps)

    def ap_name(match):
        idx = int(match.group())
        return ap_names[idx] if idx < len(ap_names) else match.group()

    for u, v, key, data in graph.edges(keys=True, data=True):
        if data['jump_id'] > 0:
            graph[u][v][key]['color'] = 'blue'
        # Use one regex pass instead of chained str.replace calls. Sequential
        # replacement re-substitutes digits inside AP names inserted earlier,
        # e.g. aps=['x', 'y0'] turned condition "1" into "yx".
        graph[u][v][key]['label'] = re.sub(r'\d+', ap_name, data['condition_as_str'])
    for node, data in graph.nodes(data=True):
        graph.nodes[node]['color'] = 'red' if data['accepting'] else 'black'
    agraph = to_agraph(graph)
    agraph.layout('dot')  # other Graphviz programs: neato, twopi, circo, fdp, nop
    agraph.draw(dir + '/ldba.png')
