from collections import namedtuple

from open_source.rlpyt.rlpyt.envs.base import Env, EnvStep, EnvSpaces
from open_source.rlpyt.rlpyt.spaces.int_box import IntBox
from open_source.rlpyt.rlpyt.spaces.float_box import FloatBox
from open_source.rlpyt.rlpyt.utils.quick_args import save__init__args
from open_source.rlpyt.rlpyt.samplers.collections import TrajInfo

import uuid
import subprocess
import cmath
import re
# import paramiko
import time
import multiprocessing
import pdb

import tensorboardX

import numpy as np

# from utils.circuits_utils import get_mos_netlist_static_feature
from envs.StrongArm_Latch.strongArm_latch_py import strongArm_simulator_py

from envs.BaseEnv import BaseEnv
from copy import deepcopy

import yaml
import yamlordereddictloader
from envs.StrongArm_Latch import utils
import os

from envs.strongArmEnv_py import strongArmEnv_py
from tensorboardX import SummaryWriter
import copy

import time
import datetime
import multiprocessing as mp
import pdb

# Per-step auxiliary info attached to every EnvStep; carries a single
# ``timeout`` field.
EnvInfo = namedtuple("EnvInfo", "timeout")

def env_step(env, action):
    """Step ``env`` once with ``action`` and return whatever ``env.step`` returns.

    The ``episode`` and ``global_stp`` keyword arguments are always passed
    as 0. This function is the callable handed to ``Pool.starmap`` by
    ``strongArmEnv_pvt.step``; it is kept at module level, presumably so
    that ``multiprocessing`` can pickle it.
    """
    fixed_counters = {"episode": 0, "global_stp": 0}
    return env.step(action, **fixed_counters)

