# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch
import argparse
import gym
import time
import json
import random
import csv
import utils
from logger import Logger
from video import VideoRecorder

from agent.vlm_agent_carla_2 import VLM_Agent
from agent.my_agent_carla import MYAgent
from carla_env.carla_env_9_6 import CarlaEnv
import os
import warnings
# Install a process-wide "ignore" filter so Python warnings are not printed.
warnings.filterwarnings('ignore')

# NOTE(review): presumably read by the HuggingFace `tokenizers` library to turn
# off its internal parallelism — confirm against the VLM agent's dependencies.
os.environ["TOKENIZERS_PARALLELISM"] = "false"


def parse_args(argv=None):
    """Build the command-line parser for the training script and parse arguments.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default) ``argparse`` falls back to ``sys.argv[1:]``, which is what
            a plain ``parse_args()`` call has always done. Passing an explicit
            list makes the function usable from tests and other scripts.

    Returns:
        argparse.Namespace: One attribute per option declared below.
    """
    parser = argparse.ArgumentParser()
    # environment
    parser.add_argument('--domain_name', default='carla', choices=['carla'])
    parser.add_argument('--task_name', default='run')
    parser.add_argument('--scenarios', default='highway', choices=['highway', 'ghost_static', 'shelter_car'])
    parser.add_argument('--image_size', default=128, type=int)
    parser.add_argument('--action_repeat', default=1, type=int)
    parser.add_argument('--frame_stack', default=3, type=int)
    parser.add_argument('--resource_files', type=str)
    parser.add_argument('--eval_resource_files', type=str)
    # NOTE: the default is the None object, which is distinct from the 'none' string choice.
    parser.add_argument('--img_source', default=None, type=str, choices=['color', 'noise', 'images', 'video', 'none'])
    parser.add_argument('--total_frames', default=1000, type=int)
    parser.add_argument('--loss_type', default='value', type=str, choices=['value', 'loss', 'dist', 'anneal'])

    # replay buffer
    parser.add_argument('--replay_buffer_capacity', default=100000, type=int)
    # train
    parser.add_argument('--agent', default='deepmdp', type=str, choices=['baseline', 'llm', 'deepmdp'])
    parser.add_argument('--init_steps', default=1000, type=int)
    parser.add_argument('--num_train_steps', default=1000000, type=int)
    parser.add_argument('--batch_size', default=512, type=int)
    parser.add_argument('--hidden_dim', default=256, type=int)
    parser.add_argument('--k', default=3, type=int, help='number of steps for inverse model')
    parser.add_argument('--bisim_coef', default=0.5, type=float, help='coefficient for bisim terms')
    parser.add_argument('--load_encoder', default=None, type=str)
    parser.add_argument('--vlm_freq', default=60, type=int)
    parser.add_argument('--load_critic_best', default=1, type=int)
    parser.add_argument('--critic_best_path', default=None, type=str)

    # eval
    parser.add_argument('--eval_freq', default=10, type=int)  # TODO: master had 10000
    parser.add_argument('--num_eval_episodes', default=20, type=int)
    # critic
    parser.add_argument('--critic_lr', default=1e-3, type=float)
    parser.add_argument('--critic_beta', default=0.9, type=float)
    parser.add_argument('--critic_tau', default=0.005, type=float)
    parser.add_argument('--critic_target_update_freq', default=2, type=int)
    # actor
    parser.add_argument('--actor_lr', default=1e-3, type=float)
    parser.add_argument('--actor_beta', default=0.9, type=float)
    parser.add_argument('--actor_log_std_min', default=-10, type=float)
    parser.add_argument('--actor_log_std_max', default=2, type=float)
    parser.add_argument('--actor_update_freq', default=2, type=int)
    # encoder/decoder
    parser.add_argument('--encoder_type', default='pixel', type=str, choices=['pixel', 'pixelCarla', 'identity'])
    parser.add_argument('--encoder_feature_dim', default=50, type=int)
    parser.add_argument('--encoder_lr', default=1e-3, type=float)
    parser.add_argument('--encoder_tau', default=0.005, type=float)
    parser.add_argument('--encoder_stride', default=1, type=int)
    parser.add_argument('--decoder_type', default='pixel', type=str, choices=['pixel', 'identity'])
    parser.add_argument('--decoder_lr', default=1e-3, type=float)
    parser.add_argument('--decoder_update_freq', default=1, type=int)
    parser.add_argument('--decoder_weight_lambda', default=0.0, type=float)
    parser.add_argument('--num_layers', default=4, type=int)
    parser.add_argument('--num_filters', default=32, type=int)
    # sac
    parser.add_argument('--discount', default=0.99, type=float)
    parser.add_argument('--init_temperature', default=0.01, type=float)
    parser.add_argument('--alpha_lr', default=1e-3, type=float)
    parser.add_argument('--alpha_beta', default=0.9, type=float)
    parser.add_argument('--latent_dim', default=128, type=int)
    # misc
    parser.add_argument('--seed', default=1, type=int)
    parser.add_argument('--work_dir', default='./log', type=str)
    parser.add_argument('--save_tb', default=False, action='store_true')
    parser.add_argument('--save_model', default=False, action='store_true')
    parser.add_argument('--save_buffer', default=False, action='store_true')
    parser.add_argument('--save_video', default=False, action='store_true')
    parser.add_argument('--render', default=False, action='store_true')
    parser.add_argument('--detach_encoder', default=False, action='store_true')
    parser.add_argument('--transition_model_type', default='', type=str,
                        choices=['', 'deterministic', 'probabilistic', 'ensemble'])
    parser.add_argument('--port', default=2000, type=int)
    parser.add_argument('--trafficManagerPort', default=8000, type=int)
    # data augs
    parser.add_argument('--data_augs', default='crop', type=str)
    parser.add_argument('--log_interval', default=100, type=int)
    args = parser.parse_args(argv)
    return args


