import os

os.environ["TOKENIZERS_PARALLELISM"] = "false"  # to avoid HF Tokenizer warning
import argparse

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger

import torch
from torch import nn
from torch.utils.data import DataLoader

from transformers import AutoTokenizer

import wandb
import numpy as np

import constrained_trf_unrolling_language.data as data
from constrained_trf_unrolling_language.models import BERTClassifier, SymmetricUnrolledTransformer


class ConstrainedLanguageExpModule(pl.LightningModule):
    """Lightning module for layer-wise constrained training of a text classifier.

    The wrapped model yields one set of class logits per layer. Training minimizes
    the last layer's cross-entropy and, unless ``constraint_type`` is ``"none"``,
    adds the Lagrangian of the constraints
    ``loss[l] - constraint_alpha * loss[l - 1] <= 0`` with one multiplier per layer.
    Optimization is manual: ``training_step`` drives the primal and dual optimizers.
    """

    def __init__(self, language_classifier, config):
        """Wrap ``language_classifier`` and read the training settings from ``config``.

        Args:
            language_classifier: Either a ``SymmetricUnrolledTransformer`` (a linear
                classification head is added for it here) or a model exposing
                ``bert.config.num_hidden_layers``.
            config: Mapping of hyperparameters. It is stored through
                ``save_hyperparameters`` and read back through ``self.hparams``.
        """
        super().__init__()
        self.save_hyperparameters(config)

        self.model = language_classifier
        self.is_unrolled_transformer = isinstance(self.model, SymmetricUnrolledTransformer)

        # Number of constrained layers (and of dual variables) depends on the model type.
        if self.is_unrolled_transformer:
            self.L = self.model.num_layers
            num_classes = config.get("num_classes", 2)
            # The unrolled transformer returns hidden states, so it needs its own head.
            # NOTE(review): 768 is hard-coded; presumably the model's embedding width.
            self.classifier_head = nn.Linear(768, num_classes)
        else:
            self.L = self.model.bert.config.num_hidden_layers

        self.tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

        # Per-sample losses; callers reduce with ``.mean()``.
        self.loss_fn = nn.CrossEntropyLoss(reduction="none")

        self.embedding_perturb_std = self.hparams.embedding_perturb_std
        # Perturbation levels swept at test time for the out-of-distribution accuracy curve.
        self.ood_test_settings = [0.0, 0.15, 0.3, 0.45, 0.6, 0.75, 0.9, 1.05, 1.2, 1.35, 1.5, 1.65, 1.8, 1.95, 2.1]

        self.LOG_FLAGS = {"sync_dist": True, "on_step": True, "on_epoch": False, "prog_bar": True}

        # Constraint settings and dual-variable hyperparameters.
        self.constraint_type = self.hparams.constraint_type
        self.dual_lr = self.hparams.dual_lr
        # Stored for completeness; not referenced elsewhere in this class.
        self.resilience_lr = self.hparams.resilience_lr
        self.resilience_coefficient = self.hparams.resilience_coefficient
        # Acts as the "loss before the first layer" in the first constraint.
        self.initial_constraint = torch.tensor([self.hparams.initial_constraint], requires_grad=True)
        self.constraint_alpha = self.hparams.constraint_alpha
        # Two optimizers are stepped by hand in ``training_step``.
        self.automatic_optimization = False

        # One dual variable per constrained layer.
        self.multipliers = nn.Parameter(torch.ones(self.L) * self.hparams.dual_init, requires_grad=True)

    def setup(self, stage=None):
        """Move the initial constraint onto the module's device.

        It is a plain tensor attribute (not a parameter or buffer), so it has to be
        moved explicitly once the device is known.
        """
        self.initial_constraint = self.initial_constraint.to(self.device)

    def forward(self, tokens, embedding_perturb_std=0.0, output_intermediates=True):
        """Run the wrapped model and return class logits.

        Args:
            tokens: Tokenizer output passed straight to the wrapped model.
            embedding_perturb_std: Perturbation level forwarded to the wrapped model.
            output_intermediates: If true, return the logits of every layer stacked
                along a new trailing axis; otherwise return only the final logits.

        Returns:
            A tensor of logits. With ``output_intermediates=True`` the last axis
            indexes layers; for the unrolled transformer it has ``L + 1`` entries
            because entry 0 is computed from the input embedding.
        """
        model_output = self.model(
            tokens,
            embedding_perturb_std=embedding_perturb_std,
            output_intermediates=output_intermediates,
        )

        if not self.is_unrolled_transformer:
            # The BERT-style classifier already returns logits.
            if output_intermediates:
                return torch.stack(model_output, dim=-1)
            return model_output

        if not output_intermediates:
            # Only the final hidden state: mean-pool over dim 1, then classify.
            return self.classifier_head(model_output.mean(dim=1))

        # Mean-pool each hidden state over dim 1 (assumed to be the sequence axis) and
        # classify it. Index 0 is the input embedding, hence ``L + 1`` entries.
        layer_logits = [
            self.classifier_head(model_output[layer_idx].mean(dim=1)) for layer_idx in range(self.L + 1)
        ]
        return torch.stack(layer_logits, dim=-1)

    def compute_metrics(self, all_layer_logits, labels):
        """Compute mean cross-entropy and accuracy for each of the ``L`` layers.

        Args:
            all_layer_logits: Logits indexed as ``[:, :, layer]`` with at least ``L``
                entries along the last axis.
            labels: Class indices; a trailing singleton axis is squeezed away and the
                values are cast to ``long``.

        Returns:
            Tuple ``(loss_all_layers, accuracy_all_layers)``, each of shape ``[L]``.
            The last element of each always corresponds to the final layer.

        Raises:
            ValueError: If fewer than ``L`` layer outputs are provided.
        """
        if labels.dim() > 1:
            labels = labels.squeeze(-1)
        labels = labels.long()

        depth = all_layer_logits.size(-1)
        if depth < self.L:
            raise ValueError(f"Expected at least {self.L} layer outputs, got {depth}")
        # Align with the *last* L entries. ``forward`` stacks L + 1 outputs for the
        # unrolled transformer (input embedding first); indexing from 0 would score
        # the embedding and silently drop the real final layer. When exactly L
        # outputs are present the offset is 0.
        offset = depth - self.L

        losses = []
        accuracies = []
        for layer_idx in range(self.L):
            layer_logits = all_layer_logits[:, :, offset + layer_idx]
            losses.append(self.loss_fn(layer_logits, labels).mean())
            predictions = layer_logits.argmax(dim=1)
            accuracies.append((predictions == labels).float().mean())

        return torch.stack(losses), torch.stack(accuracies)

    def _constrained_optimization_step(self, loss_all_layers, loss_last_layer):
        """Perform one primal-dual update.

        Args:
            loss_all_layers: Tensor of shape ``[L]`` with the loss of each layer.
            loss_last_layer: Loss of the final layer (the objective).

        Returns:
            The primal (Lagrangian) loss tensor used for the model update.
        """
        primal_optimizer, dual_optimizer = self.optimizers()

        # ---- Primal step: update the model on the Lagrangian ----
        primal_optimizer.zero_grad()

        # Prepend the initial constraint so that layer 0 is compared against it.
        loss_all_layers = torch.cat((self.initial_constraint, loss_all_layers), dim=0)

        # Slack of each constraint ``loss[l] <= alpha * loss[l - 1]``; positive means violated.
        constraint_value = (loss_all_layers[1:] - self.constraint_alpha * loss_all_layers[:-1]).squeeze()
        num_constraints_violated = (constraint_value > 0).sum()
        num_nondecreasing_layers = (loss_all_layers[1:] >= loss_all_layers[:-1]).sum()

        # Lagrangian with a quadratic penalty on the multipliers.
        coeff = 1 / (2 * self.resilience_coefficient)
        primal_loss = loss_last_layer + self.multipliers.dot(constraint_value) - coeff * self.multipliers.norm(p=2) ** 2

        # retain_graph is kept on purpose: ``initial_constraint`` may carry a graph node
        # from ``.to(device)`` that is reused on every step.
        self.manual_backward(primal_loss, retain_graph=True)
        detached_constraints = constraint_value.detach()
        primal_optimizer.step()

        # ---- Dual step: gradient ascent on the multipliers ----
        # zero_grad also discards the multiplier gradients produced by the primal backward.
        dual_optimizer.zero_grad()
        penalty = self.multipliers.dot(detached_constraints)
        dual_loss = penalty - coeff * self.multipliers.norm(p=2) ** 2
        self.manual_backward(-dual_loss)  # negated: the optimizer minimizes
        dual_optimizer.step()

        # Project the multipliers back onto the non-negative orthant.
        with torch.no_grad():
            self.multipliers.clamp_(min=0.0)

        self.log("num_constraints_violated", num_constraints_violated, **self.LOG_FLAGS)
        self.log("num_nondecreasing_layers", num_nondecreasing_layers, **self.LOG_FLAGS)
        self.log("primal_loss", primal_loss, **self.LOG_FLAGS)
        self.log("penalty", penalty, **self.LOG_FLAGS)

        return primal_loss

    def training_step(self, batch, batch_idx):
        """Run one manual-optimization training step and log per-layer metrics.

        Args:
            batch: Mapping with ``"text"`` (inputs for the tokenizer) and ``"label"``.
            batch_idx: Index of the batch (unused).

        Returns:
            The loss that was optimized: the Lagrangian when constraints are active,
            otherwise the last-layer loss.
        """
        primal_optimizer, _ = self.optimizers()

        text, labels = batch["text"], batch["label"]

        # Labels are always used as class indices.
        if labels.dim() > 1:
            labels = labels.squeeze(-1)
        labels = labels.long().to(self.device)

        tokens = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(self.device)

        all_layer_logits = self(
            tokens,
            embedding_perturb_std=self.embedding_perturb_std,
            output_intermediates=True,
        )

        loss_all_layers, accuracy_all_layers = self.compute_metrics(all_layer_logits, labels)
        loss_last_layer = loss_all_layers[-1]

        if self.constraint_type != "none":
            loss = self._constrained_optimization_step(loss_all_layers, loss_last_layer)
        else:
            # Unconstrained training: plain descent on the last layer's loss.
            loss = loss_last_layer
            primal_optimizer.zero_grad()
            self.manual_backward(loss)
            primal_optimizer.step()

        # ---- Logging ----
        quiet_flags = {**self.LOG_FLAGS, "prog_bar": False}
        for i in range(self.L):
            self.log(f"train_loss_layer_{i}", loss_all_layers[i], **quiet_flags)
            self.log(f"train_acc_layer_{i}", accuracy_all_layers[i], **quiet_flags)
        self.log("train_loss", loss_last_layer, **self.LOG_FLAGS)
        self.log("train_acc", accuracy_all_layers[-1], **self.LOG_FLAGS)

        for i, multiplier in enumerate(self.multipliers):
            self.log(f"multipliers_{i}", multiplier, **quiet_flags)
        return loss

    def _eval_step(self, batch, batch_idx, stage="val", test_ood_noise=False):
        """Shared evaluation logic for validation and testing.

        Args:
            batch: Mapping with ``"text"`` and ``"label"``.
            batch_idx: Index of the batch (unused).
            stage: Prefix for the logged metric names (``"val"`` or ``"test"``).
            test_ood_noise: If true, evaluate at every level in
                ``self.ood_test_settings`` and log the area under the
                accuracy-vs-noise curve; otherwise evaluate only at the training
                perturbation level.

        Returns:
            Dict with ``{stage}_loss`` and ``{stage}_acc`` for the training
            perturbation level, or ``None`` if that level was not evaluated.
        """
        eval_log_flags = {"sync_dist": True, "on_step": False, "on_epoch": True}
        layer_log_flags = {**eval_log_flags, "prog_bar": False, "add_dataloader_idx": False}

        text, labels = batch["text"], batch["label"].reshape(-1, 1).float()
        tokens = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(self.device)

        if test_ood_noise:
            embedding_perturbation_stds = self.ood_test_settings
        else:
            embedding_perturbation_stds = [self.embedding_perturb_std]
        ood_results = []  # (noise level, last-layer accuracy) pairs; filled only for the OOD sweep

        in_distribution_results = None
        for embedding_perturbation_std in embedding_perturbation_stds:
            suffix = f"_p{embedding_perturbation_std:.2f}" if test_ood_noise else ""
            is_in_distribution = abs(embedding_perturbation_std - self.embedding_perturb_std) < 1e-6
            all_layer_logits = self(tokens, embedding_perturb_std=embedding_perturbation_std, output_intermediates=True)

            loss_all_layers, accuracy_all_layers = self.compute_metrics(all_layer_logits, labels)
            if is_in_distribution:
                # Keep the results for the training noise level and log them layer by layer.
                in_distribution_results = {
                    f"{stage}_loss": loss_all_layers[-1],
                    f"{stage}_acc": accuracy_all_layers[-1],
                }
                for i in range(self.L):
                    self.log(f"{stage}_loss_layer_{i}", loss_all_layers[i], **layer_log_flags)
                    self.log(f"{stage}_acc_layer_{i}", accuracy_all_layers[i], **layer_log_flags)

            self.log(f"{stage}_loss{suffix}", loss_all_layers[-1], **eval_log_flags)
            self.log(f"{stage}_acc{suffix}", accuracy_all_layers[-1], **eval_log_flags)

            if test_ood_noise:
                ood_results.append((embedding_perturbation_std, accuracy_all_layers[-1].item()))

        if test_ood_noise:
            # Area under the accuracy-vs-noise curve, integrated in increasing noise order.
            ood_results.sort(key=lambda pair: pair[0])
            noise_levels, accuracies = zip(*ood_results)
            # ``np.trapz`` was renamed to ``np.trapezoid`` in NumPy 2.0; prefer the new
            # name and fall back for older NumPy releases.
            trapezoid = getattr(np, "trapezoid", None) or np.trapz
            auc = float(trapezoid(accuracies, x=noise_levels))
            self.log(f"{stage}_ood_auc", auc, **eval_log_flags)

        return in_distribution_results

    def validation_step(self, batch, batch_idx):
        """Evaluate one validation batch at the training perturbation level.

        All validation data is treated the same regardless of which loader it came from.
        """
        return self._eval_step(batch, batch_idx, stage="val", test_ood_noise=False)

    def test_step(self, batch, batch_idx):
        """Evaluate one test batch across all OOD perturbation levels.

        The sweep is logged twice, once under the ``val`` prefix and once under
        ``test``; the ``val``-prefixed OOD metrics are needed for the OOD AUC
        computation / hyperparameter sweep.
        """
        self._eval_step(batch, batch_idx, stage="val", test_ood_noise=True)
        return self._eval_step(batch, batch_idx, stage="test", test_ood_noise=True)

    def configure_optimizers(self):
        """Create the primal (model) and dual (multipliers) Adam optimizers.

        Returns:
            ``[primal_optimizer, dual_optimizer]``; the primal optimizer covers every
            parameter whose name does not contain ``"multipliers"``.
        """
        primal_params = [p for n, p in self.named_parameters() if "multipliers" not in n]
        primal_optimizer = torch.optim.Adam(primal_params, lr=self.hparams.lr)
        dual_optimizer = torch.optim.Adam([self.multipliers], lr=self.dual_lr)
        return [primal_optimizer, dual_optimizer]


