import copy
from functools import partial

from ding.envs import SyncSubprocessEnvManager, BaseEnvManager
from ding.policy import DQNPolicy
from ding.rl_utils import get_epsilon_greedy_fn
from ding.utils import set_pkg_seed
from ding.worker import BaseLearner, SampleSerialCollector, AdvancedReplayBuffer
from easydict import EasyDict
from tensorboardX import SummaryWriter

from envs.di_drive.core.envs import CarlaEnvWrapper
from envs.di_drive.core.eval import SerialEvaluator
from envs.di_drive.core.utils.data_utils.bev_utils import unpack_birdview
from envs.di_drive.core.utils.others.ding_utils import compile_config
from envs.di_drive.core.utils.others.tcp_helper import parse_carla_tcp
from envs.di_drive.demo.latent_rl.latent_rl_env import CarlaLatentRLEnv
from envs.di_drive.demo.latent_rl.model import LatentDQNRLModel

# Experiment settings for latent-space DQN training on Carla.
# NOTE(review): exp_name advertises "buf2e5" while replay_buffer_size below is
# 100000 (1e5) -- confirm which of the two is intended.
train_config = {
    'exp_name': 'latentdqn_buf2e5_lr1e4_bs64_ns1000_update10_train_ft',
    'env': {
        'collector_env_num': 7,
        'evaluator_env_num': 1,
        'simulator': {
            'town': 'Town01',
            'disable_two_wheels': True,
            'verbose': False,
            'planner': {
                'type': 'lbc',
                'resolution': 1.0,
                'threshold_before': 3.0,
                'threshold_after': 3.0,
            },
            # One-element tuple: a single bird's-eye-view observation entry.
            'obs': (
                {
                    'name': 'birdview',
                    'type': 'bev',
                    'size': [320, 320],
                    'pixels_per_meter': 5,
                    'pixels_ahead_vehicle': 100,
                },
            ),
        },
        'manager': {
            'collect': {
                'auto_reset': True,
                'shared_memory': False,
                'context': 'spawn',
                'max_retry': 1,
            },
            'eval': {},
        },
        'wrapper': {
            'collect': {'suite': 'train_ft'},
            'eval': {'suite': 'FullTown02-v1'},
        },
    },
    # main() asserts that this list yields at least
    # collector_env_num + evaluator_env_num (= 8) Carla endpoints.
    # Presumably carla_ports is a (start, stop, step) port range -- confirm
    # against parse_carla_tcp.
    'server': [
        {'carla_host': 'localhost', 'carla_ports': [9000, 9016, 2]},
    ],
    'policy': {
        'cuda': True,
        'priority': True,
        'nstep': 3,
        'discount_factor': 0.99,
        'model': {'action_shape': 100},
        'learn': {
            'batch_size': 64,
            'learning_rate': 0.0001,
            'weight_decay': 0.0001,
            'target_update_freq': 1000,
            'learner': {'hook': {'load_ckpt_before_run': ''}},
        },
        'collect': {
            'n_sample': 1000,
            'collector': {
                'collect_print_freq': 1000,
                'deepcopy_obs': True,
                'transform_obs': True,
            },
        },
        'eval': {
            'evaluator': {
                'eval_freq': 5000,
                'n_episode': 3,
                'stop_rate': 0.7,
                'transform_obs': True,
            },
        },
        'other': {
            'eps': {
                'type': 'exp',
                'start': 0.95,
                'end': 0.1,
                'decay': 30000,
            },
            'replay_buffer': {
                'replay_buffer_size': 100000,
                # max_use=100,
                'monitor': {
                    'sampled_data_attr': {'print_freq': 100},
                    'periodic_thruput': {'seconds': 120},
                },
            },
        },
    },
}

# EasyDict view of train_config; this is the object the script entry point
# passes to main().
main_config = EasyDict(train_config)


def wrapped_env(env_cfg, wrapper_cfg, host, port, tm_port=None):
    """Create a ``CarlaLatentRLEnv`` and return it wrapped in ``CarlaEnvWrapper``.

    Args:
        env_cfg: Config forwarded to ``CarlaLatentRLEnv``.
        wrapper_cfg: Config forwarded to ``CarlaEnvWrapper``.
        host: Carla server host forwarded to ``CarlaLatentRLEnv``.
        port: Carla server port forwarded to ``CarlaLatentRLEnv``.
        tm_port: Optional fourth positional argument of ``CarlaLatentRLEnv``
            (presumably the traffic-manager port -- confirm against the env).

    Returns:
        The ``CarlaEnvWrapper`` instance around the freshly built env.
    """
    latent_env = CarlaLatentRLEnv(env_cfg, host, port, tm_port)
    return CarlaEnvWrapper(latent_env, wrapper_cfg)