def evaluate(env, args, agent, video, num_episodes, L, step, device=None, embed_viz_dir=None, do_carla_metrics=False,
             res_dir="", model_dir=""):
    """Roll out evaluation episodes and log reward and driving statistics.

    ``agent.load_best(model_dir)`` is called once before the first episode, so
    the weights used here are whatever that call restores. Every episode is
    recorded through ``video`` and saved as ``'<step>.mp4'``; all episodes of
    one call therefore share the same file name.

    Args:
        env: Environment exposing ``reset()`` and a ``step(action)`` that
            returns ``(obs, reward, done, info)``.
        args: Parsed command-line namespace; only ``args.seed`` is read, to
            name the evaluation log file.
        agent: Agent providing ``load_best`` and ``select_action``.
        video: Recorder providing ``init``, ``record`` and ``save``.
        num_episodes (int): Number of episodes to run. Expected to be at least
            1, because the reward summary calls ``np.max`` on the episode list.
        L: Logger providing ``log(key, value, step)`` and ``dump(step)``.
        step (int): Training step used to tag log entries and the video name.
        device: Unused; kept so existing callers keep working.
        embed_viz_dir: Unused; kept so existing callers keep working.
        do_carla_metrics (bool): When true, read ``distance``,
            ``crash_intensity``, ``steer``, ``brake`` and
            ``reason_each_episode_ended`` from ``info`` and append one line per
            episode to ``<res_dir>/<seed>.txt``. When false the CARLA lists
            stay empty, so their printed/logged means are NaN.
        res_dir (str): Directory of the per-seed evaluation log file.
        model_dir (str): Directory passed to ``agent.load_best``.
    """
    all_ep_rewards = []
    all_steers = []
    all_brakes = []
    all_crash_intensity = []
    distance_driven_each_episode = []

    agent.load_best(model_dir)
    for _ in range(num_episodes):
        # carla metrics: reset for every episode so that each log row and each
        # list entry describes a single episode. Keeping these accumulators
        # outside the loop made every row after the first a running total.
        dist_driven_this_episode = 0.
        crash_intensity = 0.
        steer = 0.
        brake = 0.

        obs = env.reset()
        # No previous transition exists yet: reuse the first observation and a
        # zero action as the "previous" inputs.
        pre_obs = obs
        pre_action = [0, 0]
        video.init(enabled=True)
        done = False
        episode_reward = 0
        while not done:
            with utils.eval_mode(agent):
                # NOTE(review): an older comment gave obs as
                # (9, 84, 84 * num_cameras) — confirm against the current image size.
                action = agent.select_action(obs, pre_obs, pre_action)

            next_obs, reward, done, info = env.step(action)
            pre_action = action
            pre_obs = obs
            obs = next_obs

            # metrics:
            if do_carla_metrics:
                dist_driven_this_episode += info['distance']
                crash_intensity += info['crash_intensity']
                steer += abs(info['steer'])
                brake += info['brake']

            video.record(env)
            episode_reward += reward

        # metrics:
        if do_carla_metrics:
            distance_driven_each_episode.append(dist_driven_this_episode)
            all_steers.append(steer)
            all_brakes.append(brake)
            all_crash_intensity.append(crash_intensity)
            eval_log_filepath = os.path.join(res_dir, "{}.txt".format(args.seed))
            eval_log_txt_formatter = "{step},{distance_driven_each_episode},{steer},{brake}, {crash_intensity},{episode_reward},{reason}\n"
            to_write = eval_log_txt_formatter.format(step=step,
                                                     steer=steer,
                                                     brake=brake,
                                                     crash_intensity=crash_intensity,
                                                     distance_driven_each_episode=dist_driven_this_episode,
                                                     episode_reward=episode_reward,
                                                     # `info` is the dict returned by the episode's final step.
                                                     reason=info['reason_each_episode_ended'])

            with open(eval_log_filepath, "a") as f:
                f.write(to_write)

        all_ep_rewards.append(episode_reward)
        video.save('%d.mp4' % step)

    mean_ep_reward = np.mean(all_ep_rewards)
    best_ep_reward = np.max(all_ep_rewards)
    std_ep_reward = np.std(all_ep_rewards)
    print("mean_ep_reward:", mean_ep_reward)
    print("std_ep_reward:", std_ep_reward)
    print("best_ep_reward:", best_ep_reward)
    print("mean_distance_driven_episode:", np.mean(distance_driven_each_episode))
    print("mean_steer_episode:", np.mean(all_steers))
    print("mean_brake_episode:", np.mean(all_brakes))
    print("mean_crash_intensity_episode:", np.mean(all_crash_intensity))

    L.log('eval/mean_ep_reward', mean_ep_reward, step)
    L.log('eval/std_ep_reward', std_ep_reward, step)
    L.log('eval/best_ep_reward', best_ep_reward, step)
    L.log('eval/mean_distance_driven_episode', np.mean(distance_driven_each_episode), step)
    L.log('eval/mean_steer_episode', np.mean(all_steers), step)
    L.log('eval/mean_brake_episode', np.mean(all_brakes), step)

    L.dump(step)


