import os
import pandas as pd
import shutil
import logging
import argparse
import time
import ast
import json
from typing import List, Dict, Any
import datetime
from datetime import timezone, timedelta
from transformers import AutoTokenizer
from openai import OpenAI


def create_prompt(observation_action_pairs, achievement_change):
    """
    Build the chat messages asking GPT-4o to infer the agent's instruction.

    Args:
        observation_action_pairs: Sequence of dicts with "observation" and
            "action" keys, oldest first. Each one becomes a numbered turn in
            the user message.
        achievement_change: Accepted for interface compatibility; it is not
            used when building the prompt.

    Returns:
        list[dict]: A system message followed by a user message, each of the
        form {"role": ..., "content": ...}.
    """
    system_prompt = """You are a language model trained to analyze agent behavior in the game Crafter.
Your task is to infer the most likely instruction the agent was pursuing, given a sequence of environmental observations and actions.

Guidelines:
- Pay special attention to the most recent observation and action, as they reveal the agent's immediate intention.
- The agent can only interact with the tile it is directly facing, so consider only the facing tile when interpreting interaction actions.
- The do action means the agent is trying to interact with the tile it is facing. For example:
  - If facing material: collect material
  - If facing grass: collect sapling
  - If facing water: drink to restore thirst
  - If facing hostile creature: defeat the creature
  - If facing cow: hunt to restore hunger
- If there's a table or furnace nearby and your action starts with 'make', you're making a tool. Focus on that action.
- Avoid vague or generic explanations. Be precise and grounded in the recent context.

Your output should clearly state the inferred goal the agent was pursuing, based strictly on its behavior and what it was facing.
Keep your response very brief - around 10 words maximum.
"""
    header = (
        "Here is a sequence of actions and current observation-action pair the agent took in the Crafter game.\n"
        "The turns are listed in chronological order, from oldest to most recent.\n"
        "\n"
    )
    # One numbered section per turn, in the order given (oldest first).
    turn_sections = [
        f"### Turn {turn_no}:\n### Observation\n{pair['observation']}\n### Action\n{pair['action']}\n"
        for turn_no, pair in enumerate(observation_action_pairs, start=1)
    ]
    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": header + "".join(turn_sections)},
    ]


def get_response_openai(prompt, model_name="gpt-4o"):
    """
    Send a chat prompt to the OpenAI API and return the stripped reply text.

    Token usage for the call is printed to stdout.

    Args:
        prompt: List of dictionaries with 'role' and 'content' keys.
        model_name: OpenAI chat model to query. Defaults to "gpt-4o".

    Returns:
        The model's response text, or "noop" if the API call fails for any reason.

    Raises:
        ValueError: If the OPENAI_API_KEY environment variable is unset or empty.
    """
    key = os.environ.get("OPENAI_API_KEY")
    if not key:
        raise ValueError(
            "OPENAI_API_KEY environment variable is not set. Please set your OpenAI API key."
        )

    try:
        completion = OpenAI(api_key=key).chat.completions.create(
            model=model_name,
            messages=prompt,
            temperature=0,
            max_tokens=256,
        )
        usage = completion.usage
        print(f"Input tokens: {usage.prompt_tokens}")
        print(f"Output tokens: {usage.completion_tokens}")
        print(f"Total tokens used: {usage.total_tokens}")
        reply = completion.choices[0].message.content
        return reply.strip()
    except Exception as err:
        # Deliberate best-effort: any API/parsing failure degrades to "noop".
        print(f"Error calling OpenAI API: {err}")
        return "noop"


def check_is_defeat(row, next_row):
    """
    Decide whether the agent defeated a hostile creature on this row.

    A defeat is reported when the row's action_goal is "do", the entity the
    agent faces is "zombie" or "skeleton", and the entity faced on the next
    row is different.

    Args:
        row: Current row; reads "observation_goal" and "action_goal".
        next_row: Following row; reads "observation_goal".

    Returns:
        tuple: (current_facing, next_facing, is_defeat), where the facing
        values come from extract_facing_entity (None when no facing text is
        present) and is_defeat is a bool.
    """
    facing_now = extract_facing_entity(str(row["observation_goal"]))
    facing_after = extract_facing_entity(str(next_row["observation_goal"]))

    attacked_hostile = row["action_goal"] == "do" and facing_now in ("zombie", "skeleton")
    if attacked_hostile and facing_now != facing_after:
        return facing_now, facing_after, True
    return facing_now, facing_after, False


