from collections import namedtuple

from open_source.rlpyt.rlpyt.envs.base import Env, EnvStep, EnvSpaces
from open_source.rlpyt.rlpyt.spaces.int_box import IntBox
from open_source.rlpyt.rlpyt.spaces.float_box import FloatBox
from open_source.rlpyt.rlpyt.utils.quick_args import save__init__args
from open_source.rlpyt.rlpyt.samplers.collections import TrajInfo

import uuid
import subprocess
import cmath
import re
# import paramiko
import time
import multiprocessing
import pdb

import tensorboardX

import numpy as np

# from utils.circuits_utils import get_mos_netlist_static_feature
from envs.StrongArm_Latch.strongArm_latch import strongArm_simulator

from envs.BaseEnv import BaseEnv
from copy import deepcopy

import yaml
import yamlordereddictloader
from envs.StrongArm_Latch import utils
import os

from envs.strongArmEnv_py import strongArmEnv_py
from tensorboardX import SummaryWriter
import copy

import time
import datetime

import pdb

# Extra per-step information carried in ``EnvStep.env_info``:
#   timeout -- flag set by the environment's ``step`` (it passes False)
#   taskID  -- identifier of the corner/task the environment was built for
EnvInfo = namedtuple("EnvInfo", ("timeout", "taskID"))

