# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch
import argparse
import gym
import time
import json
import random

import utils
from logger import Logger
from video import VideoRecorder

from agent.vlm_agent_carla_generalization import VLM_Agent
from agent.my_agent_base import MYAgent
from carla_env.carla_env_5 import CarlaEnv
import os
import warnings
# Suppress all Python warnings for the whole run (training logs get noisy otherwise).
warnings.simplefilter('ignore')

# Presumably silences the HuggingFace tokenizers fork/parallelism warning when
# the VLM tokenizer is used alongside other workers.
os.environ.update(TOKENIZERS_PARALLELISM="false")


def parse_args(argv=None):
    """Define and parse the command-line options for CARLA training.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, exactly as before; passing an
            explicit list lets tests or notebooks build an args namespace.

    Returns:
        argparse.Namespace containing every option defined below.
    """
    parser = argparse.ArgumentParser()
    # environment
    parser.add_argument('--domain_name', default='carla', choices=['carla'])
    parser.add_argument('--task_name', default='run')
    parser.add_argument('--scenarios', default='highway', choices=['highway', 'ghost_static', 'shelter_car'])
    parser.add_argument('--image_size', default=128, type=int)
    parser.add_argument('--action_repeat', default=1, type=int)
    parser.add_argument('--frame_stack', default=3, type=int)
    parser.add_argument('--resource_files', type=str)
    parser.add_argument('--eval_resource_files', type=str)
    parser.add_argument('--img_source', default=None, type=str, choices=['color', 'noise', 'images', 'video', 'none'])
    parser.add_argument('--total_frames', default=1000, type=int)
    parser.add_argument('--loss_type', default='value', type=str, choices=['value', 'loss', 'dist', 'anneal'])

    # replay buffer
    parser.add_argument('--replay_buffer_capacity', default=100000, type=int)
    # train
    parser.add_argument('--agent', default='deepmdp', type=str, choices=['baseline', 'llm', 'deepmdp'])
    parser.add_argument('--init_steps', default=1000, type=int)
    parser.add_argument('--num_train_steps', default=1000000, type=int)
    parser.add_argument('--batch_size', default=512, type=int)
    parser.add_argument('--hidden_dim', default=256, type=int)
    parser.add_argument('--k', default=3, type=int, help='number of steps for inverse model')
    parser.add_argument('--bisim_coef', default=0.5, type=float, help='coefficient for bisim terms')
    parser.add_argument('--load_encoder', default=None, type=str)
    parser.add_argument('--vlm_freq', default=60, type=int)
    parser.add_argument('--load_critic_best', default=1, type=int)
    parser.add_argument('--critic_best_path', default=None, type=str)

    # eval
    parser.add_argument('--eval_freq', default=10, type=int)  # TODO: master had 10000
    parser.add_argument('--num_eval_episodes', default=20, type=int)
    # critic
    parser.add_argument('--critic_lr', default=1e-3, type=float)
    parser.add_argument('--critic_beta', default=0.9, type=float)
    parser.add_argument('--critic_tau', default=0.005, type=float)
    parser.add_argument('--critic_target_update_freq', default=2, type=int)
    # actor
    parser.add_argument('--actor_lr', default=1e-3, type=float)
    parser.add_argument('--actor_beta', default=0.9, type=float)
    parser.add_argument('--actor_log_std_min', default=-10, type=float)
    parser.add_argument('--actor_log_std_max', default=2, type=float)
    parser.add_argument('--actor_update_freq', default=2, type=int)
    # encoder/decoder
    parser.add_argument('--encoder_type', default='pixel', type=str, choices=['pixel', 'pixelCarla', 'identity'])
    parser.add_argument('--encoder_feature_dim', default=50, type=int)
    parser.add_argument('--encoder_lr', default=1e-3, type=float)
    parser.add_argument('--encoder_tau', default=0.005, type=float)
    parser.add_argument('--encoder_stride', default=1, type=int)
    parser.add_argument('--decoder_type', default='pixel', type=str, choices=['pixel', 'identity'])
    parser.add_argument('--decoder_lr', default=1e-3, type=float)
    parser.add_argument('--decoder_update_freq', default=1, type=int)
    parser.add_argument('--decoder_weight_lambda', default=0.0, type=float)
    parser.add_argument('--num_layers', default=4, type=int)
    parser.add_argument('--num_filters', default=32, type=int)
    # sac
    parser.add_argument('--discount', default=0.99, type=float)
    parser.add_argument('--init_temperature', default=0.01, type=float)
    parser.add_argument('--alpha_lr', default=1e-3, type=float)
    parser.add_argument('--alpha_beta', default=0.9, type=float)
    parser.add_argument('--latent_dim', default=128, type=int)
    # misc
    parser.add_argument('--seed', default=1, type=int)
    parser.add_argument('--work_dir', default='./log', type=str)
    parser.add_argument('--save_tb', default=False, action='store_true')
    parser.add_argument('--save_model', default=False, action='store_true')
    parser.add_argument('--save_buffer', default=False, action='store_true')
    parser.add_argument('--save_video', default=False, action='store_true')
    parser.add_argument('--render', default=False, action='store_true')
    parser.add_argument('--detach_encoder', default=False, action='store_true')
    parser.add_argument('--transition_model_type', default='', type=str,
                        choices=['', 'deterministic', 'probabilistic', 'ensemble'])
    parser.add_argument('--port', default=2000, type=int)
    parser.add_argument('--trafficManagerPort', default=8000, type=int)
    # data augs
    parser.add_argument('--data_augs', default='crop', type=str)
    parser.add_argument('--log_interval', default=100, type=int)
    # fine_tuning
    parser.add_argument('--data_len', default=500, type=int)  # maximum data length used for fine-tuning
    # z-score setting consumed outside this file (its previous comment was a
    # copy of --data_len's). TODO: document its exact meaning.
    parser.add_argument('--z_score_value', default=1.0, type=float)
    parser.add_argument('--project_name', default="None", type=str)
    return parser.parse_args(argv)