def extract_facing_entity(history):
    """
    Return the first word that follows "You are facing" in an observation string.

    Words are delimited by any whitespace (spaces, newlines, tabs). Punctuation
    attached to the word is kept as-is.

    Args:
        history: Observation text to search.

    Returns:
        str | None: The facing word. If the marker appears with nothing after
        it, the result is "". If the marker is absent, the result is None.
    """
    marker = "You are facing"
    pos = history.find(marker)
    if pos == -1:
        return None
    # Split on arbitrary whitespace. The previous slice-to-next-space approach
    # crashed when no space followed the marker, chopped the last character
    # when the entity ended the string, and swallowed text past a newline.
    words = history[pos + len(marker):].split()
    return words[0] if words else ""


def _has_status_change(value) -> bool:
    """
    Return True when a ``status_change`` cell holds a non-empty literal collection.

    pandas reads empty CSV cells as NaN; those are treated as "no change"
    instead of being passed to ``ast.literal_eval`` (which rejects floats).
    """
    if not isinstance(value, str) and pd.isna(value):
        return False
    return len(ast.literal_eval(value)) > 0


def process_csv_by_achievement_change(
    csv_file_path: str, max_chunk_length: int = None, logger=None
):
    """
    Split a trajectory CSV into observation-action chunks that end on a status change or defeat.

    Observation/action pairs accumulate until a row has a non-empty
    ``status_change`` or is detected as a defeat (see check_is_defeat). That
    row is included in the chunk before a new chunk starts. When ``step``
    decreases, a new trajectory is assumed and the pending chunk is discarded.

    Args:
        csv_file_path (str): Path to the CSV file. It must have the columns
            observation_goal, action_goal, achievement_change, step and status_change.
        max_chunk_length (int, optional): Maximum number of observation-action pairs per chunk.
            Longer chunks keep only their most recent pairs. If None, there is no limit.
            Defaults to None.
        logger (logging.Logger, optional): Logger to use; defaults to the root logger.

    Returns:
        list[dict]: One dict per chunk with these keys:
            "observation_action_pairs": list of {"observation", "action"} dicts;
            "start_idx": first row label covered by the chunk;
            "end_idx": one past the last row;
            "achievement_change": the ending row's achievement_change value.
        Rows start_idx .. end_idx - 1 (inclusive) belong to the chunk.

    Raises:
        ValueError: If a required column is missing.
    """
    # Use provided logger or get the root logger
    if logger is None:
        logger = logging.getLogger()

    df = pd.read_csv(csv_file_path)

    required_columns = [
        "observation_goal",
        "action_goal",
        "achievement_change",
        "step",
        "status_change",  # read below to decide where chunks end
    ]
    for col in required_columns:
        if col not in df.columns:
            raise ValueError(f"CSV file must contain a '{col}' column")

    chunks = []
    current_chunk = []
    # Row index immediately BEFORE the first row of the current chunk; the chunk
    # spans rows chunk_start_idx + 1 .. idx. Starting at -1 makes row 0 part of
    # the first chunk.
    chunk_start_idx = -1
    prev_step = -1
    num_rows = len(df)

    for idx in range(num_rows):
        row = df.iloc[idx]
        next_row = df.iloc[idx + 1] if idx + 1 < num_rows else None

        # If step decreases, it's a new trajectory: drop the unfinished chunk and
        # start labelling from this row.
        current_step = row["step"]
        if prev_step > current_step:
            current_chunk = []
            chunk_start_idx = idx - 1
        prev_step = current_step

        if pd.notna(row["observation_goal"]):
            action = row["action_goal"] if pd.notna(row["action_goal"]) else "noop"
            current_chunk.append(
                {"observation": row["observation_goal"], "action": action}
            )

        if next_row is not None:
            current_facing, next_facing, is_defeat = check_is_defeat(row, next_row)
        else:
            # Last row: there is no following observation, so no defeat can be detected.
            current_facing = extract_facing_entity(str(row["observation_goal"]))
            next_facing, is_defeat = None, False

        if is_defeat:
            logger.info(f"###### Defeat detected at step {row['step']}")
        logger.info(f"Current facing: {current_facing}, Next facing: {next_facing}")

        if not (_has_status_change(row["status_change"]) or is_defeat):
            continue

        if current_chunk:
            if max_chunk_length is None or len(current_chunk) <= max_chunk_length:
                pairs = current_chunk.copy()
            else:
                pairs = current_chunk[len(current_chunk) - max_chunk_length :]

            start_idx = chunk_start_idx + 1
            if max_chunk_length is not None:
                # Never reach back before this chunk's first row; otherwise a
                # short chunk would overwrite labels of earlier chunks.
                start_idx = max(start_idx, idx + 1 - max_chunk_length)

            chunks.append(
                {
                    "observation_action_pairs": pairs,
                    "start_idx": start_idx,
                    "end_idx": idx + 1,  # one past the row that ended the chunk
                    "achievement_change": row["achievement_change"],
                }
            )
            current_chunk = []
        # The next chunk starts right after this row.
        chunk_start_idx = idx

    return chunks


