import os
import uuid
import subprocess
import cmath
import re
# import paramiko
import time
import multiprocessing
import pdb

import tensorboardX

import numpy as np

# from utils.circuits_utils import get_mos_netlist_static_feature
from envs.FoldedCascode.foldedCascode_py import foldedCascode_simulator_py

from envs.BaseEnv import BaseEnv
from copy import deepcopy

import yaml
import yamlordereddictloader
from envs.FoldedCascode import utils
import os
# Every resource below is addressed relative to the directory holding this module.
dir_path = os.path.dirname(os.path.realpath(__file__))

# Source netlists for the closed-loop and the open-loop test benches.
netlistPath_cls = dir_path + "/FoldedCascode/NETLIST/Folded_cascode_two_stage_closed_loop_TB/spectre/schematic/netlist"
netlistPath_open = dir_path + "/FoldedCascode/NETLIST/Folded_cascode_two_stage_TB/spectre/schematic/netlist"

# Design-variable configuration, parsed once at import time and shared by the
# whole module.
# NOTE(review): yaml.load with a full loader is used here; this presumes the
# YAML file is a trusted, repository-local file - confirm.
path = dir_path + "/FoldedCascode/ota_foldedcascode_func_RSAS_py.yml"
with open(path) as f:
    config = utils.wrap_config(yaml.load(f, Loader=yamlordereddictloader.Loader))


# Names exported by a star-import. Every entry must be defined in this module;
# a name that does not exist makes `from <module> import *` raise AttributeError.
__all__ = ['foldedCascodeEnv_py_log']

