# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch
import argparse
import time
import json
import random
import csv
import dmc2gym
import utils
from logger import Logger
from video import VideoRecorder
import gym
from agent.vlm_agent_gym_and_mujoco import VLM_Agent
# from agent.my_agent_mujoco import MYAgent
from agent.bisim_agent import BisimAgent as MYAgent
import os
import warnings

# Suppress every Python warning for the whole run.
warnings.simplefilter('ignore')
# NOTE(review): presumably set to keep HuggingFace tokenizers (used by the VLM) from warning about parallelism — confirm.
os.environ.update(TOKENIZERS_PARALLELISM="false")


def parse_args(argv=None):
    """Parse the command-line options for training and evaluation.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default) the
            parser reads ``sys.argv[1:]`` exactly as before. Passing an explicit
            list lets callers (tests, notebooks) build a config programmatically.

    Returns:
        argparse.Namespace with one attribute per option defined below.
    """
    parser = argparse.ArgumentParser()
    # environment
    parser.add_argument('--domain_name', default='gym', choices=['gym', 'mujoco'])
    parser.add_argument('--task_name', default='CarRacing-v2')
    parser.add_argument('--pre_transform_image_size', default=100, type=int)
    parser.add_argument('--image_size', default=128, type=int)
    parser.add_argument('--action_repeat', default=1, type=int)
    parser.add_argument('--frame_stack', default=3, type=int)
    parser.add_argument('--resource_files', type=str)
    parser.add_argument('--eval_resource_files', type=str)
    parser.add_argument('--img_source', default=None, type=str, choices=['color', 'noise', 'images', 'video', 'none'])
    parser.add_argument('--total_frames', default=1000, type=int)
    parser.add_argument('--loss_type', default='value', type=str, choices=['value', 'loss', 'dist', 'anneal', 'random'])

    # replay buffer
    parser.add_argument('--replay_buffer_capacity', default=100000, type=int)
    # train
    parser.add_argument('--agent', default='deepmdp', type=str, choices=['baseline', 'llm', 'deepmdp'])
    parser.add_argument('--init_steps', default=1000, type=int)
    parser.add_argument('--num_train_steps', default=1000000, type=int)
    parser.add_argument('--batch_size', default=512, type=int)
    parser.add_argument('--hidden_dim', default=256, type=int)
    parser.add_argument('--k', default=3, type=int, help='number of steps for inverse model')
    parser.add_argument('--bisim_coef', default=0.5, type=float, help='coefficient for bisim terms')
    parser.add_argument('--load_encoder', default=None, type=str)
    parser.add_argument('--vlm_freq', default=60, type=int)
    # NOTE(review): load_critic_best defaults to 1 while critic_best_path defaults to None;
    # confirm the agent's load_critic tolerates a None path on default runs.
    parser.add_argument('--load_critic_best', default=1, type=int)
    parser.add_argument('--critic_best_path', default=None, type=str)

    # eval
    parser.add_argument('--eval_freq', default=10, type=int)  # TODO: master had 10000
    parser.add_argument('--num_eval_episodes', default=20, type=int)
    # critic
    parser.add_argument('--critic_lr', default=1e-3, type=float)
    parser.add_argument('--critic_beta', default=0.9, type=float)
    parser.add_argument('--critic_tau', default=0.005, type=float)
    parser.add_argument('--critic_target_update_freq', default=2, type=int)
    # actor
    parser.add_argument('--actor_lr', default=1e-3, type=float)
    parser.add_argument('--actor_beta', default=0.9, type=float)
    parser.add_argument('--actor_log_std_min', default=-10, type=float)
    parser.add_argument('--actor_log_std_max', default=2, type=float)
    parser.add_argument('--actor_update_freq', default=2, type=int)
    # encoder/decoder
    parser.add_argument('--encoder_type', default='pixel', type=str, choices=['pixel', 'pixelCarla', 'identity'])
    parser.add_argument('--encoder_feature_dim', default=50, type=int)
    parser.add_argument('--encoder_lr', default=1e-3, type=float)
    parser.add_argument('--encoder_tau', default=0.005, type=float)
    parser.add_argument('--encoder_stride', default=1, type=int)
    parser.add_argument('--decoder_type', default='pixel', type=str, choices=['pixel', 'identity'])
    parser.add_argument('--decoder_lr', default=1e-3, type=float)
    parser.add_argument('--decoder_update_freq', default=1, type=int)
    parser.add_argument('--decoder_weight_lambda', default=0.0, type=float)
    parser.add_argument('--num_layers', default=4, type=int)
    parser.add_argument('--num_filters', default=32, type=int)
    # sac
    parser.add_argument('--discount', default=0.99, type=float)
    parser.add_argument('--init_temperature', default=0.01, type=float)
    parser.add_argument('--alpha_lr', default=1e-3, type=float)
    parser.add_argument('--alpha_beta', default=0.9, type=float)
    parser.add_argument('--latent_dim', default=128, type=int)
    # misc
    parser.add_argument('--seed', default=1, type=int)
    parser.add_argument('--work_dir', default='./log', type=str)
    parser.add_argument('--save_tb', default=False, action='store_true')
    parser.add_argument('--save_model', default=False, action='store_true')
    parser.add_argument('--save_buffer', default=False, action='store_true')
    parser.add_argument('--save_video', default=False, action='store_true')
    parser.add_argument('--render', default=False, action='store_true')
    parser.add_argument('--detach_encoder', default=False, action='store_true')
    parser.add_argument('--transition_model_type', default='', type=str,
                        choices=['', 'deterministic', 'probabilistic', 'ensemble'])
    # data augs
    parser.add_argument('--data_augs', default='no_aug', type=str)
    parser.add_argument('--log_interval', default=100, type=int)
    # argv=None makes argparse fall back to sys.argv[1:].
    args = parser.parse_args(argv)
    return args