def copy_csv_to_output_folder(csv_file_path, output_dir):
    """
    Copy a CSV into ``output_dir`` as ``with_instructions_<basename>``.

    The output directory is created if needed. File metadata is preserved
    (``shutil.copy2``).

    Returns:
        str: Path of the copied file.
    """
    log = logging.getLogger()
    os.makedirs(output_dir, exist_ok=True)

    destination = os.path.join(
        output_dir, "with_instructions_" + os.path.basename(csv_file_path)
    )
    shutil.copy2(csv_file_path, destination)

    log.info(f"CSV file copied to: {destination}")
    return destination


def _format_achievement_instruction(achievement_change):
    """
    Convert a raw ``achievement_change`` cell into a readable instruction.

    The cell may be the string form of a Python list, e.g. "['collect_wood']"
    -> "Collect wood". Several entries are joined with ", ". A bare value such
    as "collect_wood" is also accepted.

    Returns:
        str | None: The formatted instruction. None when the cell is missing
        (None/NaN) or holds an empty collection.
    """
    if achievement_change is None:
        return None
    if not isinstance(achievement_change, str) and pd.isna(achievement_change):
        return None

    text = str(achievement_change)
    try:
        parsed = ast.literal_eval(text)
    except (ValueError, SyntaxError, TypeError):
        # Not a Python literal (e.g. bare "collect_wood"): use the text as-is.
        parsed = text

    if isinstance(parsed, (list, tuple)):
        items = [str(item) for item in parsed]
    elif isinstance(parsed, (set, frozenset)):
        items = sorted(str(item) for item in parsed)
    else:
        items = [str(parsed)]

    formatted = ", ".join(items).replace("_", " ").strip()
    if not formatted:
        return None
    return formatted[0].upper() + formatted[1:]


