from typing import List
import exp_utils as PQ
import torch
import torch.nn as nn
from torch.nn.functional import relu, softplus
import numpy as np
import pickle
import pytorch_lightning as pl
from loguru import logger

from rl_utils.runner import merge_episode_stats, RunnerX, EpisodeReturn, ExtractLastInfo, RunnerWithModel
from copy import deepcopy
from safe import *
from safe.debugger import Debugger

import safe.envs
from safe.envs import make_env
from rl_utils import MLP
import rl_utils


# Default torch device used throughout this script: CUDA when available, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# Experiment configuration. The class is handed to `PQ.init(FLAGS)` in `main()` and its
# values are read as class attributes (e.g. `FLAGS.task`, `FLAGS.ckpt.buf`).
class FLAGS(PQ.BaseFLAGS):
    # NOTE(review): semantics of `_strict` are defined by PQ.BaseFLAGS — not visible here.
    _strict = False

    # Dynamics-model settings.
    class model(PQ.BaseFLAGS):
        type = 'learned'  # main() only accepts 'learned'; anything else hits `assert 0`
        n_ensemble = 5  # number of models built for the EnsembleModel in main()
        n_elites = 0  # main() sets ensemble.elites = list(range(n_elites))
        frozen = True  # when true, main() calls ensemble.requires_grad_(False)
        train = TransitionModel.FLAGS

    # Checkpoint paths. An empty string means "do not load" for the entries read by main().
    class ckpt(PQ.BaseFLAGS):
        L = ''  # state dict file containing key 'L'
        policy = ''  # state dict file containing key 'policy'
        models = ''  # state dict file containing key 'models'
        s = ''  # not read in this file
        x_vs_L = ''  # not read in this file

        buf = ''  # pickled replay buffer; also disables data collection in some tasks
        safe_invariant = ''  # state dict file containing key 'safe_invariant'
        model_trainers = ''  # file containing key 'optimizer' (one entry per trainer)

    # Which components to freeze (requires_grad_(False)) in main().
    class fix(PQ.BaseFLAGS):
        L = False
        policy = False
        normalizer = True  # not read in this file
        model = False

    # Sub-configurations re-exported from the components used by main().
    env = safe.envs.FLAGS
    SAC = SafeSACTrainer2.FLAGS
    lyapunov = Lyapunov.FLAGS
    opt_s = SLangevinOptimizer.FLAGS
    opt_L = LOptimizer.FLAGS
    obj = ObjEvaluator.FLAGS

    n_iters = 500000  # number of L-optimisation iterations in the 'pretrain_L' task
    n_plot_iters = 10000  # not read in this file
    n_eval_iters = 1000  # not read in this file
    n_save_iters = 10000  # not read in this file
    n_pretrain_s_iters = 10000  # number of warm-up s_opt.step() calls before training
    use_fake_L = False  # when true, the L checkpoint is not loaded even if ckpt.L is set
    task = 'train'  # selects the branch run by main(); NOTE(review): 'train' itself is not handled there
    win_streak = 5  # not read in this file (main() uses a hard-coded streak of 1000)
    streak_threshold = 0.0  # not read in this file
    bf_safe_policy = False  # not read in this file


# Module-level RNG with a fixed seed so that sampled states are reproducible across runs.
rng = np.random.RandomState(1)


def sample_state_space(n):
    """Draw ``n`` random two-dimensional states from the shared seeded RNG.

    The first column is uniform on [-pi/2, pi/2) and the second is uniform on
    [-1, 1); the local names ``theta``/``vel`` in the original code suggest an
    angle and a velocity.

    Args:
        n: number of states to draw.

    Returns:
        A ``torch.float32`` tensor of shape ``(n, 2)``.
    """
    column_shape = (n, 1)
    # Draw order matters: first column, then second, to keep the RNG stream stable.
    angles = rng.uniform(-1, 1, size=column_shape) * np.pi / 2
    velocities = rng.uniform(-1, 1, size=column_shape)
    states = np.concatenate((angles, velocities), axis=1)
    return torch.tensor(states, dtype=torch.float32)


