import pdb
import os
import torch
import argparse
import math
import importlib
import random
import time
import datetime
import sys

import numpy as np
import torch.backends.cudnn as cudnn

# from envs.strongArmEnv import strongArmEnv
from envs.NGspiceOpampEnv_pvt import NGspiceOpampEnv_pvt
# from envs.NGspiceOpampEnv_rlpyt import NGspiceOpampEnv_rlpyt
# from envs.NGspiceOpampEnv_ewc import NGspiceOpampEnv_ewc
from envs.NGspiceOpampEnv_pcgrad_log import NGspiceOpampEnv_pcgrad_log
from envs.foldedCascodeEnv_pcgrad import foldedCascodeEnv_pcgrad
from envs.foldedCascodeEnv_pcgrad_log import foldedCascodeEnv_pcgrad_log
from envs.strongArmEnv_pcgrad import strongArmEnv_pcgrad
from envs.strongArmEnv_pcgrad_log import strongArmEnv_pcgrad_log
# from agent.lib.ddpg import DDPG
from copy import deepcopy
from agent.utils.utils import AttrDict, get_output_folder
from tensorboardX import SummaryWriter

from open_source.rlpyt.rlpyt.samplers.serial.sampler import SerialSampler, SerialSampler_Parallel
from open_source.rlpyt.rlpyt.samplers.parallel.cpu.sampler import CpuSampler
from open_source.rlpyt.rlpyt.envs.atari.atari_env import AtariEnv, AtariTrajInfo
from open_source.rlpyt.rlpyt.algos.qpg.ddpg import DDPG
from open_source.rlpyt.rlpyt.algos.qpg.ddpg_pcgrad import DDPG_PCGrad
from open_source.rlpyt.rlpyt.algos.qpg.ddpg_pcgrad_parallel import DDPG_PCGrad_Parallel
from open_source.rlpyt.rlpyt.agents.qpg.ddpg_agent import DdpgAgent, DdpgAgent_decay
from open_source.rlpyt.rlpyt.agents.qpg.ddpg_agent_ewc import DdpgAgent_EWC
from open_source.rlpyt.rlpyt.runners.minibatch_rl_pcgrad import MinibatchRlEval_PCGrad
from open_source.rlpyt.rlpyt.runners.minibatch_rl_pcgrad_parallel import MinibatchRlEval_PCGrad_Parallel
from open_source.rlpyt.rlpyt.utils.logging.context import logger_context
from open_source.rlpyt.rlpyt.models.qpg.mlp_ewc import MuMlpModel_EWC
from open_source.rlpyt.rlpyt.samplers.parallel.cpu.collectors import CpuResetCollector_Parallel
import random
from cluster.choose_next import choose_next

def corner_to_id(corner:dict):
    """Return the ``process_temp_vdd`` identifier string for a corner mapping.

    ``corner`` must provide the keys ``'process'``, ``'temp'`` and ``'vdd'``;
    each value is converted with ``str`` and the three are joined by ``_``.
    """
    fields = (corner['process'], corner['temp'], corner['vdd'])
    return '_'.join(str(field) for field in fields)

