import os
import uuid
import subprocess
import cmath
import re
# import paramiko
import time
import multiprocessing
import pdb

import tensorboardX

import numpy as np

# from utils.circuits_utils import get_mos_netlist_static_feature
from envs.StrongArm_Latch.strongArm_latch_py import strongArm_simulator_py

from envs.BaseEnv import BaseEnv
from copy import deepcopy

import yaml
import yamlordereddictloader
from envs.StrongArm_Latch import utils
import os
# All resources are located relative to this module's directory so the
# environment does not depend on the current working directory.
dir_path = os.path.dirname(os.path.realpath(__file__))
# Source netlist that every environment instance copies into its own run folder.
netlistPath = dir_path + "/StrongArm_Latch/NETLIST/ADC_COMPARATOR_GOLDEN_TB/spectre/schematic/netlist"
# YAML file describing the design variables (read via config["des_vars"]).
path = dir_path + "/StrongArm_Latch/SA_Latch_py.yml"

# The configuration is loaded once, at import time, and shared by all instances.
# NOTE(review): yaml.load with a non-safe loader can construct arbitrary Python
# objects; this is fine for a config file shipped with the project, but do not
# point `path` at untrusted input.
with open(path) as f:
    config = yaml.load(f, Loader=yamlordereddictloader.Loader)
    config = utils.wrap_config(config)


# Export the class that this module actually defines. The previous value
# ('strongArmEnv') named nothing in the module, so `from <module> import *`
# failed with AttributeError.
__all__ = ['strongArmEnv_py_log']