def append_to_csv(step, episode_reward, filename='episode_rewards.csv'):
    """Append one ``step, episode_reward`` row to a CSV file.

    If ``filename`` does not exist yet it is created and a ``Step,Value``
    header row is written before the data row. An existing file is only
    appended to and never receives a header.

    Args:
        step (int): Current step count.
        episode_reward (float): Reward of the current episode.
        filename (str): CSV file path; defaults to 'episode_rewards.csv'.
    """
    is_new_file = not os.path.exists(filename)
    # A new file is opened for writing so the header comes first; otherwise append.
    with open(filename, mode='w' if is_new_file else 'a', newline='') as handle:
        out = csv.writer(handle)
        if is_new_file:
            out.writerow(['Step', 'Value'])
        out.writerow([step, episode_reward])


def make_agent(obs_shape, action_shape, args, device):
    """Construct a ``MYAgent`` from parsed arguments and optionally restore weights.

    Args:
        obs_shape: Observation shape forwarded to ``MYAgent``.
        action_shape: Action shape forwarded to ``MYAgent``.
        args: Parsed command-line namespace supplying the hyper-parameters,
            ``load_encoder``, ``load_critic_best``, ``critic_best_path`` and
            ``loss_type``.
        device: Torch device forwarded to ``MYAgent``; also used as the
            ``map_location`` when an encoder checkpoint is loaded.

    Returns:
        The constructed agent.
    """
    # take deepmdp as base
    agent = MYAgent(
        obs_shape=obs_shape,
        action_shape=action_shape,
        device=device,
        hidden_dim=args.hidden_dim,
        discount=args.discount,
        init_temperature=args.init_temperature,
        alpha_lr=args.alpha_lr,
        alpha_beta=args.alpha_beta,
        actor_lr=args.actor_lr,
        actor_beta=args.actor_beta,
        actor_log_std_min=args.actor_log_std_min,
        actor_log_std_max=args.actor_log_std_max,
        actor_update_freq=args.actor_update_freq,
        encoder_stride=args.encoder_stride,
        critic_lr=args.critic_lr,
        critic_beta=args.critic_beta,
        critic_tau=args.critic_tau,
        critic_target_update_freq=args.critic_target_update_freq,
        encoder_type=args.encoder_type,
        encoder_feature_dim=args.encoder_feature_dim,
        encoder_lr=args.encoder_lr,
        encoder_tau=args.encoder_tau,
        decoder_type=args.decoder_type,
        decoder_lr=args.decoder_lr,
        decoder_update_freq=args.decoder_update_freq,
        decoder_weight_lambda=args.decoder_weight_lambda,
        transition_model_type=args.transition_model_type,
        num_layers=args.num_layers,
        num_filters=args.num_filters,
        data_augs=args.data_augs
    )
    # resume
    if args.load_encoder:
        prefix = 'encoder.'
        # map_location lets a checkpoint saved on another device (e.g. a GPU)
        # be restored on whatever device this run uses.
        checkpoint = torch.load(args.load_encoder, map_location=device)
        # Keep only the encoder sub-module's entries and strip the prefix so the
        # keys match the encoder's own state dict. Matching on the prefix (not
        # "contains") avoids mis-slicing keys that mention 'encoder.' elsewhere.
        encoder_dict = {k[len(prefix):]: v for k, v in checkpoint.items() if k.startswith(prefix)}
        agent.actor.encoder.load_state_dict(encoder_dict)
        agent.critic.encoder.load_state_dict(encoder_dict)

    if args.load_critic_best == 1:
        print("load critic.................")
        print(args.critic_best_path)
        # NOTE(review): critic_best_path is forwarded unchecked and this branch
        # is enabled by default — confirm load_critic copes with a None path.
        agent.load_critic(args.critic_best_path)
    print("--------------------loss_type:", args.loss_type)

    return agent