class Envs_with_corners:
    """Hold one pre-built environment per corner, keyed by ``process_temp_vdd``.

    The i-th corner is described by ``processes[i]``, ``temps[i]`` and
    ``vdds[i]``.
    """

    def __init__(self, processes, temps, vdds, kwargs, cur_env):
        """Instantiate ``cur_env`` once for every corner.

        Args:
            processes: sequence of process names, one per corner.
            temps: sequence of temperatures, indexed in parallel with ``processes``.
            vdds: sequence of supply voltages, indexed in parallel with ``processes``.
            kwargs: mutable mapping handed to every environment constructor.
                Its ``'corner'`` entry is overwritten in place before each
                construction, so after this call it holds the last corner.
            cur_env: callable invoked as ``cur_env(kwargs=kwargs, writer=None)``.

        Raises:
            IndexError: if ``temps`` or ``vdds`` is shorter than ``processes``.
        """
        self.envs = {}
        for index, process in enumerate(processes):
            temp = temps[index]
            vdd = vdds[index]
            # Build the corner description locally. The previous code stored it
            # in a name `corners` that only exists when the file is executed as
            # a script, so constructing this class from anywhere else raised
            # NameError.
            kwargs['corner'] = {
                "process": process,
                "temp": temp,
                "vdd": vdd
            }
            self.envs[self._key(process, temp, vdd)] = cur_env(kwargs=kwargs, writer=None)

    @staticmethod
    def _key(process, temp, vdd):
        """Return the dictionary key used for a (process, temp, vdd) corner."""
        return "%s_%s_%s" % (process, temp, vdd)

    def get_env(self, process, temp, vdd):
        """Return the environment built for the given corner.

        Raises:
            KeyError: if no environment was built for that corner.
        """
        return self.envs[self._key(process, temp, vdd)]

def _corner_specs(processes, temps, vdds):
    """Return parallel lists of corner dicts and their ``process_temp_vdd`` ids."""
    specs = []
    ids = []
    for index, process in enumerate(processes):
        spec = {
            'process': process,
            'temp': temps[index],
            'vdd': vdds[index]
        }
        specs.append(spec)
        ids.append('%s_%s_%s' % (spec['process'], spec['temp'], spec['vdd']))
    return specs, ids


