import argparse
import os
import numpy as np
import torch

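# Extend sys.path with the directory two levels above this file
# (presumably so the project packages imported below resolve when run as a script).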
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))

from demo_collection.utils.utils import set_up_log_dirs, logging
from demo_collection.utils.wandb_logger import wandb_logger as Logger

from rlf.envs.widowx_interface import widowx_obs_transform

from pathlib import Path
import matplotlib.pyplot as plt

def str2bool(v):
    """Interpret a command-line token as a boolean (usable as an argparse ``type``).

    Args:
        v: Either an actual ``bool`` (returned unchanged) or a string.

    Returns:
        ``True`` for 'yes' / 'true' / 't' / '1' and ``False`` for
        'no' / 'false' / 'f' / '0'; string matching is case-insensitive.

    Raises:
        argparse.ArgumentTypeError: If the string is none of the accepted tokens.
    """
    if isinstance(v, bool):
        return v

    truthy_tokens = ('yes', 'true', 't', '1')
    falsy_tokens = ('no', 'false', 'f', '0')

    token = v.lower()
    if token in truthy_tokens:
        return True
    if token in falsy_tokens:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')

def add_args(parser):
    """Register this script's command-line options on ``parser`` (in place).

    The default for ``--log_dir`` is ``<current working directory>/data/log``,
    resolved at the time this function is called.

    Args:
        parser: An ``argparse.ArgumentParser`` (or compatible object exposing
            ``add_argument``) to which the options are added.
    """
    default_log_dir = os.path.join(str(Path.cwd()), "data", "log")

    # (flags, keyword arguments) for every option, in registration order.
    # NOTE(review): '--wand' presumably toggles Weights & Biases logging;
    # confirm against the logger that consumes these args.
    option_table = (
        (('--wand',), {'type': str2bool, 'default': True}),
        (('--project_name',), {'type': str, 'default': "p-goal-prox"}),
        (('--prefix',), {'type': str, 'default': "agent_train"}),
        (('--traj_load_dir',), {'type': str, 'default': None}),
        (('--log_dir',), {'type': str, 'default': default_log_dir}),
        (("-s", "--seed"), {'type': int, 'default': 0,
                            'help': "Seed(s) for randomness."}),
    )

    for flags, options in option_table:
        parser.add_argument(*flags, **options)

def get_default_args():
    """Build the script's argument parser and parse the process command line.

    Unrecognized command-line tokens are tolerated and discarded rather than
    causing an error, because ``parse_known_args`` is used.

    Returns:
        The ``argparse.Namespace`` holding the recognized options.
    """
    arg_parser = argparse.ArgumentParser()
    add_args(arg_parser)
    # parse_known_args() yields (namespace, leftover_tokens); keep only the namespace.
    namespace = arg_parser.parse_known_args()[0]
    return namespace


def _collect_traj_files(dir_path):
    """Return the sorted paths of all ``.pt`` files under ``dir_path`` (searched recursively)."""
    return sorted(
        os.path.join(root, name)
        for root, _, names in os.walk(dir_path)
        for name in names
        if name.endswith(".pt")
    )


def _append_grasp_flag(obs, min_val=0.053, max_val=0.059):
    """Return ``obs`` with one extra trailing column holding a 0.0 / 1.0 grasp flag.

    The flag is 1.0 for rows where the Euclidean distance between columns
    22:25 and 25:28 lies within ``[min_val, max_val]`` (inclusive), else 0.0.

    Args:
        obs: 2-D tensor with at least 28 columns.
            # NOTE(review): columns 22:25 / 25:28 are treated as the left /
            # right finger positions; confirm against the env's obs layout.
        min_val: Lower bound of the distance band counted as "grasped".
        max_val: Upper bound of the distance band counted as "grasped".

    Returns:
        A tensor with the same rows as ``obs`` and one additional column.
    """
    left_finger_pos = obs[:, 22:25]
    right_finger_pos = obs[:, 25:28]

    # Stay in torch (instead of numpy) so this also works for non-CPU tensors.
    distances = torch.linalg.norm(left_finger_pos - right_finger_pos, dim=1)

    in_band = (distances >= min_val) & (distances <= max_val)
    is_grasped = in_band.to(dtype=torch.float32)
    return torch.cat([obs, is_grasped[:, None]], dim=1)


def main():
    """Merge every trajectory ``.pt`` file under ``--traj_load_dir`` into one file.

    Steps:
      1. Parse args, create the logger and the log directories.
      2. Recursively collect ``.pt`` files and load each one with ``torch.load``.
      3. Concatenate the 'obs', 'next_obs', 'done', 'actions' and
         'ep_found_goal' entries along dim 0.
      4. For 'obs' and 'next_obs', append a grasp-flag column and pass the
         result through ``widowx_obs_transform``.
      5. Verify that no row has ``done == 1`` while its grasp flag is off.
      6. Save the merged dict as ``merged_trajs.pt`` in the reward save dir.

    Raises:
        ValueError: If ``--traj_load_dir`` was not given, or if a row is marked
            done while its grasp flag is unset.
        FileNotFoundError: If the directory does not exist or holds no
            ``.pt`` files.
    """
    args = get_default_args()

    logger = Logger(args)
    logdirs = set_up_log_dirs(args, logger.prefix)
    # Six directories are returned; only the reward directory is needed here.
    _, _, _, _, reward_save_dir, _ = logdirs

    # Explicit checks instead of `assert` (which is stripped under `python -O`),
    # and a clear message when the option is omitted (its default is None).
    dir_path = args.traj_load_dir
    if dir_path is None:
        raise ValueError("--traj_load_dir must be provided")
    if not os.path.exists(dir_path):
        raise FileNotFoundError(f"traj load dir {dir_path} does not exist")

    files = _collect_traj_files(dir_path)
    print(f"Found {len(files)} trajectory files:")
    for f in files:
        print(f" - {f}")
    if not files:
        # Fail here with a readable message rather than inside torch.cat([]).
        raise FileNotFoundError(f"no .pt trajectory files found under {dir_path}")

    # One list of per-file tensors for each key of the merged buffer.
    merged_weights = {
        key: [] for key in ('obs', 'next_obs', 'done', 'actions', 'ep_found_goal')
    }

    for file in files:
        # NOTE: torch.load unpickles its input; only point this at trusted files.
        traj = torch.load(file)
        for key in merged_weights:
            merged_weights[key].append(traj[key])

    for key in merged_weights:
        merged = torch.cat(merged_weights[key], dim=0)
        if key in ('obs', 'next_obs'):
            merged = widowx_obs_transform(_append_grasp_flag(merged))
        merged_weights[key] = merged

    # Sanity check: a row flagged done must also be flagged as grasped.
    # `done` is flattened so an (N, 1) layout cannot broadcast against the
    # (N,) grasp mask into an (N, N) matrix and report bogus indices.
    done_mask = (merged_weights['done'].reshape(-1) == 1).cpu()
    for key in ('obs', 'next_obs'):
        # NOTE(review): this assumes widowx_obs_transform keeps the grasp flag
        # as the last column of its output; confirm against that function.
        is_grasped = merged_weights[key][:, -1].cpu().bool()
        inconsistent_indices = (
            torch.nonzero(done_mask & ~is_grasped, as_tuple=False).flatten().numpy()
        )
        if len(inconsistent_indices) > 0:
            raise ValueError(f"Inconsistent done and is_grasped at indices: {inconsistent_indices}")

    save_path = os.path.join(reward_save_dir, "merged_trajs.pt")
    torch.save(merged_weights, save_path)
    print(f"[✓] Merged trajectories saved at: {save_path}")
 
# Run the merge only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
    