"""
A new circuit ("ckt") gym environment for the NGspice two-stage op-amp, based on a new MDP structure.
"""
import gym
from gym import spaces

import numpy as np
import random
import psutil

from multiprocessing.dummy import Pool as ThreadPool
from collections import OrderedDict
import yaml
import yaml.constructor
import statistics
import os
import IPython
import itertools
from envs.NGspiceOpamp.eval_engines.util.core import *
import pickle
import os
import uuid

from envs.NGspiceOpamp.eval_engines.ngspice.TwoStageClass import *
import pdb
# YAML loader that keeps mapping keys in the order they are written in the file
class OrderedDictYAMLLoader(yaml.Loader):
    """
    PyYAML loader whose mappings are returned as ``OrderedDict`` objects,
    so keys keep the order in which they appear in the document.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Route both plain maps and !!omap nodes through the ordered constructor.
        for tag in ('tag:yaml.org,2002:map', 'tag:yaml.org,2002:omap'):
            self.add_constructor(tag, type(self).construct_yaml_map)

    def construct_yaml_map(self, node):
        # Generator-style constructor: PyYAML receives the (still empty)
        # container at the yield, and it is populated afterwards.
        ordered = OrderedDict()
        yield ordered
        ordered.update(self.construct_mapping(node))

    def construct_mapping(self, node, deep=False):
        """Build an ``OrderedDict`` from a mapping node, preserving key order.

        :raises yaml.constructor.ConstructorError: if ``node`` is not a
            ``yaml.MappingNode``.
        """
        if not isinstance(node, yaml.MappingNode):
            raise yaml.constructor.ConstructorError(
                None, None,
                'expected a mapping node, but found %s' % node.id,
                node.start_mark)

        # PyYAML helper that flattens merge keys into the node before reading pairs.
        self.flatten_mapping(node)

        return OrderedDict(
            (self.construct_object(key_node, deep=deep),
             self.construct_object(value_node, deep=deep))
            for key_node, value_node in node.value
        )

class TwoStageAmp(gym.Env):
    """Gym environment that sizes a two-stage op-amp via NGspice simulation.

    Each ``step`` simulates one explicit device sizing and scores the resulting
    specs against a target spec vector chosen in ``reset``.
    """
    metadata = {'render.modes': ['human']}

    PERF_LOW = -1
    PERF_HIGH = 0

    # Circuit description (params, target specs, normalization) lives next to this module.
    path = os.path.dirname(os.path.abspath(__file__))
    CIR_YAML = path+'/eval_engines/ngspice/ngspice_inputs/yaml_files/two_stage_opamp.yaml'
    corner_normal = {
        'process': 'tt',
        'temp': '27',
        'vdd': '1.2'
    }

    def __init__(self, env_config, tb_writer=None):
        """Load the circuit YAML, pick the target specs and create the simulator.

        :param env_config: dict of options; recognized keys are ``multi_goal``,
            ``generalize``, ``num_valid``, ``save_specs``, ``run_valid``,
            ``corner`` and ``log``.
        :param tb_writer: object exposing ``add_scalar``/``add_scalars`` used by
            ``write_logs``; ``None`` (the default) disables scalar logging.
        """
        self.multi_goal = env_config.get("multi_goal", False)
        self.generalize = env_config.get("generalize", False)
        num_valid = env_config.get("num_valid", 50)
        self.specs_save = env_config.get("save_specs", False)
        self.valid = env_config.get("run_valid", False)
        # Copy the class-level default so one instance can never mutate it for all others.
        self.corner = env_config.get("corner", dict(TwoStageAmp.corner_normal))
        self.log = env_config.get("log", False)
        self.env_steps = 0
        self.writer = tb_writer
        # The YAML ships with this package (trusted); the custom loader keeps key order.
        with open(TwoStageAmp.CIR_YAML, 'r') as f:
            yaml_data = yaml.load(f, OrderedDictYAMLLoader)

        # Design specs: one fixed target from the YAML, or a pickled bank of targets.
        if not self.generalize:
            specs = yaml_data['single_target_specs']
        else:
            load_specs_path = TwoStageAmp.path+"/gen_specs/ngspice_specs_gen_two_stage_opamp"
            # pickle is only acceptable because this file ships with the repo.
            with open(load_specs_path, 'rb') as f:
                specs = pickle.load(f)
        # Sort by spec name so spec order matches the sorted simulator output in update().
        self.specs = OrderedDict(sorted(specs.items(), key=lambda k: k[0]))
        if self.specs_save:
            with open("specs_"+str(num_valid)+str(random.randint(1,100000)), 'wb') as f:
                pickle.dump(self.specs, f)

        self.specs_ideal = []
        self.specs_id = list(self.specs.keys())
        self.fixed_goal_idx = -1
        # Number of candidate targets (every spec entry is a sequence of that length).
        self.num_os = len(list(self.specs.values())[0])

        # Parameter table: each YAML entry is [lower, upper, step].
        params = yaml_data['params']
        self.params_id = list(params.keys())
        self.sizing = params
        self.sizing_lower = np.array([bounds[0] for bounds in params.values()])
        self.sizing_upper = np.array([bounds[1] for bounds in params.values()])
        self.params = [np.arange(bounds[0], bounds[1], bounds[2]) for bounds in params.values()]

        # Each instance simulates in its own uuid-named directory under ./ckt_da.
        instance_path = str(uuid.uuid4())
        parent_path = os.path.dirname(os.path.realpath(__file__))
        root_dir = parent_path + "/ckt_da/" + instance_path
        self.sim_env = TwoStageClass(yaml_path=TwoStageAmp.CIR_YAML, num_process=1, path=TwoStageAmp.path,
                                     corner=self.corner, root_dir=root_dir)
        # Only used to size action_space; step() consumes explicit sizing dicts instead.
        self.action_meaning = [-1, 0, 2]
        self.action_space = spaces.Tuple([spaces.Discrete(len(self.action_meaning))]*len(self.params_id))
        # NOTE(review): parameter slots get low == high == 1 while cur_params_idx
        # stays all zeros in this class — confirm the intended bounds.
        self.observation_space = spaces.Box(
            low=np.array([TwoStageAmp.PERF_LOW]*2*len(self.specs_id)+len(self.params_id)*[1]),
            high=np.array([TwoStageAmp.PERF_HIGH]*2*len(self.specs_id)+len(self.params_id)*[1]))

        # Current param/spec observations.
        self.cur_specs = np.zeros(len(self.specs_id), dtype=np.float32)
        self.cur_params_idx = np.zeros(len(self.params_id), dtype=np.int32)

        # g*: the fixed overall design target (last entry of every spec).
        self.g_star = np.array([float(spec[self.fixed_goal_idx]) for spec in self.specs.values()])
        # Reference vector every spec is normalized against.
        self.global_g = np.array(yaml_data['normalize'])

        # Round-robin target index used in validation mode.
        self.obj_idx = 0

    def _specs_at(self, idx):
        """Return the target spec vector formed by entry ``idx`` of every spec."""
        return np.array([spec[idx] for spec in self.specs.values()])

    def reset(self):
        """Choose the target specs for the next episode.

        With ``generalize`` the target cycles through the spec bank (validation)
        or is drawn at random; otherwise it is ``g_star`` unless ``multi_goal``
        requests a random draw.

        :return: 0 (placeholder; no observation is produced here).
        """
        if self.generalize:
            if self.valid:
                if self.obj_idx > self.num_os-1:
                    self.obj_idx = 0
                idx = self.obj_idx
                self.obj_idx += 1
            else:
                idx = random.randint(0, self.num_os-1)
            self.specs_ideal = self._specs_at(idx)
        elif not self.multi_goal:
            self.specs_ideal = self.g_star
        else:
            idx = random.randint(0, self.num_os-1)
            self.specs_ideal = self._specs_at(idx)

        # Normalize the target against global_g so different targets are comparable.
        self.specs_ideal_norm = self.lookup(self.specs_ideal, self.global_g)
        return 0

    def convert_action(self, action_dict):
        """Reorder a sizing dict into the fixed device order the simulator expects.

        :raises KeyError: if any of the seven device keys is missing.
        """
        order_of_keys = ["mp1", "mn1", "mp3", "mn3", "mn4", "mn5", "cc"]
        return OrderedDict((key, action_dict[key]) for key in order_of_keys)

    def step(self, action, global_stp):
        """Simulate one sizing and score it against the current target.

        :param action: dict mapping "mp1", "mn1", "mp3", "mn3", "mn4", "mn5"
            and "cc" to their values; reordered by ``convert_action``.
        :param global_stp: global step counter; stored in ``self.env_steps``,
            which ``write_logs`` passes to the writer as the step value.
        :return: ``(observation, reward, done, ckt_perf, {})`` where ``ckt_perf``
            is the sorted simulator output (plus ``pass``, ``reward`` and the
            sizing entries when logging is enabled).
        """
        action = self.convert_action(action)
        ckt_perf = self.update(action)
        self.cur_specs = np.array(list(ckt_perf.values()))
        cur_spec_norm = self.lookup(self.cur_specs, self.global_g)
        reward = self.reward(self.cur_specs, self.specs_ideal)
        done = False

        # NOTE(review): reward() returns either a value < -0.02 or the normalized
        # ibias_max margin, which has magnitude below 1 when specs are positive,
        # so this threshold looks unreachable and `done` stays False. Kept as-is
        # pending a decision on the intended terminal condition.
        if reward >= 10:
            done = True
            print('-'*10)
            print('params = ', action)
            print('specs:', self.cur_specs)
            print('ideal specs:', self.specs_ideal)
            print('re:', reward)
            print('-'*10)

        self.ob = np.concatenate([cur_spec_norm, self.specs_ideal_norm, self.cur_params_idx])
        self.env_steps = global_stp

        if self.log:
            ckt_perf['pass'] = 1 if reward > 0 else 0
            ckt_perf['reward'] = reward
            ckt_perf.update(action)
            self.write_logs(ckt_perf, global_stp)
        return self.ob, reward, done, ckt_perf, {}

    def lookup(self, spec, goal_spec):
        """Return the element-wise relative difference (spec - goal) / (goal + spec).

        ``goal_spec`` is coerced to float; ``spec`` may be an array or a sequence.
        """
        goal_spec = np.asarray(goal_spec, dtype=float)
        norm_spec = (spec - goal_spec) / (goal_spec + spec)
        return norm_spec

    def reward(self, spec, goal_spec):
        '''
        Sum of the negative normalized spec differences; overshooting a spec is
        not penalized. ``ibias_max`` is sign-flipped so exceeding it counts as a
        shortfall. If the sum is >= -0.02 (all specs essentially met), the
        sign-flipped ``ibias_max`` margin is returned instead (0 when absent).
        '''
        rel_specs = self.lookup(spec, goal_spec)
        reward = 0.0
        ibias_max = 0
        for spec_name, rel_spec in zip(self.specs_id, rel_specs):
            if spec_name == 'ibias_max':
                rel_spec = -rel_spec
                ibias_max = rel_spec
            if rel_spec < 0:
                reward += rel_spec

        return reward if reward < -0.02 else ibias_max

    def update(self, action):
        """Simulate the given sizing and return its specs sorted by name.

        :param action: ordered sizing dict as produced by ``convert_action``.
        :return: ``OrderedDict`` of simulated specs sorted by key.
        """
        sim_specs = self.sim_env.create_design_and_simulate(action)[1]
        return OrderedDict(sorted(sim_specs.items(), key=lambda k: k[0]))

    def write_logs(self, state, global_stp) -> None:
        """Print ``state`` and push its metrics to the writer at ``self.env_steps``.

        ``state`` must contain ``pass``, ``gain``, ``ibias``, ``phm``, ``ugbw``,
        ``reward`` and ``cc``; every other key is logged under ``sizings``.
        Prints a notice and returns when no writer is configured.
        """
        if self.writer is None:
            print("writer is disabled")
            return

        print(state)
        print(global_stp)
        for tag in ('pass', 'gain', 'ibias', 'phm', 'ugbw', 'reward'):
            self.writer.add_scalar(tag, np.array(state[tag]), self.env_steps)
        # cc is logged on its own because its scale differs from the device sizes.
        invalid = {"cc", 'gain', 'ibias', 'phm', 'ugbw', 'pass', 'reward'}
        self.writer.add_scalars('sizings', {'sizing_of_' + key: np.array(state[key]) for key in
                                            state if key not in invalid}, self.env_steps)
        self.writer.add_scalar('cc', np.array(state['cc']), self.env_steps)

def main():
    """Smoke-test the environment.

    Runs one reset and one simulated step at a non-default corner, then opens
    an IPython shell for inspection.
    """
    corner = {
        'process': 'ff',
        'temp': '17',
        'vdd': '1.3'
    }
    # TwoStageAmp reads the validation flag from "run_valid"; a "valid" key is ignored.
    env_config = {"generalize": True, "run_valid": True, "corner": corner}
    # No scalar writer for a smoke test; write_logs() returns early on None.
    env = TwoStageAmp(env_config, None)
    env.reset()
    action_dict = {"mp1": 2, "mn1": 2, "mp3": 2, "mn3": 2, "mn4": 2, "mn5": 2, "cc": 1e-12}
    # step() requires the global step counter as its second argument.
    env.step(action_dict, 0)
    IPython.embed()


if __name__ == "__main__":
    main()
