from datasets import Dataset
from typing import List, Tuple, Literal
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
from prompt.text_time_series_prompt import TextTimeSeriesPrompt
from time_series_datasets.QADataset import QADataset
from time_series_datasets.har_cot.har_cot_loader import load_har_cot_splits
import torch
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from time_series_datasets.util import (
    extend_time_series_to_match_patch_size_and_aggregate,
)
import numpy as np

# One caption per accelerometer axis, in x, y, z order. The order matters:
# captions are paired positionally with the per-axis series further down.
TIME_SERIES_LABELS = [
    f"The following is the accelerometer data on the {axis}-axis"
    for axis in ("x", "y", "z")
]

class HARCoTQADataset(QADataset):
    """
    QA dataset over HAR chain-of-thought samples.

    Each sample pairs three accelerometer axes (row keys ``x_axis``,
    ``y_axis``, ``z_axis``) with the row's ``rationale`` text as the answer.
    The hooks below supply the prompt pieces to the ``QADataset`` base class.
    """

    def __init__(
        self,
        split: Literal["train", "test", "validation"],
        EOS_TOKEN: str,
        format_sample_str: bool = False,
        time_series_format_function=None,
    ):
        """
        Forward all arguments, unchanged and positionally, to ``QADataset``.

        Args:
            split: Which split to expose ("train", "test" or "validation").
            EOS_TOKEN: Passed through to the base class.
            format_sample_str: Passed through to the base class.
            time_series_format_function: Passed through to the base class.
        """
        super().__init__(split, EOS_TOKEN, format_sample_str, time_series_format_function)

    def _load_splits(self) -> Tuple[Dataset, Dataset, Dataset]:
        """
        Fetch the dataset splits from ``load_har_cot_splits``.

        Returns:
            The loader's result, annotated as a (train, validation, test) tuple.
        """
        splits = load_har_cot_splits()
        return splits

    def _get_answer(self, row) -> str:
        """
        Return the target text for a sample.

        Args:
            row: Dataset row; must contain the key ``"rationale"``.

        Returns:
            The row's ``"rationale"`` value.
        """
        rationale = row["rationale"]
        return rationale

    def _get_pre_prompt(self, _row) -> str:
        """
        Return the fixed instruction text that precedes the time series.

        Args:
            _row: Dataset row (ignored; the text is the same for every sample).

        Returns:
            The instruction block, including the list of activity labels.
        """
        # NOTE: the literal below is emitted verbatim (indentation included).
        return """
        You are given accelerometer data in all three dimensions. Your task is to classify the activity based on analysis of the data.

        Instructions:
        - Begin by analyzing the time series without assuming a specific label.
        - Think step-by-step about what the observed patterns suggest regarding movement intensity and behavior.
        - Write your rationale as a single, natural paragraph — do not use bullet points, numbered steps, or section headings.
        - Do **not** mention any class label until the final sentence.

        Possible activity labels are:
        biking, lying, running, sitting, standing, walking, walking_down, walking_up.
        
        - Make sure that your last word is the answer. You MUST end your response with "Answer: "
        """

    def _get_post_prompt(self, _row) -> str:
        """
        Return the fixed text that follows the time series.

        Args:
            _row: Dataset row (ignored).

        Returns:
            The string ``"Rationale:"``.
        """
        post_prompt = "Rationale:"
        return post_prompt

    def _get_text_time_series_prompt_list(self, row) -> List[TextTimeSeriesPrompt]:
        """
        Build one ``TextTimeSeriesPrompt`` per accelerometer axis.

        Each axis is standardized with its own mean and standard deviation;
        the caption attached to each series reports those two statistics.

        Args:
            row: Dataset row with keys ``x_axis``, ``y_axis`` and ``z_axis``.

        Returns:
            Three prompts, in x, y, z order.

        Raises:
            ValueError: If the raw series, or the standardized series,
                contains NaN or Inf. Diagnostics are printed before raising.
        """
        # Stack the three axes into a single (3, T) float tensor.
        axes = torch.tensor(
            [row[axis_key] for axis_key in ("x_axis", "y_axis", "z_axis")],
            dtype=torch.float32,
        )

        # Reject non-finite input up front; the masks are reused for reporting.
        nan_mask = torch.isnan(axes)
        inf_mask = torch.isinf(axes)
        if nan_mask.any() or inf_mask.any():
            print("❌ Invalid data detected in HAR CoT sample")
            print(f"Row data: {row}")
            print(f"Series shape: {axes.shape}")
            print(f"Series values: {axes}")
            print(f"NaN positions: {nan_mask.nonzero()}")
            print(f"Inf positions: {inf_mask.nonzero()}")
            raise ValueError("Invalid data detected")

        # Per-axis statistics; keepdim lets them broadcast against `axes`.
        axis_means = axes.mean(dim=1, keepdim=True)
        # Floor the std so a (near-)constant axis does not divide by ~0.
        axis_stds = axes.std(dim=1, keepdim=True).clamp(min=1e-6)

        normalized = (axes - axis_means) / axis_stds

        # Second guard: standardization itself may still yield NaN/Inf.
        norm_nan_mask = torch.isnan(normalized)
        norm_inf_mask = torch.isinf(normalized)
        if norm_nan_mask.any() or norm_inf_mask.any():
            print("❌ NaN/Inf detected after normalization")
            print(f"Original series: {axes}")
            print(f"Original series shape: {axes.shape}")
            print(f"Row data: {row}")
            print(f"Means: {axis_means}")
            print(f"Stds: {axis_stds}")
            print(f"Normalized series: {normalized}")
            print(f"NaN positions: {norm_nan_mask.nonzero()}")
            print(f"Inf positions: {norm_inf_mask.nonzero()}")
            raise ValueError("NaN/Inf detected after normalization")

        # Pair caption, standardized values and statistics positionally.
        per_axis = zip(
            TIME_SERIES_LABELS,
            normalized.tolist(),
            axis_means.squeeze().tolist(),
            axis_stds.squeeze().tolist(),
        )
        return [
            TextTimeSeriesPrompt(
                f"{caption}, it has mean {axis_mean:.4f} and std {axis_std:.4f}:",
                values,
            )
            for caption, values, axis_mean, axis_std in per_axis
        ]

    @staticmethod
    def get_labels() -> List[str]:
        """
        Return every activity label, in the same order as the pre-prompt lists them.
        """
        activity_labels = [
            "biking", "lying", "running", "sitting",
            "standing", "walking", "walking_down", "walking_up",
        ]
        return activity_labels

    def _format_sample(self, row):
        """
        Extend the base-class sample with the label and the raw axis data.

        Args:
            row: Dataset row with keys ``label``, ``x_axis``, ``y_axis``, ``z_axis``.

        Returns:
            The base-class sample with those four keys copied over from ``row``.
        """
        formatted = super()._format_sample(row)
        for key in ("label", "x_axis", "y_axis", "z_axis"):
            formatted[key] = row[key]
        return formatted

