import argparse
import datetime
import os
import random
import sys
import yaml
import numpy as np
import torch as th
import imageio

from crafter.env import EnvWithDirection
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
from achievement_distillation.algorithm import *
from achievement_distillation.constant import TASKS
from achievement_distillation.logger import Logger
from achievement_distillation.model import *
from achievement_distillation.sample import sample_rollouts, evaluate
from achievement_distillation.storage import RolloutStorage
from achievement_distillation.wrapper import VecPyTorch
from functools import partial
import json
from utils import *



def _success_rates_and_score(successes):
    """Return per-task success rates (percent) and the aggregate score.

    Args:
        successes: array-like with one row per finished episode and one
            column per task (presumably 0/1 flags -- TODO confirm against
            ``sample_rollouts`` / ``evaluate``).

    Returns:
        ``(success_rate, score)``: ``success_rate`` is the per-column mean
        scaled to percent; ``score`` is the geometric mean of
        ``1 + success_rate`` minus 1. With zero rows both are NaN (NumPy warns
        about the empty mean), exactly as the original inline formula did.
    """
    success_rate = 100 * np.mean(np.asarray(successes), axis=0)
    score = np.exp(np.mean(np.log(1 + success_rate))) - 1
    return success_rate, score


def _write_episode_stats(log_file, stats, prefix, tasks, extra=None):
    """Append one JSON line per finished episode in ``stats`` to ``log_file``.

    Each record holds ``{prefix}_length`` (int) and ``{prefix}_reward``
    (rounded to one decimal), then the ``extra`` key/value pairs (if any),
    then one ``{prefix}_achievement_{task}`` int per entry of ``tasks``.
    ``stats`` must provide ``episode_lengths``, ``episode_rewards`` and a 2-D
    indexable ``achievements``. The file is flushed once after all records.
    """
    lengths = stats["episode_lengths"]
    rewards = stats["episode_rewards"]
    achievements = stats["achievements"]
    for i, (length, reward) in enumerate(zip(lengths, rewards)):
        record = {
            f"{prefix}_length": int(length),
            f"{prefix}_reward": round(float(reward), 1),
        }
        if extra:
            record.update(extra)
        for j, task in enumerate(tasks):
            record[f"{prefix}_achievement_{task}"] = int(achievements[i, j])
        log_file.write(json.dumps(record) + "\n")
    log_file.flush()


