# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from .adversarial import BipedalWalkerAdversarialEnv
from .walker_test_envs import BipedalWalkerDefault

# import pandas as pd

import numpy as np

# Old Baselines code still references the NumPy type aliases that were removed
# in NumPy 1.24. Re-create only the ones that are missing, so attributes that
# the installed NumPy already defines are never overwritten.
for _alias_name, _alias_target in (
    ("bool", np.bool_),
    ("int", int),
    ("float", float),
    ("object", object),
    ("str", str),
):
    if not hasattr(np, _alias_name):
        setattr(np, _alias_name, _alias_target)
# Keep the module namespace free of the loop variables.
del _alias_name, _alias_target


# Column names for POET-style BipedalWalker level encodings, in encoding order
# (the level seed comes last).
BIPEDALWALKER_POET_DF_COLUMNS = [
    "roughness",
    "pitgap_low",
    "pitgap_high",
    "stumpheight_low",
    "stumpheight_high",
    "seed",
]

# Column names for the full BipedalWalker encoding: the POET terrain
# parameters, then the stair parameters, with the seed still last.
BIPEDALWALKER_DF_COLUMNS = [
    *BIPEDALWALKER_POET_DF_COLUMNS[:-1],
    "stairheight_low",
    "stairheight_high",
    "stair_steps",
    "seed",
]


# def bipedalwalker_df_from_encodings(env_name, encodings):
#     # Convert encodings to a DataFrame, ensuring all data is float

#     arr = np.array(encodings)
#     df = pd.DataFrame(arr.astype(float))

#     return df

# 	df = pd.DataFrame(encodings)
# 	if 'POET' in env_name:
# 		df.columns = BIPEDALWALKER_POET_DF_COLUMNS
# 	else:
# 		df.columns = BIPEDALWALKER_DF_COLUMNS

# 	return df


# dcd/envs/bipedalwalker/__init__.py

# def bipedalwalker_df_from_encodings(env_name, encodings):
#     print(f"Encodings type: {type(encodings)}")
#     print(f"Number of encodings: {len(encodings)}")

#     # Deep inspection of each encoding
#     for i, enc in enumerate(encodings):
#         print(f"\n=== Encoding {i} ===")
#         print(f"Type: {type(enc)}")
#         print(f"Shape: {getattr(enc, 'shape', 'no shape')}")
#         print(f"Dtype: {getattr(enc, 'dtype', 'no dtype')}")

#         if isinstance(enc, np.ndarray):
#             print(f"Array contents: {enc}")
#             print(f"Array dtype object: {enc.dtype == object}")

#             # Check if any elements are themselves arrays
#             if enc.dtype == object:
#                 print("Elements in this array:")
#                 for j, elem in enumerate(enc):
#                     print(f"  [{j}]: type={type(elem)}, value={elem}")

#         # Try converting this single encoding to see where it fails
#         try:
#             single_df = pd.DataFrame([enc])
#             print("✓ Single encoding converts OK")
#         except Exception as e:
#             print(f"✗ Single encoding fails: {e}")

#         try:
#             single_df = pd.DataFrame([enc.tolist()])
#             print("✓ Single encoding as list converts OK")
#         except Exception as e:
#             print(f"✗ Single encoding as list fails: {e}")

#     # Try different conversion approaches
#     print("\n=== Trying different approaches ===")

#     # Approach 1: numpy array first
#     try:
#         arr = np.array(encodings)
#         print(f"np.array() works: shape={arr.shape}, dtype={arr.dtype}")
#         df = pd.DataFrame(arr)
#         print("✓ DataFrame from np.array works")
#         return df
#     except Exception as e:
#         print(f"✗ np.array approach failed: {e}")

#     # Approach 2: explicit conversion
#     try:
#         clean_data = []
#         for enc in encodings:
#             if isinstance(enc, np.ndarray):
#                 clean_data.append(enc.astype(float).tolist())
#             else:
#                 clean_data.append(list(enc))
#         df = pd.DataFrame(clean_data)
#         print("✓ Explicit conversion works")
#         return df
#     except Exception as e:
#         print(f"✗ Explicit conversion failed: {e}")

#     # Approach 3: manual row-by-row
#     try:
#         data_dict = {}
#         for col in range(len(encodings[0])):
#             data_dict[f'feature_{col}'] = [enc[col] for enc in encodings]
#         df = pd.DataFrame(data_dict)
#         print("✓ Manual column-wise conversion works")
#         return df
#     except Exception as e:
#         print(f"✗ Manual conversion failed: {e}")

#     raise Exception("All DataFrame creation methods failed")


