import argparse
import os
import torch

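# Put the repository root (two directories above this file) on sys.path so the project-local imports below resolve.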
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))

from demo_collection.utils.utils import set_up_log_dirs, logging
from demo_collection.utils.wandb_logger import wandb_logger as Logger

from rlf.envs.widowx_interface import widowx_obs_transform

from pathlib import Path

def str2bool(v):
    """Parse a command-line boolean flag value (for use as an argparse ``type``).

    A ``bool`` is returned unchanged. A string is compared case-insensitively:
    'yes', 'true', 't', '1' map to True and 'no', 'false', 'f', '0' map to False.

    Raises:
        argparse.ArgumentTypeError: if the string is in neither set.
    """
    if isinstance(v, bool):
        return v

    lowered = v.lower()
    truthy = ('yes', 'true', 't', '1')
    falsy = ('no', 'false', 'f', '0')

    if lowered in truthy:
        return True
    if lowered in falsy:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')

def add_args(parser):
    """Register this script's command-line options on *parser*.

    Options (in registration order):
        --wand           bool via ``str2bool``, default True.
        --project_name   str, default "p-goal-prox".
        --prefix         str, default "agent_train".
        --traj_load_dir  str, default None; directory that ``main`` searches
                         recursively for ``.pt`` trajectory files.
        --log_dir        str, default ``<cwd>/data/log``, where ``<cwd>`` is the
                         working directory at the time this function is called.
        -s / --seed      int, default 0.
    """
    default_log_dir = os.path.join(str(Path.cwd()), "data", "log")

    # Logging / trajectory-loading options.
    parser.add_argument('--wand', type=str2bool, default=True)
    parser.add_argument('--project_name', type=str, default="p-goal-prox")
    parser.add_argument('--prefix', type=str, default="agent_train")
    parser.add_argument('--traj_load_dir', type=str, default=None)
    parser.add_argument('--log_dir', type=str, default=default_log_dir)

    parser.add_argument(
        "-s", "--seed",
        type=int,
        default=0,
        help="Seed(s) for randomness.",
    )

def get_default_args():
    """Build the argument parser, register this script's options, and parse ``sys.argv``.

    Unrecognised command-line arguments are tolerated and discarded
    (``parse_known_args``). Returns only the parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    add_args(parser)
    known_args, _unknown = parser.parse_known_args()
    return known_args


# Keys read from every trajectory file and concatenated along dim 0, in this order.
_TRAJ_KEYS = ('obs', 'next_obs', 'done', 'actions', 'ep_found_goal')


def _find_traj_files(dir_path):
    """Return the sorted paths of every ``.pt`` file under *dir_path*, searched recursively.

    Raises:
        FileNotFoundError: if *dir_path* is None or not an existing directory,
            or if it contains no ``.pt`` files.
    """
    # Explicit checks instead of `assert`: asserts are stripped under -O, and
    # os.path.exists(None) raises a confusing TypeError for the default None.
    if dir_path is None or not os.path.isdir(dir_path):
        raise FileNotFoundError(
            f"traj load dir {dir_path!r} is not an existing directory "
            f"(set --traj_load_dir)"
        )

    files = sorted(
        os.path.join(root, name)
        for root, _, filenames in os.walk(dir_path)
        for name in filenames
        if name.endswith(".pt")
    )
    # Without this check, torch.cat([]) later fails with an opaque error.
    if not files:
        raise FileNotFoundError(f"no .pt trajectory files found under {dir_path!r}")
    return files


def _merge_trajs(files, keys=_TRAJ_KEYS):
    """Load each file in *files* and concatenate the values stored under *keys*.

    Returns:
        dict mapping each key (in *keys* order) to ``torch.cat`` of that key's
        values across all files, in file order, along dim 0.

    Raises:
        KeyError: if a loaded file lacks one of *keys*; the message names the file.
    """
    parts = {key: [] for key in keys}
    for path in files:
        # NOTE(review): torch.load unpickles the file; only use on trusted data.
        traj = torch.load(path)
        for key in keys:
            try:
                parts[key].append(traj[key])
            except KeyError:
                raise KeyError(f"{path} is missing key {key!r}") from None
    return {key: torch.cat(chunks, dim=0) for key, chunks in parts.items()}


def main():
    """Merge all ``.pt`` trajectory files under ``--traj_load_dir`` into one file.

    Each file must be a mapping containing the keys in ``_TRAJ_KEYS``. Their
    values are concatenated along dim 0 and saved as ``merged_trajs.pt`` in the
    reward save directory returned by ``set_up_log_dirs``.
    """
    args = get_default_args()

    # Validate the input before Logger / set_up_log_dirs run, so a bad path
    # fails fast (those helpers presumably create a run and directories).
    files = _find_traj_files(args.traj_load_dir)
    print(f"Found {len(files)} trajectory files:")
    for f in files:
        print(f" - {f}")

    logger = Logger(args)
    logdirs = set_up_log_dirs(args, logger.prefix)
    # Only the reward save dir is used here; unpacking still checks the 6-tuple shape.
    _log_dir, _wandb_dir, _agent_save_dir, _agent_best_dir, reward_save_dir, _video_save_dir = logdirs

    # NOTE: obs/next_obs are saved untransformed; applying
    # widowx_obs_transform to them is currently disabled.
    merged_weights = _merge_trajs(files)

    save_path = os.path.join(reward_save_dir, "merged_trajs.pt")
    torch.save(merged_weights, save_path)
    print(f"[✓] Merged trajectories saved at: {save_path}")
 
# Run the merge only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
    