def main(cfg, seed=0):
    """Train a latent-space DQN agent on Carla until the evaluator signals stop.

    The function compiles the config, builds the collector/evaluator env
    managers (one Carla endpoint per env), the policy and the serial pipeline
    workers, then alternates between collecting samples and training on
    batches drawn from a replay buffer. All workers are closed on exit, also
    when training is interrupted by an exception or Ctrl+C.

    Args:
        cfg: Experiment config (see ``main_config``); it is passed through
            ``compile_config`` before use.
        seed: Seed given to both env managers and to ``set_pkg_seed``.

    Raises:
        AssertionError: If ``cfg.server`` yields fewer Carla endpoints than
            ``collector_env_num + evaluator_env_num``.
    """
    cfg = compile_config(
        cfg,
        SyncSubprocessEnvManager,
        DQNPolicy,
        BaseLearner,
        SampleSerialCollector,
        buffer=AdvancedReplayBuffer,
    )
    tcp_list = parse_carla_tcp(cfg.server)
    collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
    # Every env needs its own Carla endpoint.
    assert len(tcp_list) >= collector_env_num + evaluator_env_num, \
        "Carla server not enough! Need {} servers but only found {}.".format(
            collector_env_num + evaluator_env_num, len(tcp_list)
        )

    # Collector envs take the first endpoints, evaluator envs the ones after.
    collector_env = SyncSubprocessEnvManager(
        env_fn=[partial(wrapped_env, cfg.env, cfg.env.wrapper.collect, *tcp_list[i]) for i in range(collector_env_num)],
        cfg=cfg.env.manager.collect,
    )
    evaluate_env = BaseEnvManager(
        env_fn=[
            partial(wrapped_env, cfg.env, cfg.env.wrapper.eval, *tcp_list[collector_env_num + i])
            for i in range(evaluator_env_num)
        ],
        cfg=cfg.env.manager.eval,
    )
    collector_env.seed(seed)
    evaluate_env.seed(seed)
    set_pkg_seed(seed)

    model = LatentDQNRLModel(**cfg.policy.model)
    policy = DQNPolicy(cfg.policy, model=model)

    tb_logger = SummaryWriter('./log/{}/'.format(cfg.exp_name))
    learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
    collector = SampleSerialCollector(
        cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name
    )
    evaluator = SerialEvaluator(
        cfg.policy.eval.evaluator, evaluate_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
    )
    replay_buffer = AdvancedReplayBuffer(cfg.policy.other.replay_buffer, tb_logger, exp_name=cfg.exp_name)

    # Set up other modules, e.g. epsilon greedy
    eps_cfg = cfg.policy.other.eps
    epsilon_greedy = get_epsilon_greedy_fn(eps_cfg.start, eps_cfg.end, eps_cfg.decay, eps_cfg.type)

    batch_size = cfg.policy.learn.batch_size

    try:
        learner.call_hook('before_run')

        # Seed the replay buffer with one round of data before the first update.
        eps = epsilon_greedy(learner.train_iter)
        new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs={'eps': eps})
        replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)

        while True:
            if evaluator.should_eval(learner.train_iter):
                stop, _ = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
                if stop:
                    break
            eps = epsilon_greedy(learner.train_iter)
            # Sampling data from environments
            new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs={'eps': eps})
            # Ten updates per batch-worth of freshly collected samples.
            update_per_collect = len(new_data) // batch_size * 10
            replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
            # Training
            for _ in range(update_per_collect):
                train_data = replay_buffer.sample(batch_size, learner.train_iter)
                if train_data is None:
                    # Nothing was trained on, so there is no fresh priority
                    # info to feed back to the buffer.
                    continue
                # unpack_birdview is called for its effect on train_data (its
                # return value is unused); work on a copy, presumably so the
                # samples kept by the buffer stay untouched.
                train_data = copy.deepcopy(train_data)
                unpack_birdview(train_data)
                learner.train(train_data, collector.envstep)
                replay_buffer.update(learner.priority_info)

        learner.call_hook('after_run')
    finally:
        # Always release the workers (and the envs they hold), even when
        # training is aborted by an error or a keyboard interrupt.
        collector.close()
        evaluator.close()
        learner.close()
        replay_buffer.close()


# Script entry point: run training with the module-level config and main()'s
# default seed.
if __name__ == '__main__':
    main(main_config)