def evaluate(env, args, agent, video, num_episodes, L, step, device=None, embed_viz_dir=None, do_carla_metrics=False,
             res_dir="", model_dir="", final=False):
    """Evaluate the best saved policy and log episode-reward statistics.

    Loads the best checkpoint from ``model_dir`` via ``agent.load_best``, then runs
    ``num_episodes`` episodes, acting with ``agent.select_action`` inside
    ``utils.eval_mode``. Every episode is recorded with ``video``. Its total reward
    is appended as ``"<step>,<episode_reward>"`` to ``<res_dir>/<args.seed>.txt``.
    Mean, std and best episode reward are printed and logged to ``L`` under
    ``eval/`` at ``step``.

    Args:
        env: Environment whose ``step`` returns ``(obs, reward, done, info)``.
        args: Parsed arguments; only ``args.seed`` is read.
        agent: Agent providing ``load_best`` and ``select_action``.
        video: Recorder providing ``init``, ``record`` and ``save``.
        num_episodes: Number of evaluation episodes. With ``<= 0`` nothing is
            logged (previously ``np.max`` raised on the empty reward list).
        L: Logger providing ``log`` and ``dump``.
        step: Training step used to tag log lines and, when not ``final``, videos.
        device, embed_viz_dir, do_carla_metrics: Unused; kept for call compatibility.
        res_dir: Directory holding the per-seed reward text file.
        model_dir: Directory the best checkpoint is loaded from.
        final: If True, videos are named by episode index instead of ``step``.

    Returns:
        List of per-episode rewards (empty when no episodes were run).
    """
    agent.load_best(model_dir)
    eval_log_filepath = os.path.join(res_dir, "{}.txt".format(args.seed))
    all_ep_rewards = []

    for i in range(num_episodes):
        obs = env.reset()
        video.init(enabled=True)
        done = False
        episode_reward = 0
        while not done:
            with utils.eval_mode(agent):
                action = agent.select_action(obs)
            obs, reward, done, _ = env.step(action)
            video.record(env)
            episode_reward += reward

        # Append after every episode so partial results survive an interrupted run.
        with open(eval_log_filepath, "a") as f:
            f.write("{step},{episode_reward}\n".format(step=step, episode_reward=episode_reward))

        all_ep_rewards.append(episode_reward)
        # NOTE(review): with final=False every episode saves to '<step>.mp4', so only the
        # last episode's video is kept; confirm this is intended.
        video.save('%d.mp4' % (i if final else step))

    if not all_ep_rewards:
        # No episodes were run: there are no statistics to compute (np.max would raise).
        return all_ep_rewards

    mean_ep_reward = np.mean(all_ep_rewards)
    std_ep_reward = np.std(all_ep_rewards)
    best_ep_reward = np.max(all_ep_rewards)
    print("mean_ep_reward:", mean_ep_reward)
    print("std_ep_reward:", std_ep_reward)
    print("best_ep_reward:", best_ep_reward)

    L.log('eval/mean_ep_reward', mean_ep_reward, step)
    L.log('eval/std_ep_reward', std_ep_reward, step)
    L.log('eval/best_ep_reward', best_ep_reward, step)
    L.dump(step)
    return all_ep_rewards


