import pdb
import os
import torch
import argparse
import math
import importlib
import random
import time
import datetime

import numpy as np
import torch.backends.cudnn as cudnn

# from envs.strongArmEnv import strongArmEnv
from envs.NGspiceOpampEnv_pvt import NGspiceOpampEnv_pvt
# from envs.NGspiceOpampEnv_rlpyt import NGspiceOpampEnv_rlpyt

# from agent.lib.ddpg import DDPG
from copy import deepcopy
from agent.utils.utils import AttrDict, get_output_folder
from tensorboardX import SummaryWriter

from open_source.rlpyt.rlpyt.samplers.serial.sampler import SerialSampler
from open_source.rlpyt.rlpyt.samplers.parallel.cpu.sampler import CpuSampler
from open_source.rlpyt.rlpyt.envs.atari.atari_env import AtariEnv, AtariTrajInfo
from open_source.rlpyt.rlpyt.algos.qpg.ddpg import DDPG
from open_source.rlpyt.rlpyt.agents.qpg.ddpg_agent import DdpgAgent, DdpgAgent_decay
from open_source.rlpyt.rlpyt.runners.minibatch_rl import MinibatchRl, MinibatchRlEval
from open_source.rlpyt.rlpyt.utils.logging.context import logger_context
import sys

class Envs_with_corners:
    """Hold one simulation environment per process/temperature/voltage corner.

    An environment is built eagerly for every combination of ``processes``,
    ``temps`` and ``vdds``. Each one is stored in ``self.envs`` under the key
    ``"<process>_<temp>_<vdd>"``; the matching corner description is stored in
    ``self.corners`` under ``"corner_<process>_<temp>_<vdd>"``.
    """

    def __init__(self, processes, temps, vdds, kwargs):
        """Build an environment for every corner combination.

        Args:
            processes: Sequence of process-corner names.
            temps: Sequence of temperatures.
            vdds: Sequence of supply voltages.
            kwargs: Mutable mapping handed to every environment constructor.
                Its ``'corner'`` entry is overwritten in place before each
                environment is built, so on return it holds the last corner.
        """
        self.envs = {}
        # Kept on the instance instead of relying on a module-level ``corners``
        # dict, which only exists when this file is executed as a script.
        self.corners = {}
        for process in processes:
            for temp in temps:
                for vdd in vdds:
                    key = self._corner_key(process, temp, vdd)
                    corner = {
                        "process": process,
                        "temp": temp,
                        "vdd": vdd,
                    }
                    self.corners["corner_" + key] = corner
                    kwargs['corner'] = corner
                    # NOTE(review): the import of NGspiceOpampEnv_rlpyt is
                    # commented out at the top of this file; it has to be in
                    # scope for this call to succeed.
                    self.envs[key] = NGspiceOpampEnv_rlpyt(kwargs)

    @staticmethod
    def _corner_key(process, temp, vdd):
        """Return the ``self.envs`` key for one corner."""
        return "%s_%s_%s" % (process, temp, vdd)

    def get_env(self, process, temp, vdd):
        """Return the environment built for the given corner.

        Raises:
            KeyError: If no environment was built for that combination.
        """
        return self.envs[self._corner_key(process, temp, vdd)]