def process_observations_and_update_csv(
    csv_file_path: str,
    output_dir="goal/output",
    logger=None,
    max_chunk_length: int = None,
):
    """
    Label each chunk of a trajectory CSV with an achievement-based and a GPT-generated instruction.

    The CSV is copied into ``output_dir`` (see copy_csv_to_output_folder) and
    split into chunks with process_csv_by_achievement_change. For every chunk:
      - the formatted achievement_change is written to "achievement_instruction";
      - a GPT-4o instruction (get_response_openai) is written to "instruction".
    Both values go to every row the chunk covers. The copy is then saved in place.

    Args:
        csv_file_path (str): Path to the CSV file
        output_dir (str, optional): Directory to save the output. Defaults to "goal/output".
        logger (logging.Logger, optional): Logger for logging messages. Defaults to None.
        max_chunk_length (int, optional): Maximum number of observation-action pairs per chunk.
                                         If None, there is no limit. Defaults to None.
    """
    if logger is None:
        logger = logging.getLogger()

    os.makedirs(output_dir, exist_ok=True)
    new_csv_path = copy_csv_to_output_folder(csv_file_path, output_dir)

    logger.info("Starting processing")

    chunks = process_csv_by_achievement_change(csv_file_path, max_chunk_length, logger)
    logger.info(f"Found {len(chunks)} chunks to process")

    # The copy has the same row order and a fresh RangeIndex, so the positional
    # indices from the chunker can be used as .loc labels.
    df = pd.read_csv(new_csv_path)

    if "instruction" not in df.columns:
        df["instruction"] = ""
    # Create the column up front so the output schema is stable even with no chunks.
    if "achievement_instruction" not in df.columns:
        df["achievement_instruction"] = None

    chunk_lengths = []

    for i, chunk_info in enumerate(chunks):
        logger.info(f"Processing chunk {i+1}/{len(chunks)}...")

        observation_action_pairs = chunk_info["observation_action_pairs"]
        start_idx = chunk_info["start_idx"]
        end_idx = chunk_info["end_idx"]
        achievement_change = chunk_info["achievement_change"]

        chunk_lengths.append(len(observation_action_pairs))

        if not observation_action_pairs:
            logger.info(f"Skipping chunk {i+1} - no observation-action pairs")
            continue

        # e.g. "['collect_wood']" -> "Collect wood"; None when there is no achievement.
        formatted_instruction = _format_achievement_instruction(achievement_change)
        # .loc slicing is label-inclusive, so this covers start_idx .. end_idx - 1.
        df.loc[start_idx : end_idx - 1, "achievement_instruction"] = (
            formatted_instruction
        )

        prompt_for_relabeling = create_prompt(
            observation_action_pairs, achievement_change
        )
        generated_instruction = get_response_openai(prompt_for_relabeling)
        df.loc[start_idx : end_idx - 1, "instruction"] = generated_instruction

        logger.info(f"CHUNK {i+1}:")
        logger.info(f"Chunk size (count: {len(observation_action_pairs)}):")
        for j, pair in enumerate(observation_action_pairs):
            logger.info(f"  Pair {j+1}:")
            logger.info(f"    Action: {pair['action']}")
        logger.info(
            f"Achievement change that triggered chunk end: {achievement_change}"
        )
        logger.info(f"Achievement instruction: {formatted_instruction}")
        logger.info(f"Generated instruction: {generated_instruction}")
        logger.info(f"Applied to rows {start_idx} to {end_idx-1}")
        logger.info("-" * 80)

        logger.info(f"Processed chunk {i+1}/{len(chunks)}")

    df.to_csv(new_csv_path, index=False)

    if chunk_lengths:
        avg_chunk_length = sum(chunk_lengths) / len(chunk_lengths)
        logger.info(f"Average chunk length: {avg_chunk_length:.2f}")
    else:
        logger.info("No chunks processed, cannot calculate average chunk length")

    logger.info(f"CSV file updated with instructions: {new_csv_path}")
    logger.info("Processing completed successfully")


def preprocess_csv_remove_invalid(
    csv_file_path: str, logger: logging.Logger = None
) -> pd.DataFrame:
    """
    Removes rows where 'valid' column is False from the CSV file

    Args:
        csv_file_path (str): Path to the CSV file
        logger (logging.Logger, optional): Logger to use for logging messages.
            Defaults to the root logger (previously None crashed on first use).

    Returns:
        pd.DataFrame: The rows whose 'valid' value equals True. If there is no
        'valid' column, the unfiltered DataFrame is returned.
    """
    if logger is None:
        logger = logging.getLogger()

    df = pd.read_csv(csv_file_path)

    if "valid" not in df.columns:
        logger.warning("Warning: 'valid' column not found in CSV file")
        return df

    # Compare with == True (not truthiness) so NaN cells are dropped, not kept.
    df_filtered = df[df["valid"] == True].copy()  # noqa: E712
    logger.info(f"Removed {len(df) - len(df_filtered)} invalid rows from CSV")
    logger.info(f"Remaining rows: {len(df_filtered)}")

    return df_filtered


def _append_task_to_system_prompt(content, instruction):
    """Drop any "### Feedback from Previous Round" section and append the task line."""
    feedback_marker = "### Feedback from Previous Round"
    if feedback_marker in content:
        content = content.split(feedback_marker)[0]
    # Ensure exactly one blank-line separation before the task when possible.
    separator = "" if content.endswith("\n\n") else "\n\n"
    return f"{content}{separator}Now, here is the task:\nTask : {instruction}"


def _save_histories(histories, save_dir, file_tag, model_label, current_time, logger):
    """Write formatted histories to processed_histories_<file_tag>_<current_time>.csv."""
    if not histories:
        logger.warning(f"No histories were successfully processed in {model_label}")
        return
    output_file = os.path.join(
        save_dir, f"processed_histories_{file_tag}_{current_time}.csv"
    )
    pd.DataFrame({"history": histories}).to_csv(output_file, index=False)
    logger.info(f"Saved {len(histories)} processed histories to {output_file}")