def _build_data_module(config):
    """Instantiate the data module selected by ``config["dataset"]``.

    Raises:
        ValueError: If the dataset name is not one of ``imdb``, ``mnli`` or ``huggingface``.
    """
    dataset = config["dataset"]
    if dataset == "imdb":
        return data.IMDBDataModule(batch_size=config["batch_size"], num_workers=config["num_workers"])
    if dataset == "mnli":
        return data.MNLIDataModule(batch_size=config["batch_size"], num_workers=config["num_workers"])
    if dataset == "huggingface":
        # Generic Huggingface dataset described entirely through the config.
        return data.HuggingfaceDataModule(
            dataset_name=config["dataset_name"],
            num_classes=config["dataset_num_classes"],
            batch_size=config["batch_size"],
            num_workers=config["num_workers"],
            val_split=config.get("val_split", 0.1),
            seed=config.get("seed", 42),
            dataset_config=config.get("dataset_config", None),
            text_fields=config.get("text_fields", None),
            label_field=config.get("label_field", "label"),
            test_split=config.get("test_split", "test"),
        )
    raise ValueError(f"Unsupported dataset: {config['dataset']}")


def _freeze_for_last_layer_finetuning(model):
    """Set ``requires_grad`` so that only the last layer is trained during finetuning.

    For the unrolled transformer the last layer's ``D1``/``D2`` parameters and the
    classifier head stay trainable. For the BERT-style model only the last
    transformer block stays trainable; the classifier head is frozen.
    """
    if model.is_unrolled_transformer:
        last_layer_markers = (f"D1.{model.L-1}", f"D2.{model.L-1}")
        for name, param in model.model.named_parameters():
            # Freeze everything except the parameters of the last layer.
            param.requires_grad = any(marker in name for marker in last_layer_markers)
            if param.requires_grad:
                print(f"Keeping parameter trainable: {name}")

        # The classification head lives on the Lightning module and stays trainable.
        for name, param in model.named_parameters():
            if "classifier_head" in name:
                param.requires_grad = True
                print(f"Keeping parameter trainable: {name}")
        return

    # Determine the number of transformer layers dynamically.
    num_layers = len(model.model.bert.transformer.layer)
    last_layer_idx = num_layers - 1
    print(f"Model has {num_layers} transformer layers. Unfreezing layer {last_layer_idx}.")

    last_block_marker = f"transformer.layer.{last_layer_idx}"
    for name, param in model.model.named_parameters():
        # Freeze everything except the last transformer block.
        param.requires_grad = last_block_marker in name
        if param.requires_grad:
            print(f"Keeping parameter trainable: {name}")

    # The classifier head is frozen as well.
    for name, param in model.named_parameters():
        if "classifier_head" in name:
            param.requires_grad = False