def build_and_train(run_id=0, kwargs=None, sample_mode="serial", n_parallel=2, cuda_idx=None, test_date=None):
    """Train a DDPG agent on ``NGspiceOpampEnv_pvt``, optionally resuming a run.

    Args:
        run_id: Run identifier forwarded to ``logger_context``.
        kwargs: Experiment configuration mapping; required despite the
            ``None`` default. It is forwarded to the environments, and a deep
            copy with ``'log'`` and ``'log_interval'`` set is used for the
            evaluation environment. ``'seed'`` and ``'debug'`` are read from
            it when present.
        sample_mode: ``"serial"`` for ``SerialSampler`` or ``"cpu"`` for
            ``CpuSampler``.
        n_parallel: Number of worker CPUs listed in the affinity dict.
        cuda_idx: GPU index placed in the affinity dict, or ``None``.
        test_date: Name of an earlier run directory under
            ``data/example_long_ng_pvt/`` to resume from. When given, the
            agent weights, optimizer state and replay buffer are restored
            from that directory.

    Raises:
        ValueError: If ``sample_mode`` is not ``"serial"`` or ``"cpu"``.
    """
    affinity = dict(cuda_idx=cuda_idx, workers_cpus=list(range(n_parallel)))
    gpu_cpu = "CPU" if cuda_idx is None else f"GPU {cuda_idx}"
    if sample_mode == "serial":
        Sampler = SerialSampler  # (Ignores worker_cpus.)
        print(f"Using serial sampler, {gpu_cpu} for sampling and optimizing.")
    elif sample_mode == "cpu":
        Sampler = CpuSampler
        print(f"Using CPU parallel sampler (agent in workers), {gpu_cpu} for optimizing.")
    else:
        # Previously an unknown mode left ``Sampler`` unbound and the failure
        # only surfaced later as an unrelated error.
        raise ValueError(
            f"Unknown sample_mode {sample_mode!r}; expected 'serial' or 'cpu'."
        )

    # Evaluation environments get their own config copy with logging enabled.
    log_interval = 10
    kwargs_eval = deepcopy(kwargs)
    kwargs_eval['log'] = True
    kwargs_eval['log_interval'] = log_interval

    sampler = Sampler(
        EnvCls=NGspiceOpampEnv_pvt,
        env_kwargs=dict(kwargs=kwargs),
        eval_env_kwargs=dict(kwargs=kwargs_eval),
        batch_T=1,
        batch_B=1,
        max_decorrelation_steps=1,
        eval_n_envs=1,
        eval_max_steps=1,
    )

    if test_date is None:
        # Fresh run.
        itr = 0
        agent = DdpgAgent(action_std=0.2)
        algo = DDPG(
            batch_size=64,
            min_steps_learn=100,
            replay_ratio=640,
            replay_size=1000,
        )
    else:
        # Resume from a previous run. The checkpoint is always looked up under
        # the non-debug log root, even when kwargs['debug'] is set.
        resume_dir = os.path.join('data/example_long_ng_pvt', test_date)
        # WARNING: torch.load and pickle.load can execute arbitrary code while
        # unpickling; only resume from checkpoints you trust.
        data = torch.load(os.path.join(resume_dir, 'run_0', 'params.pkl'))
        itr = data['itr']
        agent = DdpgAgent(
            action_std=0.5,
            initial_model_state_dict=data['agent_state_dict']['model'],
            initial_q_model_state_dict=data['agent_state_dict']['q_model'],
        )
        print('buffer loading ...')
        import pickle
        with open(os.path.join(resume_dir, 'buffer.pkl'), 'rb') as fh:
            replay_buffer = pickle.load(fh)
        algo = DDPG(
            batch_size=2,
            min_steps_learn=10,
            replay_ratio=1,
            replay_size=1000,
            initial_optim_state_dict=data['optimizer_state_dict'],
            replay_buffer_loaded=replay_buffer,
        )

    runner = MinibatchRlEval(
        algo=algo,
        agent=agent,
        sampler=sampler,
        n_steps=int(30000),
        log_interval_steps=log_interval,
        affinity=affinity,
        # The entry point treats the seed as optional, so a missing key must
        # not raise here. NOTE(review): assumes the runner accepts seed=None.
        seed=kwargs.get('seed'),
        start_itr=itr,
    )

    timestamp = datetime.datetime.today().strftime("%Y%m%d_%H%M")
    if kwargs.get('debug', False):
        log_dir = "data/example_serial_debug_ng/{}".format(timestamp)
    else:
        log_dir = "data/example_long_ng_pvt/{}".format(timestamp)
    with logger_context(log_dir, run_id, 'test',
                        snapshot_mode="last",
                        use_summary_writer=True,
                        override_prefix=True):
        runner.train(log_dir=log_dir)


