from typing import Callable, Optional, Tuple, Union

import einops
import torch
import torch as t
from einops import rearrange, reduce
from jaxtyping import Float
from torch import Tensor, nn
from torch.nn import functional as F
from tqdm import tqdm

from adversarial_superposition.constants import DEVICE
from adversarial_superposition.toy_models.utils.model import (
    ClassificationConfig,
    constant_lr,
)
from adversarial_superposition.toy_models.utils.pgd_attack import (
    pgd_l2_adv,
    pgd_linf_adv,
)


class OrthogonalClassificationnModel(nn.Module):
    def __init__(
        self,
        cfg: ClassificationConfig,
        feature_probability: Optional[Union[float, Tensor]] = None,
        class_importance: Optional[Union[float, Tensor]] = None,
        feature_importance: Optional[Union[float, Tensor]] = None,
        privileged=False,
        privileged_out=False,
        bias=False,
        projection_bias=False,
        abs=False,
        loss_fn="ce",
        device="cpu",
        orthog_dim=0,
    ):
        """Create the encoder/projection parameters and store sampling settings.

        Args:
            cfg: Configuration object. This constructor reads ``n_instances``,
                ``n_features``, ``n_features_per_class``, ``n_hidden`` and
                ``n_classes`` from it.
            feature_probability: Scalar or tensor broadcastable to
                ``(n_instances, n_features)``; used as the activation
                threshold when sampling features. Defaults to 1.
            class_importance: Per-class weight used by ``calculate_loss``.
                Defaults to a vector of ones of length ``n_classes``.
            feature_importance: Stored on the instance; no method visible in
                this class reads it. Defaults to ones of shape
                ``(n_instances, n_features_per_class)``.
            privileged: If true, ``forward`` applies ReLU to the hidden layer.
            privileged_out: If true, ``forward`` applies ReLU to the logits.
            bias: If true, ``forward`` adds ``b_enc`` to the hidden layer.
            projection_bias: If true, ``forward`` adds ``b_projection`` to the
                logits.
            abs: If true, sampled features get randomized signs and class
                assignment is done on absolute values.
            loss_fn: ``"ce"`` or ``"mse"``; selects the loss in
                ``calculate_loss``.
            device: Intermediate device for ``feature_probability``. Kept for
                backward compatibility; every stored tensor ends up on
                ``DEVICE``.
            orthog_dim: Index of the first preserved (orthogonal) feature
                column; ``n_features_per_class`` consecutive columns are
                preserved.

        Raises:
            ValueError: If both ``class_importance`` and ``feature_importance``
                are supplied.
        """
        super().__init__()

        self.cfg = cfg

        self.loss_fn = loss_fn
        self.abs = abs

        self.privileged = privileged
        self.privileged_out = privileged_out

        if feature_probability is None:
            feature_probability = t.ones(())

        # Accept plain ints as well as floats; convert through float so the
        # resulting tensor always has a floating dtype.
        if isinstance(feature_probability, (int, float)):
            feature_probability = t.tensor(float(feature_probability))

        self.feature_probability = (
            feature_probability.to(device)
            .broadcast_to((cfg.n_instances, cfg.n_features))
            .to(DEVICE)
        )

        if (class_importance is not None) and (feature_importance is not None):
            raise ValueError("Can only have either class or feature importance")

        if class_importance is None:
            class_importance = t.ones((cfg.n_classes))
            print(
                f"Set class importance to vector of ones shape: {class_importance.shape}"
            )

        if feature_importance is None:
            feature_importance = t.ones((cfg.n_instances, cfg.n_features_per_class))
            print(
                f"Set feature importance to vector of ones shape: {feature_importance.shape}"
            )

        if isinstance(class_importance, (int, float)):
            class_importance = t.tensor(float(class_importance))

        if isinstance(feature_importance, (int, float)):
            feature_importance = t.tensor(float(feature_importance))

        # These are plain attributes, so ``self.to(DEVICE)`` below does not
        # move them. Place them on DEVICE explicitly so they share a device
        # with the logits when used as loss weights.
        self.feature_importance = feature_importance.to(DEVICE)
        self.class_importance = class_importance.to(DEVICE)

        # Trainable entries of hidden row 0 on the preserved feature columns.
        self.W_orthog = nn.Parameter(
            t.ones((cfg.n_instances, cfg.n_features_per_class))
        )
        # Trainable entries of the remaining hidden rows on the other columns.
        self.W_non_orthog = nn.Parameter(
            nn.init.xavier_normal_(
                t.empty(
                    (
                        cfg.n_instances,
                        cfg.n_hidden - 1,
                        cfg.n_features - cfg.n_features_per_class,
                    )
                )
            )
        )
        # Full encoder matrix; rebuilt from the two parameters on every forward.
        self.register_buffer(
            "W", t.zeros((cfg.n_instances, cfg.n_hidden, cfg.n_features))
        )
        self.preserved_dims = range(orthog_dim, orthog_dim + cfg.n_features_per_class)

        self.bias = bias
        self.projection_bias = projection_bias
        self.b_enc = nn.Parameter(t.zeros((cfg.n_instances, cfg.n_hidden)))
        self.b_projection = nn.Parameter(t.zeros((cfg.n_instances, cfg.n_classes)))

        self.W_projection = nn.Parameter(
            nn.init.xavier_normal_(
                t.empty((cfg.n_instances, cfg.n_classes, cfg.n_hidden))
            )
        )
        self.to(DEVICE)

    def _construct_weight_matrix(self):
        """Assemble the full encoder weight tensor from its two trainable parts.

        Hidden row 0 holds ``W_orthog`` on the preserved feature columns and is
        zero elsewhere. Hidden rows 1 and up hold ``W_non_orthog`` on the
        remaining columns and are zero on the preserved ones, so the two
        groups of feature columns never share a hidden coordinate.

        Returns:
            A new tensor with the same shape, dtype and device as the ``W``
            buffer (instances, hidden, features).
        """
        W = t.zeros_like(self.W)

        preserved = t.tensor(
            list(self.preserved_dims), dtype=t.long, device=W.device
        )
        W[:, 0, preserved] = self.W_orthog

        # Every feature column that is not preserved.
        keep = t.ones(self.cfg.n_features, dtype=t.bool, device=W.device)
        keep[preserved] = False
        non_preserved_indices = t.where(keep)[0]

        # W_non_orthog is allocated with n_hidden - 1 rows, i.e. every hidden
        # row except row 0. Slicing from n_features_per_class instead only
        # matches that shape when n_features_per_class == 1.
        W[:, 1:, non_preserved_indices] = self.W_non_orthog

        return W

    def forward(
        self, features: Float[Tensor, "... instances features"], instance_idx=None
    ) -> Float[Tensor, "... n_classes"]:
        """Encode features into the hidden space and project to class logits.

        The encoder matrix is rebuilt from ``W_orthog`` / ``W_non_orthog`` and
        stored in ``self.W`` on every call.

        Args:
            features: Input features. With ``instance_idx=None`` the
                second-to-last axis indexes instances; otherwise the input
                carries no instance axis.
            instance_idx: If given, only that instance's weights and biases
                are used. Assumed to be a single integer index, since the
                einsum patterns expect 2-D weight slices.

        Returns:
            Class logits on ``DEVICE``, with ReLU applied when
            ``privileged_out`` is set.
        """
        self.W = self._construct_weight_matrix()
        features = features.to(DEVICE)

        if instance_idx is not None:
            W_enc = self.W[instance_idx, ...]
            W_proj = self.W_projection[instance_idx, ...]
            # Biases are stored as (instances, dim): select this instance's
            # row, not column ``instance_idx`` across all instances.
            b_enc = self.b_enc[instance_idx, ...]
            b_proj = self.b_projection[instance_idx, ...]
            encode_pattern = "... features, hidden features -> ... hidden"
            project_pattern = "... features, n_classes features -> ... n_classes"
        else:
            W_enc, W_proj = self.W, self.W_projection
            b_enc, b_proj = self.b_enc, self.b_projection
            encode_pattern = (
                "... instances features, instances hidden features "
                "-> ... instances hidden"
            )
            project_pattern = (
                "... instances features, instances n_classes features "
                "-> ... instances n_classes"
            )

        out = einops.einsum(features, W_enc, encode_pattern)
        if self.bias:
            out = out + b_enc

        activations = F.relu(out) if self.privileged else out

        # Direct classification from the hidden activations.
        classification = einops.einsum(activations, W_proj, project_pattern)
        if self.projection_bias:
            classification = classification + b_proj

        if self.privileged_out:
            classification = F.relu(classification)

        return classification.to(DEVICE)

    def generate_uncorrelated_features(
        self, batch_size, n_uncorrelated
    ) -> Float[Tensor, "batch_size instances features"]:
        """Sample a sparse batch of independent, uniformly distributed features.

        Each value is drawn from U[0, 1) and kept only when a second uniform
        draw is at most the instance's feature probability (column 0 of
        ``self.feature_probability``); otherwise it is zeroed. Any
        (sample, instance) row that ends up all zero gets one randomly chosen
        feature refilled with a fresh uniform draw.

        Args:
            batch_size: Number of samples.
            n_uncorrelated: Number of features per instance.

        Returns:
            Tensor of shape (batch_size, n_instances, n_uncorrelated) on
            ``DEVICE``.
        """
        shape = (batch_size, self.cfg.n_instances, n_uncorrelated)

        values = t.rand(shape, device=self.W.device)
        seeds = t.rand(shape, device=self.W.device)

        # Sparsify: only column 0 of the probability table is consulted, so a
        # single probability per instance applies to all its features.
        is_active = seeds <= self.feature_probability[:, [0]]
        features = t.where(is_active, values, 0.0).to(DEVICE)

        # Rows with no active feature get exactly one feature switched on.
        empty_rows = (features == 0).all(dim=2).to(DEVICE)
        if empty_rows.any():
            n_empty = empty_rows.sum()
            feature_pick = torch.randint(0, features.shape[2], (n_empty,)).to(DEVICE)
            sample_pick, instance_pick = torch.where(empty_rows)
            sample_pick = sample_pick.to(DEVICE)
            instance_pick = instance_pick.to(DEVICE)
            refill = torch.rand(n_empty).to(DEVICE)
            features[sample_pick, instance_pick, feature_pick] = refill

        if self.abs:
            # NOTE(review): randomize_signs_torch is not defined or imported in
            # the visible part of this file - confirm it is in scope.
            features = randomize_signs_torch(features)

        return features.to(DEVICE)

    def _assign_class(self, batch, single_index=False):
        """Label each sample with the class whose feature group sums highest.

        The feature axis is split into ``n_classes`` contiguous, equally sized
        groups (the first group belongs to class 0, and so on); the values in
        each group are summed and the arg-max group is the label.

        Args:
            batch: Features of shape <batch, instances, features>, or
                <batch, features> when ``single_index`` is true.
            single_index: Whether ``batch`` lacks the instance axis.

        Returns:
            ``(labels, class_sums)``: the arg-max class indices and the
            per-class sums they were taken from.

        Raises:
            ValueError: If ``n_features`` is not divisible by ``n_classes``.
        """
        n_classes = self.cfg.n_classes
        if self.cfg.n_features % n_classes:
            raise ValueError(
                "There needs to be an equal amount of features in each class"
            )

        # Leading axes shared by both einops patterns.
        lead = "batch" if single_index else "batch instances"

        if self.abs:
            batch = torch.abs(batch).to(DEVICE)

        grouped = rearrange(
            batch,
            f"{lead} (classes features) -> {lead} classes features",
            classes=n_classes,
        )
        class_sums = reduce(
            grouped,
            f"{lead} classes features -> {lead} classes",
            "sum",
        )

        labels = torch.argmax(class_sums, dim=-1)
        return labels, class_sums

    def generate_batch(self, batch_size, instance_idx=None) -> Tuple[Tensor, Tensor]:
        """Sample a batch of features together with its class labels.

        Features are always generated for every instance. When
        ``instance_idx`` is given, the instance axis is indexed away from both
        returned tensors.

        Args:
            batch_size: Number of samples to draw.
            instance_idx: Optional instance to select.

        Returns:
            ``(features, labels)``.
        """
        features = self.generate_uncorrelated_features(
            batch_size, self.cfg.n_features
        )
        labels = self._assign_class(features)[0]

        if instance_idx is None:
            return features.to(DEVICE), labels.to(DEVICE)

        return features[:, instance_idx, :], labels[:, instance_idx]

    def calculate_loss(
        self,
        out: Float[Tensor, "batch instances features"],
        labels: Float[Tensor, "batch instances features"],
        instance_idx=None,
        **kwargs,
    ) -> Float[Tensor, ""]:
        """Compute the class-weighted classification loss.

        Args:
            out: Logits, (batch, instances, classes) when ``instance_idx`` is
                ``None``, otherwise (batch, classes).
            labels: Integer class indices, (batch, instances) or (batch,).
            instance_idx: Only tested against ``None`` here; it selects
                between the all-instances and single-instance layouts.
            **kwargs: Accepted and ignored.

        Returns:
            Scalar loss: the mean over the batch axis, summed over instances
            when an instance axis is present.
        """
        all_instances = instance_idx is None

        if all_instances:
            # Fold the instance axis into the batch axis so the per-sample
            # losses can be computed in one call.
            logits = rearrange(out, "b i c -> (b i) c").to(DEVICE)
            targets = rearrange(labels, "b i -> (b i)").to(DEVICE)
        else:
            logits, targets = out, labels

        if self.loss_fn == "mse":
            # Squared error between softmax probabilities and one-hot targets.
            labels_one_hot = (
                F.one_hot(targets, num_classes=logits.shape[-1]).float().to(DEVICE)
            )
            pred_probs = F.softmax(logits, dim=-1)
            loss = F.mse_loss(pred_probs, labels_one_hot, reduction="none")
            if self.class_importance is not None:
                loss = loss * self.class_importance.unsqueeze(0)
            # Sum across the class dimension.
            loss = loss.sum(dim=-1)
        else:
            # "ce" or any other value falls back to cross-entropy in both
            # layouts. Previously the all-instances layout left ``loss``
            # unbound for unrecognized values.
            loss = F.cross_entropy(
                logits, targets, weight=self.class_importance, reduction="none"
            )

        if all_instances:
            # Restore the (batch, instances) layout before reducing.
            loss = rearrange(loss, "(a b) -> a b", a=out.shape[0], b=out.shape[1])

        return torch.mean(loss, dim=0).sum()

    def optimize(
        self,
        batch_size: int = 1024,
        steps: int = 10_000,
        log_freq: int = 100,
        lr: float = 1e-3,
        lr_scale: Callable[[int, int], float] = constant_lr,
        instance_idx=None,
        optimizer: str = "adam",
    ):
        """Train the model on freshly sampled batches.

        Args:
            batch_size: Samples drawn per step.
            steps: Number of optimization steps.
            log_freq: Progress-bar refresh interval, in steps.
            lr: Base learning rate.
            lr_scale: Called as ``lr_scale(step, steps)``; its result
                multiplies ``lr`` at every step.
            instance_idx: If given, batches, the forward pass and the loss are
                all restricted to this instance.
            optimizer: ``"adam"`` selects Adam; any other value selects SGD.
        """
        optimizer_cls = t.optim.Adam if optimizer == "adam" else t.optim.SGD
        optim = optimizer_cls(list(self.parameters()), lr=lr)

        # The single-instance loss is not a sum over instances, so only
        # normalize the displayed value when training every instance.
        log_divisor = self.cfg.n_instances if instance_idx is None else 1

        print(f"Starting training loop for {self.__class__.__name__}")

        progress_bar = tqdm(range(steps), desc="Training", miniters=1, ascii=True)

        for step in progress_bar:
            # Update learning rate
            step_lr = lr * lr_scale(step, steps)
            for group in optim.param_groups:
                group["lr"] = step_lr

            optim.zero_grad()

            # Pass instance_idx to data generation and the forward pass too,
            # so logits and labels have the layout calculate_loss expects.
            batch, labels = self.generate_batch(batch_size, instance_idx=instance_idx)
            out = self(batch, instance_idx=instance_idx)
            out = out.to(DEVICE)
            loss = self.calculate_loss(out, labels, instance_idx=instance_idx)
            loss.backward()
            optim.step()

            # Display progress bar
            if step % log_freq == 0 or (step + 1 == steps):
                progress_bar.set_postfix(loss=loss.item() / log_divisor, lr=step_lr)

        print(f"Finished training loop for {self.__class__.__name__}")

    def test_accuracy(self, batch_size=10_000, batch=None, instance_idx=None):
        """Return the fraction of samples the model classifies correctly.

        Args:
            batch_size: Number of samples to generate when ``batch`` is
                ``None``.
            batch: Optional pre-generated features. A 2-D batch is labelled as
                single-instance data, anything else as (batch, instances,
                features).
            instance_idx: Passed to the forward pass when ``batch`` is given.
                Ignored when the batch is generated here: all instances are
                then evaluated.

        Returns:
            Accuracy, i.e. correct predictions summed over the batch axis
            divided by the number of samples. One value per instance when an
            instance axis is present.
        """
        # Evaluation needs no autograd graph.
        with torch.no_grad():
            if batch is not None:
                # Keep labels on the same device as the model output.
                batch = batch.to(DEVICE)
                labels, _ = self._assign_class(batch, single_index=batch.ndim == 2)
                out = self(batch, instance_idx=instance_idx)
                n_samples = batch.shape[0]
            else:
                batch, labels = self.generate_batch(batch_size)
                out = self(batch)
                n_samples = batch_size

            return torch.sum(labels == out.argmax(-1), dim=0) / n_samples

    def attack(
        self,
        attack_method="linf",
        batch_size: int = 1024,
        verbose=False,
        instance_idx=None,
        attack_params=None,
        batch=None,
        target_class=None,
        find_worst_case: bool = False,
    ):
        """Run a PGD attack against the model.

        Args:
            attack_method: ``"linf"`` or ``"l2"``.
            batch_size: Samples to generate when ``batch`` is ``None``.
            verbose: Forwarded to the attack function.
            instance_idx: Instance to attack; forwarded to batch generation
                and to the attack function.
            attack_params: Extra keyword arguments for the attack function.
                When empty or ``None``, ``num_iter=1000`` and ``alpha=0.01``
                are used. A missing ``epsilon`` defaults to 1.0 for l2 and 0.1
                for linf. The caller's dict is not modified.
            batch: Optional pre-generated features. A 2-D batch is labelled as
                single-instance data.
            target_class: Optional class index for a targeted attack; class 0
                is a valid target.
            find_worst_case: Forwarded to the attack function.

        Returns:
            ``(attack_data, batch, labels, successful_attack_mask)`` where
            ``attack_data`` is ``batch`` plus the perturbation returned by the
            attack function.

        Raises:
            NotImplementedError: For any other ``attack_method``.
        """
        if batch is not None:
            # _assign_class returns (labels, class_sums); keep only the labels.
            labels, _ = self._assign_class(batch, single_index=batch.ndim == 2)
        else:
            batch, labels = self.generate_batch(
                batch_size=batch_size, instance_idx=instance_idx
            )

        # Work on a copy so the caller's dict is never mutated.
        if attack_params:
            attack_params = dict(attack_params)
        else:
            attack_params = {"num_iter": 1000, "alpha": 0.01}

        if attack_method == "l2":
            default_epsilon, attack_fn = 1.0, pgd_l2_adv
        elif attack_method == "linf":
            default_epsilon, attack_fn = 0.1, pgd_linf_adv
        else:
            raise NotImplementedError("Only linf and l2 attacks are supported")

        # Only fill in epsilon when it is absent, so an explicit 0 is honored.
        if attack_params.get("epsilon") is None:
            attack_params["epsilon"] = default_epsilon

        # Compare against None so that class 0 can be targeted.
        if target_class is not None:
            target_classes = torch.ones_like(labels) * target_class
        else:
            target_classes = None

        train_delta, successful_attack_mask, _ = attack_fn(
            self,
            batch,
            labels,
            target_classes,
            **attack_params,
            verbose=verbose,
            loss_fn=self.calculate_loss,
            instance_idx=instance_idx,
            find_worst_case=find_worst_case,
        )

        attack_data = batch + train_delta

        return attack_data, batch, labels, successful_attack_mask