class strongArmEnv_pcgrad(Env):
    """rlpyt environment wrapping ``strongArmEnv_py`` for one PVT corner.

    Every episode is a single step: ``step`` forwards the action to the
    wrapped environment, always reports ``done=True`` and returns the
    encoding of the configured corner as the observation.

    The observation has 7 entries: a one-hot process vector of length 5
    (ordered as ``self.processes``), followed by the supply voltage and
    the temperature, each mapped linearly so that ``[vddmin, vddmax]`` and
    ``[tempmin, tempmax]`` land on ``[-1, 1]``.
    """

    def __init__(self, kwargs):
        """Build the wrapped environment and the optional TensorBoard writer.

        Args:
            kwargs: dict-like configuration. Keys read here:
                ``'corner'`` (required; mapping with ``'process'``,
                ``'temp'`` and ``'vdd'``), ``'log_dir'`` and ``'runs_dir'``
                (required; used to build the TensorBoard directory),
                and the optional ``'log'``, ``'log_interval'``, ``'eval'``,
                ``'tensorboard'`` and ``'cur_corner'``.

        Raises:
            KeyError: if a required key is missing or the process name is
                not in ``self.encodeID``.
        """
        self.global_stp = 0
        # Work on a private copy: 'runs_dir' is overwritten below and that
        # change must not leak into the caller's dict.
        self.kwargs = copy.deepcopy(kwargs)

        self.log = kwargs.get('log', False)
        self.log_interval = kwargs.get('log_interval', 1)
        self.eval = kwargs.get('eval', False)

        corner_cfg = kwargs['corner']
        corner_info = "%s_%s_%s" % (corner_cfg['process'], corner_cfg['temp'], corner_cfg['vdd'])
        # NOTE(review): this integer coding orders the processes differently
        # from ``self.processes`` (used for the one-hot observation) --
        # confirm that the mismatch is intended.
        self.encodeID = {'tt': 0, 'ff': 1, 'ss': 2, 'sf': 3, 'fs': 4}
        self.taskID = [
            self.encodeID[corner_cfg['process']],
            float(corner_cfg['temp']),
            float(corner_cfg['vdd']),
        ]

        # A fresh UUID keeps instances created in different processes from
        # sharing a log / run directory.
        instance_path = str(uuid.uuid4())
        write_path = self.kwargs['log_dir'] + '/' + self.kwargs['runs_dir'] + '/' + corner_info + '-' + instance_path
        self.tensorboard = kwargs.get('tensorboard', True)
        self.writer = SummaryWriter(log_dir=write_path) if self.tensorboard else None

        # 'runs_dir' is re-purposed to specify the ocn run_dir handed to the
        # wrapped environment.
        time_info = "{}".format(datetime.datetime.today().strftime("%Y%m%d_%H%M"))
        self.kwargs['runs_dir'] = time_info + "-" + instance_path
        self.strongArmEnv = strongArmEnv_py(tb_writer=self.writer, kwargs=self.kwargs)

        # Order of this list defines the one-hot process encoding.
        self.processes = ["tt", "ss", "ff", "sf", "fs"]
        # Normalisation ranges for supply voltage and temperature.
        self.vddmin = 0.9
        self.vddmax = 1.2
        self.tempmin = 0
        self.tempmax = 100

        # NOTE(review): this fallback is never reached, because
        # ``kwargs['corner']`` has already been indexed above.
        default_corner = {
            'process': "tt",
            'temp': "27",
            'vdd': "1.2"
        }
        self.corner = kwargs.get('corner', default_corner)
        self.cur_corner = kwargs.get('cur_corner', self.corner)

        # 5 one-hot process entries + vdd + temp = 7 observation entries.
        self._init_obs = np.zeros(7)
        self._action_space = FloatBox(low=-1, high=1, shape=7)
        self._observation_space = FloatBox(low=-1., high=1., shape=7)

    def _encode_corner(self, corner):
        """Return the float32 observation vector encoding ``corner``.

        Args:
            corner: mapping with ``'process'``, ``'vdd'`` and ``'temp'``.

        Returns:
            1-D ``np.float32`` array: one-hot process (``len(self.processes)``
            entries), then normalised vdd, then normalised temperature.

        Raises:
            ValueError: if the process is not listed in ``self.processes``
                or vdd/temp cannot be converted to float.
        """
        process = corner['process']
        vdd = np.atleast_1d(float(corner['vdd']))
        temp = np.atleast_1d(float(corner['temp']))

        one_hot = np.eye(len(self.processes), dtype=float)[self.processes.index(process)]
        # Linear map of [min, max] onto [-1, 1].
        vdd = -1 + 2 * (vdd - self.vddmin) / (self.vddmax - self.vddmin)
        temp = -1 + 2 * (temp - self.tempmin) / (self.tempmax - self.tempmin)

        return np.float32(np.concatenate((one_hot, vdd, temp)))

    def step(self, action):
        """Apply ``action`` to the wrapped environment and end the episode.

        Args:
            action: forwarded unchanged to ``self.strongArmEnv.step``.

        Returns:
            ``EnvStep`` whose observation is the encoded corner, whose reward
            is the wrapped environment's reward cast to ``np.float32`` and
            whose ``done`` flag is always ``True``.
        """
        # Only reward and info are used; the raw states and the wrapped
        # environment's own done flag are discarded.
        _, reward, _, all_info = self.strongArmEnv.step(
            action, episode=self.global_stp, global_stp=self.global_stp)

        observation = self._encode_corner(self.kwargs['corner'])
        reward = np.float32(reward)
        env_info = EnvInfo(timeout=False, taskID=self.taskID)

        if self.writer is not None and self.log:
            self.strongArmEnv.write_logs(all_info=all_info)
        self.global_stp += self.log_interval

        # Episodes are one step long, so done is unconditionally True.
        return EnvStep(observation=observation, reward=reward, done=True, env_info=env_info)

    def reset(self):
        """Return a fresh copy of the observation encoding ``self.corner``."""
        return self._encode_corner(self.corner).copy()

    def get_absolute_sizings(self, action):
        """Delegate to the wrapped environment, passing its own ``config``."""
        wrapped = self.strongArmEnv
        return wrapped.get_absolute_sizings(action=action, config=wrapped.config)

    @property
    def action_space(self):
        """Action space built in ``__init__``."""
        return self._action_space

    @property
    def observation_space(self):
        """Observation space built in ``__init__``."""
        return self._observation_space

    @property
    def spaces(self):
        """``EnvSpaces`` bundling the observation and action spaces."""
        return EnvSpaces(
            observation=self.observation_space,
            action=self.action_space,
        )

    @property
    def horizon(self):
        """No horizon is defined for this environment; always ``None``."""
        return None

    def close(self):
        """Close the TensorBoard writer, if one was created.

        Without this the event file opened in ``__init__`` is never
        explicitly flushed or released.
        """
        writer = getattr(self, 'writer', None)
        if writer is not None:
            writer.close()