class DetMLPPolicy(MLP, rl_utils.DetNetPolicy):
    """``MLP`` network combined with ``rl_utils.DetNetPolicy``; adds no members of its own."""


class MLPQFn(MLP, rl_utils.NetQFn):
    """``MLP`` network combined with ``rl_utils.NetQFn``; adds no members of its own."""


class TanhGaussianMLPPolicy(rl_utils.policy.TanhGaussianPolicy, MLP, rl_utils.NetPolicy):
    """``rl_utils.policy.TanhGaussianPolicy`` mixed with ``MLP`` and ``rl_utils.NetPolicy``.

    Adds no members of its own; behaviour comes entirely from the base classes.
    """


def bf_optimize(policy, s, U_pi, n_steps=100000, log_every=1000):
    """Optimise ``policy`` with Adam so that ``U_pi(s)`` is pushed below 1.

    Each step minimises ``mean(relu(U_pi(s) - 1))``, i.e. only states whose
    value exceeds 1 contribute to the loss.

    Args:
        policy: module whose ``parameters()`` are handed to ``torch.optim.Adam``
            (default Adam hyper-parameters).
        s: batch of states. It is detached from any existing graph; the caller's
            tensor is not modified.
        U_pi: callable mapping the state batch to a tensor. It presumably depends
            on ``policy``'s parameters, otherwise the steps have no effect.
        n_steps: number of optimisation steps (defaults to the previously
            hard-coded 100000).
        log_every: print the mean and max of the objective every ``log_every``
            steps; ``0`` or ``None`` disables printing.

    Returns:
        None. ``policy`` is updated in place.
    """
    opt = torch.optim.Adam(policy.parameters())

    # Keep requires_grad on the states: U_pi may differentiate w.r.t. its input.
    s = s.detach().requires_grad_()
    for i in range(n_steps):
        obj = relu(U_pi(s) - 1)
        if log_every and i % log_every == 0:
            print(obj.mean(), obj.max())

        loss = obj.mean()
        opt.zero_grad()
        loss.backward()
        opt.step()


def explore(t, runners, n_samples, expl_policy, buf_real):
    """Run the exploration policy, store the samples, and log safety/return statistics.

    Args:
        t: step index used as ``global_step`` for both scalars written to ``PQ.writer``.
        runners: mapping that must contain an ``'explore'`` runner.
        n_samples: number of samples passed to the runner's ``run``.
        expl_policy: policy used to collect the samples.
        buf_real: buffer passed to the runner as ``buffer=``.

    Returns:
        None. Results are reported through ``logger`` and ``PQ.writer``.
    """
    ep_infos = runners['explore'].run(expl_policy, n_samples, buffer=buf_real)
    merged_infos = merge_episode_stats(ep_infos)
    n_expl_unsafe_states = sum(info.get('episode.unsafe', 0) for info in ep_infos)

    # Any unsafe exploration episode is escalated to a critical log entry.
    method = logger.critical if n_expl_unsafe_states > 0 else logger.info
    method(f"[explore] # {t}: # expl unsafe states = {n_expl_unsafe_states}, "
           f"expl trajs return: {merged_infos['return']}")

    # Both scalars carry the same step so their curves line up.
    PQ.writer.add_scalar('violation/expl_policy', n_expl_unsafe_states, global_step=t)
    PQ.writer.add_scalar('policy/expl/return', merged_infos['return'][0], global_step=t)


def bump(t, mod):
    """Return the smallest ``x > t`` with ``x % mod == 0`` (for positive ``mod``)."""
    completed, _ = divmod(t, mod)
    return (completed + 1) * mod


