import copy
from functools import partial

from ding.envs import SyncSubprocessEnvManager, BaseEnvManager
from ding.policy import SACPolicy
from ding.utils import set_pkg_seed
from ding.worker import BaseLearner, SampleSerialCollector, NaiveReplayBuffer
from easydict import EasyDict
from tensorboardX import SummaryWriter

from envs.di_drive.core.envs import SimpleCarlaEnv
from envs.di_drive.core.eval import SerialEvaluator
from envs.di_drive.core.utils.data_utils.bev_utils import unpack_birdview
from envs.di_drive.core.utils.others.ding_utils import compile_config
from envs.di_drive.core.utils.others.tcp_helper import parse_carla_tcp
from envs.di_drive.demo.simple_rl.env_wrapper import ContinuousBenchmarkEnvWrapper
from envs.di_drive.demo.simple_rl.model import SACRLModel

# Raw experiment configuration for SAC training on bird's-eye-view observations.
# NOTE(review): the ``buf2e5`` and ``bs128`` tags inside ``exp_name`` do not match
# ``replay_buffer_size=400000`` and ``batch_size=256`` below — confirm which is intended.
train_config = {
    'exp_name': 'sac2_bev32_buf2e5_lr1e4_bs128_ns3000_update4_train_ft',
    'env': {
        'collector_env_num': 1,
        'evaluator_env_num': 1,
        'simulator': {
            'town': 'Town01',
            'disable_two_wheels': True,
            'verbose': False,
            'waypoint_num': 32,
            'planner': {
                'type': 'behavior',
                'resolution': 1,
            },
            # A one-element tuple: the single observation is a 32x32 'bev' named 'birdview'.
            'obs': (
                {
                    'name': 'birdview',
                    'type': 'bev',
                    'size': [32, 32],
                    'pixels_per_meter': 1,
                    'pixels_ahead_vehicle': 14,
                },
            ),
        },
        'col_is_failure': True,
        'stuck_is_failure': True,
        'wrong_direction_is_failure': False,
        'off_route_is_failure': False,
        'off_road_is_failure': True,
        'ignore_light': True,
        'replay_path': './sac_video',
        'visualize': {'type': 'birdview'},
        'manager': {
            'collect': {
                'auto_reset': True,
                'shared_memory': False,
                'context': 'spawn',
                'max_retry': 1,
            },
            'eval': {},
        },
        'wrapper': {
            # Collect and eval suites for training
            'collect': {'suite': 'train_ft'},
            'eval': {'suite': 'FullTown02-v1'},
        },
    },
    # NOTE(review): ``carla_ports`` presumably reads as (start, stop, step) — confirm
    # against ``parse_carla_tcp``.
    'server': [
        {'carla_host': 'localhost', 'carla_ports': [9012, 9016, 2]},
    ],
    'policy': {
        'cuda': True,
        'model': {'action_shape': 2, 'twin_critic': True},
        'learn': {
            'batch_size': 256,
            'learning_rate_q': 1e-4,
            'learning_rate_policy': 1e-4,
            'learning_rate_value': 1e-4,
            'learning_rate_alpha': 1e-4,
            'weight_decay': 0.0001,
            'learner': {'hook': {'load_ckpt_before_run': ''}},
        },
        'collect': {
            'n_sample': 3000,
            'noise_sigma': 0.1,
            'collector': {
                'collect_print_freq': 500,
                'deepcopy_obs': True,
                'transform_obs': True,
            },
        },
        'eval': {
            'evaluator': {
                'eval_freq': 500,
                'n_episode': 20,
                'stop_rate': 0.7,
                'transform_obs': True,
            },
        },
        'other': {
            'replay_buffer': {
                'replay_buffer_size': 400000,
                'replay_buffer_start_size': 10000,
                'max_use': 16,
                'monitor': {
                    'sampled_data_attr': {
                        'print_freq': 100,  # times
                    },
                    'periodic_thruput': {'seconds': 120},
                },
            },
        },
    },
}

# Wrap the raw ``train_config`` dict in an ``EasyDict``; this is the object the
# script entry point at the bottom of the file passes to ``main``.
main_config = EasyDict(train_config)


def wrapped_env(env_cfg, wrapper_cfg, host, port, tm_port=None):
    """Create a ``SimpleCarlaEnv`` and wrap it in ``ContinuousBenchmarkEnvWrapper``.

    Args:
        env_cfg: Configuration forwarded to ``SimpleCarlaEnv``.
        wrapper_cfg: Configuration forwarded to ``ContinuousBenchmarkEnvWrapper``.
        host: Server host, forwarded positionally to ``SimpleCarlaEnv``.
        port: Server port, forwarded positionally to ``SimpleCarlaEnv``.
        tm_port: Optional fourth positional argument of ``SimpleCarlaEnv``
            (presumably the traffic-manager port — confirm against the env).

    Returns:
        The wrapped environment instance.
    """
    carla_env = SimpleCarlaEnv(env_cfg, host, port, tm_port)
    benchmark_env = ContinuousBenchmarkEnvWrapper(carla_env, wrapper_cfg)
    return benchmark_env