def append_to_csv(step, episode_reward, filename='episode_rewards.csv'):
    """Append one ``(step, episode_reward)`` row to a CSV file.

    A ``Step,Value`` header row is written first whenever the file is new or
    empty, so the file always begins with a header.

    Args:
        step (int): Current training step.
        episode_reward (float): Reward of the episode that just finished.
        filename (str): Target CSV path; defaults to ``'episode_rewards.csv'``.
    """
    # A single append-mode open both creates a missing file and positions the stream
    # at end-of-file, so tell() == 0 means "new or empty file, header needed". This
    # avoids the exists()-then-open race and also covers a pre-existing empty file.
    with open(filename, mode='a', newline='') as file:
        writer = csv.writer(file)
        if file.tell() == 0:
            writer.writerow(['Step', 'Value'])
        writer.writerow([step, episode_reward])


def make_agent(obs_shape, action_shape, args, device):
    """Build the agent (``MYAgent``) from parsed arguments and optionally restore weights.

    When ``args.load_encoder`` is set, the parameters under the ``encoder.``
    prefix of that checkpoint are loaded into both the actor and critic encoders.
    When ``args.load_critic_best == 1``, ``agent.load_critic`` is called with
    ``args.critic_best_path``.

    Args:
        obs_shape: Observation shape passed to the agent.
        action_shape: Action shape passed to the agent.
        args: Parsed command-line arguments (see ``parse_args``).
        device: torch device for the agent; also used as ``map_location`` when
            loading the encoder checkpoint.

    Returns:
        The constructed agent.
    """
    # take deepmdp as base
    agent = MYAgent(
        obs_shape=obs_shape,
        action_shape=action_shape,
        device=device,
        hidden_dim=args.hidden_dim,
        discount=args.discount,
        init_temperature=args.init_temperature,
        alpha_lr=args.alpha_lr,
        alpha_beta=args.alpha_beta,
        actor_lr=args.actor_lr,
        actor_beta=args.actor_beta,
        actor_log_std_min=args.actor_log_std_min,
        actor_log_std_max=args.actor_log_std_max,
        actor_update_freq=args.actor_update_freq,
        encoder_stride=args.encoder_stride,
        critic_lr=args.critic_lr,
        critic_beta=args.critic_beta,
        critic_tau=args.critic_tau,
        critic_target_update_freq=args.critic_target_update_freq,
        encoder_type=args.encoder_type,
        encoder_feature_dim=args.encoder_feature_dim,
        encoder_lr=args.encoder_lr,
        encoder_tau=args.encoder_tau,
        decoder_type=args.decoder_type,
        decoder_lr=args.decoder_lr,
        decoder_update_freq=args.decoder_update_freq,
        decoder_weight_lambda=args.decoder_weight_lambda,
        transition_model_type=args.transition_model_type,
        num_layers=args.num_layers,
        num_filters=args.num_filters,
        data_augs=args.data_augs
    )
    # resume: share one pretrained encoder between actor and critic
    if args.load_encoder:
        # map_location lets a checkpoint saved on another device (e.g. GPU) load on this one.
        checkpoint = torch.load(args.load_encoder, map_location=device)
        prefix = 'encoder.'
        # Keep only keys that START with the prefix and strip it, so they match the
        # encoder module's own state_dict. (The old `'encoder.' in k` / `k[8:]` hack
        # mangled keys where 'encoder.' appeared mid-name.)
        encoder_dict = {k[len(prefix):]: v for k, v in checkpoint.items() if k.startswith(prefix)}
        agent.actor.encoder.load_state_dict(encoder_dict)
        agent.critic.encoder.load_state_dict(encoder_dict)

    if args.load_critic_best == 1:
        # NOTE(review): load_critic_best defaults to 1 while critic_best_path defaults to
        # None, so default runs call load_critic(None); confirm load_critic handles that.
        print("load critic.................")
        print(args.critic_best_path)
        agent.load_critic(args.critic_best_path)
    print("--------------------loss_type:", args.loss_type)

    return agent