def main():
    """Train the agent in a CARLA environment, then save and evaluate it.

    Steps, in order: parse arguments and seed the run; build the ``CarlaEnv``
    (frame-stacked for pixel encoders); create the timestamped work directory
    and its sub-directories; build the optional ``VLM_Agent``, the replay
    buffer, the agent and the logger; run ``args.num_train_steps`` environment
    steps with training updates once ``args.init_steps`` is reached; finally
    call ``agent.save`` and run ``evaluate`` for ``args.num_eval_episodes``
    episodes.
    """
    args = parse_args()
    utils.set_seed_everywhere(args.seed)
    # NOTE(review): both branches currently assign the same limit.
    if args.scenarios == "ghost_static" or args.scenarios == "shelter_car":
        max_episode_steps = 1000
    else:
        max_episode_steps = 1000
    env = CarlaEnv(
        render_display=args.render,  # for local debugging only
        display_text=args.render,  # for local debugging only
        changing_weather_speed=0.1,  # [0, +inf)
        rl_image_size=args.image_size,
        max_episode_steps=max_episode_steps,
        frame_skip=args.action_repeat,
        port=args.port,
        trafficManagerPort=args.trafficManagerPort,
        scenarios=args.scenarios,
        vlm_use=True,
        vrc_use=True
    )
    # TODO: implement env.seed(args.seed) ?
    # The evaluation env is the same underlying CarlaEnv object as the training env.
    eval_env = env

    # stack several consecutive frames together
    if args.encoder_type.startswith('pixel'):
        env = utils.FrameStack(env, k=args.frame_stack)
        eval_env = utils.FrameStack(eval_env, k=args.frame_stack)

    # add time: the work directory name carries the seed and a start timestamp
    work_dir = args.work_dir + '_' + "seed:{}_".format(args.seed) + time.strftime("%Y-%m-%d-%H-%M-%S")
    utils.make_dir(work_dir)
    video_dir = utils.make_dir(os.path.join(work_dir, 'video'))
    model_dir = utils.make_dir(os.path.join(work_dir, 'model'))
    buffer_dir = utils.make_dir(os.path.join(work_dir, 'buffer'))
    res_dir = utils.make_dir(os.path.join(work_dir, 'res_dir'))

    # The recorder gets no directory unless --save_video was given.
    video = VideoRecorder(video_dir if args.save_video else None)

    # Persist the exact arguments of this run next to its outputs.
    with open(os.path.join(work_dir, 'args.json'), 'w') as f:
        json.dump(vars(args), f, sort_keys=True, indent=4)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # NOTE(review): mask_feature is only referenced by commented-out code below,
    # but this call still draws from NumPy's global random state.
    mask_feature = np.random.rand(56, 56)
    # A VLM agent is only built when the loss type is not "value". The initial
    # VLM action is a two-element list for "loss"/"anneal"/"value" and a text
    # instruction for the remaining loss type.
    if args.loss_type != "value":
        VLM_agent = VLM_Agent(args.scenarios, device, res_dir)
        if args.loss_type == "loss" or args.loss_type == "anneal":
            VLM_init = [0, 1.0]
        else:
            VLM_init = "No steering change and keep throttle=1"
    else:
        VLM_agent = None
        VLM_init = [0, 1.0]

    #
    # cris
    # cfg = config.load_cfg_from_cfg_file("/home/user/cg/LLM_RL/config/refcoco/cris_r50.yaml")
    # cris_model, _ = build_segmenter(cfg)
    # cris_model = torch.nn.DataParallel(cris_model).cuda()
    #
    # cris_model_dir = "/home/user/cg/LLM_RL/exp/refcoco/CRIS_R50/best_model.pth"
    # if os.path.isfile(cris_model_dir):
    #     checkpoint = torch.load(cris_model_dir)
    #     cris_model.load_state_dict(checkpoint['state_dict'], strict=True)
    # else:
    #     raise ValueError(
    #         "=> resume failed! no checkpoint found at '{}'. Please check args.resume again!"
    #             .format(cris_model_dir))

    # The action space bounds must lie within [-1, 1].
    assert env.action_space.low.min() >= -1
    assert env.action_space.high.max() <= 1

    replay_buffer = utils.ReplayBuffer(
        obs_shape=env.observation_space.shape,
        action_shape=env.action_space.shape,
        capacity=args.replay_buffer_capacity,
        batch_size=args.batch_size,
        device=device,
        pre_image_size=args.image_size,
        use_loss=args.loss_type
    )

    agent = make_agent(
        obs_shape=env.observation_space.shape,
        action_shape=env.action_space.shape,
        args=args,
        device=device
    )
    L = Logger(work_dir, use_tb=args.save_tb)

    # train agent
    # NOTE(review): the four exploration variables below are only used by
    # commented-out code further down.
    explore_times = [50, 100, 150, 200, 250, 300]
    explore_times_step = 0
    current_option = 0
    use_option = False
    #
    # done starts as True so the first loop iteration runs the reset branch.
    episode, episode_reward, done = 0, 0, True
    start_time = time.time()
    for step in range(args.num_train_steps):
        if done:
            if args.decoder_type == 'inverse':
                for i in range(1, args.k):  # fill k_obs with 0s if episode is done
                    replay_buffer.k_obses[replay_buffer.idx - i] = 0
            if step > 0:
                L.log('train/duration', time.time() - start_time, step)
                start_time = time.time()
                L.dump(step)

            # evaluate agent periodically
            # if episode % 100 == 0 and episode != 0:
            #     evaluate(eval_env, args, agent, video, 1, L, step, do_carla_metrics=True,
            #              res_dir=res_dir, model_dir=model_dir)

            # This also runs at step 0, before any episode, with episode_reward == 0.
            if args.save_model:
                agent.save_best(model_dir, episode_reward)

            L.log('train/episode_reward', episode_reward, step)

            # train_log_filepath = os.path.join(res_dir, "train_log.txt")
            # train_log_txt_formatter = "{step}:{episode_reward}\n"
            # to_write = train_log_txt_formatter.format(step=step, episode_reward=episode_reward)
            # with open(train_log_filepath, "a") as f:
            #     f.write(to_write)
            # One CSV row (step, reward of the finished episode) per reset.
            filepath = os.path.join(res_dir, "{}.csv".format(args.seed))
            append_to_csv(step, episode_reward, filepath)

            # if step > args.init_steps:
            #     print('------------------episode_step:{}-------------------'.format(episode_step))
            # Start a new episode and reset all per-episode state.
            obs = env.reset()
            pre_obs = obs
            pre_action = VLM_init
            done = False
            use_option = False
            explore_times_step = 0
            info = None
            VLM_action = VLM_init
            current_option = 0
            episode_reward = 0
            episode_step = 0
            episode += 1
            reward = 0
            # NOTE(review): xx is only used by commented-out code, but this call
            # still advances the `random` module's state.
            xx = random.choice([0, 1, 2, 3, 4, 5])

            L.log('train/episode', episode, step)

        # sample action for data collection
        if step < args.init_steps:
            RL_action = env.action_space.sample()  # random
            if args.loss_type != "value":
                with torch.no_grad():
                    # NOTE(review): the refresh here is keyed on the global `step`,
                    # while the post-warm-up branch keys on `episode_step` — confirm intended.
                    if episode_step != 0 and step % args.vlm_freq == 0:
                        # print("------------------Warming Access VLM-------------------")
                        vlm_obs = [info['vlm_rgb'], info['selected_ego_velocity'], info['ego_orientation']]
                        # VLM_action, objects = VLM_agent.select_action(vlm_obs, episode_step, args.seed)
                        VLM_action = VLM_agent.select_action(vlm_obs, episode_step, args.seed)
                        # mask_feature = inference_demo(cris_model, info['vlm_rgb'], cfg, device)
                    elif episode_step == 0:
                        VLM_action = VLM_init
                    else:
                        VLM_action = VLM_action

            # NOTE(review): this overrides the random sample above, so every
            # warm-up step executes the current VLM action — confirm intended.
            RL_action = VLM_action
            # if explore_times_step % explore_times[random.choice([0, 1, 2])] == 0:
            #     current_option = 1 - current_option
            # explore_times_step = explore_times_step + 1
            #
            # if current_option == 0:
            #     RL_action = VLM_action
            # else:
            #     pass
        else:
            if args.loss_type != "value":
                with torch.no_grad():
                    # Query the VLM on the second step of an episode and then every
                    # vlm_freq episode steps; `info` is None until the first env.step.
                    if info is not None and (episode_step == 1 or episode_step % args.vlm_freq == 0):
                        # print("------------------Access VLM-------------------")
                        vlm_obs = [info['vlm_rgb'], info['selected_ego_velocity'], info['ego_orientation']]
                        # VLM_action, objects = VLM_agent.select_action(vlm_obs, episode_step, args.seed)
                        VLM_action = VLM_agent.select_action(vlm_obs, episode_step, args.seed)
                        # mask_feature = inference_demo(cris_model, info['vlm_rgb'], cfg, device)
                    else:
                        # keep VLM_action unchanged
                        VLM_action = VLM_action

            with utils.eval_mode(agent):
                RL_action = agent.sample_action(obs, pre_obs, pre_action, VLM_action)
                # RL_action, pro_action = agent.sample_action(obs, pre_obs, pre_action, VLM_action)
                # L.log('train/pro_action', pro_action[0], step)

            # if step > 50000 and np.random.rand() < agent.epsilon:
            #     use_option = True
            #
            # if use_option:
            #     if explore_times_step % explore_times[xx] == 0:
            #         current_option = 1 - current_option
            #     explore_times_step = explore_times_step + 1
            #
            #     if current_option == 0:
            #         RL_action = VLM_action
            #     else:
            #         pass

        # run training update
        if step >= args.init_steps:
            # The first update step runs init_steps updates in a row; later steps run one.
            num_updates = args.init_steps if step == args.init_steps else 1
            for _ in range(num_updates):
                agent.update(replay_buffer, L, step, Loss_type=args.loss_type, vlm_update_freq=args.vlm_freq)

        # Keep the reward of the previous step before env.step overwrites `reward`.
        curr_reward = reward
        next_obs, reward, done, info = env.step(RL_action)

        # allow infinite bootstrap: a step that reaches the env's step limit is
        # stored as not-done in the replay buffer
        done_bool = 0 if episode_step + 1 == env._max_episode_steps else float(
            done
        )
        episode_reward += reward

        replay_buffer.add_pre(obs, RL_action, VLM_action, pre_action, curr_reward, reward, next_obs, done_bool, pre_obs)
        # update current obs
        pre_obs = obs
        pre_action = RL_action
        obs = next_obs
        episode_step += 1

    # evaluate agent
    # NOTE(review): `step` is only bound if the loop ran at least once
    # (num_train_steps >= 1).
    agent.save(model_dir, step)
    # print('----------------start eval-----------------')
    evaluate(eval_env, args, agent, video, args.num_eval_episodes, L, step, do_carla_metrics=True,
             res_dir=res_dir, model_dir=model_dir)


# Run training only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()
