from envs.NGspiceOpampEnv_rlpyt import NGspiceOpampEnv_rlpyt
from envs.foldedCascodeEnv_pvt_gen import foldedCascodeEnv_pvt_gen
from envs.strongArmEnv_pvt_gen import strongArmEnv_pvt_gen
import argparse
import importlib
import random
import numpy as np
from tqdm import tqdm

# Environments selectable with --env.
ENV_CHOICES = ('ng', 'fold', 'strongArm')
# Directory the sampled data is written to.
OUTPUT_DIR = 'generate'


def kwargs_module_name(path):
    """Convert a config path such as ``configs/exp/run.py`` to a dotted module name.

    Only a trailing ``.py`` suffix is stripped. (Blindly replacing every
    ``.py`` substring would also mangle names like ``configs/pyramid``.)
    Already-dotted names (``configs.exp.run``) pass through unchanged.
    """
    if path.endswith('.py'):
        path = path[:-len('.py')]
    return path.replace('/', '.')


def load_kwargs(module_name):
    """Build the config dict for ``module_name``.

    For every ancestor package ``a``, ``a.b``, ... the ``kwargs`` dict of its
    ``defaults`` module is merged in order (deeper packages override shallower
    ones), and finally the ``kwargs`` of ``module_name`` itself overrides all.
    """
    kwargs = {}
    prefix = ''
    for name in module_name.split('.')[:-1]:
        prefix += name + '.'
        kwargs = {**kwargs, **importlib.import_module(prefix + 'defaults').kwargs}
    return {**kwargs, **importlib.import_module(module_name).kwargs}


def parse_args(argv=None):
    """Parse command-line arguments (``argv`` defaults to ``sys.argv[1:]``)."""
    parser = argparse.ArgumentParser()
    parser.add_argument('kwargs', default=None)
    parser.add_argument('--log_dir', default='runs')
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--corner', type=str)
    parser.add_argument('--temp', type=float, default=27.0)
    parser.add_argument('--vdd', type=float, default=1.2)
    parser.add_argument('--n-samp', type=int, default=10000)
    # Required + choices: any other value used to leave `env` unbound and
    # crash later with a NameError inside the sampling loop.
    parser.add_argument('--env', type=str, required=True, choices=ENV_CHOICES)
    parser.add_argument('--save-every', type=int, default=100,
                        help='checkpoint interval in samples; <= 0 disables '
                             'intermediate checkpoints (the final save always happens)')
    return parser.parse_args(argv)


def output_path(args):
    """Return the ``.npy`` file the samples for this run are written to.

    The 'ng' env is keyed by process corner, temperature and supply voltage;
    the other envs are keyed only by env name.
    """
    if args.env == 'ng':
        return f'{OUTPUT_DIR}/{args.corner}_t{args.temp}_v{args.vdd}.npy'
    return f'{OUTPUT_DIR}/{args.env}.npy'


def make_env(env_name, kwargs):
    """Construct the simulation environment selected by ``env_name``."""
    if env_name == 'ng':
        env = NGspiceOpampEnv_rlpyt(kwargs)
        # Only the NGspice env is reset before sampling (return value unused).
        env.reset()
        return env
    if env_name == 'fold':
        return foldedCascodeEnv_pvt_gen(kwargs=kwargs, writer=None)
    if env_name == 'strongArm':
        return strongArmEnv_pvt_gen(kwargs=kwargs, writer=None)
    raise ValueError(f'unknown env {env_name!r}; expected one of {ENV_CHOICES}')


def sample_once(env, env_name):
    """Draw one action from ``env.action_space``, simulate it, and return a record dict.

    The 'ng' env is stepped through its ``TwoStageAmpEnv`` with explicit
    absolute sizings; the other envs take the action directly and report the
    absolute sizings in the first element of their info list.
    """
    action = env.action_space.sample()
    if env_name == 'ng':
        absolute_sizings = env.get_absolute_sizings(action)
        states, reward, episode_finish, ckt_perf, _ = env.TwoStageAmpEnv.step(
            absolute_sizings, global_stp=0)
    else:
        states, reward, episode_finish, all_infos, ckt_perf, _ = env.step(action=action)
        absolute_sizings = all_infos[0]['absolute_sizings']
    return {
        'action': action,
        'absolute_sizings': absolute_sizings,
        'states': states,
        'reward': reward,
        'episode_finish': episode_finish,
        'ckt_perf': ckt_perf,
    }


def save_samples(path, info_all):
    """Write the list of sample records to ``path`` as a 1-D object array."""
    np.save(path, np.array(info_all, dtype=object))


def main(argv=None):
    """Generate ``--n-samp`` random samples and save them, checkpointing periodically."""
    import os

    args = parse_args(argv)
    print(args)

    random.seed(args.seed)
    np.random.seed(args.seed)

    module_name = kwargs_module_name(args.kwargs)
    kwargs = load_kwargs(module_name)
    kwargs['runs_dir'] = module_name
    kwargs['log_dir'] = args.log_dir
    kwargs['corner'] = {
        'process': args.corner,
        'temp': f"{args.temp}",
        'vdd': f"{args.vdd}",
    }
    env = make_env(args.env, kwargs)

    out_path = output_path(args)
    # np.save does not create parent directories; create it up front rather
    # than failing after the first (expensive) simulation.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)

    info_all = []
    for k in tqdm(range(args.n_samp)):
        info_all.append(sample_once(env, args.env))
        # Periodic checkpoint so a long run is not lost on interruption.
        if args.save_every > 0 and k % args.save_every == 0:
            save_samples(out_path, info_all)

    save_samples(out_path, info_all)


if __name__ == '__main__':
    main()