def build_and_test(envs: "Envs_with_corners", processes, temps, vdds, log_dir, test_date, test_corner, kwargs=None):
    """Evaluate a saved agent on every process/temperature/voltage corner.

    Args:
        envs: Container exposing ``get_env(process, temp, vdd)``.
        processes: Sequence of process-corner names.
        temps: Sequence of temperatures.
        vdds: Sequence of supply voltages.
        log_dir: Directory holding one sub-directory per experiment.
        test_date: Experiment sub-directory to load. When ``None`` the entry
            of ``log_dir`` that sorts last is used.
        test_corner: ``"<process>_<temp>_<vdd>"`` naming the per-corner
            checkpoint directory to load, or ``None`` to load
            ``<test_date>/run_0/params.pkl``.
        kwargs: Optional shared configuration mapping whose ``'corner'`` entry
            is updated before each evaluation. When omitted, a module-level
            ``kwargs`` is used if one exists (as when run as a script).

    Returns:
        ``(rewards, ckt_perfs, pass_count, sizing)``. ``rewards`` and
        ``ckt_perfs`` are dicts keyed by ``"<process>_<temp>_<vdd>"``,
        ``pass_count`` is the number of corners whose reward is positive, and
        ``sizing`` is the sizing from the last evaluated corner, or ``None``
        when no corner was evaluated.

    Raises:
        FileNotFoundError: If ``test_date`` is ``None`` and ``log_dir`` is
            empty.
        ValueError: If ``test_corner`` has fewer than three ``_``-separated
            fields.
    """
    rewards = {}
    ckt_perfs = {}
    pass_count = 0
    # Bound up front so an empty corner list cannot hit an unbound local.
    sizing = None

    if kwargs is None:
        # Backward compatibility: this function used to write straight into
        # the ``kwargs`` global created by the script entry point.
        kwargs = globals().get('kwargs')

    if test_date is None:
        exps = os.listdir(log_dir)
        if not exps:
            raise FileNotFoundError(
                'No experiments found in {!r}'.format(log_dir))
        test_date = max(exps)
        print('Using the latest experiment with timestamp: {}'.format(test_date))

    # os.path.join keeps the path correct whether or not log_dir ends in '/'.
    if test_corner is None:
        params_path = os.path.join(log_dir, test_date, 'run_0', 'params.pkl')
    else:
        corner_info = test_corner.split("_")
        if len(corner_info) < 3:
            raise ValueError(
                "test_corner must look like '<process>_<temp>_<vdd>', got {!r}"
                .format(test_corner))
        corner_dir = '{}_{}_{}'.format(*corner_info[:3])
        params_path = os.path.join(
            log_dir, test_date, corner_dir, 'run_0', 'params.pkl')

    corner_list = [
        (process, temp, vdd)
        for process in processes
        for temp in temps
        for vdd in vdds
    ]
    if not corner_list:
        return rewards, ckt_perfs, pass_count, sizing

    # The checkpoint is the same for every corner, so read it once.
    # WARNING: torch.load unpickles the file; only load trusted checkpoints.
    data = torch.load(params_path)

    for process, temp, vdd in corner_list:
        if kwargs is not None:
            kwargs['corner'] = {
                "process": process,
                "temp": temp,
                "vdd": vdd,
            }

        env = envs.get_env(process, temp, vdd)
        agent = DdpgAgent()
        agent.initialize(env.spaces)
        agent.load_state_dict(data['agent_state_dict'])
        reward, ckt_perf, sizing = evaluate_cur_agent(agent, env)
        if reward > 0:
            pass_count += 1
        key = '%s_%s_%s' % (process, temp, vdd)
        rewards[key] = reward
        ckt_perfs[key] = ckt_perf

    return rewards, ckt_perfs, pass_count, sizing


def evaluate_cur_agent(agent, env):
    """Query ``agent`` once on a freshly reset ``env`` and simulate its action.

    The agent is switched to evaluation mode, asked for one action given the
    reset observation, and that action is converted to absolute sizings which
    are passed to ``env.TwoStageAmpEnv.step``.

    Args:
        agent: Agent providing ``eval_mode`` and ``step``.
        env: Environment providing ``reset``, ``action_space``,
            ``get_absolute_sizings`` and a ``TwoStageAmpEnv`` attribute.

    Returns:
        ``(reward, ckt_perf, absolute_sizings)`` for the single step taken.
    """
    agent.eval_mode(0)

    first_obs = env.reset()
    # ``agent.step`` also takes a previous action and reward; a sampled action
    # and a zero reward are supplied for this first call.
    placeholder_action = env.action_space.sample()
    placeholder_reward = 0.

    obs_t = torch.from_numpy(first_obs).float()
    prev_action_t = torch.from_numpy(placeholder_action).float()
    prev_reward_t = torch.tensor(placeholder_reward).float()
    chosen, _agent_info = agent.step(obs_t, prev_action_t, prev_reward_t)

    absolute_sizings = env.get_absolute_sizings(chosen.numpy())
    _states, reward, _done, ckt_perf, _extra = env.TwoStageAmpEnv.step(
        absolute_sizings, global_stp=0)

    return reward, ckt_perf, absolute_sizings


