import os
import json
from typing import Tuple, List, Dict, Optional
import numpy as np
import pydantic

import torch
from torch.utils.data import IterableDataset, get_worker_info

# --- START: Added imports for plotting ---
#import matplotlib
#matplotlib.use('Agg') # Use 'Agg' backend for non-interactive saving
#import matplotlib.pyplot as plt
import pandas as pd
# --- END: Added imports ---

from models.losses import IGNORE_LABEL_ID
from dataset.common import PuzzleDatasetMetadata

from argdantic import ArgParser
from pydantic import BaseModel

# --- START: Added functions for path generation ---
def corrupt_linear_path(x_0: torch.Tensor, x_1: torch.Tensor, num_steps: int) -> Tuple[List[torch.Tensor], None]:
    """Build frames that progressively turn ``x_0`` into ``x_1``.

    The positions where the two tensors differ are shuffled once. Frame ``k``
    is a clone of ``x_0`` in which the first ``floor(k * num_diff / (num_steps - 1))``
    shuffled positions have been overwritten with the values from ``x_1``, so
    the first frame equals ``x_0`` and the last frame equals ``x_1``.

    Args:
        x_0: Starting tensor.
        x_1: Target tensor; must have the same shape as ``x_0``.
        num_steps: Number of frames to produce. With a single step the only
            frame is the fully replaced one (equal to ``x_1``); with zero or
            fewer steps the list is empty.

    Returns:
        ``(frames, None)``. The second element is always ``None`` and is kept
        so that callers can unpack two values.

    Raises:
        ValueError: If ``x_0`` and ``x_1`` have different shapes.
    """
    if x_0.shape != x_1.shape:
        raise ValueError("Input and target tensors must have the same shape.")

    # squeeze() drops size-1 dimensions so that e.g. (N,) and (1, N) inputs
    # are indexed the same way.
    diff_mask = (x_0 != x_1).squeeze()
    diff_indices = torch.where(diff_mask)
    num_diff = len(diff_indices[0])

    if num_diff == 0:
        # Nothing to interpolate: every frame is a copy of the input.
        return [x_0.clone() for _ in range(num_steps)], None

    # One random order in which differing positions are revealed.
    perm = torch.randperm(num_diff)
    permuted_diff_indices = tuple(d[perm] for d in diff_indices)
    target_view = x_1.squeeze()
    last_step = num_steps - 1

    frames = []
    for k in range(num_steps):
        if last_step > 0:
            # Exact integer floor of (k / last_step) * num_diff; the float
            # form can land one short (29 / 100 * 100 == 28.999...).
            num_to_replace = (k * num_diff) // last_step
        else:
            # A single-step path has no interpolation range: emit the target.
            num_to_replace = num_diff

        current_frame = x_0.clone()
        if num_to_replace > 0:
            indices_to_replace = tuple(d[:num_to_replace] for d in permuted_diff_indices)
            # squeeze() returns a view, so this assignment writes into current_frame.
            frame_view = current_frame.squeeze()
            frame_view[indices_to_replace] = target_view[indices_to_replace]

        frames.append(current_frame)

    return frames, None

def batch_corrupt_linear_path(x_0_batch: torch.Tensor, x_1_batch: torch.Tensor, num_steps: int) -> torch.Tensor:
    """Generate a corruption path for every (input, target) pair in a batch.

    Each row ``i`` is passed to ``corrupt_linear_path`` and the resulting
    frames are stacked into ``out[i]``.

    Args:
        x_0_batch: Batch of inputs; the first two dimensions are read as
            ``(batch_size, seq_len)``.
        x_1_batch: Batch of targets, indexed row by row alongside ``x_0_batch``.
        num_steps: Number of frames per path.

    Returns:
        Tensor of shape ``(batch_size, num_steps, seq_len)`` with the dtype
        and device of ``x_0_batch``.
    """
    batch_size = x_0_batch.shape[0]
    seq_len = x_0_batch.shape[1]

    # Rows are filled one by one below; torch.empty avoids a redundant fill.
    all_paths = torch.empty(
        (batch_size, num_steps, seq_len),
        dtype=x_0_batch.dtype,
        device=x_0_batch.device,
    )

    for i in range(batch_size):
        frames, _ = corrupt_linear_path(
            x_0=x_0_batch[i],
            x_1=x_1_batch[i],
            num_steps=num_steps,
        )
        if frames:
            # Reshape explicitly instead of squeeze(1): squeezing would also
            # drop the sequence axis when seq_len == 1 and break the assignment.
            all_paths[i] = torch.stack(frames).reshape(num_steps, seq_len)
    return all_paths
