from collections import namedtuple

from open_source.rlpyt.rlpyt.envs.base import Env, EnvStep, EnvSpaces
from open_source.rlpyt.rlpyt.spaces.int_box import IntBox
from open_source.rlpyt.rlpyt.spaces.float_box import FloatBox
from open_source.rlpyt.rlpyt.utils.quick_args import save__init__args
from open_source.rlpyt.rlpyt.samplers.collections import TrajInfo

import uuid
import subprocess
import cmath
import re
# import paramiko
import time
import multiprocessing
import pdb

import tensorboardX

import numpy as np

# from utils.circuits_utils import get_mos_netlist_static_feature
from envs.StrongArm_Latch.strongArm_latch import strongArm_simulator

from envs.BaseEnv import BaseEnv
from copy import deepcopy

import yaml
import yamlordereddictloader
from envs.StrongArm_Latch import utils
import os

from envs.NGspiceOpamp.ngspice_vanilla_opamp import TwoStageAmp
from tensorboardX import SummaryWriter
import copy

import time
import datetime

import pdb
import random

# Record attached as ``env_info`` to each EnvStep; its only field is ``timeout``.
EnvInfo = namedtuple("EnvInfo", "timeout")

class NGspiceOpampEnv_random(Env):
    """Single-step rlpyt environment over a grid of ``TwoStageAmp`` corners.

    One ``TwoStageAmp`` instance is created for every (process, temp, vdd)
    combination.  ``step`` converts a normalized action into absolute sizings
    and evaluates it either on one randomly chosen corner (``random=True``)
    or on every corner, in which case the reward is the mean over corners.
    The observation is always a constant zero vector and every step reports
    ``done=True``.
    """

    def __init__(self, kwargs):
        """Create the per-corner environments and the action/observation spaces.

        Args:
            kwargs: Configuration mapping.  ``log_dir`` and ``runs_dir`` are
                required; together they form the root directory for the
                per-corner TensorBoard writers.  Optional keys:
                ``processes``, ``temps``, ``vdds`` (sequences defining the
                corner grid), ``log``, ``log_interval`` and ``random``.

        Raises:
            KeyError: If ``log_dir`` or ``runs_dir`` is missing.
        """
        self.global_stp = 0
        # Keep a private snapshot of the configuration.
        self.kwargs = copy.deepcopy(kwargs)

        self.processes = kwargs.get('processes', ["tt", "ss", "ff", "sf", "fs"])
        self.temps = kwargs.get('temps', ["0", "100"])
        self.vdds = kwargs.get('vdds', ["1.2"])
        self.log = kwargs.get('log', False)
        self.log_interval = kwargs.get('log_interval', 1)
        self.random = kwargs.get('random', False)

        write_path_root = self.kwargs['log_dir'] + '/' + self.kwargs['runs_dir']
        self.TwoStageAmpEnvs = {}
        for process in self.processes:
            for temp in self.temps:
                for vdd in self.vdds:
                    corner = {
                        "process": process,
                        "temp": temp,
                        "vdd": vdd,
                    }
                    # Each corner logs to its own TensorBoard sub-directory.
                    write_path = write_path_root + "/%s_%s_%s" % (process, temp, vdd)
                    writer = SummaryWriter(log_dir=write_path)
                    env_config = {
                        "generalize": True,
                        "valid": True,
                        "corner": corner,
                        "log": self.log,
                    }
                    corner_env = TwoStageAmp(env_config, tb_writer=writer)
                    corner_env.reset()
                    self.TwoStageAmpEnvs[self._corner_key(process, temp, vdd)] = corner_env

        # The observation carries no information: it is a fixed zero vector.
        self._init_obs = np.zeros(162)
        # NOTE(review): 7 presumably equals the number of sizing variables in
        # TwoStageAmp -- confirm against ``env.sizing``.
        self._action_space = FloatBox(low=-1, high=1, shape=7)
        self._observation_space = FloatBox(low=-1., high=1., shape=162)

    @staticmethod
    def _corner_key(process, temp, vdd):
        """Return the ``TwoStageAmpEnvs`` dictionary key for one corner."""
        return "corner_%s_%s_%s" % (process, temp, vdd)

    def get_absolute_sizings(self, action):
        """Map a normalized action onto absolute sizing values.

        Each action component is mapped linearly from ``[-1, 1]`` onto
        ``[sizing_lower, sizing_upper]``.  The bounds and the variable names
        are read from the environment of the first configured corner.  Values
        are not rounded.

        Args:
            action: Array-like of normalized sizings.

        Returns:
            dict mapping each sizing variable name to its absolute value.
        """
        env = self.TwoStageAmpEnvs[
            self._corner_key(self.processes[0], self.temps[0], self.vdds[0])
        ]
        absolute_sizings = env.sizing_lower + (env.sizing_upper - env.sizing_lower) * (action + 1) / 2
        return {
            varname: absolute_sizings[i]
            for i, varname in enumerate(list(env.sizing.keys()))
        }

    def step(self, action):
        """Evaluate ``action`` and return a terminal ``EnvStep``.

        With ``self.random`` set, a single corner is drawn uniformly at random
        and its reward is returned.  Otherwise every corner is evaluated and
        the reward is the arithmetic mean over all corners.

        Args:
            action: Normalized sizing vector (see ``get_absolute_sizings``).

        Returns:
            EnvStep whose observation is the constant zero vector (float32),
            whose reward is a float32, with ``done=True`` and
            ``env_info=EnvInfo(timeout=False)``.
        """
        if self.random:
            env = random.choice(list(self.TwoStageAmpEnvs.values()))
            absolute_sizings = self.get_absolute_sizings(action)
            _, rewards, _, _, _ = env.step(absolute_sizings, global_stp=self.global_stp)
        else:
            total_reward = 0
            for env in self.TwoStageAmpEnvs.values():
                # Build a fresh sizing dict per corner so that nothing a
                # corner's step() does to it can leak into the next corner.
                absolute_sizings = self.get_absolute_sizings(action)
                _, reward, _, _, _ = env.step(absolute_sizings, global_stp=self.global_stp)
                total_reward += reward
            # Average once, after all corners have been accumulated.
            num_envs = len(self.TwoStageAmpEnvs)
            rewards = total_reward / num_envs if num_envs else total_reward

        observation = np.float32(np.array(self._init_obs))
        rewards = np.float32(rewards)
        # Episodes are single-step: the flag returned by the corner
        # environments is ignored and the episode always terminates here.
        episode_finish = True
        env_info = EnvInfo(timeout=False)
        # NOTE(review): with the default ``log_interval`` of 1 this adds 0, so
        # ``global_stp`` never advances -- confirm this is intended.
        self.global_stp += self.log_interval - 1

        return EnvStep(observation=observation, reward=rewards, done=episode_finish, env_info=env_info)

    def reset(self):
        """Return a fresh float32 copy of the constant initial observation."""
        return np.array(self._init_obs, dtype=np.float32)

    def write_logs(self, state, global_stp) -> None:
        """Placeholder logging hook; currently does nothing and returns None."""
        return None

    @property
    def action_space(self):
        """Action space: ``FloatBox`` in ``[-1, 1]`` with 7 components."""
        return self._action_space

    @property
    def observation_space(self):
        """Observation space: ``FloatBox`` in ``[-1, 1]`` with 162 components."""
        return self._observation_space

    @property
    def spaces(self):
        """``EnvSpaces`` bundling the observation and action spaces."""
        return EnvSpaces(
            observation=self.observation_space,
            action=self.action_space,
        )

    @property
    def horizon(self):
        """No horizon is defined for this environment; always ``None``."""
        return None

    def close(self):
        """No-op.

        NOTE(review): the ``SummaryWriter`` objects created in ``__init__``
        are not closed here -- confirm whether ``TwoStageAmp`` owns them.
        """
        pass