def train(config):
    """Train, optionally finetune the last layer, and test a constrained language model.

    Args:
        config: Dict of settings (as produced by the command-line parser). It is
            mutated: ``experiment_name`` is filled in when missing and
            ``num_classes`` is added from the data module.

    Raises:
        ValueError: If ``constraint_alpha`` exceeds 1, or the dataset or model name
            is not supported.
    """
    from datetime import datetime

    # Validate cheap settings first so a bad value fails before any data is loaded.
    if config["constraint_alpha"] > 1:
        raise ValueError("Constraint alpha must be <= 1")

    # Apply the seed to the global RNGs; previously it was only forwarded to the
    # Huggingface data module, so runs were not reproducible.
    seed = config.get("seed")
    if seed is not None:
        pl.seed_everything(seed, workers=True)

    experiment_tag = config["experiment_tag"]
    if config["experiment_name"] is None:
        experiment_name = f"{experiment_tag}_{config['dataset']}_ct{config['constraint_type']}_ca{config['constraint_alpha']}_pert{config['embedding_perturb_std']}"
        config["experiment_name"] = experiment_name

    wandb_logger = WandbLogger(
        project=config["wandb_project"],
        entity=config["wandb_entity"],
        name=config["experiment_name"],
        tags=[experiment_tag],
        log_model='all' if config["save_checkpoints"] else None,
    )

    data_module = _build_data_module(config)

    # The module reads the number of classes from the config.
    config["num_classes"] = data_module.num_classes

    if config["model"] == "unrolled_transformer":
        bert_model = SymmetricUnrolledTransformer(
            embedding_dim=768,
            hidden_dim=768,
            num_layers=config["num_layers"],
            diff_d=config.get("diff_d", True),
            alpha=config.get("model_alpha", 1.0),
        )
    elif config["model"] == "distilbert-base-uncased":
        bert_model = BERTClassifier(num_classes=data_module.num_classes)
    else:
        raise ValueError(f"Unsupported model: {config['model']}")

    model = ConstrainedLanguageExpModule(language_classifier=bert_model, config=config)

    wandb_logger.watch(model, log="all", log_freq=100)

    callbacks = []

    if config["save_checkpoints"]:
        wandb_logger.experiment.log_code('.')
        wandb_run_id = wandb_logger.experiment.id
        dataset = config["dataset"]
        date = datetime.now()
        checkpoint_dir = f"checkpoints/{date:%Y%m%d-%H_%M}_{dataset}_{experiment_tag}_{wandb_run_id}"
        # Keep only the checkpoint with the lowest validation loss.
        checkpoint_callback = ModelCheckpoint(
            monitor="val_loss",
            dirpath=checkpoint_dir,
            filename="best",
            save_top_k=1,
            mode="min",
        )
        callbacks.append(checkpoint_callback)
    else:
        print("WARNING: Did not pass --save_checkpoints flag. Checkpoints will not be saved")

    if config["debug_mode"]:
        # Enable anomaly detection in debug mode
        torch.autograd.set_detect_anomaly(True)

    # In debug mode every stage is cut down to a handful of batches.
    debug_arguments = {"limit_train_batches": 5, "limit_val_batches": 5, "limit_test_batches": 5}
    extra_trainer_kwargs = debug_arguments if config["debug_mode"] else {}
    trainer = pl.Trainer(
        strategy=config["strategy"],
        accelerator=config["accelerator"],
        devices=config["devices"],
        max_epochs=config["num_epochs"],
        check_val_every_n_epoch=1,
        logger=wandb_logger,
        callbacks=callbacks,
        log_every_n_steps=config["log_every_n_steps"],
        detect_anomaly=config["debug_mode"],
        **extra_trainer_kwargs,
    )

    # Initial constrained training
    trainer.fit(model, data_module)

    # Final layer finetuning (only after a constrained run, and only if requested)
    if config["last_layer_ft_epochs"] > 0 and config["constraint_type"] != "none":
        print(f"Starting final layer finetuning for {config['last_layer_ft_epochs']} epochs")

        _freeze_for_last_layer_finetuning(model)

        # Freeze the duals. The finetuning trainer calls ``configure_optimizers``
        # again, so the learning rate has to be changed on the module: editing the
        # optimizer of the finished trainer would be discarded.
        model.dual_lr = 0.0

        # Switch to unconstrained training
        model.hparams.constraint_type = "none"
        model.constraint_type = "none"

        ft_trainer = pl.Trainer(
            strategy=config["strategy"],
            accelerator=config["accelerator"],
            devices=config["devices"],
            max_epochs=config["last_layer_ft_epochs"],
            logger=wandb_logger,
            callbacks=callbacks,
            log_every_n_steps=config["log_every_n_steps"],
            **extra_trainer_kwargs,
        )
        ft_trainer.fit(model, data_module)

    trainer.test(model, data_module)