class foldedCascodeEnv_py_log(BaseEnv):
    def __init__(self, tb_writer: tensorboardX.SummaryWriter, kwargs, need_initial_states=False):
        """Prepare the run directory, the simulator and the static lookup tables.

        Args:
            tb_writer: Summary writer stored on ``self.writer``.
            kwargs: A plain dict of settings (passed positionally, not as
                ``**kwargs``). Keys read here: ``tech``, ``root_folder``,
                ``runs_dir``, ``corner``, ``unsaturation_reward``,
                ``nb_actions``, ``nb_states``, ``state_type``,
                ``step_per_episode``, ``log_scale``, ``fix_step_per_episode``,
                ``max_step_per_episode``, ``max_unsatu_step_per_episode``,
                ``round_method`` and ``model_feature_dict``.
            need_initial_states: When true, ``step`` is called once with an
                all-zero action and ``real_step=False``; the returned states
                are cached in ``self._initial_states``.

        Raises:
            KeyError: If one of the keys listed above is missing.
            OSError: If the ``mkdir`` or ``cp`` executable cannot be started.
        """
        self.kwargs = kwargs

        self.tech = kwargs['tech']
        self.root_folder = kwargs['root_folder']
        self.envPath = kwargs['runs_dir']
        self.corner = kwargs['corner']

        # Each run works inside its own directory with private netlist copies.
        curEnvPath = dir_path + "/FoldedCascode/MATLAB_ENGINE/ocn_files/" + self.envPath
        self.curEnvPath = curEnvPath
        dstPath_cls = curEnvPath + '/netlist_cls/'
        dstPath_open = curEnvPath + '/netlist_open/'

        # Commands are given as argument lists and run without a shell, so a
        # path containing spaces or shell metacharacters is passed through
        # verbatim. Output is captured and exit codes are not checked, i.e. the
        # setup stays best-effort (for instance, mkdir on an existing
        # directory simply fails quietly).
        setup_commands = (
            ["mkdir", curEnvPath],
            ["cp", "-avr", netlistPath_cls, dstPath_cls],
            ["cp", "-avr", netlistPath_open, dstPath_open],
        )
        for argv in setup_commands:
            proc = subprocess.Popen(argv, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            proc.communicate()

        # Lower / upper bound of every design variable, in config order.
        des_vars = config["des_vars"]
        var_names = list(des_vars.keys())
        self.sizing_upper = np.array([des_vars[name][0][1] for name in var_names])
        self.sizing_lower = np.array([des_vars[name][0][0] for name in var_names])

        self.config = config
        self.simulator = foldedCascode_simulator_py(self.config, curEnvPath)
        self.writer = tb_writer
        self.episode = 0
        self.global_stp = 0
        self.unsaturation_reward = kwargs['unsaturation_reward']
        self.nb_components = 9
        self.step_ctr = 0
        self.unsatu_step_ctr = 0
        self.nb_actions = kwargs['nb_actions']
        self.nb_states = kwargs['nb_states']
        self.state_type = kwargs['state_type']
        self.step_per_episode = kwargs['step_per_episode']
        self.runs_dir = kwargs['runs_dir']
        self.log_scale = kwargs['log_scale']
        self.fix_step_per_episode = kwargs['fix_step_per_episode']
        self.max_step_per_episode = kwargs['max_step_per_episode']
        self.max_unsatu_step_per_episode = kwargs['max_unsatu_step_per_episode']
        self.round_method = kwargs['round_method']
        # NOTE(review): not read anywhere in this class; the purpose of this
        # array and of its length 8 is unclear - confirm before relying on it.
        self.frac = np.zeros(8)
        self.model_feature_dict = kwargs['model_feature_dict']

        # Device type of every component, used by get_states.
        self.components_type = {
            'M1': 'pmos',
            'M2': 'pmos',
            'M3': 'pmos',
            'M4': 'pmos',
            'M5': 'pmos',
            'M6': 'pmos',
            'M7': 'nmos',
            'MCAP': 'c',
            'Cf': 'c',
        }

        # Component order; position i here is component index i in get_states.
        self.components = ['M1', 'M2', 'M3', 'M4', 'M5', 'M6', 'M7',
                           'MCAP', 'Cf']

        # Target value per performance metric, consumed by get_fom.
        self.range = {
            'power': 1e-3,
            # a loose requirement is set for the output swing
            'output_swing': 0.5,
            'gain': 60,
            'cmrr': 80,
            'psrr': 80,
            'pm_dm': 60,
            'rms_noise_out_dm': 3e-4,
            'rise_time': 30e-9,
            'static_error': 0.1,
            'lg_ugb': 30e6,
        }

        if need_initial_states:
            self._initial_states = self.step(np.zeros(self.nb_actions), 0, 0, real_step=False)

    def step(self, action, episode, global_stp=1, real_step=True):
        self.episode = episode
        self.global_stp = global_stp

        all_info = {}
        all_info['actions'] = action
        # pdb.set_trace()
        all_info['global_stp'] = global_stp
        all_info['episode'] = episode

        # get absolute sizings
        absolute_sizings = self.get_absolute_sizings(action, config)
        all_info['absolute_sizings'] = absolute_sizings
        # print(absolute_sizings)
        # simulate
        sim_info = self.sim(absolute_sizings, self.corner)
        all_info.update(sim_info)

        fom, fracs = self.get_fom(all_info['metrics'])
        all_info['fom'] = fom
        all_info['fracs'] = fracs

        # get states
        states = self.get_states(all_info, self.state_type)

        # this is used for initial states generation
        if not real_step:
            return states

        # get reward
        reward = self.get_reward(all_info)

        # determine if the episode if finished
        episode_finish = self.get_episode_finish(reward)
        print("reward=", reward)
        print('corner=', '%s_%s_%s' % (self.kwargs['corner']['process'], self.kwargs['corner']['temp'],self.kwargs['corner']['vdd']))
        print(all_info['metrics'])
        # print(all_info['absolute_sizings']['w1'])
        # pdb.set_trace()
        # write log
        # self.write_logs(all_info, episode_finish)
        # reset ctr and cumulative sample
        if episode_finish:
            self.step_ctr = 0
            self.unsatu_step_ctr = 0
            # self.cumulative_sample = deepcopy(self.starting_sample)

        return states, reward, episode_finish, all_info


    def sim(self, sizings, corner):
        # perf is a dict of simulation responses
        # pdb.set_trace()
        perf = self.simulator.simulate(sizings, corner=self.corner)
        # pdb.set_trace()
        info = {}
        info['metrics'] = perf

        return info

    def get_states(self, all_info: dict, state_type: str) -> list:
        if state_type == 'one_hot_idx_and_feat':
            states = []
            for i in range(self.nb_components):
                state = []
                # one hot index
                state.extend(np.eye(self.nb_components)[i].tolist())

                # one-hot type
                one_hot_types = {'nmos': [1, 0, 0, 0], 'pmos': [0, 1, 0, 0], 'r': [0, 0, 1, 0], 'c': [0, 0, 0, 1]}
                features = {'nmos': [1.8e-07, 6.86e+17, 1.95e-08, 5e-9, 113332.6], 'pmos': [1.6e-07, 6.28e+17, 1.25e-08, 1.15e-08, 94000], 'r': [0] * 5, 'c': [0] * 5}
                state.extend(one_hot_types[self.components_type[self.components[i]]])
                state.extend(features[self.components_type[self.components[i]]])
                states.append(state)
            # print(states)
            # normalize states:
            states_mean = np.mean(states, axis=0)
            states_std = np.std(states, axis=0)
            states_std[abs(states_std) < 1e-12] = 1.
            states = (states - states_mean) / states_std
            # print(states)
            states = states.flatten().tolist()
            return states

        elif state_type == 'idx_and_feat':
            states = []
            for i in range(self.nb_components):
                state = []
                # index
                state.extend([i])
                one_hot_types = {'nmos': [1, 0, 0, 0], 'pmos': [0, 1, 0, 0], 'r': [0, 0, 1, 0], 'c': [0, 0, 0, 1]}
                features = {'nmos': [1.8e-07, 6.86e+17, 1.95e-08, 5e-9, 113332.6], 'pmos': [1.6e-07, 6.28e+17, 1.25e-08, 1.15e-08, 94000], 'r': [0] * 5, 'c': [0] * 5}
                state.extend(one_hot_types[self.components_type[self.components[i]]])
                state.extend(features[self.components_type[self.components[i]]])
                states.append(state)


            # normalize states:
            states_mean = np.mean(states, axis=0)
            states_std = np.std(states, axis=0)
            states_std[abs(states_std) < 1e-12] = 1.
            states = (states - states_mean) / states_std
            # print(states)
            states = states.flatten().tolist()
            return states

        elif state_type == 'idx_and_simple_feat':
            states = []
            for i in range(self.nb_components):
                state = []
                # index
                state.extend([i])
                simple_feat = {'nmos': [0, 0], 'pmos': [0, 1], 'r': [1, 0], 'c': [1, 1]}
                state.extend(simple_feat[self.components_type[self.components[i]]])
                states.append(state)

            # normalize states:
            states_mean = np.mean(states, axis=0)
            states_std = np.std(states, axis=0)
            states_std[abs(states_std) < 1e-12] = 1.
            states = (states - states_mean) / states_std
            # print(states)
            states = states.flatten().tolist()
            return states

        return [0] * self.nb_states

    def write_logs(self, all_info: dict, episode_finish=False) -> None:
        # write per episode
        if bool(all_info['metrics']):
            # write performance metrics
            corner_info = '%s_%s_%s_' % (self.corner['process'], self.corner['temp'], self.corner['vdd'])
            self.writer.add_scalar(corner_info+'power', all_info['metrics']['power'], self.global_stp)
            # self.writer.add_scalar(corner_info+'output_swing', all_info['metrics']['output_swing'], self.global_stp)
            self.writer.add_scalar(corner_info+'gain', all_info['metrics']['gain'], self.global_stp)
            self.writer.add_scalar(corner_info+'cmrr', all_info['metrics']['cmrr'], self.global_stp)
            self.writer.add_scalar(corner_info+'psrr', all_info['metrics']['psrr'], self.global_stp)
            self.writer.add_scalar(corner_info+'noise', all_info['metrics']['rms_noise_out_dm'], self.global_stp)
            # self.writer.add_scalar(corner_info+'rise_time', all_info['metrics']['rise_time'], self.global_stp)
            # self.writer.add_scalar(corner_info+'static_error', all_info['metrics']['static_error'], self.global_stp)
            self.writer.add_scalar(corner_info+'lg_ugb', all_info['metrics']['lg_ugb'], self.global_stp)
            self.writer.add_scalar(corner_info+'pm_dm', all_info['metrics']['pm_dm'], self.global_stp)

            self.writer.add_scalar(corner_info+'MCAP', np.array(all_info['absolute_sizings']['MCAP']), self.global_stp)
            self.writer.add_scalar(corner_info+'Cf', np.array(all_info['absolute_sizings']['Cf']), self.global_stp)
            # write sizings

            # test_dict = {'sizing_of_' + key: np.squeeze(np.array(value)) for key, value in all_info['absolute_sizings'].items() \
            #              if key.startswith('l')}
            # pdb.set_trace()
            invalid = {"MCAP", "Cf"}
            self.writer.add_scalars(corner_info+'sizings', {'sizing_of_' + key: np.array(all_info['absolute_sizings'][key]) for key in
                                                               all_info['absolute_sizings'] if key not in invalid}, self.global_stp)

            # self.writer.add_scalars('test', {'t1': 1, 't2': 2}, self.global_stp)
            # write fracs
            # self.writer.add_scalar('frac_Power', all_info['fracs']['Power'], self.global_stp)
            # self.writer.add_scalar('frac_noise', all_info['fracs']['noise'], self.global_stp)
            # self.writer.add_scalar('frac_rise', all_info['fracs']['rise'], self.global_stp)
            # self.writer.add_scalar('frac_reset', all_info['fracs']['reset'], self.global_stp)
            self.writer.add_scalars(corner_info+'fracs', {'frac_of_' + key: np.array(all_info['fracs'][key]) for key in
                                     all_info['fracs']}, self.global_stp)

            self.writer.add_scalar(corner_info+'fom', all_info['fom'], self.global_stp)

            # track std of action distribution
            # self.writer.add_scalar('action_noise_tracker', self.kwargs['delta_decay'] ** max(0, (self.global_stp - 128)), self.global_stp)
            # debug = 0.99 ** max(0, (self.global_stp - 1))
            # pdb.set_trace()
        return None

    def get_fom(self, metrics:dict):
        # pdb.set_trace()
        if not metrics:
            return -4, 0
        components = {}
        for key in {'gain', 'cmrr', 'psrr', 'pm_dm', 'lg_ugb'}:
            components[key] = (metrics[key] - self.range[key]) / (metrics[key] + self.range[key])
            components[key] = min(0, components[key])

        # for key in {'rms_noise_out_dm', 'static_error', 'rise_time'}:
        #     component_x = -0.1 * (1/self.range[key]) * metrics[key] + 1.1
        #     component_x = max(0, component_x)
        #     components[key] = (component_x - 1) / (component_x + 1)
        #     components[key] = min(0, components[key])

        # for key in {'rms_noise_out_dm', 'static_error', 'rise_time'}:
        for key in {'rms_noise_out_dm', 'power'}:
            metrics[key] = max(0, metrics[key])
            components[key] = -(metrics[key] - self.range[key]) / (metrics[key] + self.range[key])
            components[key] = min(0, components[key])

        # components['power'] = -(metrics['power'] - self.range['power']) / (metrics['power'] + self.range['power'])

        reward = 0
        for key, value in components.items():
            reward += value

        if reward >= -0.02:
            reward = 0.2

        fracs = {}
        for key in components:
            fracs[key] = components[key]

        return reward, fracs

    def get_path(self) -> str:
        """Return the per-run working directory path stored in ``self.curEnvPath``."""
        return self.curEnvPath

    def get_absolute_sizings(self, action, config):
        # do not use starting points, currently do not round
        # pdb.set_trace()

        absolute_sizings = self.sizing_lower + (self.sizing_upper - self.sizing_lower) * (action + 1) / 2
        # absolute_sizings = self.round_sizings(absolute_sizings)
        # print(absolute_sizings)
        sizing_dict = dict()
        for i, varnames in enumerate(list(config["des_vars"].keys())):
            sizing_dict[varnames] = absolute_sizings[i]
            # eg1 = [absolute_sizings[i]]

        return sizing_dict

    @property
    def initial_states(self):
        """States cached by ``__init__`` when ``need_initial_states`` was true.

        ``__init__`` assigns ``self._initial_states`` only in that case, so on
        an instance built with ``need_initial_states=False`` this access raises
        ``AttributeError`` (unless a base class provides the attribute -
        not visible from this file).
        """
        return self._initial_states