def process_and_save_history(
    output_dir: str,
    save_dir: str = "goal/output/sft_dataset_final",
    logger: logging.Logger = None,
    current_time: str = None,
):
    """
    Render each labelled row's chat history into Llama and Qwen chat-template strings.

    For every CSV in ``output_dir`` that has "history" and "instruction"
    columns, and for every row where both are present, this function:
      1. parses "history" with ast.literal_eval;
      2. appends the task instruction to each system message (after removing
         any "### Feedback from Previous Round" section);
      3. appends "action_goal" as an assistant message when present;
      4. renders the result with both tokenizers' chat templates.
    The rendered strings are saved to processed_histories_llama_<current_time>.csv
    and processed_histories_qwen_<current_time>.csv in ``save_dir``.

    Args:
        output_dir (str): Directory containing CSV files to process
        save_dir (str, optional): Directory to save the processed histories. Defaults to "goal/output/sft_dataset_final".
        logger (logging.Logger, optional): Logger for logging messages
        current_time (str, optional): Suffix for the output filenames. Defaults to
            the current local time formatted as %Y%m%d_%H%M%S.
    """
    if logger is None:
        logger = logging.getLogger()
    if current_time is None:
        current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

    if not os.path.exists(output_dir):
        logger.error(f"Output directory {output_dir} does not exist!")
        return

    os.makedirs(save_dir, exist_ok=True)

    # Sorted for a deterministic row order in the output datasets.
    csv_files = sorted(f for f in os.listdir(output_dir) if f.endswith(".csv"))
    if not csv_files:
        logger.info(f"No CSV files found in {output_dir}")
        return

    logger.info(f"Found {len(csv_files)} CSV files to process in {output_dir}")

    # Load tokenizers only once there is work to do; loading them is expensive.
    tokenizer_llama = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
    tokenizer_qwen = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")

    all_histories_llama = []
    all_histories_qwen = []

    for csv_file in csv_files:
        file_path = os.path.join(output_dir, csv_file)
        logger.info(f"Processing file: {file_path}")

        try:
            df = pd.read_csv(file_path)

            if "history" not in df.columns or "instruction" not in df.columns:
                logger.warning(
                    f"Required columns 'history' or 'instruction' not found in {csv_file}"
                )
                continue

            for idx, row in df.iterrows():
                if pd.isna(row["history"]) or pd.isna(row["instruction"]):
                    continue

                try:
                    history = ast.literal_eval(row["history"])
                    instruction = row["instruction"]

                    for item in history:
                        if isinstance(item, dict) and item.get("role") == "system":
                            item["content"] = _append_task_to_system_prompt(
                                item["content"], instruction
                            )

                    # The chosen action becomes the assistant's target response.
                    if "action_goal" in row and pd.notna(row["action_goal"]):
                        history.append(
                            {"role": "assistant", "content": f"{row['action_goal']}"}
                        )

                    rendered_llama = tokenizer_llama.apply_chat_template(
                        history, tokenize=False, date_string=""
                    )
                    rendered_qwen = tokenizer_qwen.apply_chat_template(
                        history, tokenize=False, date_string=""
                    )

                    all_histories_llama.append(rendered_llama)
                    all_histories_qwen.append(rendered_qwen)

                except (ValueError, SyntaxError, TypeError, AttributeError) as e:
                    # Skip just this malformed row instead of abandoning the file.
                    logger.error(
                        f"Error parsing history in row {idx} of {csv_file}: {e}"
                    )
                    continue

        except Exception as e:
            logger.error(f"Error processing file {csv_file}: {e}")
            continue

    _save_histories(all_histories_llama, save_dir, "llama", "Llama", current_time, logger)
    _save_histories(all_histories_qwen, save_dir, "qwen", "Qwen", current_time, logger)


def _build_arg_parser():
    """Create the command-line parser for this script."""
    parser = argparse.ArgumentParser(description="Extract goals from Crafter game logs")
    parser.add_argument(
        "csv_files",
        type=str,
        nargs="+",
        help="CSV files or directories containing CSV files to process",
    )
    parser.add_argument(
        "--temp_with_instruction_dir",
        type=str,
        default="goal/output/temp_with_instruction/csv",
        help="Output directory for processed files",
    )
    parser.add_argument(
        "--temp_valid_dir",
        type=str,
        default="goal/output/temp_valid_csv",
        help="Temporary directory for filtered CSV files",
    )
    parser.add_argument(
        "--log_level",
        type=str,
        default="INFO",
        choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
        help="Logging level (default: INFO)",
    )
    parser.add_argument(
        "--max_chunk_length",
        type=int,
        default=None,
        help="Maximum number of observation-action pairs per chunk. If not specified, there is no limit.",
    )
    # History processing stays enabled by default, as before. Previously this flag
    # used store_false, so passing "--process_history" actually DISABLED it.
    parser.add_argument(
        "--process_history",
        dest="process_history",
        action="store_true",
        default=True,
        help="Process history column and save to new file (enabled by default)",
    )
    parser.add_argument(
        "--no_process_history",
        dest="process_history",
        action="store_false",
        help="Skip processing of the history column",
    )
    parser.add_argument(
        "--save_dir",
        type=str,
        default="goal/output/sft_dataset_final",
        help="Directory to save processed histories (default: goal/output/sft_dataset_final)",
    )
    return parser