class strongArmEnv_pvt(Env):
    """Single-step environment evaluating one action on several corners.

    One ``strongArmEnv_py`` instance is built for every combination of
    ``processes`` x ``temps`` x ``vdds``. ``step`` forwards the same action
    to all of them and reports the mean of their rewards; every episode
    ends after exactly one step and the observation is a constant zero
    vector.

    Corner environments are stepped either sequentially in this process
    (``parallel=False``) or through a ``multiprocessing`` pool
    (``parallel=True``, the default).
    """

    def __init__(self, kwargs, writer=None):
        """Build the per-corner environments.

        Args:
            kwargs: configuration dict. ``'log_dir'`` and ``'runs_dir'`` are
                required. Optional keys read here: ``'processes'``,
                ``'temps'``, ``'vdds'``, ``'log'``, ``'log_interval'``,
                ``'random'`` and ``'parallel'``. A deep copy (with
                ``'runs_dir'`` and ``'corner'`` overridden) is forwarded to
                every ``strongArmEnv_py``.
            writer: optional summary writer used for per-corner reward
                logging in parallel mode (needs ``add_scalar``).

        Raises:
            KeyError: if ``'log_dir'`` or ``'runs_dir'`` is missing.
        """
        self.parallel_writer = writer
        self.global_stp = 0
        self.kwargs = copy.deepcopy(kwargs)

        self.processes = kwargs.get('processes', ["tt", "ss", "ff", "sf", "fs"])
        self.temps = kwargs.get('temps', ["0", "100"])
        self.vdds = kwargs.get('vdds', ["1.2"])
        self.log = kwargs.get('log', False)
        self.log_interval = kwargs.get('log_interval', 1)
        self.random = kwargs.get('random', False)
        self.parallel = kwargs.get('parallel', True)

        # Resolve the required keys before any worker process is spawned so
        # that a bad configuration cannot leave an orphaned pool behind.
        write_path_root = self.kwargs['log_dir'] + '/' + self.kwargs['runs_dir']

        # Worker processes are only needed in parallel mode. The pool is
        # started before the corner environments are built, so workers do
        # not inherit them.
        self.pool = mp.Pool(mp.cpu_count()) if self.parallel else None

        # Unique run-directory prefix for this instance (timestamp + UUID),
        # so that instances living in different processes do not collide.
        time_info = datetime.datetime.today().strftime("%Y%m%d_%H%M")
        tmp_path = time_info + "-" + str(uuid.uuid4())

        self.strongArmEnvs = {}
        for process in self.processes:
            for temp in self.temps:
                for vdd in self.vdds:
                    suffix = "%s_%s_%s" % (process, temp, vdd)
                    corner_kwargs = deepcopy(self.kwargs)
                    corner_kwargs['runs_dir'] = tmp_path + "-" + suffix
                    corner_kwargs['corner'] = {
                        "process": process,
                        "temp": temp,
                        "vdd": vdd,
                    }
                    if self.parallel:
                        # Parallel mode logs through ``parallel_writer``.
                        corner_writer = None
                    else:
                        # Sequential mode: one summary writer per corner.
                        corner_writer = SummaryWriter(
                            log_dir=write_path_root + "/" + suffix)
                    self.strongArmEnvs["corner_" + suffix] = strongArmEnv_py(
                        tb_writer=corner_writer, kwargs=corner_kwargs)

        self._init_obs = np.zeros(7)
        # Both actions and observations are 7-dimensional boxes in [-1, 1].
        self._action_space = FloatBox(low=-1, high=1, shape=7)
        self._observation_space = FloatBox(low=-1., high=1., shape=7)

    def get_absolute_sizings(self, action):
        """Map a normalized action to absolute sizings, keyed by name.

        Each component of ``action`` is mapped linearly from [-1, 1] onto
        [``sizing_lower``, ``sizing_upper``] of the first corner
        environment. Values are not rounded and starting points are not
        used.

        Returns:
            dict mapping each key of the first environment's ``sizing`` to
            the corresponding absolute value.
        """
        first_key = "corner_%s_%s_%s" % (
            self.processes[0], self.temps[0], self.vdds[0])
        env = self.strongArmEnvs[first_key]
        span = env.sizing_upper - env.sizing_lower
        absolute_sizings = env.sizing_lower + span * (action + 1) / 2
        return {
            name: absolute_sizings[i]
            for i, name in enumerate(env.sizing.keys())
        }

    def __getstate__(self):
        """Return the instance state without the (unpicklable) pool."""
        state = self.__dict__.copy()
        state.pop('pool', None)
        return state

    def __setstate__(self, state):
        """Restore the state; the pool is recreated lazily when needed."""
        self.__dict__.update(state)
        self.pool = None

    def _get_pool(self):
        """Return the worker pool, creating it if it does not exist yet.

        The pool is missing after unpickling/copying or after ``close``.
        """
        pool = getattr(self, 'pool', None)
        if pool is None:
            pool = self.pool = mp.Pool(mp.cpu_count())
        return pool

    def step(self, action):
        """Apply ``action`` to every corner and return the mean reward.

        Returns:
            EnvStep whose observation is the constant initial observation
            (float32), whose reward is the mean reward over all corners,
            with ``done=True`` and ``env_info.timeout=False``.
        """
        if self.parallel:
            env_ids = list(self.strongArmEnvs)
            inputs = [(self.strongArmEnvs[env_id], action) for env_id in env_ids]
            step_outputs = self._get_pool().starmap(env_step, inputs)
            # Each output is (states, reward, episode_finish, all_info);
            # only the rewards are consumed here.
            _, rewards, _, _ = zip(*step_outputs)
            reward = sum(rewards) / len(rewards)
            if self.parallel_writer is not None and self.log:
                self.parallel_write_logs(
                    dict(zip(env_ids, rewards)), self.global_stp)
        else:
            total = 0
            for env in self.strongArmEnvs.values():
                _, corner_reward, _, all_info = env.step(
                    action, episode=self.global_stp, global_stp=self.global_stp)
                total += corner_reward
                if env.writer is not None and self.log:
                    env.write_logs(all_info=all_info)
            reward = np.float32(total) / len(self.strongArmEnvs)

        observation = np.float32(self._init_obs)
        # The step counter advances by the logging interval on every call.
        self.global_stp += self.log_interval

        # Every episode finishes after a single step and never times out.
        return EnvStep(observation=observation, reward=reward, done=True,
                       env_info=EnvInfo(timeout=False))

    def parallel_write_logs(self, reward_dict:dict, stp:int):
        """Write one reward scalar per corner to ``parallel_writer``.

        Args:
            reward_dict: maps corner id to the reward obtained on it.
            stp: step value attached to every scalar.
        """
        print('tensorboard step:', stp)
        for corner_info, reward in reward_dict.items():
            self.parallel_writer.add_scalar(corner_info+'reward', reward, stp)

    def reset(self):
        """Return a fresh float32 copy of the constant initial observation."""
        init_obs = np.float32(np.array(self._init_obs))
        return init_obs.copy()

    @property
    def action_space(self):
        """Action space: 7-dimensional ``FloatBox`` in [-1, 1]."""
        return self._action_space

    @property
    def observation_space(self):
        """Observation space: 7-dimensional ``FloatBox`` in [-1, 1]."""
        return self._observation_space

    @property
    def spaces(self):
        """Observation and action spaces bundled as ``EnvSpaces``."""
        return EnvSpaces(
            observation=self.observation_space,
            action=self.action_space,
        )

    @property
    def horizon(self):
        """No horizon is defined for this environment; always ``None``."""
        return None

    def close(self):
        """Shut down the worker pool, if any. Safe to call repeatedly."""
        pool = getattr(self, 'pool', None)
        if pool is not None:
            pool.close()
            pool.join()
            self.pool = None