from collections import namedtuple

from open_source.rlpyt.rlpyt.envs.base import Env, EnvStep, EnvSpaces
from open_source.rlpyt.rlpyt.spaces.int_box import IntBox
from open_source.rlpyt.rlpyt.spaces.float_box import FloatBox
from open_source.rlpyt.rlpyt.utils.quick_args import save__init__args
from open_source.rlpyt.rlpyt.samplers.collections import TrajInfo

import uuid
import subprocess
import cmath
import re
# import paramiko
import time
import multiprocessing
import pdb

import tensorboardX

import numpy as np

# from utils.circuits_utils import get_mos_netlist_static_feature
from envs.StrongArm_Latch.strongArm_latch import strongArm_simulator

from envs.BaseEnv import BaseEnv
from copy import deepcopy

import yaml
import yamlordereddictloader
from envs.StrongArm_Latch import utils
import os

from envs.NGspiceOpamp.ngspice_vanilla_opamp_log import TwoStageAmp_log
# from tensorboardX import SummaryWriter
from torch.utils.tensorboard import SummaryWriter
import copy

import time
import datetime

import pdb

# Auxiliary per-step record placed in ``EnvStep.env_info``: a ``timeout`` flag
# and the ``taskID`` of the environment that produced the step.
EnvInfo = namedtuple("EnvInfo", ("timeout", "taskID"))