def main():
    """Build every experiment component and run the task selected by ``FLAGS.task``.

    Steps:
      1. Initialise flags and the W&B logger.
      2. Create the environment, normalizer, replay buffers, policy, dynamics
         ensemble, Lyapunov function ``L``, uncertainty ``U``, safe invariant,
         optimizers and debugger.
      3. Optionally restore checkpoints listed under ``FLAGS.ckpt``.
      4. Dispatch on ``FLAGS.task``: ``'plot_policy_safe_region'``, ``'check_L'``,
         ``'pretrain_model'``, ``'pretrain_L'``, ``'safe-init'`` or
         ``'unified-new-algo'``.

    Raises:
        AssertionError: if ``FLAGS.model.type`` is not ``'learned'`` or
            ``FLAGS.task`` is not one of the tasks above.
    """
    import logging
    PQ.init(FLAGS)
    logging.getLogger('lightning').setLevel(0)
    from pytorch_lightning.loggers import WandbLogger
    wandb_logger = WandbLogger(save_dir=PQ.log_dir, name=PQ.log_dir.name)
    wandb_logger.experiment.config.update(FLAGS.to_dict())

    PQ.log.info(f"wandb url = {wandb_logger.experiment.url}")

    env = make_env()
    s0 = torch.tensor(env.reset(), device=device, dtype=torch.float32)
    dim_state = env.observation_space.shape[0]
    dim_action = env.action_space.shape[0]

    normalizer = Normalizer(dim_state, clip=1000).to(device)

    if FLAGS.ckpt.buf:
        # SECURITY NOTE: pickle.load executes arbitrary code from the file; only
        # point FLAGS.ckpt.buf at buffers you produced yourself.
        with open(FLAGS.ckpt.buf, 'rb') as f:
            buf_real = pickle.load(f)
            logger.warning(f"load model buffer from {FLAGS.ckpt.buf}")
    else:
        buf_real = rl_utils.TorchReplayBuffer(env, max_buf_size=1000_000)

    buf_dev = rl_utils.TorchReplayBuffer(env, max_buf_size=10_000)
    # policy = DetMLPPolicy([dim_state, 64, 64, dim_action], auto_squeeze=False, output_activation=nn.Tanh).to(device)
    # mean_policy = policy
    policy = TanhGaussianMLPPolicy([dim_state, 64, 64, dim_action * 2]).to(device)
    # unsafe_policy = TanhGaussianMLPPolicy([dim_state, 64, 64, dim_action * 2]).to(device)
    mean_policy = rl_utils.policy.MeanPolicy(policy)

    # Both factories take the ensemble index because they are called as make_model(i) below.
    if FLAGS.env.id == 'MySafexp-PointGoal1-v1':
        logger.warning("use DomainModel")
        # The index is accepted but ignored: DomainModel takes no per-member name.
        make_model = lambda i: DomainModel(env, env.hazards_pos, env.vases_pos, env.goal_pos)
    else:
        make_model = lambda i: TransitionModel(
            dim_state, normalizer, [dim_state + dim_action, 256, 256, 256, 256, dim_state * 2],
            name=f'model-{i}')
        # make_model = lambda i: StableDynamics(
        #     TransitionModel(dim_state, normalizer, [dim_state + dim_action, 256, 256, 256, 256, dim_state * 2]),
        #     dim_state, 0.01, buf=buf_real, buf_dev=None, name=f'model_{i}').to(device)
    ensemble = EnsembleModel([make_model(i) for i in range(FLAGS.model.n_ensemble)])
    # model_trainers = [ModelTrainer(model, buf_real, buf_dev, device=device, name=f'model/{i}')
    #                   for i, model in enumerate(ensemble.models)]

    # One Lightning trainer per ensemble member (some tasks below rebuild this list).
    model_trainers = [pl.Trainer(
        gradient_clip_val=10., max_epochs=50, default_root_dir=PQ.log_dir, progress_bar_refresh_rate=0, logger=wandb_logger)
        for _ in ensemble.models]
    horizon = env.spec.max_episode_steps
    make_stats = [lambda: ExtractLastInfo('episode.unsafe'), lambda: EpisodeReturn()]
    runners = {
        'explore': RunnerX(make_env, 1, make_stats, device=device),
        'evaluate': RunnerX(make_env, 1, make_stats, device=device),
        'test': RunnerX(make_env, 1, make_stats, device=device),
    }

    L = Lyapunov(nn.Sequential(normalizer, MLP([dim_state, 256, 256, 1])), env.barrier_fn, s0).to(device)

    if FLAGS.model.type == 'learned':
        model = ensemble
        U = EnsembleUncertainty(ensemble, L)
        if FLAGS.model.frozen:
            ensemble.requires_grad_(False)
            logger.warning("models are frozen!")
    else:
        assert 0, f"invalid model type `{FLAGS.model.type}`"
    safe_invariant = SafeInvariant(L, U, mean_policy)
    obj_eval = ObjEvaluator(safe_invariant)

    # set requires_grad to False so that SafeAlgo won't optimize them.
    if FLAGS.fix.L:
        L.requires_grad_(False)
    if FLAGS.fix.policy:
        policy.requires_grad_(False)
    if FLAGS.fix.model:
        ensemble.requires_grad_(False)

    # policy optimization
    fake_model_fns = {
        'transition': model,
        'reset': lambda: s0,
        # 'done': lambda s, a, sp: ~env.is_state_safe(sp),
        'done': lambda s, a, sp: torch.zeros(len(s), device=s.device, dtype=torch.bool),
        # Unsafe next states receive a large negative reward instead of the env reward.
        'reward': lambda s, a, sp: torch.where(env.is_state_safe(sp), env.reward_fn(s, a, sp),
                                               torch.tensor(-10000., device=s.device)),
    }
    model_runner = RunnerWithModel(fake_model_fns, horizon, dim_state, [EpisodeReturn], n=1, device=device)
    # model_eval_runner = ModelRunner(fake_model_fns, horizon, dim_state, batch_size=1, device=device)
    # buf_fake = rl_utils.TorchReplayBuffer(env, max_buf_size=1000_000)

    # Checkpoint restoration. Order matters: the reference invariant and the
    # optimizers created below copy / wrap these modules.
    if FLAGS.ckpt.policy != '':  # must be done before define policy_optimizer (policy target init)
        policy.load_state_dict(torch.load(FLAGS.ckpt.policy, map_location=device)['policy'])
        logger.info(f"Load policy from {FLAGS.ckpt.policy}")
    if FLAGS.ckpt.L != '' and not FLAGS.use_fake_L:  # must be done before L_target is init
        L.load_state_dict(torch.load(FLAGS.ckpt.L, map_location=device)['L'])
        logger.info(f"Load L from {FLAGS.ckpt.L}")
        # assert FLAGS.should_optimize_policy
    if FLAGS.ckpt.models != '':
        model.load_state_dict(torch.load(FLAGS.ckpt.models, map_location=device)['models'])
        logger.info(f"Load model from {FLAGS.ckpt.models}")
    if FLAGS.ckpt.safe_invariant != '':
        safe_invariant.load_state_dict(torch.load(FLAGS.ckpt.safe_invariant, map_location=device)['safe_invariant'])
        logger.info(f"Load SafeInvariant from {FLAGS.ckpt.safe_invariant}")
    if FLAGS.ckpt.model_trainers != '':
        # NOTE(review): model_trainers are pl.Trainer objects here; confirm they
        # expose load_state_dict (this looks inherited from the old ModelTrainer).
        model_trainers_ckpt = torch.load(FLAGS.ckpt.model_trainers, map_location=device)['optimizer']
        for model_trainer, ckpt in zip(model_trainers, model_trainers_ckpt):
            model_trainer.load_state_dict(ckpt)

    # Frozen copy of the invariant, used by the exploration policy and as the
    # reference for LOptimizer; refreshed in 'unified-new-algo' on a win streak.
    safe_invariant_ref = deepcopy(safe_invariant)
    expl_policy = ExplorationPolicy(policy, safe_invariant_ref).to(device)

    hardD = lambda s: torch.where(L(s) <= 1, safe_invariant.U(s) - 1, -L(s) - 100)
    softD = lambda s: safe_invariant.U(s) - 1 - 100 * relu(L(s) - 1)

    policy_optimizer = SafeSACTrainer2(policy, [
            MLPQFn([dim_state + dim_action, 256, 256, 1]),
            MLPQFn([dim_state + dim_action, 256, 256, 1]),
        ], U,
        sampler=buf_real.sample,
        device=device,
        target_entropy=-dim_action,
    )

    L_opt = LOptimizer(dim_state, obj_eval, nn.ModuleList([L, policy]).parameters(), safe_invariant_ref.L).to(device)
    s_opt_langevin = SLangevinOptimizer(obj_eval, normalizer).to(device)
    s_opt_sample = SSampleOptimizer(dim_state, obj_eval).to(device)
    s_opt_grad = SGradOptimizer(dim_state, obj_eval, normalizer).to(device)
    s_opt = s_opt_langevin

    fns = {'L': L, 'U': U, 'hardD': hardD, 'softD': softD, 'logBarrier': lambda x: env.barrier_fn(x).log()}
    logger.debug(f"[normalizer]: mean = {normalizer.mean.cpu().numpy()}, std = {normalizer.std.cpu().numpy()}")

    buf_out = None
    debugger = Debugger(env, policy, mean_policy, L, model, runners['evaluate'], horizon, s0, s_opt, s_opt_grad,
                        s_opt_sample, L_opt, fns, obj_eval, buf_out, safe_invariant, FLAGS)

    if FLAGS.model.type == 'learned':
        ensemble.elites = list(range(FLAGS.model.n_elites))
        # ensemble.elites = [FLAGS.model.n_elites]

    if FLAGS.task == 'plot_policy_safe_region':
        plot_policy_safe_region(env.trans_fn, mean_policy, device)

    elif FLAGS.task == 'check_L':
        # Interactive inspection task: `state` is only available for use at the
        # debugger prompt below. It lives on the module device so the task also
        # starts on CPU-only hosts.
        state = torch.tensor([-0.0542, -0.0988,  1.2449, -2.1043], device=device)
        breakpoint()
        debugger.evaluate(0, policy=True, mean_policy=True, virt_safe=True, s_grad=True, video=True)
        logger.info("pretrain s...")
        for i in range(100_000):
            if i % 1_000 == 0:
                # breakpoint()
                s_opt.evaluate(step=i)
            s_opt.step()

    elif FLAGS.task == 'pretrain_model':
        # expl_policy = UniformPolicy(dim_action)
        dev_infos = RunnerX(make_env, 10, device=device).run(rl_utils.policy.AddGaussianNoise(policy, 0, 0.4),
                                                             horizon * 10, buffer=buf_dev)

        # Collect training data only when no buffer checkpoint was supplied.
        if not FLAGS.ckpt.buf:
            RunnerX(make_env, 10, device=device).run(rl_utils.policy.UniformPolicy(dim_action), horizon * 10, buffer=buf_dev)
            runner = RunnerX(make_env, 1, stats=[EpisodeReturn, lambda: ExtractLastInfo('episode.unsafe')], device=device)

            # One episode per noise level, sweeping the Gaussian noise from 0 to 1.
            for noise in np.linspace(0, 1.0, 500):
                print(noise)
                runner.reset()
                runner.run(rl_utils.policy.AddGaussianNoise(policy, 0, noise), 1 * horizon, buf_real)
            runner.reset()
            runner.run(rl_utils.policy.UniformPolicy(dim_action), 200 * horizon, buf_real)

        print('dev', merge_episode_stats(dev_infos))
        normalizer.fit(buf_real.state)

        # train_models(model_trainers, n_steps=50000)
        # NOTE: `model` is rebound to each ensemble member by this loop.
        for model, trainer in zip(ensemble.models, model_trainers):
            trainer.fit(model, train_dataloader=buf_real.sampling_data_loader(1_000, 256),
                        val_dataloaders=buf_real.sampling_data_loader(1, 10_000))
            # model.test(mean_policy, s0)

        if not FLAGS.ckpt.buf:
            with open(PQ.log_dir / 'buf.pkl', 'wb') as f:
                pickle.dump(buf_real, f)
        torch.save({
            'models': ensemble.state_dict(),
        }, PQ.log_dir / 'final.pt')

    elif FLAGS.task == 'pretrain_L':
        debugger.evaluate(0, policy=True, mean_policy=True, virt_safe=True)
        logger.info("pretrain s...")
        for i in range(FLAGS.n_pretrain_s_iters):
            if i % 1_000 == 0:
                # breakpoint()
                s_opt.evaluate(step=i)
            s_opt.step()

        # logger.critical("zeroing the policy!")
        # for p in policy.parameters():
        #     nn.init.zeros_(p)
        # Restrict the L optimizer to L's own parameters and drop the reference L.
        L_opt.opt_params.param_groups[0]['params'] = list(L.parameters())
        L_opt.L_ref = None
        for t in range(FLAGS.n_iters):
            if t % 1_000 == 0:
                # breakpoint()
                logger.info(f"# iter {t}")
            debugger.evaluate(t, s=t % 1_000 == 0, video=t == 0, virt_safe=t % 10_000 == 0, save=t % 10_000 == 0,
                              plot=t % 5_000 == 0, s_grad=t % 50_000 == 0 and t > 0)
            for i in range(FLAGS.opt_s.n_steps):
                d = s_opt.step()['optimal']
                PQ.meters[f'opt_progress/{i}'] += d
            L_opt.step(s_opt.s)

    elif FLAGS.task == 'safe-init':
        model_trainers = [pl.Trainer(
            gradient_clip_val=10., max_epochs=20, gpus=-1, auto_select_gpus=True, default_root_dir=PQ.log_dir,
            progress_bar_refresh_rate=0, logger=wandb_logger)
            for _ in ensemble.models]

        # expl_policy = UniformPolicy(dim_action)
        dev_infos = RunnerX(make_env, 20, device=device).run(
            rl_utils.policy.AddGaussianNoise(policy, 0, 0.4), horizon * 50, buffer=buf_dev)

        if not FLAGS.ckpt.buf:
            RunnerX(make_env, 10, device=device).run(rl_utils.policy.UniformPolicy(dim_action), horizon * 10, buffer=buf_dev)
            runner = RunnerX(make_env, 1, stats=[EpisodeReturn, lambda: ExtractLastInfo('episode.unsafe')], device=device)

            for noise in np.linspace(0, 1, 1000):
                runner.reset()
                # Stage each episode in a scratch buffer; keep it only if it stayed safe.
                buf_tmp = rl_utils.TorchReplayBuffer(env, max_buf_size=1_000)
                ep_infos = runner.run(rl_utils.policy.AddGaussianNoise(policy, 0, noise), 1 * horizon, buf_tmp)
                if not ep_infos[0]['episode.unsafe']:
                    print(noise, ep_infos)
                    buf_real.add_transitions({
                        'state': buf_tmp.state,
                        'action': buf_tmp.action,
                        'next_state': buf_tmp.next_state,
                        'reward': buf_tmp.reward,
                        'done': buf_tmp.done,
                        'timeout': buf_tmp.timeout,
                    })
            # ep_infos = runner.run(rl_utils.policy.UniformPolicy(dim_action), 100 * horizon, buf_real)
            # print('rand', ep_infos)
            with open(PQ.log_dir / 'buf.pkl', 'wb') as f:
                pickle.dump(buf_real, f)

        print('dev', merge_episode_stats(dev_infos))
        normalizer.fit(buf_real.state)

        # train_models(model_trainers, n_steps=50000)
        for model, trainer in zip(ensemble.models, model_trainers):
            trainer.fit(model, train_dataloader=buf_real.sampling_data_loader(1_000, 256),
                        val_dataloaders=buf_dev.sampling_data_loader(1, 50_000))

        torch.save({
            'models': ensemble.state_dict(),
        }, PQ.log_dir / 'final.pt')

    elif FLAGS.task == 'unified-new-algo':
        logger.info("pretrain s...")
        for i in range(FLAGS.n_pretrain_s_iters):
            if i % 1000 == 0:
                s_opt.evaluate(step=i)
            s_opt.step()

        debugger.evaluate(0, video=True, plot=True)
        # collect 10_000 samples
        explore(0, runners, 10_000, expl_policy, buf_real)
        global_step = 0

        # Each epoch: (1) fit the models, (2) refresh s, (3) train the policy,
        # (4) try to train a new L and accept it only after a long win streak.
        # train_models(model_trainers, n_steps=50000)
        for epoch in range(50):
            model_trainers = [pl.Trainer(
                gradient_clip_val=10., max_epochs=1, gpus=1, auto_select_gpus=True, default_root_dir=PQ.log_dir,
                progress_bar_refresh_rate=0, logger=wandb_logger)
                for _ in ensemble.models]

            for model, trainer in zip(ensemble.models, model_trainers):
                trainer.fit(model, train_dataloader=buf_real.sampling_data_loader(1_000, 256))
                model.to(device)  # pytorch lightning transferred the model to cpu
                # see pytorch_lightning/trainer/training_loop.py:220

            for _ in range(1_000):
                for i in range(FLAGS.opt_s.n_steps):
                    s_opt.step()

            # train policy
            logger.info(f"Epoch {epoch}: train policy")
            policy_optimizer.can_update_policy = True
            # Jump to the next multiple of 1e9 so each epoch's logs occupy their own step range.
            global_step = bump(global_step, 1000_000_000)
            # debugger.evaluate(global_step, video=True, plot=True)
            L_opt.opt_params.param_groups[0]['params'] = list(policy.parameters())   # + list(L.parameters())
            debugger.evaluate(epoch * 2000, mean_policy=True)
            for t in range(2_000):
                global_step = bump(global_step, 1)
                if t % 1000 == 0:
                    debugger.evaluate(global_step, s=True, video=True, expl=True, virt_safe=True, plot=True)
                if t % 1000 == 0:
                    # NOTE(review): `t` restarts every epoch, so explore() logs at
                    # steps 0 and 1000 repeatedly — confirm global_step isn't intended.
                    explore(t, runners, horizon, expl_policy, buf_real)
                # The first epoch only collects data / optimises s for its first 1000 steps.
                if t >= 1000 or epoch > 0:
                    policy_optimizer.step()   # optimize unsafe policy

                for i in range(FLAGS.opt_s.n_steps):  # opt s
                    s_opt.step()
                L_opt.step(s_opt.s, should_update=True)
            global_step = bump(global_step, 1)
            debugger.evaluate(global_step, video=True, plot=True)

            # check if virt_safe
            global_step = bump(global_step, 1_000_000)
            debugger.evaluate(global_step, virt_safe=True)

            # train L
            logger.info(f"Epoch {epoch}: train L!")
            L_opt.opt_params.param_groups[0]['params'] = list(L.parameters())  # + list(policy.parameters())
            win_streak = 0
            global_step = bump(global_step, 1_000_000)
            total_updates = 0

            for t in range(10_000):
                global_step = bump(global_step, 1)
                if t % 1_000 == 0:
                    logger.info(f"# iter {t}")
                debugger.evaluate(global_step, s=t % 1_000 == 0 and t != 0,
                                  virt_safe=t % 10_000 == 0, plot=t % 1_000 == 0)
                for i in range(FLAGS.opt_s.n_steps):
                    s_opt.step()
                result = L_opt.step(s_opt.s)
                # A "win" is a step whose max objective is non-positive; any other
                # step counts as an update and resets the streak.
                if result['max_obj'] <= 0.0:
                    win_streak += 1
                else:
                    total_updates += 1
                    win_streak = 0

                # Give up early if more than 99% of the first 1000 steps needed an update.
                if t == 1_000 and total_updates / t > 0.99:
                    # NOTE(review): L_opt was built from `safe_invariant_ref.L` but the
                    # rollback reads `.barrier` — confirm both attributes exist.
                    L.load_state_dict(safe_invariant_ref.barrier.state_dict())
                    logger.critical("early stop... unlikely to find an L")
                    break

                if win_streak == 1000:
                    safe_invariant_ref.load_state_dict(safe_invariant.state_dict())
                    logger.info(f"win streak at {t} => find a new invariant => update ref. ")
                    break
            else:
                # Loop exhausted without a win streak: roll L back to the reference.
                L.load_state_dict(safe_invariant_ref.barrier.state_dict())
                logger.warning("can't find L")
            global_step = bump(global_step, 1)
            # NOTE(review): evaluated at `epoch`, not the `global_step` bumped just above — confirm.
            debugger.evaluate(epoch, s=True, virt_safe=True, save=True, plot=True)
    else:
        assert 0, f"invalid task `{FLAGS.task}`"


# Run the experiment only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()