# --- END: Added functions ---


def _sample_batch(rng: np.random.Generator, group_order: np.ndarray, puzzle_indices: np.ndarray, group_indices: np.ndarray, start_index: int, global_batch_size: int):
    # Pack examples into a full batch
    batch = []
    batch_puzzle_indices = []
    current_size = 0

    while (start_index < group_order.size) and (current_size < global_batch_size):
        # Pick a group and a puzzle from that group
        group_id = group_order[start_index]
        puzzle_id = rng.integers(group_indices[group_id], group_indices[group_id + 1])
        start_index += 1

        # Get range of the puzzle
        puzzle_start = puzzle_indices[puzzle_id]
        puzzle_size = int(puzzle_indices[puzzle_id + 1] - puzzle_start)

        append_size = min(puzzle_size, global_batch_size - current_size)

        # Put into batch
        batch_puzzle_indices.append(np.full(append_size, puzzle_id, dtype=np.int32))
        batch.append(puzzle_start + np.random.choice(puzzle_size, append_size, replace=False))

        current_size += append_size

    return start_index, np.concatenate(batch), np.concatenate(batch_puzzle_indices)


class PuzzleDatasetConfig(pydantic.BaseModel):
    # Settings read by PuzzleDataset.

    # Base seed; the training iterator seeds its generator with seed + iteration count.
    seed: int
    # Dataset root directories; files are looked up under <path>/<split>/.
    dataset_paths: List[str]
    # Batch size summed over all replicas; must be divisible by num_replicas.
    global_batch_size: int
    # True selects the sequential test iterator, False the shuffled training iterator.
    test_set_mode: bool
    # Number of group permutations concatenated per training pass over a set
    # (batches several epochs into one iteration to reduce overhead).
    epochs_per_iter: int
    # Index of this replica; selects its slice of every global batch.
    rank: int
    # Number of replicas the global batch is split across.
    num_replicas: int