def _collect_csv_files(paths):
    """Expand CLI paths into CSV file paths (directories are scanned one level, sorted)."""
    csv_files = []
    for path in paths:
        if os.path.isdir(path):
            csv_files.extend(
                os.path.join(path, name)
                for name in sorted(os.listdir(path))
                if name.endswith(".csv")
            )
        elif path.endswith(".csv"):
            csv_files.append(path)
    return csv_files


def _log_name_part(paths, max_length=50):
    """Build the log-file name stem from directory names / CSV stems, truncated to max_length."""
    names = []
    for path in paths:
        if os.path.isdir(path):
            # normpath strips a trailing separator, which would otherwise yield "".
            names.append(os.path.basename(os.path.normpath(path)))
        elif path.endswith(".csv"):
            names.append(os.path.splitext(os.path.basename(path))[0])
    return "_".join(names)[:max_length]


def main():
    """
    CLI entry point: filter, relabel and optionally convert Crafter trajectory CSVs.

    For each input CSV:
      1. rows whose 'valid' is not True are dropped, and the result is written
         to ``--temp_valid_dir``;
      2. chunks are labelled with instructions, and the labelled copies are
         written to ``--temp_with_instruction_dir``.
    Afterwards, unless ``--no_process_history`` is given, the labelled CSVs are
    converted into chat-template datasets in ``--save_dir``.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    if args.max_chunk_length is not None and args.max_chunk_length <= 0:
        parser.error("--max_chunk_length must be a positive integer")

    temp_valid_dir = args.temp_valid_dir
    temp_with_instruction_dir = args.temp_with_instruction_dir
    save_dir = args.save_dir
    log_level = getattr(logging, args.log_level.upper())
    max_chunk_length = args.max_chunk_length

    # Fixed log directory
    log_dir = os.path.join("goal/output/temp_with_instruction", "logs")
    os.makedirs(log_dir, exist_ok=True)

    csv_name_part = _log_name_part(args.csv_files)
    current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    main_log_file = os.path.join(log_dir, f"{csv_name_part}_{current_time}.log")

    logging.basicConfig(
        level=log_level,
        format="%(asctime)s - %(levelname)s - %(message)s",
        handlers=[logging.FileHandler(main_log_file), logging.StreamHandler()],
    )
    logger = logging.getLogger()
    logger.info(f"Starting process with global log file: {main_log_file}")

    all_csv_files = _collect_csv_files(args.csv_files)
    if not all_csv_files:
        logger.info("No CSV files found to process!")
        return

    logger.info(f"Found {len(all_csv_files)} CSV files to process")

    os.makedirs(temp_valid_dir, exist_ok=True)

    for i, csv_file in enumerate(all_csv_files, start=1):
        logger.info(f"Processing file {i}/{len(all_csv_files)}: {csv_file}")

        temp_csv_path = os.path.join(
            temp_valid_dir, str(os.path.basename(csv_file)) + "_filtered.csv"
        )
        df_filtered = preprocess_csv_remove_invalid(csv_file, logger)

        if df_filtered.empty:
            logger.info(
                f"Warning: Filtered DataFrame for {csv_file} is empty! Skipping this file."
            )
            continue

        # Save first, then report it.
        df_filtered.to_csv(temp_csv_path, index=False)
        logger.info(f"Filtered CSV saved to: {temp_csv_path}")

        process_observations_and_update_csv(
            csv_file_path=temp_csv_path,
            output_dir=temp_with_instruction_dir,
            logger=logger,
            max_chunk_length=max_chunk_length,
        )

    if args.process_history:
        logger.info("Inserting history data to CSV files...")
        process_and_save_history(
            temp_with_instruction_dir, save_dir, logger, current_time
        )

    logger.info("All files processed successfully.")


if __name__ == "__main__":
    # Run the CLI entry point only when executed as a script, not when imported.
    main()