def main(args):
    """Train an agent on Crafter, evaluating it periodically.

    Reads ``configs/{args.exp_name}.yaml`` and seeds ``random``, NumPy and
    torch with ``args.seed``. It then builds the vectorized environment, the
    training and evaluation rollout storages, and the model/algorithm classes
    named in the config. Finally it runs ``config["nepoch"]`` epochs of:
    rollout, return computation, update.

    Optional config keys (backward-compatible defaults):
        * ``imitation_ratio`` (0.2): the first ``int(nepoch * ratio)`` epochs
          run with ``imitation_phase=True``.
        * ``eval_freq`` (5): evaluate every this many epochs.

    Side effects, depending on flags:
        * ``--log_stats``: per-episode JSON lines in
          ``./logs/<run>/stats.jsonl`` plus aggregate stats via ``Logger``.
        * ``--save_ckpt``: ``state_dict`` snapshots every ``save_freq`` epochs
          in ``./models/<run>/``.
        * ``--save_video``: one mp4 writer per environment in
          ``./videos/<run>/``, passed to ``evaluate``.

    The stats file and video writers are closed even if training raises.
    """
    from contextlib import ExitStack

    # Load config file (trusted local YAML)
    with open(f"configs/{args.exp_name}.yaml", "r") as config_file:
        config = yaml.load(config_file, Loader=yaml.FullLoader)

    # Fix random seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    th.manual_seed(args.seed)
    th.cuda.manual_seed_all(args.seed)
    th.backends.cudnn.benchmark = False
    current_task = args.task

    # CUDA setting
    th.set_num_threads(1)
    cuda = th.cuda.is_available()
    device = th.device("cuda:0" if cuda else "cpu")

    # Get current time for run name
    current_time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    # Run/group names shared by logs, checkpoints and videos
    group_name = f"{args.exp_name}-{args.timestamp}"
    run_name = f"{group_name}-s{args.seed:02}-{current_task}-{current_time}"

    # Everything registered on `cleanup` (stats file, video writers) is closed
    # when this block exits, including when training raises.
    with ExitStack() as cleanup:
        log_file = None
        logger = None
        if args.log_stats:
            log_dir = os.path.join("./logs", run_name)
            os.makedirs(log_dir, exist_ok=True)
            log_path = os.path.join(log_dir, "stats.jsonl")
            log_file = cleanup.enter_context(open(log_path, "w"))
            logger = Logger(config=config, group=group_name, name=run_name)

        if args.save_ckpt:
            ckpt_dir = os.path.join("./models", run_name)
            os.makedirs(ckpt_dir, exist_ok=True)

        # Create environment using DummyVecEnv; per-env seeds come from the
        # NumPy RNG seeded above.
        seeds = np.random.randint(0, 2**31 - 1, size=config["nproc"])
        env_fns = [partial(EnvWithDirection, seed=seed) for seed in seeds]
        venv = DummyVecEnv(env_fns)
        venv = VecPyTorch(venv, device=device)
        obs = venv.reset()

        # Create training and evaluation storages with identical layouts
        storage_kwargs = dict(
            nstep=config["nstep"],
            nproc=config["nproc"],
            observation_space=venv.observation_space,
            action_space=venv.action_space,
            hidsize=config["model_kwargs"]["hidsize"],
            device=device,
        )
        storage = RolloutStorage(**storage_kwargs)
        storage.obs[0].copy_(obs)
        # NOTE(review): evaluation below steps the same `venv` used for
        # training; confirm that `evaluate` / `storage.reset()` keep the
        # training storage's first observation consistent with the env state.
        storage_eva = RolloutStorage(**storage_kwargs)
        storage_eva.obs[0].copy_(obs)

        # Create model (class looked up by name among this module's imports)
        model_cls = getattr(sys.modules[__name__], config["model_cls"])
        model: BaseModel = model_cls(
            observation_space=venv.observation_space,
            action_space=venv.action_space,
            **config["model_kwargs"],
        )
        model = model.to(device)
        print(model)

        # Create algorithm
        algorithm_cls = getattr(sys.modules[__name__], config["algorithm_cls"])
        algorithm: BaseAlgorithm = algorithm_cls(
            model=model,
            **config["algorithm_kwargs"],
        )

        # Set lambda_t (fixed or manually adjustable)
        lambda_t = config["lambda_t"]
        lambda_decay = config["lambda_decay"]

        # Total number of training epochs, imitation phase length, eval period
        total_games = config["nepoch"]
        imitation_threshold = int(total_games * config.get("imitation_ratio", 0.2))
        eval_freq = config.get("eval_freq", 5)

        # Video writers (one per environment) only when requested
        video_writers = None
        if args.save_video:
            video_dir = os.path.join("./videos", run_name)
            os.makedirs(video_dir, exist_ok=True)
            video_writers = []
            for env_id in range(config["nproc"]):
                video_path = os.path.join(video_dir, f"env_{env_id}.mp4")
                writer = imageio.get_writer(video_path, fps=30)
                cleanup.callback(writer.close)
                video_writers.append(writer)

        # Train successes accumulate over all epochs, so train rates are cumulative
        total_successes = np.zeros((0, len(TASKS)), dtype=np.int32)

        for epoch in range(1, total_games + 1):
            imitation_phase = epoch <= imitation_threshold

            # Sample episodes
            rollout_stats = sample_rollouts(
                venv, model, storage, lambda_t, lambda_decay, imitation_phase, task=args.task,
            )

            # Compute returns
            storage.compute_returns(config["gamma"], config["gae_lambda"])

            # Update model (returned training stats are currently not logged)
            algorithm.update(storage, imitation_phase)

            # Reset storage
            storage.reset()

            # Cumulative train success rate / score
            total_successes = np.concatenate(
                [total_successes, rollout_stats["successes"]], axis=0
            )
            success_rate, score = _success_rates_and_score(total_successes)
            total_tokens_used = rollout_stats["total_tokens"]

            train_stats_log = {
                "train_success_rate": dict(zip(TASKS, success_rate)),
                "train_score": score,
                "train_total_tokens": total_tokens_used,
            }

            # Evaluate the model every `eval_freq` epochs
            if epoch % eval_freq == 0:
                eval_results = evaluate(
                    venv, model, storage_eva, task=args.task, video_writers=video_writers
                )
                storage_eva.reset()

                # Eval rates cover only this evaluation's episodes (not cumulative)
                success_rate_eval, score_eval = _success_rates_and_score(
                    eval_results["successes"]
                )
                eval_stats_log = {
                    "eval_success_rate": dict(zip(TASKS, success_rate_eval)),
                    "eval_score": score_eval,
                }

                if args.log_stats:
                    _write_episode_stats(log_file, eval_results, "eval", TASKS)
                    logger.log(eval_stats_log, epoch)
                    print("evaluation result:")
                    print(json.dumps(eval_stats_log, indent=2))

            if args.log_stats:
                _write_episode_stats(
                    log_file,
                    rollout_stats,
                    "train",
                    TASKS,
                    extra={"train_tokens": int(total_tokens_used)},
                )
                logger.log(train_stats_log, epoch)

            if args.save_ckpt and epoch % config["save_freq"] == 0:
                ckpt_path = os.path.join(ckpt_dir, f"agent-e{epoch:03}.pt")
                th.save(model.state_dict(), ckpt_path)

def _build_arg_parser():
    """Build the command-line parser for this training script.

    Flags: ``--exp_name`` (required; selects ``configs/<name>.yaml``),
    ``--timestamp``, ``--seed``, the boolean switches ``--log_stats`` and
    ``--save_ckpt``, ``--task``, and ``--save_video``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--exp_name", type=str, required=True)
    parser.add_argument("--timestamp", type=str, default="debug")
    parser.add_argument("--seed", type=int, default=0)
    # Plain on/off switches without help text
    for switch in ("--log_stats", "--save_ckpt"):
        parser.add_argument(switch, action="store_true")
    parser.add_argument(
        "--task",
        type=str,
        default="make_stone_pickaxe",
        help='[make_stone_pickaxe, make_wood_pickaxe]',
    )
    # store_true already defaults to False
    parser.add_argument(
        "--save_video",
        action="store_true",
        help="Whether to save video of the environment",
    )
    return parser


if __name__ == "__main__":
    main(_build_arg_parser().parse_args())