def main(cfg, seed=0):
    """Train a SAC policy on Carla until the evaluator reports the stop condition.

    The configuration is compiled, one subprocess env manager is built for
    collection and one base env manager for evaluation (each env bound to its
    own entry of the parsed server list), and then the loop alternates
    evaluation, sample collection and learner updates.

    The collector, evaluator, learner and replay buffer are closed in a
    ``finally`` clause, so they are released even when collection or training
    raises (including ``KeyboardInterrupt``).

    Args:
        cfg: Experiment configuration; must provide ``server``, ``env``,
            ``policy`` and ``exp_name`` after ``compile_config``.
        seed: Seed applied to both env managers and to ``set_pkg_seed``.

    Raises:
        AssertionError: If fewer servers are parsed from ``cfg.server`` than
            ``collector_env_num + evaluator_env_num``.
    """
    cfg = compile_config(
        cfg,
        SyncSubprocessEnvManager,
        SACPolicy,
        BaseLearner,
        SampleSerialCollector,
        buffer=NaiveReplayBuffer,
    )
    tcp_list = parse_carla_tcp(cfg.server)
    collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
    assert len(tcp_list) >= collector_env_num + evaluator_env_num, \
        "Carla server not enough! Need {} servers but only found {}.".format(
            collector_env_num + evaluator_env_num, len(tcp_list)
        )

    # The first ``collector_env_num`` servers go to collection, the following
    # ``evaluator_env_num`` servers go to evaluation.
    collector_env = SyncSubprocessEnvManager(
        env_fn=[partial(wrapped_env, cfg.env, cfg.env.wrapper.collect, *tcp_list[i]) for i in range(collector_env_num)],
        cfg=cfg.env.manager.collect,
    )
    evaluate_env = BaseEnvManager(
        env_fn=[
            partial(wrapped_env, cfg.env, cfg.env.wrapper.eval, *tcp_list[collector_env_num + i])
            for i in range(evaluator_env_num)
        ],
        cfg=cfg.env.manager.eval,
    )
    # Uncomment this to add save replay when evaluation
    # evaluate_env.enable_save_replay(cfg.env.replay_path)
    collector_env.seed(seed)
    evaluate_env.seed(seed)
    set_pkg_seed(seed)

    model = SACRLModel(**cfg.policy.model)
    policy = SACPolicy(cfg.policy, model=model)

    tb_logger = SummaryWriter('./log/{}/'.format(cfg.exp_name))
    learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
    collector = SampleSerialCollector(
        cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name
    )
    evaluator = SerialEvaluator(
        cfg.policy.eval.evaluator, evaluate_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
    )
    replay_buffer = NaiveReplayBuffer(cfg.policy.other.replay_buffer, tb_logger, exp_name=cfg.exp_name)

    # Loop-invariant: read once instead of on every iteration.
    batch_size = cfg.policy.learn.batch_size

    try:
        learner.call_hook('before_run')

        # Fill the buffer with an initial 10000 samples before any training step.
        new_data = collector.collect(n_sample=10000, train_iter=learner.train_iter)
        replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)

        while True:
            if evaluator.should_eval(learner.train_iter):
                # Only the stop flag is used; the second returned value is ignored.
                stop, _ = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
                if stop:
                    break
            # Sampling data from environments
            new_data = collector.collect(train_iter=learner.train_iter)
            # Four learner updates per full batch of freshly collected samples.
            update_per_collect = len(new_data) // batch_size * 4
            replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
            # Training
            for _ in range(update_per_collect):
                train_data = replay_buffer.sample(batch_size, learner.train_iter)
                if train_data is None:
                    continue
                # Work on a copy: ``unpack_birdview`` presumably modifies its
                # argument in place (its return value is unused), and the
                # buffer's stored samples should stay untouched.
                train_data = copy.deepcopy(train_data)
                unpack_birdview(train_data)
                learner.train(train_data, collector.envstep)

        learner.call_hook('after_run')
    finally:
        # Always release workers, even if the run above failed part-way.
        collector.close()
        evaluator.close()
        learner.close()
        replay_buffer.close()


# Script entry point: run training with the module-level ``main_config`` and
# the default seed of ``main``.
if __name__ == '__main__':
    main(main_config)