def _mean_or_nan(values):
    """Return the mean of ``values`` as a float, or NaN when it is empty.

    ``np.mean([])`` also yields NaN but emits a RuntimeWarning; this gives the
    same value without depending on the global warning filter.
    """
    return float(np.mean(values)) if len(values) > 0 else float('nan')


def evaluate(env, args, agent, video, num_episodes, L, step, device=None, embed_viz_dir=None, do_carla_metrics=False,
             res_dir="", model_dir=""):
    """Roll out the best saved policy and log reward / driving statistics.

    Loads the best checkpoint from ``model_dir`` via ``agent.load_best`` and
    runs ``num_episodes`` episodes, recording video for each. Reward statistics
    (mean / std / best) and, when enabled, CARLA driving metrics are printed,
    logged to ``L`` and dumped at ``step``.

    Args:
        env: environment whose ``reset()`` returns ``(obs, info)`` and whose
            ``step(action)`` returns ``(obs, reward, done, info)``.
        args: parsed CLI args; only ``args.seed`` is read (names the metrics file).
        agent: policy exposing ``load_best(dir)`` and
            ``select_action(obs, pre_obs, pre_action)``.
        video: recorder exposing ``init``, ``record`` and ``save``.
        num_episodes: number of evaluation episodes to run.
        L: logger exposing ``log(key, value, step)`` and ``dump(step)``.
        step: training step used to tag logs and the video filename.
        device: unused; kept for interface compatibility.
        embed_viz_dir: unused; kept for interface compatibility.
        do_carla_metrics: when True, accumulate distance / |steer| / brake /
            crash intensity from ``info`` and append one line per episode to
            ``<res_dir>/<seed>.txt``.
        res_dir: directory for the per-episode metrics file.
        model_dir: directory passed to ``agent.load_best``.
    """
    all_ep_rewards = []
    all_steers = []
    all_brakes = []
    reason_each_episode_ended = []
    distance_driven_each_episode = []
    all_crash_intensity = []
    agent.load_best(model_dir)
    for _ in range(num_episodes):
        # Per-episode CARLA metrics. steer/brake/crash used to be initialised
        # once before this loop, so the "per-episode" values were running
        # totals over all episodes evaluated so far.
        dist_driven_this_episode = 0.
        steer = 0.
        brake = 0.
        crash_intensity = 0.

        obs, info = env.reset()
        pre_obs = obs
        pre_action = [0, 0]
        video.init(enabled=True)
        done = False
        episode_reward = 0
        while not done:
            with utils.eval_mode(agent):
                # NOTE(review): an older comment said obs is
                # (9, 84, 84 * num_cameras); may be stale for --image_size=128.
                action = agent.select_action(obs, pre_obs, pre_action)

            next_obs, reward, done, info = env.step(action)
            pre_action = action
            pre_obs = obs
            obs = next_obs

            if do_carla_metrics:
                dist_driven_this_episode += info['distance']
                crash_intensity += info['crash_intensity']
                steer += abs(info['steer'])
                brake += info['brake']

            video.record(env)
            episode_reward += reward

        if do_carla_metrics:
            # `info` here is the one returned by the final step of the episode.
            reason = info['reason_each_episode_ended']
            reason_each_episode_ended.append(reason)
            distance_driven_each_episode.append(dist_driven_this_episode)
            all_steers.append(steer)
            all_brakes.append(brake)
            all_crash_intensity.append(crash_intensity)
            eval_log_filepath = os.path.join(res_dir, "{}.txt".format(args.seed))
            eval_log_txt_formatter = "{step},{distance_driven_each_episode},{steer},{brake}, {crash_intensity},{episode_reward},{reason}\n"
            to_write = eval_log_txt_formatter.format(step=step,
                                                     steer=steer,
                                                     brake=brake,
                                                     crash_intensity=crash_intensity,
                                                     distance_driven_each_episode=dist_driven_this_episode,
                                                     episode_reward=episode_reward,
                                                     reason=reason)

            with open(eval_log_filepath, "a", encoding="utf-8") as f:
                f.write(to_write)

        all_ep_rewards.append(episode_reward)
        # NOTE(review): the filename depends only on `step`, so with
        # num_episodes > 1 each episode presumably overwrites the previous video.
        video.save('%d.mp4' % step)

    mean_ep_reward = np.mean(all_ep_rewards)
    best_ep_reward = np.max(all_ep_rewards)
    std_ep_reward = np.std(all_ep_rewards)
    mean_distance = _mean_or_nan(distance_driven_each_episode)
    mean_steer = _mean_or_nan(all_steers)
    mean_brake = _mean_or_nan(all_brakes)
    mean_crash = _mean_or_nan(all_crash_intensity)

    print("mean_ep_reward:", mean_ep_reward)
    print("std_ep_reward:", std_ep_reward)
    print("best_ep_reward:", best_ep_reward)
    print("mean_distance_driven_episode:", mean_distance)
    print("mean_steer_episode:", mean_steer)
    print("mean_brake_episode:", mean_brake)
    print("mean_crash_intensity_episode:", mean_crash)

    L.log('eval/mean_ep_reward', mean_ep_reward, step)
    L.log('eval/std_ep_reward', std_ep_reward, step)
    L.log('eval/best_ep_reward', best_ep_reward, step)
    L.log('eval/mean_distance_driven_episode', mean_distance, step)
    L.log('eval/mean_steer_episode', mean_steer, step)
    L.log('eval/mean_brake_episode', mean_brake, step)

    L.dump(step)