if __name__ == "__main__":
    # Command-line entry point: parse the settings into a dict and hand it to ``train``.

    def _str2bool(value):
        """Parse a command-line boolean.

        ``type=bool`` would turn every non-empty string (including ``"False"``) into
        ``True``. The empty string still maps to ``False`` as it did before.
        """
        if isinstance(value, bool):
            return value
        normalized = value.strip().lower()
        if normalized in {"true", "t", "yes", "y", "1"}:
            return True
        if normalized in {"false", "f", "no", "n", "0", ""}:
            return False
        raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser(description="Train a constrained language model")

    # Training parameters
    parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate")
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
    parser.add_argument("--num_epochs", type=int, default=5, help="Number of epochs")
    parser.add_argument("--num_workers", type=int, default=16, help="Number of workers")
    parser.add_argument("--devices", type=int, default=1, help="Number of GPU devices to use")
    parser.add_argument("--accelerator", type=str, default="gpu", help="Accelerator to use for training")
    parser.add_argument(
        "--strategy", type=str, default="ddp_find_unused_parameters_true", help="Strategy to use for training"
    )
    parser.add_argument(
        "--model",
        type=str,
        default="distilbert-base-uncased",
        choices=["distilbert-base-uncased", "unrolled_transformer"],
        help="Model name or type to use",
    )
    parser.add_argument("--dataset", type=str, default="imdb", help="Dataset name")
    parser.add_argument(
        "--embedding_perturb_std",
        type=float,
        default=0.01,
        help="Standard deviation for embedding perturbation",
    )

    # Constraint parameters
    # NOTE(review): there is no default, so omitting the flag yields None, which the
    # training module treats like a constrained run (anything other than "none").
    # Confirm whether the flag should be required.
    parser.add_argument(
        "--constraint_type",
        type=str,
        choices=["none", "cross_entropy"],
        help="Type of constraint to apply",
    )
    parser.add_argument("--constraint_alpha", type=float, default=0.9, help="Constraint alpha parameter")
    parser.add_argument(
        "--dual_lr", type=float, default=0.0003, help="Learning rate for the dual variables (constraint multipliers)"
    )
    parser.add_argument("--resilience_lr", type=float, default=0.0001, help="Learning rate for resilience")
    parser.add_argument(
        "--resilience_coefficient", type=float, default=1.0, help="Coefficient for the resilience penalty term."
    )
    parser.add_argument("--dual_init", type=float, default=0.0, help="Initial value for dual variables")
    parser.add_argument("--initial_constraint", type=float, default=0.9, help="Initial constraint value")

    # Logging parameters
    parser.add_argument(
        "--wandb_project",
        type=str,
        default="project_name",
        help="Weights & Biases project name",
    )
    parser.add_argument(
        "--wandb_entity",
        type=str,
        default="anonymous",
        help="Weights & Biases entity name",
    )
    parser.add_argument("--log_every_n_steps", type=int, default=5, help="Log every N steps")
    parser.add_argument(
        "--experiment_tag",
        type=str,
        default="experiment",
        help="Tag to identify this experiment run",
    )
    parser.add_argument(
        "--experiment_name",
        type=str,
        default=None,
        help="Optional experiment name for wandb",
    )

    # Reproducibility parameters
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Seed for reproducibility",
    )

    # Checkpointing and debug parameters
    parser.add_argument(
        "--save_checkpoints",
        action="store_true",
        help="Whether to save model checkpoints and log them to wandb",
    )

    # The help text must not contain a bare "%": argparse %-formats help strings, so
    # the previous wording made --help fail.
    parser.add_argument(
        "--debug_mode",
        action="store_true",
        help="Debug run: limit train/val/test to 5 batches each and enable autograd anomaly detection",
    )

    # Final layer finetuning parameter
    parser.add_argument(
        "--last_layer_ft_epochs",
        type=int,
        default=0,
        help="Number of epochs for final layer finetuning (if > 0, finetune only the last layer after constrained training)",
    )

    # Unrolled transformer specific arguments
    parser.add_argument(
        "--diff_d",
        type=_str2bool,
        default=True,
        help="Whether to use different D1 and D2 matrices for each layer in unrolled transformer",
    )
    parser.add_argument(
        "--model_alpha", type=float, default=1.0, help="Scaling factor for the FFN in unrolled transformer"
    )
    parser.add_argument(
        "--num_layers",
        type=int,
        default=6,
        help="Number of layers in the unrolled transformer model",
    )

    # Parameters for the HuggingfaceDataModule
    parser.add_argument("--dataset_name", type=str, help="HuggingFace dataset name (for huggingface dataset type)")
    parser.add_argument("--dataset_config", type=str, help="HuggingFace dataset config (for huggingface dataset type)")
    parser.add_argument("--dataset_num_classes", type=int, help="Number of classes (for huggingface dataset type)")
    parser.add_argument("--text_fields", nargs="+", help="Text fields to combine (for huggingface dataset type)")
    parser.add_argument(
        "--label_field", type=str, default="label", help="Label field name (for huggingface dataset type)"
    )
    parser.add_argument("--test_split", type=str, default="test", help="Test split name (for huggingface dataset type)")
    parser.add_argument("--val_split", type=float, default=0.1, help="Validation split ratio (if needed)")

    args = parser.parse_args()
    config = vars(args)

    if config["model"] == "distilbert-base-uncased":
        # distilbert-base-uncased has a fixed depth of 6 transformer layers (12 is its
        # number of attention heads). The value is only recorded with the
        # hyperparameters; the training module reads the depth from the model config.
        config["num_layers"] = 6

    print("Starting training with config:")
    print(config)
    train(config)