def build_and_train(run_id=0, kwargs=None, sample_mode="serial", n_parallel=2, cuda_idx=None, load_log_dir=None,
                    train_log_dir=None, start_itr=0, total_stp=128, use_ewc=False, ewc_lambda=400, load_buffer=True,
                    min_steps_learn=100, writer=None, cur_env=None):
    """Build the algo, agent, sampler and runner for the current corners and train.

    Args:
        run_id: run identifier forwarded to ``logger_context``.
        kwargs: configuration mapping. Must provide ``'processes'``/``'temps'``/
            ``'vdds'`` (all corners), ``'cur_processes'``/``'cur_temps'``/
            ``'cur_vdds'`` (corners trained in this call) and ``'seed'``.
        sample_mode: ``"serial"`` or ``"cpu"``; selects the sampler class.
        n_parallel: number of worker CPUs listed in the affinity.
        cuda_idx: GPU index placed in the affinity, or ``None``.
        load_log_dir: if given, ``<load_log_dir>/run_0/last_params.pkl`` is
            loaded to seed the optimizer and both models.
        train_log_dir: directory passed to ``logger_context`` and ``runner.train``.
        start_itr: accepted for interface compatibility; currently unused (the
            runner is always created with ``start_itr=0``).
        total_stp: per-environment step budget; multiplied by the number of
            current corners to obtain the runner's ``n_steps``.
        use_ewc: accepted for interface compatibility; currently unused.
        ewc_lambda: accepted for interface compatibility; currently unused.
        load_buffer: when loading, also unpickle ``<load_log_dir>/buffers.pkl``
            and hand it to the algorithm as the pre-filled replay buffers.
        min_steps_learn: forwarded to the algorithm.
        writer: forwarded to ``runner.train``.
        cur_env: environment class handed to the sampler as ``EnvCls``.

    Returns:
        Whatever ``runner.train(train_log_dir, writer)`` returns.

    Raises:
        ValueError: if ``sample_mode`` is not one of the supported values.
    """
    affinity = dict(cuda_idx=cuda_idx, workers_cpus=list(range(n_parallel)))
    gpu_cpu = "CPU" if cuda_idx is None else f"GPU {cuda_idx}"
    if sample_mode == "serial":
        Sampler = SerialSampler_Parallel  # (Ignores worker_cpus.)
        print(f"Using serial sampler, {gpu_cpu} for sampling and optimizing.")
    elif sample_mode == "cpu":
        Sampler = CpuSampler
        print(f"Using CPU parallel sampler (agent in workers), {gpu_cpu} for optimizing.")
    else:
        # Fail here instead of with an UnboundLocalError when the sampler is built.
        raise ValueError("unsupported sample_mode %r (expected 'serial' or 'cpu')" % (sample_mode,))

    initial_optim_state_dict = None
    replay_buffers = None
    initial_model_state_dict = None
    initial_q_model_state_dict = None
    if load_log_dir is not None:
        params_path = load_log_dir + '/run_0/last_params.pkl'
        data = torch.load(params_path)
        if load_buffer:
            import pickle
            buffer_path = load_log_dir + '/buffers.pkl'
            # NOTE: pickle executes arbitrary code on load; only point
            # load_log_dir at directories written by a trusted run.
            with open(buffer_path, 'rb') as fh:
                replay_buffers = pickle.load(fh)
        initial_optim_state_dict = data['optimizer_state_dict']
        initial_model_state_dict = data['agent_state_dict']['model']
        initial_q_model_state_dict = data['agent_state_dict']['q_model']

    # Corner descriptions and ids: every known corner, and the subset trained now.
    all_corners, all_task_id = _corner_specs(kwargs['processes'], kwargs['temps'], kwargs['vdds'])
    cur_corners, cur_task_id = _corner_specs(kwargs['cur_processes'], kwargs['cur_temps'], kwargs['cur_vdds'])
    n_cur = len(cur_corners)

    algo = DDPG_PCGrad_Parallel(
        # Batch size and replay ratio are split evenly across the current corners.
        batch_size=int(64 / n_cur),
        min_steps_learn=min_steps_learn,
        replay_ratio=int(640 / n_cur),
        replay_size=1000,
        initial_optim_state_dict=initial_optim_state_dict,
        replay_buffer_loaded=replay_buffers,
        num_buffers=len(all_corners),
        all_task_id=all_task_id,
        cur_task_id=cur_task_id
    )

    agent = DdpgAgent_EWC(
        ModelCls=MuMlpModel_EWC,
        initial_model_state_dict=initial_model_state_dict,
        initial_q_model_state_dict=initial_q_model_state_dict,
        action_std=0.2,
    )

    # Evaluation environments get a copy of the config with logging/eval flags set.
    kwargs_eval = deepcopy(kwargs)
    kwargs_eval['log'] = True
    log_interval = 10
    kwargs_eval['log_interval'] = log_interval
    kwargs_eval['eval'] = True

    ### one kwargs dict per environment ###
    env_kwargs_list = []
    eval_env_kwargs_list = []
    for task_id, corner in zip(cur_task_id, cur_corners):
        cur_kwargs = deepcopy(kwargs)
        cur_kwargs_eval = deepcopy(kwargs_eval)
        cur_kwargs['taskID'] = task_id
        cur_kwargs_eval['taskID'] = task_id
        cur_kwargs['corner'] = corner
        cur_kwargs_eval['corner'] = corner
        env_kwargs_list.append(dict(kwargs=cur_kwargs, writer=None))
        # Use the eval configuration here. The previous code appended the
        # training kwargs again, leaving cur_kwargs_eval built but unused.
        eval_env_kwargs_list.append(dict(kwargs=cur_kwargs_eval, writer=None))

    all_eval_env_kwargs_list = []
    for task_id, corner in zip(all_task_id, all_corners):
        cur_kwargs_eval = deepcopy(kwargs_eval)
        cur_kwargs_eval['taskID'] = task_id
        cur_kwargs_eval['corner'] = corner
        all_eval_env_kwargs_list.append(dict(kwargs=cur_kwargs_eval, writer=None))

    sampler = Sampler(
        EnvCls=cur_env,
        env_kwargs_list=env_kwargs_list,
        eval_env_kwargs_list=eval_env_kwargs_list,
        batch_T=1,
        batch_B=len(env_kwargs_list),
        max_decorrelation_steps=1,
        eval_n_envs=len(eval_env_kwargs_list),
        eval_max_steps=len(eval_env_kwargs_list),
        CollectorCls=CpuResetCollector_Parallel,
        all_eval_env_kwargs_list=all_eval_env_kwargs_list,
    )
    runner = MinibatchRlEval_PCGrad_Parallel(
        algo=algo,
        agent=agent,
        sampler=sampler,
        # Step budget and log interval scale with the number of environments.
        n_steps=total_stp * len(env_kwargs_list),
        log_interval_steps=log_interval * len(env_kwargs_list),
        affinity=affinity,
        seed=kwargs['seed'],
        start_itr=0
    )
    with logger_context(train_log_dir, run_id, 'train_pcgrad',
                        snapshot_mode="last",
                        use_summary_writer=True,
                        override_prefix=True):
        final_step = runner.train(train_log_dir, writer)
    return final_step


