from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional, Tuple, Union

import einops
import torch
import torch as t
from einops import einsum, rearrange, reduce
from jaxtyping import Float
from torch import Tensor, nn
from torch.nn import functional as F
from tqdm import tqdm

from adversarial_superposition.constants import DEVICE
from adversarial_superposition.toy_models.utils.pgd_attack import (
    pgd_l2_adv,
    pgd_linf_adv,
)


@dataclass
class ClassificationConfig:
    """Hyper-parameters for the toy classification task and its synthetic data.

    ``__post_init__`` derives ``n_features_per_class`` and validates the
    combination of ``data_range``, ``feature_generation_mode`` and
    ``class_correlation_matrix``.
    """

    # Number of output classes.
    n_classes: int = 5
    # Total number of input features (split evenly across classes).
    n_features: int = 0
    # Number of independent model copies.
    n_instances: int = 0
    # Width of the hidden layer.
    n_hidden: int = 0
    # Value range of generated features: (0, 1) or (-1, 1).
    data_range: Tuple[int, int] = (0, 1)
    # Number of jointly-present feature pairs generated in "mixed" mode.
    n_correlated_feature_pairs: int = 0
    # Data generation strategy: "mixed" or "class_correlated".
    feature_generation_mode: str = "mixed"
    # Optional (n_classes, n_classes) matrix; shape-checked in "class_correlated" mode.
    class_correlation_matrix: Optional[Tensor] = None
    # Scale of the latent class signal used by the class-correlated generator.
    class_correlation_scale: float = 1.0
    # Baseline activation probability of the cyclic generator.
    cycle_base_prob: float = 0.5
    # Amplitude of the cosine modulation of the cyclic activation probability.
    cycle_amplitude: float = 0.4
    # Probability that a cyclic-generator feature is additionally zeroed out.
    cycle_sparsity: float = 0.0

    def __post_init__(self):
        """Compute derived fields and reject invalid configurations.

        Raises:
            ValueError: if ``data_range`` is not (0, 1) or (-1, 1); if a
                correlation matrix is supplied in "class_correlated" mode with a
                shape other than (n_classes, n_classes); or if
                ``feature_generation_mode`` is not a supported value.
        """
        self.n_features_per_class = self.n_features // self.n_classes

        supported_ranges = ((0, 1), (-1, 1))
        if self.data_range not in supported_ranges:
            raise ValueError("data_range must be either (0, 1) or (-1, 1)")

        matrix = self.class_correlation_matrix
        wants_matrix_check = (
            self.feature_generation_mode == "class_correlated" and matrix is not None
        )
        if wants_matrix_check and matrix.shape != (self.n_classes, self.n_classes):
            raise ValueError(
                f"class_correlation_matrix must have shape ({self.n_classes}, {self.n_classes})"
            )

        supported_modes = ("mixed", "class_correlated")
        if self.feature_generation_mode not in supported_modes:
            raise ValueError(
                "feature_generation_mode must be either 'mixed' or 'class_correlated'"
            )


def constant_lr(*_):
    """Return a constant learning-rate multiplier of 1.0.

    Any positional arguments (the training loop passes ``step`` and the total
    number of steps) are accepted and ignored, so this fits the ``lr_scale``
    callback shape.
    """
    return float(1)


