import os
import torch
import random
import numpy as np
import warnings

from collections import OrderedDict
import yaml
import yaml.constructor
import statistics
import os
import itertools
from eval_engines.util.core import *
import pickle
from eval_engines.ngspice.TwoStageClass import *
import shutil

from geneticalgorithm import geneticalgorithm as ga
import math

cir_name = 'cir1_GA'
run_name = "run1"

warnings.filterwarnings("ignore")
print('-'*50)

# Start from a clean slate: remove design folders left over from a previous
# run with the same run/circuit names.
designs_dir = os.path.join("cir", run_name, "designs_" + cir_name)
if os.path.exists(designs_dir):
    shutil.rmtree(designs_dir)
os.makedirs(cir_name, exist_ok=True)

# The progress log lives under sizing/<cir_name>/. The makedirs call above
# does not create that directory, so create it before opening the file.
wtdir = "sizing/" + cir_name + "/" + cir_name + ".log"
os.makedirs(os.path.dirname(wtdir), exist_ok=True)
# Kept open for the whole script; closed after the GA run finishes.
log = open(wtdir, 'w')

# YAML loader that preserves the key order of mappings as they appear in the file.
class OrderedDictYAMLLoader(yaml.Loader):
    """
    PyYAML loader whose mappings are returned as ``OrderedDict`` objects.

    Both ``!!map`` and ``!!omap`` nodes are routed through
    :meth:`construct_yaml_map`, so keys keep the order in which they appear
    in the source document.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        cls = type(self)
        for tag in (u'tag:yaml.org,2002:map', u'tag:yaml.org,2002:omap'):
            self.add_constructor(tag, cls.construct_yaml_map)

    def construct_yaml_map(self, node):
        # Two-step generator constructor: the empty container is yielded
        # first and only filled afterwards (NOTE(review): this mirrors the
        # pattern PyYAML's own map constructor uses — keep it as a generator).
        result = OrderedDict()
        yield result
        result.update(self.construct_mapping(node))

    def construct_mapping(self, node, deep=False):
        """Build an ``OrderedDict`` from a mapping node, keeping key order.

        Raises ``ConstructorError`` if ``node`` is not a ``MappingNode``.
        """
        if not isinstance(node, yaml.MappingNode):
            raise yaml.constructor.ConstructorError(
                None, None,
                'expected a mapping node, but found %s' % node.id,
                node.start_mark)
        # Merge keys ('<<') are expanded in place before the pairs are read.
        self.flatten_mapping(node)

        ordered = OrderedDict()
        for key_node, value_node in node.value:
            # The key is constructed before its value, as in document order.
            key = self.construct_object(key_node, deep=deep)
            ordered[key] = self.construct_object(value_node, deep=deep)
        return ordered

class TwoStageAmp():
    """Evaluate one two-stage amplifier design for the genetic algorithm.

    ``step`` scales a raw parameter vector into the values handed to the
    simulator (``update``). It runs one simulation through
    ``TwoStageClass.create_design_and_simulate`` and scores the resulting
    specs (``opt``).
    """

    PERF_LOW = -1
    PERF_HIGH = 0

    # Spec targets used by ``opt``. The gain is compared linearly
    # (317 V/V is roughly 50 dB); NOTE(review): the phase margin target is
    # presumably in degrees — confirm the units of spec['phm'].
    GAIN_TARGET = 317
    PHM_TARGET = 60

    # Raw genes for these parameters are multiplied by a unit size before
    # simulation; every other parameter is passed through unchanged.
    # (50e-9 is 50 nm and 0.1e-12 is 0.1 pF if the simulator expects SI units.)
    LENGTH_PARAMS = frozenset(('lp1', 'lp2', 'lp3', 'ln1', 'ln2', 'ln3', 'ln4', 'ln5'))
    LENGTH_UNIT = 50e-9
    CAP_PARAMS = frozenset(('c1',))
    CAP_UNIT = 0.1e-12

    # Resolved once, when the class body executes (at import time).
    path = os.getcwd()
    CIR_YAML = path+'/eval_engines/ngspice/ngspice_inputs/yaml_files/' + cir_name + '.yaml'

    def __init__(self):
        """Load the circuit YAML and build the simulation environment.

        The key order of the YAML ``params`` mapping is stored in
        ``params_id``; it defines how positions of a GA vector map to
        circuit parameters.
        """
        with open(TwoStageAmp.CIR_YAML, 'r') as f:
            # NOTE(review): the loader derives from yaml.Loader (not
            # SafeLoader); fine for this repo's own config file, but never
            # point it at untrusted YAML.
            yaml_data = yaml.load(f, OrderedDictYAMLLoader)

        params = yaml_data['params']
        self.params = []
        self.params_id = list(params.keys())
        self.sim_env = TwoStageClass(yaml_path=TwoStageAmp.CIR_YAML, num_process=1, path=TwoStageAmp.path, cir_path = "cir/" + run_name)

    def step(self, action):
        """Simulate one candidate design and score it.

        :param action: raw parameter vector ordered like ``self.params_id``.
            Length and capacitor entries are scaled by ``update``; all other
            entries are passed to the simulator unchanged.
        :return: ``(reward, UGF, GBW, Gain, BW, Power, Phm)`` as produced by ``opt``.
        """
        cur_specs = self.update(np.array(action))
        return self.opt(cur_specs)

    @staticmethod
    def _shortfall(value, target):
        """Penalty in [-1, 0) for a spec ``value`` below ``target``.

        Equal to ``(value - target) / (target + value)`` for non-negative
        values and saturates at exactly -1 for negative ones. Only meaningful
        when ``value < target``.
        """
        return (value - target) / (target + abs(value))

    def opt(self, spec):
        '''
        Score a set of simulated specs.

        Figures returned alongside the reward:

        * BW / UGF: ``spec['bw']`` / ``spec['ugbw']`` divided by 1e6; negative
          values are clamped to 0.
        * Power: ``(spec['ibias'] * 1e6) * 1.2``.
        * GBW: ``spec['gain'] * BW`` when the gain exceeds 1, otherwise 0.

        Reward:

        * If the gain is below ``GAIN_TARGET`` or the phase margin below
          ``PHM_TARGET``, the reward is the sum of the per-spec shortfall
          penalties (each in [-1, 0), see ``_shortfall``). The reward is then
          in [-2, 0) and does not depend on GBW or Power.
        * Otherwise it is the efficiency figure ``10 * GBW / Power``. Power
          must be non-zero in this case.

        Overshooting a target is never penalised.

        :param spec: mapping with at least the keys 'bw', 'ugbw', 'ibias',
            'gain' and 'phm'.
        :return: ``(reward, UGF, GBW, gain, BW, Power, Phm)``.
        '''
        BW = 0 if spec['bw'] < 0 else spec['bw'] / 1000000
        UGF = 0 if spec['ugbw'] < 0 else spec['ugbw'] / 1000000
        Power = (spec['ibias']*1000000)*1.2
        gain = spec['gain']
        GBW = gain*BW if gain > 1 else 0
        Phm = spec['phm']

        gain_short = gain < self.GAIN_TARGET
        phm_short = Phm < self.PHM_TARGET
        if gain_short or phm_short:
            # Failing designs are ranked only by how far they miss the
            # targets. The efficiency term is deliberately not computed here,
            # so a failed simulation reporting zero bias current cannot raise
            # ZeroDivisionError.
            gain_reward = self._shortfall(gain, self.GAIN_TARGET) if gain_short else 0
            phm_reward = self._shortfall(Phm, self.PHM_TARGET) if phm_short else 0
            reward = gain_reward + phm_reward
        else:
            reward = (GBW*10)/(Power)
        return reward, UGF, GBW, gain, BW, Power, Phm

    def _scale(self, name, value):
        """Convert one raw gene into the value handed to the simulator."""
        if name in self.LENGTH_PARAMS:
            return value*self.LENGTH_UNIT
        if name in self.CAP_PARAMS:
            return value*self.CAP_UNIT
        return value

    def update(self, params_idx):
        """Scale ``params_idx``, simulate the design and return its specs.

        :param params_idx: raw parameter values, one per entry of
            ``self.params_id`` and in the same order.
        :return: ``OrderedDict`` of the simulator's specs, sorted by spec name.
        :raises IndexError: if ``params_idx`` has more entries than ``self.params_id``.
        """
        scaled = [self._scale(self.params_id[i], value) for i, value in enumerate(params_idx)]
        param_val = OrderedDict(zip(self.params_id, scaled))

        self.mystate, self.myspec, self.myinfo, self.mydesignpath = self.sim_env.create_design_and_simulate(param_val)
        # The per-design directory is not needed once the specs have been
        # returned; delete it so the run does not accumulate files.
        shutil.rmtree(self.mydesignpath)

        return OrderedDict(sorted(self.myspec.items(), key=lambda item: item[0]))

def write_log(env, step, log, best_reward, best_para, best_ugf, best_gbw, best_gain, best_bw, best_power, best_phm):
    """Append one progress record to ``log`` and echo it to stdout.

    The record has these columns, each followed by a tab, and ends with a newline:

    * ``step``
    * the best reward and its UGF, GBW, gain, BW, power and phase margin
    * every entry of ``best_para``, scaled like the simulator input: length
      parameters (lp*/ln*) times 50e-9, ``c1`` times 0.1e-12, others unchanged.

    ``env.params_id`` supplies the parameter name for each position of
    ``best_para``. The log is flushed after every record.
    """
    length_ids = ('lp1', 'lp2', 'lp3', 'ln1', 'ln2', 'ln3', 'ln4', 'ln5')
    header = (step, best_reward, best_ugf, best_gbw, best_gain, best_bw, best_power, best_phm)
    log.write("".join(str(field) + "\t" for field in header))

    scaled = []
    for idx, raw in enumerate(best_para):
        pid = env.params_id[idx]
        if pid in length_ids:
            shown = raw*50e-9
        elif pid == 'c1':
            shown = raw*0.1e-12
        else:
            shown = raw
        scaled.append(shown)

    log.write("".join(str(value) + "\t" for value in scaled))
    log.write("\n")
    log.flush()
    print(f"Simulation steps: {step} Best reward: {best_reward} Best parameter: {scaled}")

# Best-so-far bookkeeping: updated by target_function after every simulation
# and written out through write_log.
simulation = 0
best_reward = float('-inf')
# Must be iterable: write_log enumerates it even when no evaluation has yet
# beaten the initial -inf reward (e.g. NaN rewards, since NaN > -inf is False).
best_parameter = []
best_ugf = 0
best_gbw = 0
best_gain = 0
best_bw = 0
best_power = 0
best_phm = 0
def target_function(parameter):
    """GA objective: simulate ``parameter`` and return its negated reward.

    A fresh ``TwoStageAmp`` is built for every call. The module-level
    best-so-far globals are updated whenever the reward strictly improves,
    and a progress record is logged after every evaluation. The reward is
    returned negated so that minimising the objective maximises the reward.
    """
    global simulation, best_reward, best_parameter
    global best_ugf, best_gbw, best_gain, best_bw, best_power, best_phm

    amp = TwoStageAmp()
    reward, ugf, gbw, gain, bw, power, phm = amp.step(parameter)
    simulation += 1

    if reward > best_reward:
        best_reward, best_parameter = reward, parameter
        best_ugf, best_gbw, best_gain = ugf, gbw, gain
        best_bw, best_power, best_phm = bw, power, phm

    write_log(amp, simulation, log, best_reward, best_parameter, best_ugf, best_gbw, best_gain, best_bw, best_power, best_phm)
    return -reward

# One bound/type per gene, in the order of the YAML ``params`` keys
# (TwoStageAmp.params_id). The first gene is real-valued in [1e-6, 1e-4]
# (NOTE(review): presumably the bias current — confirm against the YAML).
# The remaining 17 are integers in [1, 1000]; TwoStageAmp.update multiplies
# the length and capacitor genes by their unit size.
varbound = np.array([[1.0e-6, 1.0e-4]] + [[1, 1000]] * 17)
vartype = np.array([['real']] + [['int']] * 17)

algorithm_param = {
    'max_num_iteration': 800,
    'population_size': 2000,
    'mutation_probability': 0.1,
    'elit_ratio': 0.01,
    'crossover_probability': 0.5,
    'parents_portion': 0.3,
    'crossover_type': 'uniform',
    'max_iteration_without_improv': None,
}

model = ga(
    function=target_function,
    dimension=len(varbound),
    variable_type_mixed=vartype,
    variable_boundaries=varbound,
    algorithm_parameters=algorithm_param,
    function_timeout=100,
)

try:
    model.run()
    print("Done!")
finally:
    # Close the progress log even if model.run() raises.
    log.close()