def build_and_test(envs:Envs_with_corners,processes, temps, vdds, log_dir, writer=None, step=0, random_agent=False):
    """Evaluate the checkpointed agent (or a random action) once on every corner.

    Args:
        envs: container providing ``get_env(process, temp, vdd)``.
        processes: process names, one per corner.
        temps: temperatures, indexed in parallel with ``processes``.
        vdds: supply voltages, indexed in parallel with ``processes``.
        log_dir: directory containing ``run_0/last_params.pkl``.
        writer: optional summary writer; receives ``passes``, ``rewards`` and
            ``pass_count`` at ``step``.
        step: global step used for the summary entries.
        random_agent: if true, no checkpoint is loaded and every corner is
            evaluated with a randomly sampled action.

    Returns:
        Tuple ``(rewards, ckt_perfs, pass_count, sizing, worst_corner)``.
        ``rewards`` and ``ckt_perfs`` are dicts keyed by ``process_temp_vdd``;
        ``pass_count`` is the number of corners with reward > 0; ``sizing`` is
        the sizing returned for the LAST corner evaluated; ``worst_corner`` is
        the key with the lowest reward.

    Raises:
        ValueError: if ``processes`` is empty.
    """
    if len(processes) == 0:
        # Previously this surfaced later as a bare "min() arg is an empty
        # sequence" (and `sizing` would have been unbound).
        raise ValueError('build_and_test needs at least one corner to evaluate')

    rewards = {}
    ckt_perfs = {}
    passes = {}
    pass_count = 0
    sizing = None

    # The checkpoint is the same for every corner, so read it once rather than
    # once per loop iteration. A random agent needs no checkpoint at all.
    data = None
    if not random_agent:
        params_path = '{}/run_0/last_params.pkl'.format(log_dir)
        data = torch.load(params_path)

    for index, process in enumerate(processes):
        temp = temps[index]
        vdd = vdds[index]
        corner_id = '%s_%s_%s' % (process, temp, vdd)

        # NOTE(review): `corners` and `kwargs` are module-level names assigned
        # by the script entry point, not parameters. The write to
        # kwargs['corner'] is kept because the environments were constructed
        # from that kwargs object and may read the active corner from it --
        # confirm before removing.
        corners["corner_" + corner_id] = {
                "process": process,
                "temp": temp,
                "vdd": vdd
            }
        kwargs['corner'] = corners["corner_" + corner_id]

        env = envs.get_env(process, temp, vdd)
        agent = None
        if not random_agent:
            agent = DdpgAgent_EWC(
                ModelCls=MuMlpModel_EWC,
                initial_model_state_dict=data['agent_state_dict']['model'],
                initial_q_model_state_dict=data['agent_state_dict']['q_model'],
            )
            agent.initialize(env.spaces)
            agent.load_state_dict(data['agent_state_dict'])
        reward, ckt_perf, sizing = evaluate_cur_agent(agent, env)

        # A corner counts as passing when its reward is strictly positive.
        env_pass = 0
        if reward > 0:
            pass_count += 1
            env_pass = 1

        rewards[corner_id] = reward
        ckt_perfs[corner_id] = ckt_perf
        passes[corner_id] = env_pass

    if writer is not None:
        writer.add_scalars('passes', passes, step)
        writer.add_scalars('rewards', rewards, step)
        writer.add_scalar('pass_count', pass_count, step)

    worst_corner = min(rewards, key=rewards.get)

    return rewards, ckt_perfs, pass_count, sizing, worst_corner