def _make_env(args, image_size):
    """Create the DMC cartpole/swingup env (frame-stacked for pixel encoders).

    ``args.domain_name`` / ``args.task_name`` are not consulted; the task is fixed here.
    """
    env = dmc2gym.make(
        domain_name='cartpole',
        task_name='swingup',  # swingup  balance
        resource_files=args.resource_files,
        img_source=args.img_source,
        total_frames=args.total_frames,
        seed=args.seed,
        visualize_reward=False,
        obs_mode='pixel',
        # NOTE(review): 'pixelCarla' gets FrameStack below but from_pixels=False here; confirm intended.
        from_pixels=(args.encoder_type == 'pixel'),
        height=image_size,
        width=image_size,
        frame_skip=args.action_repeat
    )
    env.seed(args.seed)

    if args.encoder_type.startswith('pixel'):
        env = utils.FrameStack(env, k=args.frame_stack)
    return env


def _query_vlm(env, vlm_agent, episode_step, seed):
    """Render the current frame and return the VLM's action wrapped in a 1-element array."""
    with torch.no_grad():
        image = env.render(mode='rgb_array')
        return np.array([vlm_agent.select_action(image, episode_step, seed)])


def main():
    """Train the agent on DMC cartpole/swingup, then save it and run a final evaluation.

    Each finished training episode's reward is logged and appended to
    ``<run_dir>/res_dir/<seed>.csv``. When ``--loss_type`` is not ``value``, a VLM
    is consulted every ``--vlm_freq`` episode steps. Its latest action is stored
    alongside each transition in the replay buffer.
    """
    args = parse_args()
    utils.set_seed_everywhere(args.seed)

    # Render at the pre-transform size only when a 'crop' augmentation is requested
    # (presumably cropped to image_size later); otherwise render at image_size directly.
    pre_transform_image_size = args.pre_transform_image_size if 'crop' in args.data_augs else args.image_size
    env = _make_env(args, pre_transform_image_size)

    # Unique run directory: <work_dir>_seed:<seed>_<timestamp>
    work_dir = args.work_dir + '_' + "seed:{}_".format(args.seed) + time.strftime("%Y-%m-%d-%H-%M-%S")
    utils.make_dir(work_dir)
    video_dir = utils.make_dir(os.path.join(work_dir, 'video'))
    model_dir = utils.make_dir(os.path.join(work_dir, 'model'))
    utils.make_dir(os.path.join(work_dir, 'buffer'))  # reserved for --save_buffer; not written by this script
    res_dir = utils.make_dir(os.path.join(work_dir, 'res_dir'))

    video = VideoRecorder(video_dir if args.save_video else None)

    with open(os.path.join(work_dir, 'args.json'), 'w') as f:
        json.dump(vars(args), f, sort_keys=True, indent=4)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # The VLM is only constructed for loss types other than 'value'.
    if args.loss_type != "value":
        VLM_agent = VLM_Agent("cartpole_swingup", device, res_dir)
    else:
        VLM_agent = None
    use_vlm = VLM_agent is not None

    # VLM action used until the VLM is first queried within an episode.
    VLM_init = np.array([0])

    if not (env.action_space.low.min() >= -1 and env.action_space.high.max() <= 1):
        raise ValueError('action space bounds must lie within [-1, 1]')

    if args.encoder_type == 'pixel':
        obs_shape = (3 * args.frame_stack, args.image_size, args.image_size)
    else:
        obs_shape = env.observation_space.shape

    replay_buffer = utils.ReplayBuffer(
        obs_shape=obs_shape,
        action_shape=env.action_space.shape,
        capacity=args.replay_buffer_capacity,
        batch_size=args.batch_size,
        device=device,
        pre_image_size=args.image_size,
        use_loss=args.loss_type
    )

    agent = make_agent(
        obs_shape=obs_shape,
        action_shape=env.action_space.shape,
        args=args,
        device=device
    )
    L = Logger(work_dir, use_tb=args.save_tb)

    # train agent
    episode, episode_reward, done = 0, 0, True
    step = 0  # keeps `step` bound for the final save/eval even if num_train_steps == 0
    start_time = time.time()
    for step in range(args.num_train_steps):
        if done:
            # At step 0 `done` is only the initial sentinel: no episode has finished yet,
            # so skip the end-of-episode bookkeeping (it used to record a bogus 0 reward).
            if step > 0:
                L.log('train/duration', time.time() - start_time, step)
                start_time = time.time()
                L.dump(step)

                if args.save_model:
                    agent.save_best(model_dir, episode_reward)

                L.log('train/episode_reward', episode_reward, step)
                append_to_csv(step, episode_reward, os.path.join(res_dir, "{}.csv".format(args.seed)))

            obs = env.reset()
            done = False
            VLM_action = VLM_init
            episode_reward = 0
            episode_step = 0
            episode += 1
            reward = 0
            L.log('train/episode', episode, step)

        # sample action for data collection
        if step < args.init_steps:
            RL_action = np.array(env.action_space.sample())  # random warm-up action
            # During warm-up the VLM is also queried at the first step of every episode.
            if use_vlm and episode_step % args.vlm_freq == 0:
                VLM_action = _query_vlm(env, VLM_agent, episode_step, args.seed)
        else:
            # After warm-up, episode step 0 keeps VLM_init; otherwise VLM_action is
            # carried over until the next query.
            if use_vlm and episode_step != 0 and episode_step % args.vlm_freq == 0:
                VLM_action = _query_vlm(env, VLM_agent, episode_step, args.seed)
            with utils.eval_mode(agent):
                RL_action = np.array(agent.sample_action(obs))

        # run training update (catch-up burst of init_steps updates right after warm-up)
        if step >= args.init_steps:
            num_updates = args.init_steps if step == args.init_steps else 1
            for _ in range(num_updates):
                agent.update(replay_buffer, L, step, Loss_type=args.loss_type, vlm_update_freq=args.vlm_freq)

        # Reward of the previous transition in this episode (0 on the first step).
        curr_reward = reward
        next_obs, reward, done, _ = env.step(RL_action)

        # allow infinite bootstrap: hitting the env's step limit is not stored as terminal
        done_bool = 0 if episode_step + 1 == env._max_episode_steps else float(done)
        episode_reward += reward

        replay_buffer.add_raw(obs, RL_action, VLM_action, curr_reward, reward, next_obs, done_bool)
        obs = next_obs
        episode_step += 1

    # Final save and evaluation; the episode count honours --num_eval_episodes.
    agent.save(model_dir, step)
    evaluate(env, args, agent, video, args.num_eval_episodes, L, step, do_carla_metrics=False,
             res_dir=res_dir, model_dir=model_dir, final=True)


if __name__ == "__main__":
    # Script entry point: parse arguments, train, then run the final evaluation.
    main()