def make_agent(obs_shape, action_shape, args, device):
    """Build the ``MYAgent`` (DeepMDP-style base) configured from CLI args.

    Optionally warm-starts the actor and critic encoders from
    ``args.load_encoder``, and loads the best saved checkpoint from
    ``args.critic_best_path`` when ``args.load_critic_best == 1``.

    Args:
        obs_shape: observation shape forwarded to ``MYAgent``.
        action_shape: action shape forwarded to ``MYAgent``.
        args: parsed CLI args supplying every hyper-parameter below.
        device: torch device for the agent and for loading checkpoints.

    Returns:
        The constructed ``MYAgent``.
    """
    agent = MYAgent(
        obs_shape=obs_shape,
        action_shape=action_shape,
        device=device,
        hidden_dim=args.hidden_dim,
        discount=args.discount,
        init_temperature=args.init_temperature,
        alpha_lr=args.alpha_lr,
        alpha_beta=args.alpha_beta,
        actor_lr=args.actor_lr,
        actor_beta=args.actor_beta,
        actor_log_std_min=args.actor_log_std_min,
        actor_log_std_max=args.actor_log_std_max,
        actor_update_freq=args.actor_update_freq,
        encoder_stride=args.encoder_stride,
        critic_lr=args.critic_lr,
        critic_beta=args.critic_beta,
        critic_tau=args.critic_tau,
        critic_target_update_freq=args.critic_target_update_freq,
        encoder_type=args.encoder_type,
        encoder_feature_dim=args.encoder_feature_dim,
        encoder_lr=args.encoder_lr,
        encoder_tau=args.encoder_tau,
        decoder_type=args.decoder_type,
        decoder_lr=args.decoder_lr,
        decoder_update_freq=args.decoder_update_freq,
        decoder_weight_lambda=args.decoder_weight_lambda,
        transition_model_type=args.transition_model_type,
        num_layers=args.num_layers,
        num_filters=args.num_filters,
        data_augs=args.data_augs
    )
    # resume encoder weights
    if args.load_encoder:
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        encoder_dict = torch.load(args.load_encoder, map_location=device)
        prefix = 'encoder.'
        # Keep only encoder parameters and strip everything up to and including
        # the first 'encoder.' (the old `k[8:]` slice was only correct for keys
        # that *started* with 'encoder.'). If the checkpoint holds several
        # encoders, later keys with the same suffix overwrite earlier ones.
        encoder_dict = {k.split(prefix, 1)[1]: v for k, v in encoder_dict.items() if prefix in k}
        agent.actor.encoder.load_state_dict(encoder_dict)
        agent.critic.encoder.load_state_dict(encoder_dict)

    if args.load_critic_best == 1:
        print("load critic.................")
        print(args.critic_best_path)
        # NOTE(review): --load_critic_best defaults to 1 while
        # --critic_best_path defaults to None; confirm load_best handles None.
        agent.load_best(args.critic_best_path)
    print("--------------------loss_type:", args.loss_type)

    return agent