class PuzzleDataset(IterableDataset):
    def __init__(self, config: PuzzleDatasetConfig, split: str = "train"):
        super().__init__()
        self.config = config
        self.split = split

        # Merge multiple metadata
        prev_seq_len = None
        prev_vocab_size = None
        prev_pad_id = None
        prev_ignore_label_id = None
        prev_blank_identifier_id = None
        prev_sets = None
        prev_num_identifiers = None
        mean_puzzle_examples = 0
        total_puzzles = 0
        total_groups = 0
        num_identifiers = 0
        for dataset_path in config.dataset_paths:
            current_metadata = self._load_metadata(dataset_path)
            if prev_seq_len is None:
                prev_seq_len = current_metadata.seq_len
                prev_vocab_size = current_metadata.vocab_size
                prev_pad_id = current_metadata.pad_id
                prev_ignore_label_id = current_metadata.ignore_label_id
                prev_blank_identifier_id = current_metadata.blank_identifier_id
                prev_sets = current_metadata.sets
                prev_num_identifiers = current_metadata.num_puzzle_identifiers
            else:
                assert prev_seq_len == current_metadata.seq_len
                assert prev_vocab_size == current_metadata.vocab_size
                assert prev_pad_id == current_metadata.pad_id
                assert prev_ignore_label_id == current_metadata.ignore_label_id
                assert prev_blank_identifier_id == current_metadata.blank_identifier_id
                assert prev_sets == current_metadata.sets
                assert prev_num_identifiers == current_metadata.num_puzzle_identifiers
            
            # --- MODIFIED: Handle zero puzzles in test set ---
            if current_metadata.total_puzzles > 0:
                mean_puzzle_examples += current_metadata.mean_puzzle_examples*current_metadata.total_puzzles
                total_puzzles += current_metadata.total_puzzles
            # --- END MODIFICATION ---
            
            total_groups += current_metadata.total_groups
            num_identifiers += current_metadata.num_puzzle_identifiers
        
        # --- MODIFIED: Prevent ZeroDivisionError ---
        if total_puzzles > 0:
            mean_puzzle_examples = mean_puzzle_examples / total_puzzles
        else:
            mean_puzzle_examples = 0
        # --- END MODIFICATION ---

        self.metadata = PuzzleDatasetMetadata(
            seq_len=prev_seq_len,
            vocab_size=prev_vocab_size,
            pad_id=prev_pad_id,
            ignore_label_id=prev_ignore_label_id,
            blank_identifier_id=prev_blank_identifier_id,
            num_puzzle_identifiers=num_identifiers,
            total_groups=total_groups,
            mean_puzzle_examples=mean_puzzle_examples,
            total_puzzles=total_puzzles,
            sets=prev_sets
        )

        # Checks
        assert self.config.global_batch_size % self.config.num_replicas == 0, f"Global batch size {self.config.global_batch_size} must be multiples of nodes {self.config.num_replicas}."
        self.local_batch_size = self.config.global_batch_size // self.config.num_replicas

        # State
        self._data = None
        self._iters = 0

    def _load_metadata(self, dataset_path) -> PuzzleDatasetMetadata:
        metadata_path = os.path.join(dataset_path, self.split, "dataset.json")
        if not os.path.exists(metadata_path):
            # Handle missing test split gracefully
            print(f"Warning: Metadata file not found: {metadata_path}")
            # Return a dummy metadata with 0 puzzles
            return PuzzleDatasetMetadata(
                seq_len=1, vocab_size=1, pad_id=0, ignore_label_id=-100,
                blank_identifier_id=-1, num_puzzle_identifiers=0,
                total_groups=0, mean_puzzle_examples=0, total_puzzles=0, sets=[]
            )
        with open(metadata_path, "r") as f:
            return PuzzleDatasetMetadata(**json.load(f))

    def _lazy_load_dataset(self):
        if self._data is not None:
            return

        field_mmap_modes = {
            "inputs": "r",
            "labels": "r",
            
            # --- START MODIFICATION: Load into memory (None) instead of mmap ("r") ---
            # This allows us to modify the arrays in-place during validation.
            "intermediate_steps": None, # "r",
            "intermediate_mask": None,  # "r",
            # --- END MODIFICATION ---

            # Keep indices in memory
            "puzzle_identifiers": None,
            "puzzle_indices": None,
            "group_indices": None
        }

        # Load data
        self._data = {}
        for set_name in self.metadata.sets: # Load subset
            for i, dataset_path in enumerate(self.config.dataset_paths):
                if i > 0:
                    set_name_ = set_name + str(i)
                else:
                    set_name_ = set_name
                
                # Check which files exist before trying to load
                current_set_data = {}
                for field_name, mmap_mode in field_mmap_modes.items():
                    file_path = os.path.join(dataset_path, self.split, f"{set_name}__{field_name}.npy")
                    if os.path.exists(file_path):
                        # --- START MODIFICATION: Ensure we can write to intermediate steps ---
                        writeable_mmap = mmap_mode
                        if self.split == "train" and field_name in ["intermediate_steps", "intermediate_mask"]:
                             # If training, load as writeable copy ('c') instead of read-only ('r')
                             # or load into memory (None)
                             if writeable_mmap == "r":
                                 writeable_mmap = "c" 
                        
                        current_set_data[field_name] = np.load(file_path, mmap_mode=writeable_mmap)
                        
                        # If we loaded as 'c', we must copy to make it writeable
                        if writeable_mmap == "c":
                            current_set_data[field_name] = current_set_data[field_name].copy()
                        # --- END MODIFICATION ---
                            
                    else:
                        # --- MODIFIED: Conditionally load intermediate steps ---
                        if field_name in ["intermediate_steps", "intermediate_mask"]:
                            print(f"Info: Optional file not found, skipping: {file_path}")
                        elif field_name in ["puzzle_identifiers", "puzzle_indices", "group_indices"] and self.config.test_set_mode:
                             print(f"Info: Optional index file not found in test mode, skipping: {file_path}")
                        else:
                            # Raise error for essential files
                            raise FileNotFoundError(f"Essential dataset file not found: {file_path}")
                        # --- END MODIFICATION ---
                
                self._data[set_name_] = current_set_data

        # --- START: NEW VALIDATION/REGENERATION BLOCK ---
        # After all data is loaded, validate and fix it.
        print("--- Running Dataset Validation and Regeneration ---")
        
        # For plotting
        plot_data = []
        
        for set_name, dataset in self._data.items():
            # Only validate training sets that have intermediate steps
            if self.split == "train" and "intermediate_steps" in dataset:
                
                # Load data into torch tensors for validation
                # Note: We are modifying the numpy arrays in 'dataset'
                inputs_np = dataset["inputs"]
                labels_np = dataset["labels"]
                paths_np = dataset["intermediate_steps"]
                
                # Ensure data is writeable (it should be now)
                if not inputs_np.flags.writeable: inputs_np = inputs_np.copy()
                if not labels_np.flags.writeable: labels_np = labels_np.copy()
                if not paths_np.flags.writeable: paths_np = paths_np.copy()
                
                # Move to torch for computation
                inputs_torch = torch.from_numpy(inputs_np).to(torch.int32)
                labels_torch = torch.from_numpy(labels_np).to(torch.int32)
                paths_torch = torch.from_numpy(paths_np).to(torch.int32)
                
                total_paths = len(inputs_np)

                # Check 1: Final step in path must match the label
                final_step = paths_torch[:, -1, :]
                target_mismatch = (final_step != labels_torch).any(dim=1)

                # Check 2: Path is stagnant (e.g., step 1 is same as input)
                input_mismatch_label = (inputs_torch != labels_torch).any(dim=1)
                step_1 = paths_torch[:, 1, :]
                step_1_equals_input = (step_1 == inputs_torch).all(dim=1)
                is_stagnant = input_mismatch_label & step_1_equals_input
                
                # Find which items need fixing
                indices_to_fix = torch.where(target_mismatch | is_stagnant)[0]
                invalid_paths = indices_to_fix.numel()
                
                # Add data for plotting
                plot_data.append({"set": set_name, "type": "Invalid", "count": invalid_paths})
                plot_data.append({"set": set_name, "type": "Valid", "count": total_paths - invalid_paths})
                
                if invalid_paths > 0:
                    print(f"  Found {invalid_paths} invalid paths out of {total_paths} in {set_name}. Regenerating...")
                    
                    # Get inputs/labels for *only* these bad items
                    inputs_to_fix = inputs_torch[indices_to_fix]
                    labels_to_fix = labels_torch[indices_to_fix]
                    
                    num_steps = paths_torch.shape[1] # Match original path length
                    
                    # Regenerate paths
                    regenerated_paths = batch_corrupt_linear_path(
                        inputs_to_fix, 
                        labels_to_fix, 
                        num_steps
                    )
                    
                    # Overwrite the bad paths in the original numpy array
                    dataset["intermediate_steps"][indices_to_fix.cpu().numpy()] = regenerated_paths.cpu().numpy()
                    
                    # --- START: ERROR FIX ---
                    # Overwrite the mask to be all 1s (valid)
                    mask_np = dataset["intermediate_mask"]
                    if not mask_np.flags.writeable: mask_np = mask_np.copy()
                    
                    # Get the slice we are about to overwrite
                    mask_slice_to_overwrite = mask_np[indices_to_fix.cpu().numpy()]
                    
                    # Create an array of ones with the *same shape as the slice*
                    new_mask_values = np.ones_like(mask_slice_to_overwrite, dtype=np.int32)
                    
                    # Assign the new values
                    mask_np[indices_to_fix.cpu().numpy()] = new_mask_values
                    dataset["intermediate_mask"] = mask_np
                    # --- END: ERROR FIX ---
                    
                    print(f"  Regeneration complete for {set_name}.")
                else:
                    print(f"  All {total_paths} paths in {set_name} are valid.")
            else:
                print(f"Skipping validation for {set_name} (not a train split or no intermediate steps).")
        
        #print("--- Dataset Validation Complete ---")

        # # --- START: New Plotting Logic ---
        # if plot_data:
        #     try:
        #         df = pd.DataFrame(plot_data)
                
        #         # Pivot for stacking
        #         pivot_df = df.pivot(index='set', columns='type', values='count').fillna(0)
        #         # Ensure correct order for stacking
        #         if 'Invalid' not in pivot_df: pivot_df['Invalid'] = 0
        #         if 'Valid' not in pivot_df: pivot_df['Valid'] = 0
                
        #         pivot_df = pivot_df[['Valid', 'Invalid']] # Stack 'Valid' first, then 'Invalid'
                
        #         # Create the plot
        #         ax = pivot_df.plot(
        #             kind='bar', 
        #             stacked=True, 
        #             color={'Valid': '#2ca02c', 'Invalid': '#d62728'}, # Green/Red
        #             figsize=(10, 6),
        #             rot=0
        #         )
                
        #         plt.title('Training Data Path Validation Summary', fontsize=16)
        #         plt.ylabel('Number of Paths', fontsize=12)
        #         plt.xlabel('Dataset Set', fontsize=12)
        #         plt.legend(title='Path Type', bbox_to_anchor=(1.02, 1), loc='upper left')
        #         plt.tight_layout(rect=[0, 0, 0.85, 1]) # Make room for legend
                
        #         # Add text labels for counts
        #         for c in ax.containers:
        #             # Filter out containers with no value
        #             labels = [f'{v.get_height():,}' if v.get_height() > 0 else '' for v in c]
        #             ax.bar_label(c, labels=labels, label_type='center', color='white', fontweight='bold')
                
        #         plt.savefig("path_validation_summary.png")
        #         print("Saved path validation summary plot to path_validation_summary.png")

        #     except Exception as e:
        #         print(f"Warning: Failed to generate validation plot. Error: {e}")
        # --- END: New Plotting Logic ---
        # --- END: NEW VALIDATION/REGENERATION BLOCK ---


    def _collate_batch(self, batch):
        # Convert dtype
        # Handle different dtypes (mask is bool/int, others int)
        for k, v in batch.items():
            if k == "intermediate_mask":
                batch[k] = v.astype(np.int32) # Use int for padding
            elif k in ["inputs", "labels", "puzzle_identifiers", "intermediate_steps"]:
                batch[k] = v.astype(np.int32)


        # Convert ignore label IDs
        if self.metadata.ignore_label_id is not None and "labels" in batch:
            batch["labels"][batch["labels"] == self.metadata.ignore_label_id] = IGNORE_LABEL_ID

        # Pad
        # Find the max size from a key that is always present (e.g., 'inputs')
        # We need to find the effective batch size *before* padding
        
        # --- MODIFIED: Check for 'puzzle_identifiers' before padding ---
        # 'puzzle_identifiers' might not be present in test mode if not found
        key_for_size = "puzzle_identifiers" if "puzzle_identifiers" in batch else "inputs"
        if key_for_size not in batch:
             # If no data in batch (e.g., empty test set), just return empty dict
             if not batch:
                 return {}
             # This should not be empty, but as a safeguard
             key_for_size = next(iter(batch.keys()))

        current_batch_size = batch[key_for_size].size if batch[key_for_size].ndim == 1 else len(batch[key_for_size])

        if current_batch_size < self.local_batch_size:
            pad_size = self.local_batch_size - current_batch_size
            pad_values = {
                "inputs": self.metadata.pad_id,
                "labels": IGNORE_LABEL_ID,
                "puzzle_identifiers": self.metadata.blank_identifier_id,
                "intermediate_steps": self.metadata.pad_id, # <-- ADDED
                "intermediate_mask": 0                      # <-- ADDED (0 for False)
            }
            
            batch = {
                k: np.pad(v, ((0, pad_size), ) + ((0, 0), ) * (v.ndim - 1), constant_values=pad_values.get(k, 0))
                for k, v in batch.items()
            }
        # --- END MODIFICATION ---

        # To tensor
        return {k: torch.from_numpy(v) for k, v in batch.items()}
    
    def _iter_test(self):
        for set_i, (set_name, dataset) in enumerate(self._data.items()):  # type: ignore
            if "inputs" not in dataset:
                print(f"Skipping empty dataset set: {set_name}")
                continue
                
            total_examples = len(dataset["inputs"])

            # --- MODIFIED: Handle missing indices in test mode ---
            has_indices = "puzzle_indices" in dataset and "puzzle_identifiers" in dataset
            if not has_indices:
                print(f"Warning: Missing 'puzzle_indices' or 'puzzle_identifiers' for test set {set_name}. Iterating example-by-example.")
            # --- END MODIFICATION ---

            # Load examples one by one
            start_index = 0
            while start_index < total_examples:
                # Compute indices
                end_index = min(total_examples, start_index + self.config.global_batch_size)
                
                local_start = start_index + self.config.rank * self.local_batch_size
                local_end   = min(start_index + (self.config.rank + 1) * self.local_batch_size, end_index)
                
                if local_start >= local_end:
                    start_index += self.config.global_batch_size
                    continue

                # Build the batch dictionary
                batch_data = {
                    "inputs": dataset["inputs"][local_start: local_end],
                    "labels": dataset["labels"][local_start: local_end],
                }

                # --- MODIFIED: Conditionally add puzzle_identifiers ---
                if has_indices:
                    # Get batch of examples, and also puzzle IDs
                    puzzle_indices_list = []
                    puzzle_index = np.searchsorted(dataset["puzzle_indices"], local_start, side="right") - 1
                    for i in range(local_start, local_end):
                        while puzzle_index + 1 < len(dataset["puzzle_indices"]) and i >= dataset["puzzle_indices"][puzzle_index + 1]:
                            puzzle_index += 1
                        puzzle_indices_list.append(puzzle_index)
                    
                    batch_data["puzzle_identifiers"] = dataset["puzzle_identifiers"][puzzle_indices_list]
                else:
                    # Create dummy identifiers if missing
                    batch_data["puzzle_identifiers"] = np.full(local_end - local_start, self.metadata.blank_identifier_id)
                # --- END MODIFICATION ---

                # --- MODIFIED: Conditionally add intermediate steps ---
                if "intermediate_steps" in dataset:
                    batch_data["intermediate_steps"] = dataset["intermediate_steps"][local_start: local_end]
                if "intermediate_mask" in dataset:
                    batch_data["intermediate_mask"] = dataset["intermediate_mask"][local_start: local_end]
                # --- END MODIFICATION ---
                
                batch = self._collate_batch(batch_data)

                yield set_name, batch, end_index - start_index
                
                # Advance to next batch
                start_index += self.config.global_batch_size

    def _iter_train(self):
        for set_name, dataset in self._data.items():  # type: ignore
            # Increase epoch count
            self._iters += 1

            # Randomly shuffle groups
            rng = np.random.Generator(np.random.Philox(seed=self.config.seed + self._iters))

            group_order = np.concatenate([rng.permutation(dataset["group_indices"].size - 1) for _i in range(self.config.epochs_per_iter)])
            start_index = 0
            
            while start_index < group_order.size:
                start_index, batch_indices, batch_puzzle_indices = _sample_batch(
                    rng,
                    group_order=group_order,
                    puzzle_indices=dataset["puzzle_indices"],
                    group_indices=dataset["group_indices"],
                    start_index=start_index,
                    global_batch_size=self.config.global_batch_size,
                )

                # Select current rank and collate
                global_effective_batch_size = batch_puzzle_indices.size  # Global effective batch size, excluding pads

                # Drop last batch
                if global_effective_batch_size < self.config.global_batch_size:
                    break

                batch_indices        = batch_indices       [self.config.rank * self.local_batch_size: (self.config.rank + 1) * self.local_batch_size]
                batch_puzzle_indices = batch_puzzle_indices[self.config.rank * self.local_batch_size: (self.config.rank + 1) * self.local_batch_size]
                
                # Build the batch dictionary
                batch_data = {
                    "inputs": dataset["inputs"][batch_indices],
                    "labels": dataset["labels"][batch_indices],
                    "puzzle_identifiers": dataset["puzzle_identifiers"][batch_puzzle_indices]
                }

                # --- MODIFIED: Conditionally add intermediate steps ---
                if "intermediate_steps" in dataset:
                    batch_data["intermediate_steps"] = dataset["intermediate_steps"][batch_indices]
                if "intermediate_mask" in dataset:
                    batch_data["intermediate_mask"] = dataset["intermediate_mask"][batch_indices]
                # --- END MODIFICATION ---

                batch = self._collate_batch(batch_data)

                yield set_name, batch, global_effective_batch_size
                
    def __iter__(self):
        worker_info = get_worker_info()
        assert worker_info is None or worker_info.num_workers == 1, "Multithreaded data loading is not currently supported."
        
        self._lazy_load_dataset()
        
        # --- MODIFIED: Handle case where no data was loaded (e.g., empty test) ---
        if not self._data:
            print("Warning: No data loaded, dataset iterator will be empty.")
            return
        # --- END MODIFICATION ---
        
        # Iterate using specified mode
        if self.config.test_set_mode:
            yield from self._iter_test()
        else:
            yield from self._iter_train()