def evaluate_cur_agent(agent, env):
    """Reset ``env``, take exactly one step, and report the outcome.

    When ``agent`` is ``None`` the step uses an action sampled from
    ``env.action_space``. Otherwise the agent is put in eval mode and asked for
    an action, receiving the reset observation, the sampled action and a zero
    reward as its inputs.

    Returns:
        Tuple ``(reward, ckt_perf, absolute_sizings)``: the reward from
        ``env.step``, the result of ``env.get_ckt_perf()`` after the step, and
        ``env.get_absolute_sizings`` of the executed action.
    """
    first_obs = env.reset()
    # Sampled unconditionally: it is either executed as-is or passed to the
    # agent as the previous action.
    chosen = env.action_space.sample()

    if agent is not None:
        agent.eval_mode(0)
        obs_t = torch.from_numpy(first_obs).float()
        prev_action_t = torch.from_numpy(chosen).float()
        prev_reward_t = torch.tensor(0.).float()
        agent_action, _ = agent.step(obs_t, prev_action_t, prev_reward_t)
        chosen = agent_action.numpy()

    sizings = env.get_absolute_sizings(chosen)
    # env.step yields a 4-tuple; only the reward is needed here.
    _, reward, _, _ = env.step(chosen)
    perf = env.get_ckt_perf()

    return reward, perf, sizings