def main():
    """Train the RL agent in CARLA with VLM-guided exploration.

    Builds the CARLA environment (frame-stacked for pixel encoders), a run
    directory tree, the VLM agent, the replay buffer and the RL agent, then
    runs ``args.num_train_steps`` environment steps:

    * The VLM action is refreshed every ``args.vlm_freq`` steps (global step
      during warm-up; episode step, plus episode step 1, afterwards).
    * Within an episode, segments of ``explore_times`` steps (random in
      [50, 400], redrawn each episode) alternate between executing the RL
      action and the VLM action, starting with the RL action.
    * Gradient updates start once ``step >= args.init_steps``.

    The agent is saved to ``<work_dir>/model`` at the end.
    """
    args = parse_args()
    utils.set_seed_everywhere(args.seed)
    # All scenarios ('highway', 'ghost_static', 'shelter_car') share the same horizon.
    max_episode_steps = 1000
    env = CarlaEnv(
        render_display=args.render,  # for local debugging only
        display_text=args.render,  # for local debugging only
        changing_weather_speed=0.1,  # [0, +inf)
        rl_image_size=args.image_size,
        max_episode_steps=max_episode_steps,
        frame_skip=args.action_repeat,
        port=args.port,
        trafficManagerPort=args.trafficManagerPort,
        scenarios=args.scenarios,
        vlm_use=True,
        vrc_use=True
    )
    # TODO: implement env.seed(args.seed) ?
    # NOTE: eval_env wraps the same CarlaEnv instance as env.
    eval_env = env

    # stack several consecutive frames together
    if args.encoder_type.startswith('pixel'):
        env = utils.FrameStack_Carla(env, k=args.frame_stack)
        eval_env = utils.FrameStack_Carla(eval_env, k=args.frame_stack)

    # Suffix the run directory with the seed and a timestamp so runs don't collide.
    work_dir = args.work_dir + '_' + "seed:{}_".format(args.seed) + time.strftime("%Y-%m-%d-%H-%M-%S")
    utils.make_dir(work_dir)
    video_dir = utils.make_dir(os.path.join(work_dir, 'video'))
    model_dir = utils.make_dir(os.path.join(work_dir, 'model'))
    buffer_dir = utils.make_dir(os.path.join(work_dir, 'buffer'))
    res_dir = utils.make_dir(os.path.join(work_dir, 'res_dir'))

    video = VideoRecorder(video_dir if args.save_video else None)

    with open(os.path.join(work_dir, 'args.json'), 'w') as f:
        json.dump(vars(args), f, sort_keys=True, indent=4)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # LoRA checkpoint for the VLM; set VLM_LORA_WEIGHT_PATH to use another one
    # without editing this file (the default is the previously hard-coded path).
    lora_weight_path = os.environ.get(
        "VLM_LORA_WEIGHT_PATH",
        "/home/user/cg/LM_F/save/carla_HighwayLimit_ft_10w_count_hard_no_used_info_loss_seed:1_Qwen-2B-E_2025-04-25-12-20-16/buffer/output/Qwen2-VL/checkpoint-31",
    )
    VLM_agent = VLM_Agent(args.scenarios, device, res_dir, buffer_dir, args.seed, True, lora_weight_path)
    # "Previous action" fed to the replay buffer at the start of each episode.
    VLM_init = [0, 0.0]

    assert env.action_space.low.min() >= -1
    assert env.action_space.high.max() <= 1

    replay_buffer = utils.ReplayBuffer(
        obs_shape=env.observation_space.shape,
        action_shape=env.action_space.shape,
        capacity=args.replay_buffer_capacity,
        batch_size=args.batch_size,
        device=device,
        pre_image_size=args.image_size,
        use_loss=args.loss_type
    )

    agent = make_agent(
        obs_shape=env.observation_space.shape,
        action_shape=env.action_space.shape,
        args=args,
        device=device
    )
    L = Logger(work_dir, use_tb=args.save_tb)

    # train agent
    explore_times_step = 0
    current_option = 1
    explore_times = random.randint(50, 400)
    episode, episode_reward, done = 0, 0, True
    start_time = time.time()
    step = 0  # defined even when num_train_steps == 0 (used by the final save)
    for step in range(args.num_train_steps):
        if done:
            if step > 0:
                L.log('train/duration', time.time() - start_time, step)
                start_time = time.time()
                L.dump(step)

            # Periodic evaluation is disabled; call evaluate(eval_env, ...) here
            # to re-enable it.

            if args.save_model:
                agent.save_best(model_dir, episode_reward)

            L.log('train/episode_reward', episode_reward, step)

            train_log_filepath = os.path.join(res_dir, "train_log.txt")
            train_log_txt_formatter = "{step}:{episode_reward}\n"
            to_write = train_log_txt_formatter.format(step=step,
                                                      episode_reward=episode_reward)
            with open(train_log_filepath, "a") as f:
                f.write(to_write)

            obs, info = env.reset()
            pre_obs = obs
            pre_action = VLM_init
            done = False
            explore_times_step = 0
            current_option = 1
            episode_reward = 0
            episode_step = 0
            episode += 1
            reward = 0
            explore_times = random.randint(50, 400)

            L.log('train/episode', episode, step)

        # sample action for data collection
        if step < args.init_steps:
            RL_action = agent.select_action(obs)
            with torch.no_grad():
                # Warm-up: refresh the VLM action on global-step multiples of
                # vlm_freq; otherwise keep the previous VLM_action.
                if step % args.vlm_freq == 0:
                    vlm_obs = [info['vlm_rgb'], info['selected_ego_velocity'], info['ego_orientation']]
                    VLM_action = VLM_agent.select_action(vlm_obs, step, args.seed)
        else:
            with utils.eval_mode(agent):
                RL_action = agent.select_action(obs)
            with torch.no_grad():
                # Refresh the VLM action on episode steps 1 and multiples of
                # vlm_freq; otherwise keep the previous VLM_action.
                if info is not None and (episode_step == 1 or episode_step % args.vlm_freq == 0):
                    vlm_obs = [info['vlm_rgb'], info['selected_ego_velocity'], info['ego_orientation']]
                    VLM_action = VLM_agent.select_action(vlm_obs, step, args.seed)

        # Toggle between RL (0) and VLM (1) control every `explore_times` steps;
        # the toggle at explore_times_step == 0 makes each episode start with RL.
        if explore_times_step % explore_times == 0:
            current_option = 1 - current_option
        explore_times_step = explore_times_step + 1

        if current_option == 1:
            RL_action = VLM_action

        # run training update
        if step >= args.init_steps:
            num_updates = args.init_steps if step == args.init_steps else 1
            for _ in range(num_updates):
                agent.update(replay_buffer, L, step, Loss_type=args.loss_type, vlm_update_freq=args.vlm_freq)

        curr_reward = reward
        next_obs, reward, done, info = env.step(RL_action)

        # allow infinite bootstrap: a time-limit termination is not a true terminal
        done_bool = 0 if episode_step + 1 == env._max_episode_steps else float(
            done
        )
        episode_reward += reward

        replay_buffer.add_pre(obs, RL_action, VLM_action, pre_action, curr_reward, reward, next_obs, done_bool, pre_obs)
        # update current obs
        pre_obs = obs
        pre_action = RL_action
        obs = next_obs
        episode_step += 1

    agent.save(model_dir, step)
    # Final evaluation is disabled; call evaluate(eval_env, args, agent, video,
    # args.num_eval_episodes, L, step, do_carla_metrics=True, res_dir=res_dir,
    # model_dir=model_dir) here to re-enable it.


if __name__ == "__main__":
    # Run training only when executed as a script, not when imported.
    main()