if __name__ == '__main__':
    # Script entry point: parse the CLI, assemble the experiment config, then
    # either evaluate a saved agent on all corners (--test) or train.
    parser = argparse.ArgumentParser()
    parser.add_argument('kwargs', default=None)
    parser.add_argument('--devices', default='0')
    parser.add_argument('--log_dir', default='runs')
    parser.add_argument('--cuda_idx', help='gpu to use ', type=int, default=None)
    parser.add_argument('--sample_mode', help='serial or parallel sampling',
                        type=str, default='serial', choices=['serial', 'cpu'])
    parser.add_argument('--n_parallel', help='number of sampler workers', type=int, default=2)
    parser.add_argument('--debug', help="run a few steps to debug", action="store_true")
    parser.add_argument('--test', dest='test', action='store_true')
    parser.add_argument('--pvt', dest='pvt', action='store_true')
    parser.add_argument('--test_date', type=str, default=None)
    parser.add_argument('--use_mode', type=str, default='serial')
    parser.add_argument('--test_corner', type=str, default=None)
    # parse arguments
    args = parser.parse_args()
    # Turn a config file path such as "a/b/c.py" into the module name "a.b.c".
    args.kwargs = args.kwargs.replace('/', '.').replace('.py', '')

    # set devices
    # NOTE(review): --devices defaults to '0', so the 'cpu' branch is never
    # reached from the command line; left unchanged.
    if args.devices is not None:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.devices
        cudnn.benchmark = True
        device = 'cuda'
    else:
        device = 'cpu'

    # load kwargs: merge the ``defaults`` module of every parent package,
    # outermost first, then the requested config module on top.
    kwargs = {}
    prefix = ''
    for name in args.kwargs.split('.')[:-1]:
        prefix += name + '.'
        kwargs = {**kwargs, **importlib.import_module(prefix + 'defaults').kwargs}
    kwargs = {**kwargs, **importlib.import_module(args.kwargs).kwargs}

    # save runs dir for further use
    kwargs['runs_dir'] = args.kwargs
    kwargs['gpu_id'] = args.devices
    kwargs['log_dir'] = args.log_dir
    ts = time.time()
    st = datetime.datetime.fromtimestamp(ts).strftime('%Y-%m-%d-%H-%M-%S')
    kwargs['runs_dir'] += '/' + st
    kwargs['corner'] = {
        'process': 'tt',
        'temp': '27',
        'vdd': '1.2',
    }
    # set random seed
    if 'seed' in kwargs and kwargs['seed'] is not None:
        random.seed(kwargs['seed'])
        np.random.seed(kwargs['seed'])
        torch.manual_seed(kwargs['seed'])

        if device == 'cuda':
            cudnn.deterministic = True
            torch.cuda.manual_seed(kwargs['seed'])
    kwargs['output_folder'] = get_output_folder('output', kwargs['runs_dir'])
    kwargs['debug'] = args.debug
    kwargs = AttrDict(kwargs)

    print('==> parsed arguments')
    # A config without 'no_print_kwargs' simply prints every entry.
    hidden_keys = kwargs.get('no_print_kwargs', ())
    for k, v in kwargs.items():
        if k not in hidden_keys:
            print('[{}] = {}'.format(k, v))

    # Default corner grid; each axis can be overridden by the config.
    processes = ["tt", "ss", "ff", "sf", "fs"]
    temps = ["0", "100"]
    vdds = ['0.9', '1.0', '1.1', '1.2']
    processes = kwargs.get('processes', processes)
    temps = kwargs.get('temps', temps)
    vdds = kwargs.get('vdds', vdds)
    # Module-level registry that the corner helpers in this file may write to.
    corners = {}
    log_dir = 'data/example_long_ng_pvt/'

    if args.test:
        # The per-corner environments are only needed for evaluation, so they
        # are built here rather than on every run. Previously ``envs`` was
        # never defined and this branch raised NameError.
        # NOTE(review): Envs_with_corners instantiates NGspiceOpampEnv_rlpyt,
        # whose top-of-file import is commented out; it is imported here so
        # the name is in module scope. Confirm the module path is still valid.
        from envs.NGspiceOpampEnv_rlpyt import NGspiceOpampEnv_rlpyt
        envs = Envs_with_corners(processes=processes, temps=temps, vdds=vdds, kwargs=kwargs)
        rewards, ckt_perfs, pass_count, sizing = build_and_test(envs=envs, processes=processes, temps=temps, vdds=vdds,
                                                        log_dir=log_dir, test_date=args.test_date,
                                                        test_corner=args.test_corner)
        print("rewards:", rewards)
        print("pass_count:", pass_count)
        print("sizing:", sizing)
        sys.exit(0)

    build_and_train(run_id=0,
                    kwargs=kwargs,
                    sample_mode=args.sample_mode,
                    n_parallel=args.n_parallel,
                    cuda_idx=args.cuda_idx,
                    test_date=args.test_date)