if __name__ == '__main__':
    # NOTE: this driver deliberately stays at module level instead of living in
    # a main() function: code elsewhere in this file reads the module-level
    # names `kwargs` and `corners` that are assigned below.

    # Per-environment settings keyed by the --env value:
    # (environment class, number of clusters for choose_next, nominal corner id
    # that is always kept in the next training set).
    env_setups = {
        'strongArm': (strongArmEnv_pcgrad_log, 2, 'tt_27_1.2'),
        'fold': (foldedCascodeEnv_pcgrad_log, 4, 'tt_27_1.8'),
        'ng': (NGspiceOpampEnv_pcgrad_log, 4, 'tt_27_1.2'),
    }

    parser = argparse.ArgumentParser()
    parser.add_argument('kwargs', default=None)
    parser.add_argument('--devices', default='0')
    parser.add_argument('--log_dir', default='runs')
    parser.add_argument('--cuda_idx', help='gpu to use ', type=int, default=None)
    parser.add_argument('--sample_mode', help='serial or parallel sampling',
                        type=str, default='serial', choices=['serial', 'cpu'])
    parser.add_argument('--n_parallel', help='number of sampler workers', type=int, default=2)
    parser.add_argument('--debug', help="run a few steps to debug", action="store_true")
    # NOTE(review): no type is given, so any value passed on the command line
    # (even "False") is a non-empty string and therefore truthy.
    parser.add_argument('--ctd', help="load state_dict", default=True)
    parser.add_argument('--random_choose', action="store_true")
    parser.add_argument('--test', action="store_true")
    parser.add_argument('--date', type=str)
    parser.add_argument('--env', type=str)
    parser.add_argument('--N', type=str)

    # parse arguments
    args = parser.parse_args()
    args.kwargs = args.kwargs.replace('/', '.').replace('.py', '')

    # choose env; reject unknown values here instead of failing later when the
    # (missing) environment class is first called.
    if args.env not in env_setups:
        parser.error('--env must be one of: %s' % ', '.join(sorted(env_setups)))
    cur_env, num_clusters, tt_corner = env_setups[args.env]

    # set devices
    if args.devices is not None:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.devices
        cudnn.benchmark = True
        device = 'cuda'
    else:
        device = 'cpu'

    # load kwargs: merge the `defaults` module of every parent package, then
    # the requested config module on top.
    kwargs = {}
    prefix = ''
    for name in args.kwargs.split('.')[:-1]:
        prefix += name + '.'
        kwargs = {**kwargs, **importlib.import_module(prefix + 'defaults').kwargs}
    kwargs = {**kwargs, **importlib.import_module(args.kwargs).kwargs}

    # save runs dir for further use
    kwargs['runs_dir'] = args.kwargs
    kwargs['gpu_id'] = args.devices
    kwargs['log_dir'] = args.log_dir
    ts = time.time()
    st = datetime.datetime.fromtimestamp(ts).strftime('%Y-%m-%d-%H-%M-%S')
    kwargs['runs_dir'] += '/' + st
    kwargs['corner'] = {
        'process' : 'tt',
        'temp' : '27',
        'vdd' : '1.2'
    }
    # set random seed
    if 'seed' in kwargs and kwargs['seed'] is not None:
        random.seed(kwargs['seed'])
        np.random.seed(kwargs['seed'])
        torch.manual_seed(kwargs['seed'])

        if device == 'cuda':
            cudnn.deterministic = True
            torch.cuda.manual_seed(kwargs['seed'])
    kwargs['output_folder'] = get_output_folder('output', kwargs['runs_dir'])
    kwargs = AttrDict(kwargs)

    # Fallback corner lists, used only when the config does not provide them.
    processes = ["tt", "fs"]
    temps = ["0", "100"]
    vdds = ["1.2", "0.9"]

    corners = {}
    processes = kwargs.get('processes', processes)
    temps = kwargs.get('temps', temps)
    vdds = kwargs.get('vdds', vdds)

    train_log_dir_root = 'data/example_pcgrad/%s' % (args.env)
    if args.debug:
        train_log_dir_root += '/debug'
    train_log_dir_root += '/' + st
    envs = Envs_with_corners(processes=processes, temps=temps, vdds=vdds, kwargs=kwargs, cur_env=cur_env)
    start_itr = 0
    runs_dir_root = kwargs['runs_dir']
    ts = time.time()
    st = datetime.datetime.fromtimestamp(ts).strftime('%Y-%m-%d-%H-%M')
    logger_path = '{}/test-'.format(train_log_dir_root) + st
    if not args.test:
        writer = SummaryWriter(log_dir=logger_path)
    else:
        writer = None

    corners = {}
    for index in range(len(processes)):
        process = processes[index]
        temp = temps[index]
        vdd = vdds[index]
        corners["corner_%s_%s_%s" % (process, temp, vdd)] = {
                "process": process,
                "temp": temp,
                "vdd": vdd
        }

    # Training starts from the first configured corner.
    kwargs['corner'] = corners['corner_%s_%s_%s' % (processes[0], temps[0], vdds[0])]
    load_log_dir = None
    cur_corner = '%s_%s_%s' % (kwargs['corner']['process'], kwargs['corner']['temp'], kwargs['corner']['vdd'])
    worst_corners = [cur_corner]
    random_corners = [cur_corner]
    steps = []
    next_lists = []
    min_steps = 100

    if args.test:
        # Test mode: evaluate a stored run once, dump the performances, exit.
        train_log_dir_root = 'data/example_pcgrad/%s' % (args.env)
        train_log_dir = train_log_dir_root + '/' + args.date
        exps = os.listdir(train_log_dir)
        exps.sort(reverse=True)
        exps = [exp for exp in exps if '%s' % (args.N) in exp]
        if args.N == 'Nr':
            random_agent = True
            train_log_dir += '/' + 'N0'
        else:
            random_agent = False
            if not exps:
                raise FileNotFoundError(
                    'no run matching %r found in %s' % (args.N, train_log_dir))
            train_log_dir += '/' + exps[0]
        rewards, ckt_perfs, pass_count, sizing, worst_corner = build_and_test(
            envs=envs,
            processes=processes,
            temps=temps,
            vdds=vdds,
            log_dir=train_log_dir,
            writer=writer,
            step=0,
            random_agent=random_agent,
        )
        print("using data of {}".format(train_log_dir))
        print('rewards:', rewards)
        print('perfs:', ckt_perfs)
        import pickle
        # The output name uses the --env value (fold / strongArm / ng); the
        # context manager makes sure the file is flushed and closed.
        with open("dist/%s_%s.pkl" % (args.env, args.N), "wb") as file_to_write:
            pickle.dump(ckt_perfs, file_to_write)
        sys.exit()

    for step in range(15):
        ### dir to save cur params ###
        train_log_dir = train_log_dir_root + '/' + 'N%s' % (step)
        ### dir for this round's training summaries ###
        writer_path = runs_dir_root + '-' + 'N%s' % (step)
        train_writer = SummaryWriter(log_dir=(writer_path))
        itr_step = build_and_train(run_id=0,
                        kwargs=kwargs,
                        sample_mode=args.sample_mode,
                        n_parallel=args.n_parallel,
                        cuda_idx=args.cuda_idx,
                        load_log_dir=load_log_dir,
                        train_log_dir=train_log_dir,
                        start_itr=start_itr,
                        total_stp=6000,
                        use_ewc=kwargs['use_ewc'],
                        ewc_lambda=kwargs['ewc_lambda'],
                        load_buffer=kwargs['load_buffer'],
                        min_steps_learn=min_steps,
                        writer=train_writer,
                        cur_env=cur_env,
                        )
        train_writer.close()
        rewards, ckt_perfs, pass_count, sizing, worst_corner = build_and_test(
            envs=envs,
            processes=processes,
            temps=temps,
            vdds=vdds,
            log_dir=train_log_dir,
            writer=writer,
            step=step,
        )

        # Pick the corners for the next round from the measured performances,
        # always keeping the nominal corner in the set.
        next_corners = choose_next(data=ckt_perfs, env=args.env, num_clusters=num_clusters)
        if tt_corner not in next_corners:
            next_corners[tt_corner] = 0
        next_lists.append(next_corners)

        next_proc_list = []
        next_temp_list = []
        next_vdd_list = []
        for corner in next_corners:
            info = corner.split('_')
            next_proc_list.append(info[0])
            next_temp_list.append(info[1])
            next_vdd_list.append(info[2])
        kwargs['cur_processes'] = next_proc_list
        kwargs['cur_temps'] = next_temp_list
        kwargs['cur_vdds'] = next_vdd_list
        steps.append(itr_step)

        # Stop as soon as even the worst corner has a positive reward.
        if rewards[worst_corner] > 0:
            break
        worst_corners.append(worst_corner)
        worst = worst_corner.split('_')
        kwargs['corner'] = {'process' : worst[0],
                            'temp' : worst[1],
                            'vdd' : worst[2]}
        if args.random_choose:
            failing = [name for name, value in rewards.items() if value < 0]
            # Guard against an empty list (e.g. the worst reward is exactly 0),
            # which used to crash random.choice; the worst corner stays selected.
            if failing:
                next_chosen = random.choice(failing)
                random_corners.append(next_chosen)
                chosen_fields = next_chosen.split('_')
                kwargs['corner'] = {'process': chosen_fields[0],
                                    'temp': chosen_fields[1],
                                    'vdd': chosen_fields[2]}
        if args.ctd:
            # Continue the next round from the parameters just saved.
            load_log_dir = train_log_dir

    writer.close()
    print('worst corners:', worst_corners)
    print('random conrers:', random_corners)
    print('steps of each run:', steps)
    print('corners for each run:', next_lists)