class NGspiceOpampEnv_pcgrad_log(Env):
    """rlpyt ``Env`` wrapping a ``TwoStageAmp_log`` instance for one PVT corner.

    Every call to :meth:`step` maps a normalized action in ``[-1, 1]`` to
    absolute sizings, forwards them to the wrapped environment and always
    reports ``done=True`` (each episode is a single step). The observation
    does not depend on the simulation result: it is a fixed 7-element
    encoding of the corner (5-way one-hot process, normalized vdd,
    normalized temperature).
    """

    # Corner used when ``kwargs`` carries no 'corner' entry.
    _DEFAULT_CORNER = {
        'process': "tt",
        'temp': "27",
        'vdd': "1.2",
    }

    def __init__(self, kwargs, writer):
        """Build the environment.

        Args:
            kwargs: configuration dict. Recognized keys: 'log',
                'log_interval', 'eval', 'tensorboard' and 'corner' (a dict
                with 'process', 'temp' and 'vdd'). A deep copy is kept in
                ``self.kwargs``.
            writer: summary writer forwarded to ``TwoStageAmp_log`` as
                ``tb_writer`` and used by :meth:`write_logs`.

        Raises:
            KeyError: if the corner's process is not a key of
                ``self.encodeID`` or the corner lacks a required entry.
        """
        self.global_stp = 0
        self.kwargs = copy.deepcopy(kwargs)
        self.log = kwargs.get('log', False)
        self.log_interval = kwargs.get('log_interval', 1)
        self.eval = kwargs.get('eval', False)
        self.tensorboard = kwargs.get('tensorboard', True)
        self.writer = writer

        # Resolve the corner once, *before* anything reads it, so that the
        # default is actually usable when 'corner' is absent from kwargs.
        self.corner = kwargs.get('corner', dict(self._DEFAULT_CORNER))
        corner_info = "%s_%s_%s" % (
            self.corner['process'], self.corner['temp'], self.corner['vdd'])

        # Integer code of the process used in the task identifier.
        # NOTE(review): this ordering differs from ``self.processes`` below
        # ('ff' is 1 here but one-hot index 2 there) -- confirm it is intended.
        self.encodeID = {'tt': 0, 'ff': 1, 'ss': 2, 'sf': 3, 'fs': 4}
        process_code = self.encodeID[self.corner['process']]
        temp_code = float(self.corner['temp'])
        vdd_code = float(self.corner['vdd'])
        self.taskID = [process_code, temp_code, vdd_code]

        # One-hot ordering and normalization ranges for the observation.
        self.processes = ["tt", "ss", "ff", "sf", "fs"]
        self.vddmin = 0.9
        self.vddmax = 1.2
        self.tempmin = 0
        self.tempmax = 100

        # Performance dict of the most recent step (see get_ckt_perf).
        self.ckt_perf = {}

        env_config = {
            "generalize": True,
            "valid": True,
            "corner": self.corner,
            "log": self.log,
            "corner_info": corner_info,
        }
        self.TwoStageAmpEnv = TwoStageAmp_log(env_config, tb_writer=self.writer)
        self.TwoStageAmpEnv.reset()

        # 5 one-hot process entries + vdd + temp.
        self._init_obs = np.zeros(7)
        self._action_space = FloatBox(low=-1, high=1, shape=7)
        self._observation_space = FloatBox(low=-1., high=1., shape=7)

    def _corner_observation(self):
        """Return the float32 observation encoding ``self.corner``.

        Layout: one-hot process over ``self.processes``, then vdd and temp,
        each mapped linearly from its ``[min, max]`` range onto ``[-1, 1]``.

        Raises:
            ValueError: if the corner's process is not in ``self.processes``.
        """
        proc_index = self.processes.index(self.corner['process'])
        process = np.eye(len(self.processes), dtype=float)[proc_index]
        vdd = np.atleast_1d(float(self.corner['vdd']))
        temp = np.atleast_1d(float(self.corner['temp']))
        vdd = -1 + 2 * (vdd - self.vddmin) / (self.vddmax - self.vddmin)
        temp = -1 + 2 * (temp - self.tempmin) / (self.tempmax - self.tempmin)
        return np.float32(np.concatenate((process, vdd, temp)))

    def get_absolute_sizings(self, action):
        """Map a normalized action to a ``{sizing name: absolute value}`` dict.

        Each action component in ``[-1, 1]`` is scaled linearly between the
        wrapped environment's ``sizing_lower`` and ``sizing_upper``. Values
        are not rounded. Names are taken from ``TwoStageAmpEnv.sizing`` in
        key order and paired with the scaled values by position.
        """
        lower = self.TwoStageAmpEnv.sizing_lower
        upper = self.TwoStageAmpEnv.sizing_upper
        absolute_sizings = lower + (upper - lower) * (action + 1) / 2
        return {
            name: absolute_sizings[i]
            for i, name in enumerate(self.TwoStageAmpEnv.sizing.keys())
        }

    def step(self, action):
        """Simulate one sizing and return the resulting ``EnvStep``.

        The wrapped environment's own state and done flag are discarded: the
        observation is the corner encoding and ``done`` is always ``True``.
        The returned performance dict is extended with the absolute sizings
        and stored for :meth:`get_ckt_perf`. ``global_stp`` advances by
        ``log_interval`` after every step.
        """
        absolute_sizings = self.get_absolute_sizings(action)
        _, reward, _, ckt_perf, _ = self.TwoStageAmpEnv.step(
            absolute_sizings, global_stp=self.global_stp)
        ckt_perf.update(absolute_sizings)
        self.ckt_perf = ckt_perf

        observation = self._corner_observation()
        reward = np.float32(reward)
        env_info = EnvInfo(timeout=False, taskID=self.taskID)
        self.global_stp += self.log_interval
        return EnvStep(observation=observation, reward=reward, done=True,
                       env_info=env_info)

    def reset(self):
        """Return the initial observation, i.e. the corner encoding."""
        return self._corner_observation()

    def write_logs(self, state, global_stp) -> None:
        """Print ``state`` and log its 'gain', 'ibias', 'phm' and 'ugbw'
        entries as scalars to the writer at step ``global_stp``."""
        print(state)
        print(global_stp)
        for tag in ('gain', 'ibias', 'phm', 'ugbw'):
            self.writer.add_scalar(tag, state[tag], global_stp)

        return None

    def get_ckt_perf(self):
        """Return the performance dict stored by the latest :meth:`step`
        (an empty dict before the first step)."""
        return self.ckt_perf

    @property
    def action_space(self):
        """7-dimensional ``FloatBox`` over ``[-1, 1]``."""
        return self._action_space

    @property
    def observation_space(self):
        """7-dimensional ``FloatBox`` over ``[-1, 1]``."""
        return self._observation_space

    @property
    def spaces(self):
        """Observation and action spaces bundled as ``EnvSpaces``."""
        return EnvSpaces(
            observation=self.observation_space,
            action=self.action_space,
        )

    @property
    def horizon(self):
        """No horizon is defined for this environment; always ``None``."""
        return None

    def close(self):
        """Close the summary writer, if one was supplied."""
        if self.writer is not None:
            self.writer.close()