class strongArmEnv_py_log(BaseEnv):


    def __init__(self, tb_writer: tensorboardX.SummaryWriter, kwargs, need_initial_states=False):
        """Prepare the run directory, sizing bounds, simulator and bookkeeping.

        Args:
            tb_writer: summary writer stored as ``self.writer`` (used by
                ``write_logs``).
            kwargs: a plain dict of settings (passed positionally, not as
                ``**kwargs``). Keys read here: ``tech``, ``root_folder``,
                ``runs_dir``, ``corner``, ``unsaturation_reward``,
                ``nb_actions``, ``nb_states``, ``state_type``,
                ``step_per_episode``, ``log_scale``, ``fix_step_per_episode``,
                ``max_step_per_episode``, ``max_unsatu_step_per_episode``,
                ``round_method`` and ``model_feature_dict``. ``corner`` must
                provide ``vdd``. The dict itself is kept as ``self.kwargs``;
                ``get_fom`` later reads ``alpha_power`` from it.
            need_initial_states: when true, one non-real step with an all-zero
                action is executed and its states are stored in
                ``self._initial_states``.
        """
        self.kwargs = kwargs

        self.tech = kwargs['tech']
        self.root_folder = kwargs['root_folder']
        self.envPath = kwargs['runs_dir']
        self.corner = kwargs['corner']

        # Per-run working directory and the netlist copy placed inside it.
        curEnvPath = dir_path + "/StrongArm_Latch/EVAL_ENGINE/ocn_files/" + self.envPath
        dstPath = curEnvPath + '/netlist/'
        self.curEnvPath = curEnvPath

        # Create the directory and copy the netlist without going through a
        # shell, so a runs_dir containing spaces or shell metacharacters cannot
        # break (or inject into) the command. Both operations stay best-effort,
        # as before, but failures are now reported instead of silently dropped.
        try:
            os.makedirs(curEnvPath, exist_ok=True)
        except OSError as err:
            print("warning: could not create %s: %s" % (curEnvPath, err))
        try:
            copy_result = subprocess.run(['cp', '-avr', netlistPath, dstPath],
                                         stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        except OSError as err:
            print("warning: could not run cp for %s: %s" % (dstPath, err))
        else:
            if copy_result.returncode != 0:
                print("warning: copying netlist to %s failed: %s"
                      % (dstPath, copy_result.stderr.decode(errors='replace').strip()))

        # Each design variable entry is indexed as [0][0] (lower) / [0][1] (upper).
        des_vars = config["des_vars"]
        self.sizing_upper = np.array([des_vars[varname][0][1] for varname in des_vars.keys()])
        self.sizing_lower = np.array([des_vars[varname][0][0] for varname in des_vars.keys()])

        # Eight-entry performance bounds; not referenced by the methods shown
        # in this class. NOTE(review): presumably used by callers for
        # normalisation -- confirm.
        self.all_perf_max = np.array([8.00000000e+00, 5e-04, 6e-04, 2.5e+02, 5e-04, 5e-04, 8e-02, 1.5e-03])
        self.all_perf_min = np.array([4.00000000e+00, 4e-06, 3e-06, 6e-03, 1e-06, 1e-06, 1e-08, 1e-04])
        self.config = config
        self.simulator = strongArm_simulator_py(self.config, curEnvPath)
        self.writer = tb_writer
        self.episode = 0
        self.global_stp = 0
        self.unsaturation_reward = kwargs['unsaturation_reward']
        self.nb_components = 15
        self.step_ctr = 0
        self.unsatu_step_ctr = 0
        self.nb_actions = kwargs['nb_actions']
        self.nb_states = kwargs['nb_states']
        self.state_type = kwargs['state_type']
        self.step_per_episode = kwargs['step_per_episode']
        self.runs_dir = kwargs['runs_dir']
        self.log_scale = kwargs['log_scale']
        self.fix_step_per_episode = kwargs['fix_step_per_episode']
        self.max_step_per_episode = kwargs['max_step_per_episode']
        self.max_unsatu_step_per_episode = kwargs['max_unsatu_step_per_episode']
        self.round_method = kwargs['round_method']
        self.frac = np.zeros(8)
        self.model_feature_dict = kwargs['model_feature_dict']

        # Device type of every component; get_states turns these into features.
        self.components_type = {
            'M0': 'pmos',
            'M1': 'pmos',
            'M2': 'pmos',
            'M3': 'pmos',
            'M4': 'pmos',
            'M5': 'pmos',
            'M6': 'nmos',
            'M7': 'nmos',
            'M8': 'nmos',
            'M9': 'nmos',
            'M10': 'nmos',
            'C0': 'c',
            'C1': 'c',
            'C2': 'c',
            'C3': 'c',
        }

        # Component order; its length matches self.nb_components.
        self.components = ['M0', 'M1', 'M2', 'M3', 'M4', 'M5', 'M6', 'M7',
                           'M8', 'M9', 'M10', 'C0', 'C1', 'C2', 'C3']

        # Reference values that get_fom compares the simulated metrics against
        # (according to real golden comparator tb @ 1MHz fs).
        self.range = {
            'Power': 4.5e-6,
            'delay': 10e-9 * 1.4,
            'reset': 7e-9 * 1.4,
            'area': 2.6e-11 * 10,
            'reset_val': 0,
            'rise_val': float(self.corner['vdd']),
            'input_ref_noise': 5e-5 * 1.2,
        }

        if need_initial_states:
            self._initial_states = self.step(np.zeros(self.nb_actions), 0, 0, real_step=False)


    def step(self, action, episode, global_stp=1, real_step=True):
        """Size the circuit from ``action``, simulate it and score the result.

        Args:
            action: normalised action passed to ``get_absolute_sizings``.
            episode: episode index, stored on the instance and in the info dict.
            global_stp: global step index, stored likewise.
            real_step: when false, only the states are returned (used to
                generate the initial states).

        Returns:
            ``states`` alone if ``real_step`` is false, otherwise the tuple
            ``(states, reward, episode_finish, all_info)``.
        """
        self.episode, self.global_stp = episode, global_stp

        # Map the action onto absolute sizings and collect everything known
        # about this step in one dict.
        sizings = self.get_absolute_sizings(action, config)
        all_info = {
            'actions': action,
            'global_stp': global_stp,
            'episode': episode,
            'absolute_sizings': sizings,
        }

        # Simulation adds the 'metrics' entry.
        all_info.update(self.sim(sizings, self.corner))

        all_info['fom'], all_info['fracs'] = self.get_fom(all_info['metrics'])

        states = self.get_states(all_info, self.state_type)
        if not real_step:
            return states

        # get_reward / get_episode_finish are not defined in this class;
        # presumably they come from BaseEnv -- verify.
        reward = self.get_reward(all_info)
        episode_finish = self.get_episode_finish(reward)

        corner = self.kwargs['corner']
        print("reward=", reward)
        print('corner=', '%s_%s_%s' % (corner['process'], corner['temp'], corner['vdd']))
        print(all_info['metrics'])

        # Counters restart together with the episode.
        if episode_finish:
            self.step_ctr = self.unsatu_step_ctr = 0

        return states, reward, episode_finish, all_info


    def sim(self, sizings, corner):
        # perf is a dict of simulation responses
        # pdb.set_trace()
        perf = self.simulator.simulate(sizings, corner=self.corner)
        # pdb.set_trace()
        info = {}
        info['metrics'] = perf

        return info

    def get_states(self, all_info: dict, state_type: str) -> list:
        if state_type == 'one_hot_idx_and_feat':
            states = []
            for i in range(self.nb_components):
                state = []
                # one hot index
                state.extend(np.eye(self.nb_components)[i].tolist())

                # one-hot type
                one_hot_types = {'nmos': [1, 0, 0, 0], 'pmos': [0, 1, 0, 0], 'r': [0, 0, 1, 0], 'c': [0, 0, 0, 1]}
                features = {'nmos': [1.8e-07, 6.86e+17, 1.95e-08, 5e-9, 113332.6], 'pmos': [1.6e-07, 6.28e+17, 1.25e-08, 1.15e-08, 94000], 'r': [0] * 5, 'c': [0] * 5}
                state.extend(one_hot_types[self.components_type[self.components[i]]])
                state.extend(features[self.components_type[self.components[i]]])
                states.append(state)
            # print(states)
            # normalize states:
            states_mean = np.mean(states, axis=0)
            states_std = np.std(states, axis=0)
            states_std[abs(states_std) < 1e-12] = 1.
            states = (states - states_mean) / states_std
            # print(states)
            states = states.flatten().tolist()
            return states

        elif state_type == 'idx_and_feat':
            states = []
            for i in range(self.nb_components):
                state = []
                # index
                state.extend([i])
                one_hot_types = {'nmos': [1, 0, 0, 0], 'pmos': [0, 1, 0, 0], 'r': [0, 0, 1, 0], 'c': [0, 0, 0, 1]}
                features = {'nmos': [1.8e-07, 6.86e+17, 1.95e-08, 5e-9, 113332.6], 'pmos': [1.6e-07, 6.28e+17, 1.25e-08, 1.15e-08, 94000], 'r': [0] * 5, 'c': [0] * 5}
                state.extend(one_hot_types[self.components_type[self.components[i]]])
                state.extend(features[self.components_type[self.components[i]]])
                states.append(state)


            # normalize states:
            states_mean = np.mean(states, axis=0)
            states_std = np.std(states, axis=0)
            states_std[abs(states_std) < 1e-12] = 1.
            states = (states - states_mean) / states_std
            # print(states)
            states = states.flatten().tolist()
            return states

        elif state_type == 'idx_and_simple_feat':
            states = []
            for i in range(self.nb_components):
                state = []
                # index
                state.extend([i])
                simple_feat = {'nmos': [0, 0], 'pmos': [0, 1], 'r': [1, 0], 'c': [1, 1]}
                state.extend(simple_feat[self.components_type[self.components[i]]])
                states.append(state)

            # normalize states:
            states_mean = np.mean(states, axis=0)
            states_std = np.std(states, axis=0)
            states_std[abs(states_std) < 1e-12] = 1.
            states = (states - states_mean) / states_std
            # print(states)
            states = states.flatten().tolist()
            return states

        return [0] * self.nb_states

    def write_logs(self, all_info: dict, episode_finish=False) -> None:
        # write per episode
        if bool(all_info['metrics']):
            # write performance metrics
            corner_info = '%s_%s_%s_' % (self.corner['process'], self.corner['temp'], self.corner['vdd'])
            self.writer.add_scalar(corner_info+'Power', all_info['metrics']['Power'], self.global_stp)
            self.writer.add_scalar(corner_info+'delay', all_info['metrics']['delay'], self.global_stp)
            self.writer.add_scalar(corner_info+'reset', all_info['metrics']['reset'], self.global_stp)
            self.writer.add_scalar(corner_info+'input_ref_noise', all_info['metrics']['input_ref_noise'], self.global_stp)
            self.writer.add_scalar(corner_info+'reset_val', all_info['metrics']['reset_val'], self.global_stp)
            self.writer.add_scalar(corner_info+'rise_val', all_info['metrics']['rise_val'], self.global_stp)
            self.writer.add_scalar(corner_info+'area', all_info['metrics']['area'], self.global_stp)
            self.writer.add_scalar(corner_info+'Cl_finger', np.array(all_info['absolute_sizings']['Cl_finger']), self.global_stp)
            # write sizings

            # test_dict = {'sizing_of_' + key: np.squeeze(np.array(value)) for key, value in all_info['absolute_sizings'].items() \
            #              if key.startswith('l')}
            # pdb.set_trace()
            invalid = {"Cl_finger", "l1", "l2", "l3", "l4", "l6", "l8"}
            self.writer.add_scalars(corner_info+'sizings', {'sizing_of_' + key: np.array(all_info['absolute_sizings'][key]) for key in
                                                               all_info['absolute_sizings'] if key not in invalid}, self.global_stp)

            # self.writer.add_scalars('test', {'t1': 1, 't2': 2}, self.global_stp)
            # write fracs
            # self.writer.add_scalar('frac_Power', all_info['fracs']['Power'], self.global_stp)
            # self.writer.add_scalar('frac_noise', all_info['fracs']['noise'], self.global_stp)
            # self.writer.add_scalar('frac_rise', all_info['fracs']['rise'], self.global_stp)
            # self.writer.add_scalar('frac_reset', all_info['fracs']['reset'], self.global_stp)
            self.writer.add_scalars(corner_info+'fracs', {'frac_of_' + key: np.array(all_info['fracs'][key]) for key in
                                     all_info['fracs']}, self.global_stp)

            self.writer.add_scalar(corner_info+'fom', all_info['fom'], self.global_stp)

            # track std of action distribution
            # self.writer.add_scalar('action_noise_tracker', self.kwargs['delta_decay'] ** max(0, (self.global_stp - 128)), self.global_stp)
            # debug = 0.99 ** max(0, (self.global_stp - 1))
            # pdb.set_trace()
        return None

    def get_fom(self, metrics:dict):

        # frac = 0

        metricIsEmpty = not metrics
        if metricIsEmpty:
            return -2, 0

        # settling requirement
        reset_val = metrics['reset_val']
        rise_val = metrics['rise_val']
        dReset = reset_val - self.range['reset_val']
        dRise = self.range['rise_val'] - rise_val
        # dReset [0, 1.2]
        # dRise [0, 1.2]
        # prototype: (x-1) / (x+1)
        # reset_component = min(-0.05, -dReset)

        ### modified on 5-10 ###
        # reset_x = -0.87 * dReset + 1.044
        # reset_component = min(0, (reset_x-1)/ (reset_x+1))
        # rise_x = -0.87 * dRise + 1.044
        # rise_component = min(0, (rise_x-1)/ (rise_x+1))

        reset_component = -(dReset - 0.05) / (dReset + 0.05)
        reset_component = min(0, reset_component)
        rise_component = -(dRise - 0.05) / (dRise + 0.05)
        rise_component = min(0, rise_component)

        # noise requirement
        # input_ref_noise [50uV, 550uV)
        # input_ref_noise [range, 11*range)
        input_ref_noise = metrics['input_ref_noise']
        # noise_component = min(1, self.range['input_ref_noise'] / input_ref_noise)
        ### modified on 5-10 ###
        # noise_x = -0.1 * (1/self.range['input_ref_noise']) * input_ref_noise + 1.1
        # noise_component = max(-2, min(0, (noise_x-1)/(noise_x+1)))
        noise_component = -(input_ref_noise-self.range['input_ref_noise']) / (input_ref_noise + self.range['input_ref_noise'])
        noise_component = max(-2, min(0, noise_component))

        # switch between two reward functions according to current situation
        # power [epsilon, inf)
        Power = metrics['Power']
        # power_component = self.range['Power'] / Power
        # power_component = -Power / self.range['Power']
        power_component = -(Power - self.range['Power']) / (Power + self.range['Power'])

        delay = metrics['delay']
        delay_component = -(delay-self.range['delay']) / (delay + self.range['delay'])
        delay_component = min(0, delay_component)
        reset_delay = metrics['reset']
        reset_delay_component = -(reset_delay-self.range['reset']) / (reset_delay + self.range['reset'])
        reset_delay_component = min(0, reset_delay_component)

        # if reset_component >= 0 and rise_component >= 0 and noise_component >= 0:
        #     reward = 25 * power_component + 40
        # else:
        #     # reward = 0.2 * power_component + reset_component + rise_component + noise_component
        #     reward = reset_component + rise_component + noise_component
        alpha = self.kwargs['alpha_power']
        power_component = alpha * power_component
        power_component = max(-1, min(0, power_component))
        # alpha = 1
        # if reset_component < 0 or rise_component < 0 or noise_component < 0:
        #     reward = reset_component + rise_component + noise_component + alpha * power_component
        # else:
        #     reward = 10 * alpha * power_component + 3

        ### modified on 5-10 ###
        # reward = reset_component + rise_component + noise_component + power_component
        reward = reset_component + rise_component + noise_component + delay_component + reset_delay_component + power_component
        if reward >= -0.02:
            reward = 0.2

        fracs = {'Power': power_component,
                'noise': noise_component,
                'reset': reset_component,
                'rise' : rise_component,
                'delay': delay_component,
                'reset_delay': reset_delay_component
                }

        return reward, fracs

    def get_path(self):
        return self.curEnvPath

    def get_absolute_sizings(self, action, config):
        # do not use starting points, currently do not round
        # pdb.set_trace()

        absolute_sizings = self.sizing_lower + (self.sizing_upper - self.sizing_lower) * (action + 1) / 2
        # absolute_sizings = self.round_sizings(absolute_sizings)
        # print(absolute_sizings)
        sizing_dict = dict()
        for i, varnames in enumerate(list(config["des_vars"].keys())):
            sizing_dict[varnames] = absolute_sizings[i]
            # eg1 = [absolute_sizings[i]]

        return sizing_dict

    @property
    def initial_states(self):
        return self._initial_states