if __name__ == "__main__":
    # Smoke test: build all three splits, report their sizes and dump the
    # first training sample (long text fields are cut to 200 characters).
    dataset = HARCoTQADataset(split="train", EOS_TOKEN="")
    dataset_val = HARCoTQADataset(split="validation", EOS_TOKEN="")
    dataset_test = HARCoTQADataset(split="test", EOS_TOKEN="")

    n_train, n_val, n_test = len(dataset), len(dataset_val), len(dataset_test)
    print(f"Dataset sizes: Train: {n_train}, Validation: {n_val}, Test: {n_test}")

    # Show sample data
    if n_train > 0:
        print(f"\n{'=' * 50}\n")
        print("Sample data from training set:")
        sample = dataset[0]
        print("Sample keys:", sample.keys())
        print("Sample label:", sample["label"])

        answer = sample["answer"]
        print("Sample answer length:", len(answer))
        if len(answer) > 200:
            answer_preview = answer[:200] + "..."
        else:
            answer_preview = answer
        print("Sample answer preview:", answer_preview)

        # The key is optional, so fall back to a placeholder.
        if "time_series_text" in sample:
            time_series_text = sample["time_series_text"]
        else:
            time_series_text = "N/A"
        print("Sample time series text:", time_series_text)

        pre_prompt = sample["pre_prompt"]
        if len(pre_prompt) > 200:
            pre_prompt_preview = pre_prompt[:200] + "..."
        else:
            pre_prompt_preview = pre_prompt
        print("Sample pre prompt:", pre_prompt_preview)
        print("Sample post prompt:", sample["post_prompt"])