# --- put near your other imports ---
# import os
# import numpy as np
# import pandas as pd

# def bipedalwalker_df_from_encodings(env_name, encodings, columns=None):
#     """
#     encodings: list/array of shape (N, K) or (K,)
#     Ensures no nested arrays/lists end up in the DataFrame.
#     Casts all but last column to float and the last to int (seed).
#     """
#     arr = np.array(encodings)
#     if arr.ndim == 1:
#         arr = arr.reshape(1, -1)  # single row -> (1, K)

#     # Default column names: p1..p{K-1}, seed
#     if columns is None:
#         K = arr.shape[1]
#         if K >= 2:
#             columns = [f"p{i+1}" for i in range(K-1)] + ["seed"]
#         else:
#             columns = [f"p{i+1}" for i in range(K)]

#     df = pd.DataFrame(arr, columns=columns[:arr.shape[1]])

#     # Convert numerics
#     if df.shape[1] >= 2:
#         # all but last as float
#         for c in df.columns[:-1]:
#             df[c] = pd.to_numeric(df[c], errors="raise")
#         # last as int64 (seed or id)
#         last = df.columns[-1]
#         df[last] = pd.to_numeric(df[last], errors="raise").astype("int64")
#     else:
#         # single column -> try float then int fallback
#         c0 = df.columns[0]
#         try:
#             df[c0] = pd.to_numeric(df[c0], errors="raise")
#         except Exception:
#             pass

#     # Scalarize any leftover array/list cells (belt & suspenders)
#     def _scalarize(x):
#         if isinstance(x, np.ndarray):
#             return x.item() if x.size == 1 else str(x.tolist())
#         if isinstance(x, (list, tuple)):
#             return x[0] if len(x) == 1 else str(x)
#         return x

#     df = df.applymap(_scalarize)
#     return df


# def bipedalwalker_df_from_encodings(env_name, encodings, columns=None):
#     """
#     encodings: list/array of shape (N, K) or (K,)
#     Ensures no nested arrays/lists end up in the DataFrame.
#     Casts all but last column to float and the last to int (seed).
#     """
#     # Convert to pure Python lists to avoid numpy 2.x issues
#     try:
#         # First convert encodings to pure Python structure
#         if isinstance(encodings, np.ndarray):
#             python_encodings = encodings.tolist()
#         else:
#             python_encodings = []
#             for enc in encodings:
#                 if isinstance(enc, np.ndarray):
#                     python_encodings.append(enc.tolist())
#                 elif isinstance(enc, (list, tuple)):
#                     # Ensure all nested items are Python types
#                     clean_row = []
#                     for item in enc:
#                         if isinstance(item, np.ndarray):
#                             clean_row.extend(item.tolist())
#                         elif isinstance(item, (list, tuple)):
#                             clean_row.extend(list(item))
#                         else:
#                             clean_row.append(float(item))
#                     python_encodings.append(clean_row)
#                 else:
#                     python_encodings.append([float(enc)])

#         # Create DataFrame directly from Python lists
#         df = pd.DataFrame(python_encodings)

#         # Set column names
#         if columns is None:
#             K = df.shape[1]
#             if K >= 2:
#                 columns = [f"p{i+1}" for i in range(K-1)] + ["seed"]
#             else:
#                 columns = [f"p{i+1}" for i in range(K)]

#         # Ensure we have the right number of columns
#         columns = columns[:df.shape[1]]
#         df.columns = columns

#     except Exception as e:
#         # Fallback: create minimal DataFrame
#         df = pd.DataFrame([[0.0, 0]], columns=["p1", "seed"])

#     # Convert numerics
#     if df.shape[1] >= 2:
#         # all but last as float
#         for c in df.columns[:-1]:
#             df[c] = pd.to_numeric(df[c], errors="raise")
#         # last as int64 (seed or id)
#         last = df.columns[-1]
#         df[last] = pd.to_numeric(df[last], errors="raise").astype("int64")
#     else:
#         # single column -> try float then int fallback
#         c0 = df.columns[0]
#         try:
#             df[c0] = pd.to_numeric(df[c0], errors="raise")
#         except Exception:
#             pass

#     # Scalarize any leftover array/list cells (belt & suspenders)
#     def _scalarize(x):
#         if isinstance(x, np.ndarray):
#             return x.item() if x.size == 1 else str(x.tolist())
#         if isinstance(x, (list, tuple)):
#             return x[0] if len(x) == 1 else str(x)
#         return x

#     df = df.map(_scalarize)
#     return df


import os
import numpy as np
import pandas as pd