class ClassificationModel(nn.Module):
    """Toy classifier made of ``n_instances`` independent copies trained in parallel.

    Each instance maps ``n_features`` inputs to ``n_hidden`` activations through
    ``W`` (optional bias ``b_enc``, optional ReLU when ``privileged``). It then maps
    those activations to ``n_classes`` logits through ``W_projection`` (optional
    bias ``b_projection``, optional ReLU when ``privileged_out``). The class also
    owns the synthetic data generators, the loss, a training loop, an accuracy
    helper and a PGD adversarial-attack wrapper.
    """

    def __init__(
        self,
        cfg: ClassificationConfig,
        feature_probability: Optional[Union[float, Tensor]] = None,
        class_importance: Optional[Union[float, Tensor]] = None,
        feature_importance: Optional[Union[float, Tensor]] = None,
        privileged: bool = False,
        privileged_out: bool = False,
        bias: bool = False,
        projection_bias: bool = False,
        loss_fn: str = "ce",
        device: str = DEVICE,
    ) -> None:
        """Build the model and its parameters.

        Args:
            cfg: Task/model configuration.
            feature_probability: Probability that a generated feature is
                non-zero. ``None`` means 1.0. A Python number is converted to
                a scalar tensor.
            class_importance: Per-class loss weights (scalar or tensor). Defaults
                to ones of length ``n_classes``.
            feature_importance: Per-feature weights (scalar or tensor). Defaults
                to ones of length ``n_features_per_class``.
            privileged: Apply ReLU to the hidden layer.
            privileged_out: Apply ReLU to the output logits.
            bias: Add a learned hidden-layer bias ``b_enc``.
            projection_bias: Add a learned output bias ``b_projection``.
            loss_fn: ``"mse"`` for MSE between softmax probabilities and one-hot
                labels. Any other value uses cross-entropy.
            device: Device ``feature_probability`` is first moved to (see note).

        Raises:
            ValueError: If both ``class_importance`` and ``feature_importance``
                are given.
        """
        super().__init__()
        self.cfg = cfg
        self.loss_fn = loss_fn
        self.privileged = privileged
        self.privileged_out = privileged_out
        self.bias = bias
        self.projection_bias = projection_bias

        if feature_probability is None:
            feature_probability = t.ones(())
        if isinstance(feature_probability, (int, float)):
            # Accept plain Python numbers such as ``1``, not only floats.
            feature_probability = t.tensor(float(feature_probability))
        # NOTE(review): the tensor is moved to ``device`` and then to the
        # module-level DEVICE. Parameters and generated data also live on
        # DEVICE, so ``device`` currently has no lasting effect.
        self.feature_probability = feature_probability.to(device).to(DEVICE)

        if class_importance is not None and feature_importance is not None:
            raise ValueError("Can only have either class or feature importance")

        self.class_importance = self._init_importance(class_importance, cfg.n_classes)
        self.feature_importance = self._init_importance(
            feature_importance, cfg.n_features_per_class
        )

        self._init_parameters()
        self.to(DEVICE)

    def _init_importance(
        self, importance: Optional[Union[float, Tensor]], size: int
    ) -> Tensor:
        """Return importance weights as a tensor on DEVICE.

        ``None`` becomes a vector of ones of length ``size``. A Python scalar is
        broadcast to a length-``size`` vector, because ``F.cross_entropy``
        requires its ``weight`` argument to be 1-D with one entry per class (a
        0-d tensor is rejected). Tensors are used as given.
        """
        if importance is None:
            importance = t.ones(size)
        elif isinstance(importance, (int, float)):
            importance = t.full((size,), float(importance))
        return importance.to(DEVICE)

    def _init_parameters(self) -> None:
        """Create the (per-instance) weight and optional bias parameters."""
        # Hidden layer weights: (instances, hidden, features).
        # NOTE(review): for a 3-D tensor, torch's xavier init derives fan_in/fan_out
        # from dims 1 and 0 times the trailing "receptive field", so the init scale
        # depends on n_instances. Confirm this is intended.
        self.W = nn.Parameter(
            nn.init.xavier_normal_(
                t.empty((self.cfg.n_instances, self.cfg.n_hidden, self.cfg.n_features))
            )
        )
        if self.bias:
            self.b_enc = nn.Parameter(
                t.zeros((self.cfg.n_instances, self.cfg.n_hidden))
            )
        # Projection layer (classification head): (instances, classes, hidden).
        self.W_projection = nn.Parameter(
            nn.init.xavier_normal_(
                t.empty((self.cfg.n_instances, self.cfg.n_classes, self.cfg.n_hidden))
            )
        )
        if self.projection_bias:
            self.b_projection = nn.Parameter(
                t.zeros((self.cfg.n_instances, self.cfg.n_classes))
            )

    def forward(
        self,
        features: Float[Tensor, "... instances features"],
        instance_idx: Optional[int] = None,
    ) -> Float[Tensor, "... instances classes"]:
        """Forward pass through the model.

        Args:
            features: Input features. Shape ``(..., instances, features)`` when
                ``instance_idx`` is None, otherwise ``(..., features)``.
            instance_idx: If given, run only that instance's weights.

        Returns:
            Classification logits, with a trailing ``classes`` axis.
        """
        features = features.to(DEVICE)

        if instance_idx is not None:
            return self._forward_single_instance(features, instance_idx)
        return self._forward_batch(features)

    def _forward_single_instance(
        self, features: Float[Tensor, "... features"], instance_idx: int
    ) -> Float[Tensor, "... classes"]:
        """Forward pass using only the weights of ``instance_idx``."""
        out = einsum(
            features,
            self.W[instance_idx],
            "... features, hidden features -> ... hidden",
        )
        if self.bias:
            out = out + self.b_enc[instance_idx]
        activations = F.relu(out) if self.privileged else out
        classification = einsum(
            activations,
            self.W_projection[instance_idx],
            "... hidden, n_classes hidden -> ... n_classes",
        )
        if self.projection_bias:
            classification = classification + self.b_projection[instance_idx]
        if self.privileged_out:
            classification = F.relu(classification)
        return classification

    def _forward_batch(
        self, features: Float[Tensor, "... instances features"]
    ) -> Float[Tensor, "... instances classes"]:
        """Forward pass for all instances at once."""
        out = einsum(
            features,
            self.W,
            "... instances features, instances hidden features -> ... instances hidden",
        )
        if self.bias:
            out = out + self.b_enc
        activations = F.relu(out) if self.privileged else out
        classification = einsum(
            activations,
            self.W_projection,
            "... instances hidden, instances n_classes hidden -> ... instances n_classes",
        )
        if self.projection_bias:
            classification = classification + self.b_projection
        if self.privileged_out:
            classification = F.relu(classification)
        return classification

    def generate_correlated_features_between_classes(
        self, batch_size
    ) -> Float[Tensor, "batch_size instances features"]:
        """Generate ``n_correlated_feature_pairs`` pairs of jointly-present features.

        Each pair shares a single presence draw. Because of the
        ``(pair features)`` layout, feature ``j`` and feature
        ``j + n_correlated_feature_pairs`` form a pair. Values are uniform in
        [0, 1).

        NOTE(review): unlike ``generate_uncorrelated_features``, values are not
        mapped to [-1, 1] when ``data_range == (-1, 1)``. Confirm this is intended.
        """
        n_pairs = self.cfg.n_correlated_feature_pairs
        n_features_to_generate = 2 * n_pairs
        if n_features_to_generate == 0:
            return torch.empty((batch_size, self.cfg.n_instances, 0), device=DEVICE)

        feat = t.rand(
            (batch_size, self.cfg.n_instances, n_features_to_generate),
            device=self.W.device,
        )
        feat_set_seeds = t.rand(
            (batch_size, self.cfg.n_instances, n_pairs),
            device=self.W.device,
        )
        # For per-instance probabilities (instances, k), take column 0 as the
        # instance's scalar probability. A 0-d probability broadcasts directly;
        # indexing it with [:, [0]] would raise.
        prob = self.feature_probability
        if prob.ndim > 0:
            prob = prob[:, [0]]
        feat_set_is_present = feat_set_seeds <= prob
        feat_is_present = einops.repeat(
            feat_set_is_present,
            "batch instances features -> batch instances (pair features)",
            pair=2,
        )

        return t.where(feat_is_present, feat, 0.0).to(DEVICE)

    def generate_uncorrelated_features(
        self, batch_size: int, n_uncorrelated: int
    ) -> Float[Tensor, "batch instances features"]:
        """Generate independently-masked features in the configured data range.

        Each feature is uniform in ``data_range`` and kept with probability
        ``feature_probability``. Any (batch, instance) row that ends up all zeros
        gets one randomly chosen feature set to a fresh random value.
        """
        feat = t.rand((batch_size, self.cfg.n_instances, n_uncorrelated), device=DEVICE)
        # Map [0, 1) to [-1, 1) when a symmetric range is requested.
        features = 2 * feat - 1 if self.cfg.data_range == (-1, 1) else feat

        feat_seeds = t.rand(
            (batch_size, self.cfg.n_instances, n_uncorrelated), device=DEVICE
        )
        features = t.where(feat_seeds <= self.feature_probability, features, 0.0)

        # Ensure no row is entirely zero: re-activate one random feature per such row.
        mask = (features == 0).all(dim=2).to(DEVICE)
        if mask.any():
            n_missing = int(mask.sum())
            # Drawn on the default (CPU) generator and then moved, as before, so
            # seeded runs reproduce the same indices.
            k_indices = torch.randint(0, features.shape[2], (n_missing,)).to(DEVICE)
            n_indices, m_indices = torch.where(mask)
            n_indices, m_indices = n_indices.to(DEVICE), m_indices.to(DEVICE)
            rand_val = torch.rand(n_missing, device=DEVICE)
            if self.cfg.data_range == (-1, 1):
                rand_val = 2 * rand_val - 1  # random value in [-1, 1)
            features[n_indices, m_indices, k_indices] = rand_val

        return features

    def generate_class_correlated_features(
        self, batch_size: int
    ) -> Float[Tensor, "batch instances features"]:
        """Generate features driven by correlated latent per-class scores.

        Latent scores ``Z @ L.T`` (with ``L`` the Cholesky factor of
        ``cfg.class_correlation_matrix``) are added, scaled by
        ``class_correlation_scale``, to Gaussian noise on each class's block of
        features. The result is squashed into ``data_range`` and masked by
        ``feature_probability``. Rows that end up all zeros get one random
        feature re-activated. The Cholesky and latent sampling run on CPU; the
        remaining steps run on DEVICE.

        Raises:
            AssertionError: If no correlation matrix is configured or
                ``n_features`` is not divisible by ``n_classes``.
            RuntimeError: If the Cholesky decomposition fails (matrix not
                positive definite).
        """
        assert (
            self.cfg.class_correlation_matrix is not None
        ), "Correlation matrix required"
        assert (
            self.cfg.n_features % self.cfg.n_classes == 0
        ), "n_features must be divisible by n_classes"

        n_features_per_class = self.cfg.n_features // self.cfg.n_classes
        batch_shape = (batch_size, self.cfg.n_instances)
        target_device = DEVICE

        # --- Cholesky and latent sampling on CPU ---
        cpu_device = torch.device("cpu")
        corr_matrix_cpu = self.cfg.class_correlation_matrix.to(cpu_device)

        try:
            # L is lower triangular with corr_matrix = L @ L.T.
            L_cpu = torch.linalg.cholesky(corr_matrix_cpu)
        except RuntimeError as e:  # torch.linalg.LinAlgError subclasses RuntimeError
            print(f"Cholesky decomposition failed on CPU: {e}")
            print("Ensure the correlation matrix is positive definite.")
            raise

        standard_normal_samples_cpu = torch.randn(
            batch_shape + (self.cfg.n_classes,), device=cpu_device
        )
        # latent[b, i, j] = sum_c L[j, c] * Z[b, i, c]  (i.e. Z @ L.T)
        latent_scores_cpu = torch.einsum(
            "jc, bic -> bij", L_cpu, standard_normal_samples_cpu
        )
        latent_scores = latent_scores_cpu.to(target_device)

        # --- Remaining steps on the target device ---
        # Broadcast each class score onto that class's contiguous feature block.
        latent_influence = (
            einsum(
                latent_scores,
                torch.eye(self.cfg.n_classes, device=target_device).repeat_interleave(
                    n_features_per_class, dim=1
                ),
                "b i c, c f -> b i f",
            )
            * self.cfg.class_correlation_scale
        )

        base_noise = torch.randn(
            batch_size, self.cfg.n_instances, self.cfg.n_features, device=target_device
        )
        features_raw = base_noise + latent_influence

        # Squash into the configured range.
        min_val, max_val = self.cfg.data_range
        if min_val == 0 and max_val == 1:
            features_scaled = torch.sigmoid(features_raw)
        elif min_val == -1 and max_val == 1:
            features_scaled = torch.tanh(features_raw)
        else:
            features_scaled = torch.clamp(features_raw, min_val, max_val)

        feature_probability = self.feature_probability.to(target_device)
        feat_seeds = torch.rand_like(features_scaled, device=target_device)
        features_masked = torch.where(
            feat_seeds <= feature_probability, features_scaled, 0.0
        )

        # Ensure no row is entirely zero after masking.
        mask = (features_masked == 0).all(dim=2)
        if mask.any():
            n_indices, m_indices = torch.where(mask)
            num_zero_samples = len(n_indices)

            k_indices = torch.randint(
                0, self.cfg.n_features, (num_zero_samples,), device=target_device
            )
            rand_vals = torch.rand(num_zero_samples, device=target_device)
            if self.cfg.data_range == (-1, 1):
                rand_vals = 2 * rand_vals - 1
            elif self.cfg.data_range != (0, 1):
                rand_vals = rand_vals * (max_val - min_val) + min_val

            features_masked[n_indices, m_indices, k_indices] = rand_vals

        return features_masked

    def _assign_class(
        self,
        batch: Float[Tensor, "batch instances features"],
        single_index: bool = False,
    ) -> Tuple[Tensor, Tensor]:
        """Label each sample with the class whose feature block has the largest sum.

        Features are split into ``n_classes`` contiguous, equal-sized blocks.

        Args:
            batch: ``(batch, instances, features)``, or ``(batch, features)``
                when ``single_index`` is True.
            single_index: Whether ``batch`` lacks the instances axis.

        Returns:
            ``(labels, class_sums)``. Labels are the argmax over class sums.

        Raises:
            ValueError: If ``n_features`` is not divisible by ``n_classes``.
        """
        if self.cfg.n_features % self.cfg.n_classes != 0:
            raise ValueError(
                "There needs to be an equal amount of features in each class"
            )

        rearrange_str = (
            "batch (classes features) -> batch classes features"
            if single_index
            else "batch instances (classes features) -> batch instances classes features"
        )
        reduce_str = (
            "batch classes features -> batch classes"
            if single_index
            else "batch instances classes features -> batch instances classes"
        )

        reshaped = rearrange(batch, rearrange_str, classes=self.cfg.n_classes)
        class_sums = reduce(reshaped, reduce_str, "sum")
        labels = t.argmax(class_sums, dim=-1)

        return labels, class_sums

    def _generate_cycle_batch(
        self, batch_size: int
    ) -> Float[Tensor, "batch instances features"]:
        """Generate features whose activation probability follows a random phase cycle."""
        assert self.cfg.n_features == self.cfg.n_classes, (
            f"For 'class_correlated' mode using cycle generation, n_features ({self.cfg.n_features})"
            f" must equal n_classes ({self.cfg.n_classes})"
        )

        base_prob = getattr(self.cfg, "cycle_base_prob", 0.5)
        amplitude = getattr(self.cfg, "cycle_amplitude", 0.4)

        # One random phase per (sample, instance).
        phase = (
            torch.rand(batch_size, self.cfg.n_instances, 1, device=DEVICE)
            * 2
            * torch.pi
        )

        batch = torch.zeros(
            batch_size, self.cfg.n_instances, self.cfg.n_classes, device=DEVICE
        )

        for k in range(self.cfg.n_classes):
            # Activation probability for feature k, phase-shifted by k / n_classes.
            p_k = base_prob + amplitude * torch.cos(
                phase + 2 * torch.pi * k / self.cfg.n_classes
            )
            p_k = torch.clip(p_k, 0, 1)

            # Feature k is active with probability p_k (Bernoulli draw).
            is_active_mask = torch.rand_like(p_k) < p_k

            # Active values are uniform in [0, p_k] instead of [0, 1].
            scaled_values = torch.rand_like(p_k) * p_k
            batch[:, :, k] = torch.where(
                is_active_mask.squeeze(-1), scaled_values.squeeze(-1), 0.0
            )

        # Additionally zero each feature with probability cycle_sparsity.
        sparsity_mask = (
            torch.rand(
                batch_size, self.cfg.n_instances, self.cfg.n_classes, device=DEVICE
            )
            >= self.cfg.cycle_sparsity
        )
        batch *= sparsity_mask

        # Ensure no sample is all zeros: activate one random feature in [0, 1).
        all_zeros_mask = (batch == 0).all(dim=2)
        if all_zeros_mask.any():
            n_indices, m_indices = torch.where(all_zeros_mask)
            num_zero_samples = len(n_indices)
            k_indices = torch.randint(
                0, self.cfg.n_classes, (num_zero_samples,), device=DEVICE
            )
            random_values = torch.rand(num_zero_samples, device=DEVICE)
            batch[n_indices, m_indices, k_indices] = random_values

        return batch

    def _generate_mixed_batch(
        self, batch_size: int
    ) -> Float[Tensor, "batch instances features"]:
        """Concatenate correlated-pair features with uncorrelated features."""
        data = []
        n_features_generated = 0
        if self.cfg.n_correlated_feature_pairs > 0:
            corr_features = self.generate_correlated_features_between_classes(
                batch_size
            )
            data.append(corr_features)
            n_features_generated += corr_features.shape[-1]

        n_uncorrelated = self.cfg.n_features - n_features_generated
        if n_uncorrelated < 0:
            raise ValueError(
                f"More features specified in correlated pairs ({n_features_generated}) than total features ({self.cfg.n_features})"
            )
        if n_uncorrelated > 0:
            uncorr_features = self.generate_uncorrelated_features(
                batch_size, n_uncorrelated
            )
            data.append(uncorr_features)
            n_features_generated += uncorr_features.shape[-1]
        assert (
            n_features_generated == self.cfg.n_features
        ), f"Generated {n_features_generated} features, but expected {self.cfg.n_features}"
        batch = t.cat(data, dim=-1)

        # Ensure no sample is all zeros. If the re-activated feature belongs to a
        # correlated pair, also activate its partner (index +/- n_pairs).
        final_mask = (batch == 0).all(dim=2)
        if final_mask.any():
            n_indices, m_indices = torch.where(final_mask)
            num_zero_samples = len(n_indices)
            n_corr = 2 * self.cfg.n_correlated_feature_pairs
            n_pairs = self.cfg.n_correlated_feature_pairs
            k_indices = torch.randint(
                0, self.cfg.n_features, (num_zero_samples,), device=DEVICE
            )
            rand_vals = torch.rand(num_zero_samples, device=DEVICE)
            if self.cfg.data_range == (-1, 1):
                rand_vals = 2 * rand_vals - 1
            batch[n_indices, m_indices, k_indices] = rand_vals
            is_correlated_pair_mask = k_indices < n_corr
            if is_correlated_pair_mask.any():
                corr_indices_mask = torch.where(is_correlated_pair_mask)[0]
                corr_n_indices = n_indices[corr_indices_mask]
                corr_m_indices = m_indices[corr_indices_mask]
                corr_k_indices = k_indices[corr_indices_mask]
                num_corr_to_activate = len(corr_k_indices)
                is_first_half = corr_k_indices < n_pairs
                k_partners = torch.where(
                    is_first_half, corr_k_indices + n_pairs, corr_k_indices - n_pairs
                )
                rand_vals2 = torch.rand(num_corr_to_activate, device=DEVICE)
                if self.cfg.data_range == (-1, 1):
                    rand_vals2 = 2 * rand_vals2 - 1
                batch[corr_n_indices, corr_m_indices, k_partners] = rand_vals2
        return batch

    def generate_batch(
        self, batch_size: int, instance_idx: Optional[int] = None
    ) -> Tuple[Tensor, Tensor]:
        """Generate a batch of features and their class labels.

        Args:
            batch_size: Number of samples.
            instance_idx: If given, return only that instance's slice
                (``(batch, features)`` and ``(batch,)``).

        Returns:
            ``(batch, labels)``. By default these have shapes
            ``(batch, instances, features)`` and ``(batch, instances)``.

        Raises:
            ValueError: For an unknown ``feature_generation_mode`` or too many
                correlated pairs.
        """
        mode = self.cfg.feature_generation_mode
        if mode == "class_correlated":
            batch = self._generate_cycle_batch(batch_size)
        elif mode == "mixed":
            batch = self._generate_mixed_batch(batch_size)
        else:
            raise ValueError(
                f"Invalid feature_generation_mode: {self.cfg.feature_generation_mode}"
            )
        labels, _ = self._assign_class(batch)
        if instance_idx is not None:
            return batch[:, instance_idx, :], labels[:, instance_idx]
        return batch, labels

    def calculate_loss(
        self,
        out: Float[Tensor, "batch instances classes"],
        labels: Tensor,
        instance_idx: Optional[int] = None,
        **kwargs,
    ) -> Float[Tensor, ""]:
        """Compute the classification loss.

        Args:
            out: Logits. Shape ``(batch, classes)`` when ``instance_idx`` is given,
                otherwise ``(batch, instances, classes)``.
            labels: Integer class indices matching ``out`` without the class axis.
            instance_idx: Selects the single-instance path. Only its presence
                matters here.
            **kwargs: Accepted and ignored (the attack code passes this method
                as a loss callback).

        Returns:
            A scalar. For a single instance, this is the batch mean. Otherwise it
            is the sum over instances of each instance's batch mean.
        """
        if instance_idx is not None:
            return self._calculate_loss_single_instance(out, labels)
        return self._calculate_loss_batch(out, labels)

    def _per_sample_loss(self, logits: Tensor, labels: Tensor) -> Tensor:
        """Per-sample loss for 2-D ``(N, classes)`` logits and ``(N,)`` labels."""
        if self.loss_fn == "mse":
            labels_one_hot = F.one_hot(labels, num_classes=logits.shape[-1]).float()
            pred_probs = F.softmax(logits, dim=-1)
            loss = F.mse_loss(pred_probs, labels_one_hot, reduction="none")
            if self.class_importance is not None:
                loss = loss * self.class_importance.unsqueeze(0)
            return loss.sum(dim=-1)
        # "ce" or any other value.
        return F.cross_entropy(
            logits, labels, weight=self.class_importance, reduction="none"
        )

    def _calculate_loss_single_instance(
        self, out: Float[Tensor, "batch classes"], labels: Tensor
    ) -> Float[Tensor, ""]:
        """Mean loss for a single instance's ``(batch, classes)`` logits."""
        return self._per_sample_loss(out, labels).mean()

    def _calculate_loss_batch(
        self,
        out: Float[Tensor, "batch instances classes"],
        labels: Tensor,
    ) -> Float[Tensor, ""]:
        """Sum over instances of the per-instance mean loss."""
        logits_flat = rearrange(out, "b i c -> (b i) c")
        labels_flat = rearrange(labels, "b i -> (b i)")
        loss = self._per_sample_loss(logits_flat, labels_flat)
        loss = rearrange(loss, "(a b) -> a b", a=out.shape[0], b=out.shape[1])
        return loss.mean(dim=0).sum()

    def optimize(
        self,
        batch_size: int = 1024,
        steps: int = 10_000,
        log_freq: int = 100,
        lr: float = 1e-3,
        lr_scale: Callable[[int, int], float] = constant_lr,
        instance_idx: Optional[int] = None,
        optimizer: str = "adam",
    ) -> None:
        """Train the model on freshly generated batches.

        Args:
            batch_size: Samples per step.
            steps: Number of optimization steps.
            log_freq: Update the progress-bar loss every ``log_freq`` steps.
            lr: Base learning rate.
            lr_scale: ``lr_scale(step, steps)`` multiplies ``lr`` each step.
            instance_idx: If given, train only that instance. Batches are
                sliced to it, and the single-instance forward and loss are used.
            optimizer: ``"adam"`` selects AdamW. Any other value selects SGD.
        """
        opt = (
            t.optim.AdamW(self.parameters(), lr=lr)
            if optimizer == "adam"
            else t.optim.SGD(self.parameters(), lr=lr)
        )
        # The batched loss sums over instances; normalize it for display only.
        loss_divisor = 1 if instance_idx is not None else self.cfg.n_instances

        progress_bar = tqdm(range(steps), desc="Training", miniters=1, ascii=True)

        for step in progress_bar:
            step_lr = lr * lr_scale(step, steps)
            for group in opt.param_groups:
                group["lr"] = step_lr

            opt.zero_grad()
            # Generate/forward with the same instance selection the loss uses, so
            # shapes match the single-instance loss path.
            batch, labels = self.generate_batch(batch_size, instance_idx=instance_idx)
            out = self(batch, instance_idx=instance_idx)
            loss = self.calculate_loss(out, labels, instance_idx=instance_idx)
            loss.backward()
            opt.step()

            if step % log_freq == 0 or (step + 1 == steps):
                progress_bar.set_postfix(loss=loss.item() / loss_divisor, lr=step_lr)

    def test_accuracy(
        self,
        batch_size: int = 10_000,
        batch: Optional[Tensor] = None,
        instance_idx: Optional[int] = None,
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """Compute classification accuracy.

        Args:
            batch_size: Samples to generate when ``batch`` is None.
            batch: Optional pre-made ``(batch, features)`` input. It is labelled
                with ``_assign_class(..., single_index=True)``, so
                ``instance_idx`` should be given in this case.
            instance_idx: Instance to evaluate on a provided ``batch``.
                NOTE(review): ignored when ``batch`` is None, which evaluates
                all instances.

        Returns:
            ``(accuracy, logits, labels)``. Accuracy is averaged over the batch
            axis (one value per instance in the multi-instance case).
        """
        if batch is not None:
            labels, _ = self._assign_class(batch, single_index=True)
            out = self(batch, instance_idx=instance_idx)
        else:
            batch, labels = self.generate_batch(batch_size)
            out = self(batch)
        return t.sum(labels == out.argmax(-1), dim=0) / batch.shape[0], out, labels

    def attack(
        self,
        attack_method: str = "linf",
        batch_size: int = 1024,
        verbose: bool = False,
        instance_idx: Optional[int] = None,
        attack_params: Optional[Dict[str, Any]] = None,
        batch: Optional[Tensor] = None,
        target_class: Optional[int] = None,
        find_worst_case: bool = False,
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Generate adversarial examples with a PGD attack.

        Args:
            attack_method: ``"linf"`` (default epsilon 0.1) or ``"l2"`` (default
                epsilon 1.0).
            batch_size: Samples to generate when ``batch`` is None.
            verbose: Forwarded to the attack function.
            instance_idx: Instance to attack. It is forwarded to the attack and
                used to slice generated batches.
            attack_params: Extra attack kwargs. Defaults to
                ``{"num_iter": 1000, "alpha": 0.01}``. The caller's dict is not
                modified. ``epsilon`` is filled in only when missing or None.
            batch: Optional pre-made ``(batch, features)`` input, labelled with
                ``_assign_class(..., single_index=True)``.
            target_class: If not None (class 0 included), run a targeted attack
                towards this class.
            find_worst_case: Forwarded to the attack function.

        Returns:
            ``(adversarial_batch, clean_batch, labels, successful_attack_mask)``.

        Raises:
            NotImplementedError: For an unsupported ``attack_method``.
        """
        if batch is not None:
            labels = self._assign_class(batch, single_index=True)[0]
        else:
            batch, labels = self.generate_batch(
                batch_size=batch_size, instance_idx=instance_idx
            )

        # Copy so filling in defaults never mutates the caller's dict.
        attack_params = (
            {"num_iter": 1000, "alpha": 0.01}
            if attack_params is None
            else dict(attack_params)
        )

        if attack_method == "l2":
            default_epsilon, attack_fn = 1.0, pgd_l2_adv
        elif attack_method == "linf":
            default_epsilon, attack_fn = 0.1, pgd_linf_adv
        else:
            raise NotImplementedError("Only linf and l2 attacks are supported")
        # ``is None`` (not falsiness) so an explicit epsilon of 0 is respected.
        if attack_params.get("epsilon") is None:
            attack_params["epsilon"] = default_epsilon

        # ``is not None`` so that target class 0 yields a targeted attack.
        target_classes = (
            t.ones_like(labels) * target_class if target_class is not None else None
        )

        train_delta, successful_attack_mask, _ = attack_fn(
            self,
            batch,
            labels,
            target_classes,
            **attack_params,
            verbose=verbose,
            loss_fn=self.calculate_loss,
            instance_idx=instance_idx,
            find_worst_case=find_worst_case,
        )

        attack_data = batch + train_delta
        return attack_data, batch, labels, successful_attack_mask
