from dataclasses import dataclass, field
from pathlib import Path  # Import Path
import torch
import torch.nn as nn
from transformers import Trainer, TrainingArguments, get_scheduler
from transformers.utils import logging
from typing import Optional, Dict, Union, Any, Tuple, List
import math
import os
import numpy as np

from loader.models.my_gpt2 import MyGPT2, sinkhorn  # Import the custom model and sinkhorn

import wandb
from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl
import matplotlib.pyplot as plt
import io

# Module-level logger, created through transformers.utils.logging.get_logger.
logger = logging.get_logger(__name__)


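# No-op optimizer/scheduler placeholders installed on the Hugging Face Trainer;
# the real per-stage optimizers/schedulers are stepped manually in MyTrainer.training_step.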
class DummyScheduler:
    """Do-nothing learning-rate scheduler.

    ``MyTrainer`` hands an instance of this class to the base ``Trainer`` as its
    ``lr_scheduler`` so that the Trainer's own scheduler calls have no effect.
    """

    def get_last_lr(self):
        """Return ``[0.0]`` so callers that read ``get_last_lr()[0]`` keep working."""
        zero_lr = 0.0
        return [zero_lr]

    def step(self):
        """Ignore the request to advance the schedule."""
        return None

    def state_dict(self):
        """Return an empty state; there is nothing to checkpoint."""
        return dict()

    def load_state_dict(self, state_dict):
        """Accept a previously saved state and discard it."""
        return None


class DummyOptimizer(torch.optim.Optimizer):
    """No-op optimizer handed to the Hugging Face ``Trainer``.

    ``MyTrainer`` installs this as ``Trainer.optimizer`` so that the base Trainer's
    own ``optimizer.step()`` / ``zero_grad()`` calls change nothing; the real
    stage-specific optimizers are stepped manually.
    """

    def __init__(self):
        # torch.optim.Optimizer rejects an empty parameter list, so register one
        # zero-element placeholder parameter with no default hyperparameters.
        super().__init__([torch.nn.Parameter(torch.empty(0))], {})

    def step(self, closure=None):
        """Evaluate ``closure`` (if given) and return its loss without updating anything.

        Follows the ``torch.optim.Optimizer.step`` contract, which returns the loss
        produced by the closure (or ``None`` when there is no closure).
        """
        loss = None
        if closure is not None:
            # Closures re-run forward/backward, so make sure autograd is enabled.
            with torch.enable_grad():
                loss = closure()
        return loss

    def zero_grad(self, set_to_none: bool = False):
        """Do nothing; gradients are cleared on the real per-stage optimizers."""
        return None


# TrainingArguments extended with the two-stage training and experiment/data options used by this project.
@dataclass
class MyTrainingArguments(TrainingArguments):
    """
    Custom TrainingArguments extending Hugging Face's TrainingArguments
    to include specific parameters for two-stage training and other custom settings.

    Highlights:
      * ``transformer_lr`` / ``permutation_lr``: learning rates for the stage-1
        (transformer weights) and stage-2 (permutation matrix) optimizers.
      * ``stage1_scheduler_type`` / ``stage2_scheduler_type``: scheduler names used
        when the per-stage schedulers are created manually.
      * ``num_batch`` / ``test_batch_size``: global batch sizes; converting them to
        per-device sizes is done by the launching script, not here.
      * Some parent defaults are overridden (``remove_unused_columns=False``,
        ``per_device_train_batch_size=128``, ``per_device_eval_batch_size=100``), and
        ``save_total_limit`` falls back to 1 when unset.
    """

    # --- Stage Parameters ---
    # Stage-specific epoch/step counts removed. Use standard num_train_epochs or max_steps.
    transformer_lr: Optional[float] = field(
        default=5e-5, metadata={"help": "Learning rate for transformer weights (stage 1)"}
    )
    permutation_lr: Optional[float] = field(
        default=1e-4, metadata={"help": "Learning rate for permutation matrix P (stage 2)"}
    )

    # --- Custom Experiment/Data Parameters (Moved from argparse) ---
    data_path: str = field(
        default="./data/data_sum", metadata={"help": "Path prefix for data files (e.g., ./data/wikitext-103/wiki)"}
    )
    data_encoding: str = field(default="infix", metadata={"help": "Data encoding (should be compatible with GPT-2)"})
    # output_dir is already defined in TrainingArguments.
    # save_periodic may be redundant with save_strategy/save_steps.
    save_periodic: int = field(
        default=0,
        metadata={
            "help": "Save the model periodically (0 to disable)"
        },
    )
    exp_name: str = field(default="permutation_learning", metadata={"help": "Experiment name"})
    exp_id: str = field(default="", metadata={"help": "Experiment ID"})
    model_name_or_path: str = field(default="gpt2", metadata={"help": "Base model name or path (e.g., gpt2)"})
    max_sequence_length: int = field(
        default=1024, metadata={"help": "Maximum sequence length (should match model config)"}
    )
    # --- Batch Size and Workers (Mapping from argparse) ---
    # Kept for mapping from the shell script; the per-device conversion happens in the main script.
    num_batch: Optional[int] = field(
        default=10,
        metadata={
            "help": "Total batch size for training (will be divided by num GPUs). If set, overrides per_device_train_batch_size calculated from this."
        },
    )
    test_batch_size: Optional[int] = field(
        default=None,
        metadata={
            "help": "Total batch size for evaluation (will be divided by num GPUs). If set, overrides per_device_eval_batch_size calculated from this."
        },
    )
    num_workers: int = field(
        default=0, metadata={"help": "Number of CPU workers for DataLoader (used if dataloader_num_workers not set)"}
    )
    # --- Vocab/Tokenizer Parameters (from main.py) ---
    num_variables: int = field(default=2, metadata={"help": "Number of variables for polynomial tasks"})
    polynomial_field: str = field(default="QQ", metadata={"help": "Field for polynomial tasks (e.g., QQ, GF7)"})
    max_coefficient: int = field(default=1000, metadata={"help": "Maximum coefficient for polynomial tasks"})
    max_degree: int = field(default=10, metadata={"help": "Maximum degree for polynomial tasks"})
    # Project-specific debugging flag; declared here so HfArgumentParser exposes it.
    dry_run: bool = field(default=False, metadata={"help": "Quick test for debugging purposes."})
    # Keep every dataset column: the collator needs columns the Trainer would otherwise drop.
    remove_unused_columns: bool = field(
        default=False, metadata={"help": "Whether or not to automatically remove unused columns from the dataset."}
    )
    bf16: bool = field(
        default=False,
        metadata={
            "help": "Whether to use bf16 (mixed precision) instead of fp16. Requires PyTorch >= 1.10 and GPU support."
        },
    )
    per_device_train_batch_size: Optional[int] = field(
        default=128,
        metadata={"help": "Batch size per device during training. If set, overrides num_batch."},
    )
    per_device_eval_batch_size: Optional[int] = field(
        default=100,
        metadata={"help": "Batch size per device during evaluation. If set, overrides test_batch_size."},
    )
    lr_scheduler_type: str = field(default="linear", metadata={"help": "Learning rate scheduler type."})
    stage1_scheduler_type: str = field(
        default="linear", metadata={"help": "Scheduler type for stage 1 (used when creating the manual scheduler)."}
    )
    stage2_scheduler_type: str = field(
        default="linear", metadata={"help": "Scheduler type for stage 2 (used when creating the manual scheduler)."}
    )
    # Gumbel-Sinkhorn parameters (for stage 2)
    gumbel_sinkhorn_tau: float = field(
        default=0.2, metadata={"help": "Temperature parameter for Gumbel-Sinkhorn. Trying a lower value again."}
    )
    gumbel_sinkhorn_iters: int = field(
        default=20, metadata={"help": "Number of Sinkhorn iterations for Gumbel-Sinkhorn."}
    )
    # Optimizer specific settings per stage (User reverted these, keeping args for context if needed later)
    # stage1_optim: str = field(
    #     default="adamw_torch", metadata={"help": "Optimizer to use for stage 1 (transformer weights)."}
    # )
    # stage2_optim: str = field(
    #     default="sgd", metadata={"help": "Optimizer to use for stage 2 (permutation matrix P)."}
    # )
    # stage2_momentum: float = field(
    #     default=0.9, metadata={"help": "Momentum for SGD optimizer in stage 2."}
    # )

    def __post_init__(self):
        """Run the parent's validation, then apply this project's defaults."""
        super().__post_init__()

        # Use transformer_lr as the Trainer-level learning rate when none was given.
        # NOTE: TrainingArguments normally supplies a numeric default for learning_rate,
        # so this presumably only triggers when learning_rate=None is passed explicitly.
        if self.learning_rate is None and self.transformer_lr is not None:
            self.learning_rate = self.transformer_lr

        # Mapping num_batch/test_batch_size to per-device sizes is intentionally done in
        # the main script (after GPU detection), not here.

        # Honour the legacy `num_workers` option. TrainingArguments defaults
        # dataloader_num_workers to 0 rather than None, so the previous `is None` test never
        # fired and num_workers was silently ignored; treat 0/None as "not set".
        if not self.dataloader_num_workers:
            self.dataloader_num_workers = self.num_workers

        # Keep only the most recent checkpoint unless the user chose a limit.
        if self.save_total_limit is None:
            self.save_total_limit = 1


class MyTrainer(Trainer):
    """Custom Trainer for step-by-step two-stage optimization.

    * Stage 1 (``global_step`` below the switch step) trains every parameter whose
      name does not contain ``perm_logits`` on a next-token cross-entropy loss, using
      ``optimizer1`` / ``lr_scheduler1``.
    * Stage 2 (from the switch step on) trains only the ``perm_logits`` parameters to
      minimise the entropy of ``outputs.attentions[0]``, using ``optimizer2`` (SGD) /
      ``lr_scheduler2``.

    The switch step is ``args.stage_switch_step`` when the training arguments define
    such an attribute, otherwise :attr:`DEFAULT_STAGE_SWITCH_STEP`.

    The base ``Trainer`` only ever sees a ``DummyOptimizer`` / ``DummyScheduler``; the
    real stage optimizers are stepped inside :meth:`training_step`.

    Known limitations:
      * Gradient accumulation is not supported: the stage optimizer steps on every
        ``training_step`` call.
      * ``optimizer1/2`` and ``lr_scheduler1/2`` are not saved with checkpoints, so a
        resumed run creates them fresh.
      * NOTE(review): ``requires_grad`` is toggled every step, possibly after the model
        has been wrapped for distributed training; verify this under DDP.
    """

    # Global step at which training moves from stage 1 to stage 2 (unless the training
    # arguments provide a ``stage_switch_step`` attribute).
    DEFAULT_STAGE_SWITCH_STEP = 2000
    # Stage-2 entropy ignores query positions before this index.
    ENTROPY_START_POSITION = 20
    # Name fragment identifying the permutation parameters trained in stage 2.
    PERM_PARAM_KEY = "perm_logits"

    def __init__(
        self,
        model: Optional[MyGPT2] = None,  # Ensure model is MyGPT2 type
        args: Optional[MyTrainingArguments] = None,  # Renamed back to args for consistency with parent
        **kwargs,
    ):
        # Fail fast: the stage optimizers read MyTrainingArguments-only fields
        # (transformer_lr, permutation_lr, stage*_scheduler_type).
        if args is not None and not isinstance(args, MyTrainingArguments):
            raise ValueError(
                f"MyTrainer requires an instance of MyTrainingArguments, but received {type(args)}."
                " Ensure you are parsing arguments using HfArgumentParser((MyTrainingArguments,))"
            )

        super().__init__(model=model, args=args, **kwargs)

        # The stage is recomputed from global_step at the start of every training step.
        self.current_stage = 1
        # Real per-stage optimizers/schedulers, built in create_optimizer_and_scheduler().
        self.optimizer1 = None
        self.lr_scheduler1 = None
        self.optimizer2 = None
        self.lr_scheduler2 = None

        if self.args.gradient_accumulation_steps > 1:
            logger.warning(
                "MyTrainer steps its stage optimizers on every micro-batch; "
                f"gradient_accumulation_steps={self.args.gradient_accumulation_steps} is not applied to the updates."
            )

        self._setup_stage_parameters()

    def _stage_switch_step(self) -> int:
        """Return the first global step that belongs to stage 2."""
        return int(getattr(self.args, "stage_switch_step", self.DEFAULT_STAGE_SWITCH_STEP))

    def _stage_for_step(self, step: int) -> int:
        """Return the stage (1 or 2) that optimizer step ``step`` belongs to."""
        return 1 if step < self._stage_switch_step() else 2

    def _setup_stage_parameters(self):
        """Set ``requires_grad`` so that only the current stage's parameters train.

        Stage 1 enables every parameter except the permutation logits; stage 2 enables
        only the permutation logits. Logs an error and returns without changes when the
        model or the stage is missing.

        Raises:
            ValueError: if ``current_stage`` is neither 1 nor 2.
        """
        if getattr(self, "model", None) is None:
            logger.error("Model not available in _setup_stage_parameters. Skipping setup.")
            return
        if self.current_stage is None:
            logger.error("Current stage not set in _setup_stage_parameters. Skipping setup.")
            return

        if self.current_stage == 1:
            logger.info("Setting up for Stage 1: Training Transformer weights.")
        elif self.current_stage == 2:
            logger.info("Setting up for Stage 2: Training Permutation matrix P.")
        else:
            raise ValueError(f"Invalid stage: {self.current_stage}")

        train_permutation = self.current_stage == 2
        for name, param in self.model.named_parameters():
            # Trainable iff "is a permutation parameter" matches "this is stage 2".
            param.requires_grad_((self.PERM_PARAM_KEY in name) == train_permutation)

    def create_optimizer_and_scheduler(self, num_training_steps: int):
        """Create the real per-stage optimizers/schedulers and give the Trainer dummies.

        Args:
            num_training_steps: Total number of optimizer steps for the run, as passed by
                the base ``Trainer``.

        The step budget is split at the same switch step that :meth:`training_step`
        uses, so each stage's scheduler spans exactly the steps that stage runs for.
        The training arguments are not modified.
        """
        total_steps = max(int(num_training_steps), 1)
        switch_step = self._stage_switch_step()
        # max(..., 1) keeps the schedulers well-defined when a stage is empty.
        num_steps_stage1 = max(min(switch_step, total_steps), 1)
        num_steps_stage2 = max(total_steps - switch_step, 1)
        logger.info(
            f"Trainer steps per stage: S1={num_steps_stage1}, S2={num_steps_stage2} (switch at step {switch_step})"
        )

        # --- Stage 1: transformer weights, optimizer chosen by args.optim ---
        transformer_params = [p for n, p in self.model.named_parameters() if self.PERM_PARAM_KEY not in n]
        if transformer_params:
            optimizer_cls1, optimizer_kwargs1 = Trainer.get_optimizer_cls_and_kwargs(self.args)
            # The per-group "lr" overrides any lr in optimizer_kwargs1.
            self.optimizer1 = optimizer_cls1(
                [{"params": transformer_params, "lr": self.args.transformer_lr}], **optimizer_kwargs1
            )
            self.lr_scheduler1 = get_scheduler(
                name=self.args.stage1_scheduler_type,
                optimizer=self.optimizer1,
                num_warmup_steps=self.args.get_warmup_steps(num_steps_stage1),
                num_training_steps=num_steps_stage1,
            )
            logger.info(
                f"Created optimizer and {self.args.stage1_scheduler_type} scheduler for Stage 1 ({num_steps_stage1} steps)."
            )
        else:
            logger.warning("No parameters found for Stage 1 optimizer.")

        # --- Stage 2: permutation logits, always SGD ---
        permutation_params = [p for n, p in self.model.named_parameters() if self.PERM_PARAM_KEY in n]
        if permutation_params:
            # Temporarily switch args.optim to resolve the SGD class/kwargs, then restore
            # it so the user's optimizer choice is not clobbered in the saved/logged args.
            original_optim = self.args.optim
            self.args.optim = "sgd"
            try:
                optimizer_cls2, optimizer_kwargs2 = Trainer.get_optimizer_cls_and_kwargs(self.args)
            finally:
                self.args.optim = original_optim
            self.optimizer2 = optimizer_cls2(
                [{"params": permutation_params, "lr": self.args.permutation_lr}], **optimizer_kwargs2
            )
            self.lr_scheduler2 = get_scheduler(
                name=self.args.stage2_scheduler_type,
                optimizer=self.optimizer2,
                num_warmup_steps=self.args.get_warmup_steps(num_steps_stage2),
                num_training_steps=num_steps_stage2,
            )
            logger.info(
                f"Created optimizer and {self.args.stage2_scheduler_type} scheduler for Stage 2 ({num_steps_stage2} steps)."
            )
        else:
            logger.warning("No parameters found for Stage 2 optimizer.")

        # The base Trainer steps these itself; keep them inert.
        self.optimizer = DummyOptimizer()
        self.lr_scheduler = DummyScheduler()
        logger.info("Trainer internal optimizer/scheduler set to Dummies. Step logic handles real ones.")

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the loss for ``self.current_stage`` and log stage metrics.

        Args:
            model: Model to run; it is called with ``global_step`` and ``max_steps``
                keyword arguments in addition to the non-label inputs.
            inputs: Batch dict; ``labels`` is removed before the forward pass.
            return_outputs: If True, return ``(loss, outputs)`` instead of ``loss``.
            num_items_in_batch: Accepted for compatibility with newer ``Trainer``
                versions; unused.

        Raises:
            ValueError: for an invalid stage, or when stage 1 has neither a model loss
                nor labels.
        """
        labels = inputs.get("labels")
        model_inputs = {k: v for k, v in inputs.items() if k != "labels"}

        outputs = model(**model_inputs, global_step=self.state.global_step, max_steps=self.state.max_steps)

        if self.current_stage == 1:
            loss, metrics = self._stage1_loss(outputs, labels)
        elif self.current_stage == 2:
            loss, metrics = self._stage2_loss(outputs)
        else:
            raise ValueError(f"Invalid stage: {self.current_stage}")

        if self.is_world_process_zero() and metrics:
            self.log(metrics)

        return (loss, outputs) if return_outputs else loss

    def _stage1_loss(self, outputs, labels):
        """Return the stage-1 language-modeling loss and its metrics dict."""
        if getattr(outputs, "loss", None) is not None:
            loss = outputs.loss
        elif labels is not None:
            logits = outputs.logits if hasattr(outputs, "logits") else outputs[0]
            # Shift so that tokens < n predict token n.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = nn.CrossEntropyLoss()(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        else:
            raise ValueError("Labels must be provided for stage 1 training")

        metrics = {"stage1_lm_loss": loss.item()}
        if getattr(self, "optimizer1", None) is not None:
            metrics["stage1_lr"] = self.optimizer1.param_groups[0]["lr"]
        return loss, metrics

    def _stage2_loss(self, outputs):
        """Return the stage-2 attention-entropy loss and its metrics dict.

        Falls back to a zero loss (with a warning) when the model returns no attentions.
        """
        attentions = getattr(outputs, "attentions", None)

        if attentions is None or len(attentions) == 0:
            logger.warning("Attentions not found in model output for stage 2 loss calculation.")
            loss = torch.tensor(0.0, device=self.args.device, requires_grad=True)  # Dummy loss
            metrics = {"stage2_entropy_loss": 0.0}
        else:
            # NOTE(review): assumes attentions[0] holds (batch, heads, query, key)
            # probabilities from the Gumbel-Sinkhorn path — TODO confirm.
            attn_scores = attentions[0] + 1e-9  # epsilon keeps log() finite
            entropy = -torch.sum(attn_scores * torch.log(attn_scores), dim=-1)
            start = self.ENTROPY_START_POSITION
            if entropy.shape[-1] > start:
                # Ignore the first `start` query positions.
                entropy = entropy[:, :, start:]
            else:
                # Slicing would leave nothing and mean() would return NaN, corrupting
                # perm_logits on the next step; use every position instead.
                logger.warning(
                    f"Sequence length {entropy.shape[-1]} <= {start}; using all positions for stage 2 entropy."
                )
            avg_entropy = entropy.mean(dim=(1, 2))  # average over heads and query positions
            loss = avg_entropy.mean()  # average over batch
            metrics = {"stage2_entropy_loss": loss.item()}

        if getattr(self, "optimizer2", None) is not None:
            metrics["stage2_lr"] = self.optimizer2.param_groups[0]["lr"]
        return loss, metrics

    def training_step(
        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
    ) -> torch.Tensor:
        """Run one forward/backward pass and step the current stage's optimizer.

        Args:
            model: The (possibly wrapped) model being trained.
            inputs: Batch from the train dataloader; ``labels`` is required in stage 1
                unless the model returns its own loss.
            num_items_in_batch: Accepted for compatibility with the base ``Trainer``; unused.

        Returns:
            The detached loss divided by ``gradient_accumulation_steps`` (for logging).
        """
        model.train()

        # 1. The stage is a pure function of the optimizer step counter.
        step = self.state.global_step
        self.current_stage = self._stage_for_step(step)

        # 2. Freeze/unfreeze parameters for this stage.
        self._setup_stage_parameters()

        # 3-4. Forward pass and stage-specific loss.
        inputs = self._prepare_inputs(inputs)
        with self.compute_loss_context_manager():
            loss = self.compute_loss(model, inputs)

        # 5. Backward pass.
        self._stage_backward(loss)

        # 6-7. Pick and step the real optimizer/scheduler for this stage.
        optimizer, scheduler = self._select_stage_optimizer(step)
        if optimizer is not None:
            self._stage_optimizer_step(optimizer, scheduler)

        # 8. Clear gradients on both optimizers so nothing leaks across stages.
        for stage_optimizer in (self.optimizer1, self.optimizer2):
            if stage_optimizer is not None:
                stage_optimizer.zero_grad(set_to_none=True)

        return loss.detach() / self.args.gradient_accumulation_steps

    def _stage_backward(self, loss):
        """Backpropagate ``loss`` through whichever mixed-precision/engine path is active."""
        if self.use_apex:
            from apex import amp

            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        elif getattr(self, "use_cpu_amp", False):
            self.scaler.scale(loss).backward()
        elif getattr(self, "deepspeed", None):
            self.deepspeed.backward(loss)
        else:
            loss.backward()

    def _select_stage_optimizer(self, step: int):
        """Return ``(optimizer, scheduler)`` for the current stage and log its diagnostics.

        Returns ``(None, None)``, after logging an error, when that stage has no optimizer.
        """
        if self.current_stage == 1 and self.optimizer1 is not None and self.lr_scheduler1 is not None:
            self._log_transformer_grad()
            return self.optimizer1, self.lr_scheduler1
        if self.current_stage == 2 and self.optimizer2 is not None and self.lr_scheduler2 is not None:
            self._log_permutation_state()
            return self.optimizer2, self.lr_scheduler2
        # Skip the update rather than crash; the Trainer loop keeps running.
        logger.error(f"Optimizer/Scheduler not found for stage {self.current_stage} at step {step}")
        return None, None

    def _log_transformer_grad(self):
        """Log the mean gradient of ``lm_head.weight``; skipped when there is no gradient."""
        lm_head = getattr(self.model, "lm_head", None)
        grad = getattr(getattr(lm_head, "weight", None), "grad", None)
        if grad is not None:
            self.log({"transformer_grad": grad.mean().item()})

    def _log_permutation_state(self):
        """Log the softmax of ``perm_logits`` (one value per entry) and its mean gradient."""
        perm_logits = self.model.perm_logits
        # NOTE(review): softmax over dim 0 and per-entry .item() assume perm_logits is 1-D.
        probs = torch.softmax(perm_logits.detach(), dim=0)
        metrics = {}
        if perm_logits.grad is not None:
            metrics["N_select_grad"] = perm_logits.grad.mean().item()
        for i, prob in enumerate(probs):
            metrics[f"N_select_{i}"] = prob.item()
        self.log(metrics)

    def _stage_optimizer_step(self, optimizer, scheduler):
        """Clip gradients of ``optimizer``'s first param group, step it, then step ``scheduler``."""
        max_norm = self.args.max_grad_norm
        params = optimizer.param_groups[0]["params"]
        if max_norm is not None and max_norm > 0:
            if self.use_apex:
                from apex import amp

                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_norm)
            elif getattr(self, "use_cpu_amp", False):
                self.scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(params, max_norm)
            elif getattr(self, "deepspeed", None):
                pass  # DeepSpeed clips according to its own config.
            else:
                torch.nn.utils.clip_grad_norm_(params, max_norm)

        if getattr(self, "use_cpu_amp", False):
            self.scaler.step(optimizer)
            self.scaler.update()
        elif getattr(self, "deepspeed", None):
            # NOTE(review): the DeepSpeed engine wraps the Trainer's DummyOptimizer, so this
            # likely does not update the stage optimizers — verify before using DeepSpeed.
            self.deepspeed.step()
        else:
            optimizer.step()

        if scheduler is not None:
            scheduler.step()

    # NOTE: optimizer1/2 and lr_scheduler1/2 are not saved/restored with checkpoints;
    # supporting resume would require overriding _save_optimizer_and_scheduler /
    # _load_optimizer_and_scheduler.


class LogPermutationCallback(TrainerCallback):
    """Logs the permutation matrix (softmax of model.P) as a heatmap image to W&B.

    On the main process, at every logging step that belongs to stage 2, the row-wise
    softmax of ``model.P`` is written to ``args.output_dir``:
      * ``step_<global_step>.png`` holds the heatmap;
      * ``step_<global_step>.npy`` holds the raw matrix.
    When a W&B run is active, the heatmap is also logged there.

    Args:
        stage_switch_step: First global step of stage 2. When ``None``, uses
            ``args.stage_switch_step`` if the arguments define it, else
            :attr:`DEFAULT_STAGE_SWITCH_STEP` (the step at which ``MyTrainer`` switches).
    """

    DEFAULT_STAGE_SWITCH_STEP = 2000

    def __init__(self, stage_switch_step: Optional[int] = None):
        super().__init__()
        self.stage_switch_step = stage_switch_step

    def _switch_step(self, args: TrainingArguments) -> int:
        """Return the first global step of stage 2."""
        if self.stage_switch_step is not None:
            return int(self.stage_switch_step)
        return int(getattr(args, "stage_switch_step", self.DEFAULT_STAGE_SWITCH_STEP))

    def on_step_end(
        self, args: TrainingArguments, state: TrainerState, control: TrainerControl, model: MyGPT2, **kwargs
    ):
        """Log the permutation matrix at the end of logging steps in stage 2."""
        if not (state.is_world_process_zero and hasattr(model, "P") and state.global_step > 0):
            return
        step = state.global_step

        # on_step_end fires after global_step has been incremented, so the update that just
        # finished ran at step - 1. It was a stage-2 update iff step - 1 >= switch step.
        # (The old even/odd test predates the fixed-step switch and, with an even
        # logging_steps, never fired.)
        if step - 1 < self._switch_step(args):
            return

        # Prefer the Trainer's resolved absolute interval; args.logging_steps may be a ratio.
        logging_steps = getattr(state, "logging_steps", None) or args.logging_steps
        if not logging_steps or step % logging_steps != 0:
            return

        fig = None
        try:
            # float() first: .numpy() does not support bfloat16 tensors.
            p_matrix_logits = model.P.detach().float().cpu()
            p_matrix_softmax = torch.softmax(p_matrix_logits, dim=-1).numpy()

            fig, ax = plt.subplots(figsize=(8, 8))  # Adjust figsize as needed
            im = ax.imshow(p_matrix_softmax, cmap="viridis", aspect="auto")
            fig.colorbar(im, ax=ax)
            ax.set_title(f"Permutation Matrix (Softmax of P) at Step {step}")
            ax.set_xlabel("Output Position")
            ax.set_ylabel("Input Position")

            # --- Save locally: heatmap image plus the raw matrix for later analysis ---
            save_dir = args.output_dir
            os.makedirs(save_dir, exist_ok=True)
            fig.savefig(os.path.join(save_dir, f"step_{step}.png"))
            np.save(os.path.join(save_dir, f"step_{step}.npy"), p_matrix_softmax)

            # Also log to W&B when a run is active. Log against "train/global_step"
            # (the x-axis HF's WandbCallback uses) instead of step=: the Trainer calls
            # self.log several times per step, so W&B's internal step can run ahead of
            # global_step and an explicit step= can be dropped as non-monotonic.
            if wandb.run is not None:
                try:
                    wandb.log({"permutation_matrix_heatmap": wandb.Image(fig), "train/global_step": step})
                except Exception as e:
                    logger.warning(f"Failed to log heatmap to W&B at step {step}: {e}")

        except Exception as e:
            logger.warning(f"Failed to log permutation matrix heatmap at step {step}: {e}")
        finally:
            # Always release the figure, even when saving/logging failed.
            if fig is not None:
                plt.close(fig)


# Reference sketch of Sinkhorn normalization. The implementation actually used is
# `sinkhorn`, imported from loader.models.my_gpt2 at the top of this file.
# def sinkhorn(P, n_iters=10):
#     for _ in range(n_iters):
#         P = P / P.sum(dim=1, keepdim=True)
#         P = P / P.sum(dim=0, keepdim=True)
#     return P