class SimpleDataFrame:
    """Minimal DataFrame-like container that writes itself to CSV without pandas.

    Attributes:
        data: List of rows, each a sequence of cell values.
        columns: Header names, written as the first CSV row.
    """

    def __init__(self, data, columns=None):
        """Store ``data`` and its column names.

        Args:
            data: Sequence of rows (each a sequence of cell values).
            columns: Header names. When falsy (``None`` or empty) they default
                to ``col_0 .. col_{K-1}``, where K is the width of the first
                row, or to no columns at all when ``data`` is empty.
        """
        self.data = data
        self.columns = columns or [f"col_{i}" for i in range(len(data[0]) if data else 0)]

    @staticmethod
    def _to_cell(value):
        """Return ``value`` with NumPy scalars unwrapped to plain Python scalars.

        ``csv.writer`` formats ``float`` instances with ``repr()``, and
        ``np.float64`` is a ``float`` subclass. Under NumPy 2 its repr is
        ``np.float64(1.5)``, which would end up verbatim in the CSV.
        """
        return value.item() if isinstance(value, np.generic) else value

    def to_csv(self, path):
        """Write the header row followed by every data row to ``path``.

        The file is overwritten. No index column is written.
        """
        import csv

        with open(path, "w", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            writer.writerow([self._to_cell(name) for name in self.columns])
            writer.writerows([self._to_cell(cell) for cell in row] for row in self.data)


def _flatten_to_python(value):
    """Flatten ``value`` into a flat list of plain Python scalars.

    ``np.ndarray``, ``list`` and ``tuple`` values are expanded recursively, so
    object-dtype arrays whose elements are themselves arrays are flattened as
    well. NumPy scalars are unwrapped with ``.item()``, so ``csv.writer`` never
    receives ``np.float64``, whose NumPy 2 repr is ``np.float64(...)``. Any
    other value is returned unchanged, as a one-element list.
    """
    if isinstance(value, np.ndarray):
        # Numeric arrays become nested lists of Python scalars. Object arrays
        # become nested lists of their elements, which may still be arrays.
        value = value.tolist()
    if isinstance(value, (list, tuple)):
        flat = []
        for element in value:
            flat.extend(_flatten_to_python(element))
        return flat
    if isinstance(value, np.generic):
        return [value.item()]
    return [value]


def bipedalwalker_df_from_encodings(env_name, encodings, columns=None):
    """Convert BipedalWalker level encodings into a ``SimpleDataFrame``.

    Args:
        env_name: Environment name. Not used by this implementation; accepted
            for interface compatibility.
        encodings: One encoding or a batch of encodings, in one of two forms.
            - An ``np.ndarray``. A 0-d or 1-d array is a single row. For 2-d
              and higher, each entry along the first axis is one row,
              flattened.
            - An iterable of items, one row per item. An item may be an
              array, a list/tuple (whose array or list elements are flattened
              and whose scalar elements are converted with ``float``), or a
              scalar (converted with ``float``).
        columns: Optional header names, e.g. ``BIPEDALWALKER_DF_COLUMNS``.
            Defaults to ``p1 .. p{K-1}, seed``, where K is the width of the
            first row; for K == 1 the single column is ``p1``.

    Returns:
        A ``SimpleDataFrame`` whose rows are flat lists of plain Python scalars.
    """
    if isinstance(encodings, np.ndarray):
        if encodings.ndim <= 1:
            # A single encoding. 0-d arrays are handled here too; calling
            # ``tolist()`` on them directly would return a bare scalar.
            clean_data = [_flatten_to_python(encodings)]
        else:
            # One row per leading index. Higher-rank rows are flattened so no
            # nested lists end up as CSV cells.
            clean_data = [_flatten_to_python(row) for row in encodings]
    else:
        clean_data = []
        for enc in encodings:
            if isinstance(enc, np.ndarray):
                clean_data.append(_flatten_to_python(enc))
            elif isinstance(enc, (list, tuple)):
                row = []
                for item in enc:
                    if isinstance(item, (np.ndarray, list, tuple)):
                        row.extend(_flatten_to_python(item))
                    else:
                        # Direct scalar members of a row are stored as floats.
                        row.append(float(item))
                clean_data.append(row)
            else:
                clean_data.append([float(enc)])

    # Default header: p1..p{K-1} for the parameters, "seed" for the last column.
    if columns is None and clean_data:
        width = len(clean_data[0])
        if width >= 2:
            columns = [f"p{i + 1}" for i in range(width - 1)] + ["seed"]
        else:
            columns = [f"p{i + 1}" for i in range(width)]

    return SimpleDataFrame(clean_data, columns)
