import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data
from einops import rearrange, repeat
from typing import List, Tuple, Dict, Any

from .base_classifier import BaseMILClassifier, PredictionResult
from .classifier_factory import register_classifier

# Shorthand converters between complex tensors and their real views
# (a real view carries real/imaginary parts in a trailing size-2 axis).
_c2r, _r2c = torch.view_as_real, torch.view_as_complex


class DropoutNd(nn.Module):
    """Dropout whose mask can be shared ("tied") across every length axis.

    Args:
        p: drop probability; must lie in [0, 1).
        tie: if True, one mask value is drawn per (batch, channel) pair and
            broadcast over all remaining axes; if False, every element gets
            its own mask value.
        transposed: if True the input is channels-first,
            ``(batch, channels, *lengths)``; if False it is channels-last,
            ``(batch, *lengths, channels)`` (the same convention S4D in this
            module uses for its ``transposed`` flag).

    Raises:
        ValueError: if ``p`` is outside [0, 1).
    """

    def __init__(self, p: float = 0.5, tie=True, transposed=True):
        super().__init__()
        if p < 0 or p >= 1:
            raise ValueError(f"dropout probability has to be in [0, 1), but got {p}")
        self.p = p
        self.tie = tie
        self.transposed = transposed

    def forward(self, X):
        """Apply inverted dropout in training mode; identity in eval mode."""
        if not self.training:
            return X
        if not self.transposed:
            # Channels-last -> channels-first so the channel axis is dim 1.
            # (The previous einops patterns were swapped, which misplaced the
            # channel axis for inputs with more than one length dimension.)
            X = X.movedim(-1, 1)
        if self.tie:
            # One draw per (batch, channel); broadcast over all length axes.
            mask_shape = tuple(X.shape[:2]) + (1,) * (X.ndim - 2)
        else:
            mask_shape = tuple(X.shape)
        mask = torch.rand(*mask_shape, device=X.device) < 1.0 - self.p
        # Rescale kept activations so the expected value is unchanged.
        X = X * mask * (1.0 / (1 - self.p))
        if not self.transposed:
            X = X.movedim(1, -1)
        return X


class S4DKernel(nn.Module):
    """Diagonal state-space (S4D) convolution-kernel generator.

    Stores ``N // 2`` complex diagonal modes per channel and, for a requested
    length ``L``, materialises a real kernel of shape ``(d_model, L)``.

    Args:
        d_model: number of independent channels ``H``.
        N: state size; only ``N // 2`` complex modes are stored, and
            :meth:`forward` returns twice the real part of their sum.
        dt_min, dt_max: ``log_dt`` is drawn uniformly in
            ``[log(dt_min), log(dt_max)]``.
        lr: override for the SSM parameters (``log_dt``, ``log_A_real``,
            ``A_imag``). ``0.0`` registers them as frozen buffers; otherwise
            they are trainable and tagged via :meth:`register`.
    """

    def __init__(self, d_model, N=64, dt_min=0.001, dt_max=0.1, lr=None):
        super().__init__()
        H = d_model
        log_dt = torch.rand(H) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)

        C = torch.randn(H, N // 2, dtype=torch.cfloat)
        # Stored as a real (H, N//2, 2) view so it is an ordinary real Parameter.
        self.C = nn.Parameter(torch.view_as_real(C))
        self.register("log_dt", log_dt, lr)

        log_A_real = torch.log(0.5 * torch.ones(H, N // 2))
        # Row n of A_imag is pi * n, replicated for every channel: (H, N//2).
        A_imag = math.pi * torch.arange(N // 2).repeat(H, 1)
        self.register("log_A_real", log_A_real, lr)
        self.register("A_imag", A_imag, lr)

    def forward(self, L):
        """Return the real convolution kernel of shape ``(H, L)``."""
        dt = torch.exp(self.log_dt)                             # (H,)
        C = torch.view_as_complex(self.C)                       # (H, N//2)
        A = -torch.exp(self.log_A_real) + 1j * self.A_imag      # (H, N//2)

        dtA = A * dt.unsqueeze(-1)                              # (H, N//2)
        # Exponent dtA * l for l = 0..L-1 -> (H, N//2, L)
        K = dtA.unsqueeze(-1) * torch.arange(L, device=A.device)
        # Fold the discretisation factor (exp(dtA) - 1) / A into C.
        C = C * (torch.exp(dtA) - 1.0) / A
        K = 2 * torch.einsum('hn, hnl -> hl', C, torch.exp(K)).real
        return K

    def register(self, name, tensor, lr=None):
        """Register ``tensor`` as a buffer (``lr == 0.0``) or a trainable parameter.

        Trainable tensors get an ``_optim`` attribute holding
        ``{"weight_decay": 0.0}`` plus ``"lr"`` when ``lr`` is given, intended
        for an optimizer builder that creates dedicated parameter groups.
        """
        if lr == 0.0:
            self.register_buffer(name, tensor)
            return
        self.register_parameter(name, nn.Parameter(tensor))
        # Named optim_hints so it does not shadow the module-level `optim` alias.
        optim_hints = {"weight_decay": 0.0}
        if lr is not None:
            optim_hints["lr"] = lr
        setattr(getattr(self, name), "_optim", optim_hints)


class S4D(nn.Module):
    """One S4D layer: FFT long convolution plus skip term, GELU, dropout, GLU projection.

    Args:
        d_model: channel count ``H`` (input and output width).
        d_state: state size ``N`` handed to :class:`S4DKernel`.
        dropout: probability for the post-activation dropout; ``0`` disables it.
        transposed: if True inputs are ``(B, H, L)``; if False ``(B, L, H)``.
        **kernel_args: forwarded unchanged to :class:`S4DKernel`.
    """

    def __init__(self, d_model, d_state=64, dropout=0.0, transposed=True, **kernel_args):
        super().__init__()
        self.h = d_model
        self.n = d_state
        self.d_output = d_model
        self.transposed = transposed

        # Per-channel skip weight (created before the kernel; order matters for RNG).
        self.D = nn.Parameter(torch.randn(d_model))
        self.kernel = S4DKernel(d_model, N=d_state, **kernel_args)

        self.activation = nn.GELU()
        if dropout > 0.0:
            self.dropout = DropoutNd(dropout)
        else:
            self.dropout = nn.Identity()

        # 1x1 conv doubles the channels, GLU over the channel axis halves them back.
        self.output_linear = nn.Sequential(
            nn.Conv1d(d_model, 2 * d_model, kernel_size=1),
            nn.GLU(dim=-2),
        )

    def forward(self, u, **kwargs):
        """Run the layer on ``u``; extra keyword arguments are accepted and ignored."""
        channels_last = not self.transposed
        if channels_last:
            u = u.transpose(-1, -2)
        seq_len = u.size(-1)

        # Zero-pad to 2L so the FFT product is a linear (not circular) convolution.
        fft_len = 2 * seq_len
        kernel_f = torch.fft.rfft(self.kernel(L=seq_len), n=fft_len)
        signal_f = torch.fft.rfft(u.to(torch.float32), n=fft_len)
        y = torch.fft.irfft(signal_f * kernel_f, n=fft_len)[..., :seq_len]

        y = y + u * self.D.unsqueeze(-1)
        y = self.output_linear(self.dropout(self.activation(y)))

        return y.transpose(-1, -2) if channels_last else y


class S4MILNet(nn.Module):
    """Bag-level network: project instances, mix them with one S4D layer, max-pool, classify.

    Args:
        feature_dim: width of each input instance feature vector.
        hidden_dim: width after projection and inside the S4D layer.
        n_classes: number of output logits.
        dropout_rate: dropout after the projection; ``0`` disables it.
        d_state: S4D state size.
    """

    def __init__(self, feature_dim, hidden_dim, n_classes, dropout_rate, d_state):
        super().__init__()
        self.feature_dim = feature_dim
        self.hidden_dim = hidden_dim
        self.n_classes = n_classes

        projection_layers = [nn.Linear(feature_dim, hidden_dim), nn.GELU()]
        projection_layers.append(
            nn.Dropout(dropout_rate) if dropout_rate > 0 else nn.Identity()
        )
        self.feature_projection = nn.Sequential(*projection_layers)

        # Channels-last S4D consumes the [B, N, hidden_dim] projection directly.
        self.s4_block = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            S4D(d_model=hidden_dim, d_state=d_state, transposed=False),
        )

        self.classifier = nn.Linear(hidden_dim, n_classes)

    def forward(self, x):
        """
        Args:
            x: [batch_size, n_patches, feature_dim]
        Returns:
            logits: [batch_size, n_classes]
        """
        instance_feats = self.s4_block(self.feature_projection(x))  # [B, N, hidden_dim]
        bag_embedding = instance_feats.max(dim=1).values  # [B, hidden_dim]
        return self.classifier(bag_embedding)  # [B, n_classes]


@register_classifier('s4mil')
class S4MILClassifier(BaseMILClassifier):
    """Multiple-instance classifier backed by :class:`S4MILNet`.

    Bags are fed one at a time (each as ``[1, n_patches, feature_dim]``), and
    ``batch_size`` consecutive bags are combined through gradient accumulation.

    Config keys (defaults): ``hidden_dim`` (512), ``d_state`` (32),
    ``dropout_rate`` (0.25), ``learning_rate`` (1e-3), ``weight_decay`` (1e-4),
    ``batch_size`` (8).
    """

    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)

        # Cast explicitly, like the float hyper-parameters below: config values
        # may arrive as strings.
        self.hidden_dim = int(config.get('hidden_dim', 512))
        self.d_state = int(config.get('d_state', 32))
        self.dropout_rate = float(config.get('dropout_rate', 0.25))

        self.learning_rate = float(config.get('learning_rate', 0.001))
        self.weight_decay = float(config.get('weight_decay', 1e-4))
        self.batch_size = int(config.get('batch_size', 8))

    def build_model(self, feature_dim: int, n_classes: int) -> nn.Module:
        """Instantiate the underlying :class:`S4MILNet`."""
        return S4MILNet(
            feature_dim=feature_dim,
            hidden_dim=self.hidden_dim,
            n_classes=n_classes,
            dropout_rate=self.dropout_rate,
            d_state=self.d_state
        )

    def prepare_data(self, bags: List[Tuple[np.ndarray, Any]], labels: List[int]) -> Tuple[List[torch.Tensor], torch.Tensor]:
        """Convert bags to float32 tensors (always copied) and labels to int64.

        Args:
            bags: ``(features, name)`` pairs; only ``features`` is used here.
            labels: class index per bag.

        Returns:
            ``(bag_tensors, labels_tensor)``.
        """
        bag_tensors = [torch.tensor(features, dtype=torch.float32) for features, _ in bags]
        labels_tensor = torch.tensor(labels, dtype=torch.long)
        return bag_tensors, labels_tensor

    def _build_optimizer(self) -> optim.Optimizer:
        """Create Adam, honouring per-parameter ``_optim`` hints.

        Parameters carrying an ``_optim`` dict (S4DKernel in this module tags
        its SSM parameters with ``weight_decay=0.0`` and an optional ``lr``)
        are placed in their own param groups so those overrides take effect;
        every other parameter uses the configured lr / weight decay.
        """
        default_params = []
        hinted_groups: Dict[Tuple, Dict[str, Any]] = {}
        for param in self.model.parameters():
            hints = getattr(param, '_optim', None)
            if not hints:
                default_params.append(param)
                continue
            key = tuple(sorted(hints.items()))
            hinted_groups.setdefault(key, {'params': [], **hints})['params'].append(param)

        groups = [{'params': default_params}] if default_params else []
        groups.extend(hinted_groups.values())
        return optim.Adam(groups, lr=self.learning_rate, weight_decay=self.weight_decay)

    def _train_pass(self, train_data: List[torch.Tensor],
                    train_labels: torch.Tensor) -> Tuple[float, float]:
        """Run one optimisation pass; return (mean per-bag loss, accuracy), NaN if empty."""
        self.model.train()
        n_bags = min(len(train_data), len(train_labels))
        if n_bags == 0:
            return math.nan, math.nan

        steps = max(1, int(self.batch_size))
        # Start from clean gradients regardless of what happened before.
        self.optimizer.zero_grad()
        total_loss = 0.0
        n_correct = 0

        for i, (bag_data, label) in enumerate(zip(train_data, train_labels)):
            bag_data = self._to_device(bag_data).unsqueeze(0)  # [1, n_patches, feature_dim]
            label = self._to_device(label).unsqueeze(0)  # [1]

            outputs = self.model(bag_data)
            loss = self.criterion(outputs, label)

            # Divide by the size of the group this bag belongs to, so the final
            # (possibly shorter) group also yields a mean gradient.
            group_start = (i // steps) * steps
            group_size = min(steps, n_bags - group_start)
            (loss / group_size).backward()

            total_loss += loss.item()
            n_correct += int((outputs.argmax(dim=1) == label).sum().item())

            if (i + 1) % steps == 0 or (i + 1) == n_bags:
                self.optimizer.step()
                self.optimizer.zero_grad()

        return total_loss / n_bags, n_correct / n_bags

    def _eval_pass(self, data: List[torch.Tensor],
                   labels: torch.Tensor) -> Tuple[float, float]:
        """Evaluate without gradients; return (mean loss, accuracy), NaN if empty."""
        self.model.eval()
        total_loss = 0.0
        n_correct = 0
        n_seen = 0

        with torch.no_grad():
            for bag_data, label in zip(data, labels):
                bag_data = self._to_device(bag_data).unsqueeze(0)  # [1, n_patches, feature_dim]
                label = self._to_device(label).unsqueeze(0)  # [1]

                outputs = self.model(bag_data)
                total_loss += self.criterion(outputs, label).item()
                n_correct += int((outputs.argmax(dim=1) == label).sum().item())
                n_seen += 1

        if n_seen == 0:
            return math.nan, math.nan
        return total_loss / n_seen, n_correct / n_seen

    def train_epoch(self, train_data: List[torch.Tensor], train_labels: torch.Tensor,
                   val_data: List[torch.Tensor], val_labels: torch.Tensor,
                   epoch: int) -> Tuple[float, float, float, float]:
        """Train for one epoch, then evaluate on the validation bags.

        The optimizer and loss are created lazily on the first call. ``epoch``
        is accepted for interface compatibility and not used here.

        Returns:
            ``(train_loss, train_acc, val_loss, val_acc)``; a split with no
            bags reports NaN for its loss and accuracy instead of raising.
        """
        if getattr(self, 'optimizer', None) is None:
            self.optimizer = self._build_optimizer()
        if getattr(self, 'criterion', None) is None:
            self.criterion = nn.CrossEntropyLoss()

        train_loss, train_acc = self._train_pass(train_data, train_labels)
        val_loss, val_acc = self._eval_pass(val_data, val_labels)
        return train_loss, train_acc, val_loss, val_acc

    def predict_bags(self, bags: List[Tuple[np.ndarray, Any]]) -> PredictionResult:
        """Alias of :meth:`predict`."""
        return self.predict(bags)

    def predict(self, bags: List[Tuple[np.ndarray, Any]]) -> PredictionResult:
        """Predict class, class probabilities and confidence for each bag.

        Args:
            bags: ``(features, name)`` pairs.

        Returns:
            A ``PredictionResult`` with per-bag argmax predictions, softmax
            probabilities, max probability as confidence, and bag names.
        """
        self.model.eval()

        # Labels are irrelevant for inference; dummy zeros satisfy prepare_data.
        bag_tensors, _ = self.prepare_data(bags, [0] * len(bags))

        predictions = []
        probabilities = []
        confidences = []
        bag_names = []

        with torch.no_grad():
            for bag_tensor, (_, bag_name) in zip(bag_tensors, bags):
                bag_data = self._to_device(bag_tensor).unsqueeze(0)  # [1, n_patches, feature_dim]

                outputs = self.model(bag_data)
                probs = F.softmax(outputs, dim=1)
                preds = torch.argmax(outputs, dim=1)
                confs = torch.max(probs, dim=1)[0]

                predictions.append(preds.cpu().numpy()[0])
                probabilities.append(probs.cpu().numpy()[0])
                confidences.append(confs.cpu().numpy()[0])
                bag_names.append(bag_name)

        return PredictionResult(
            predictions=np.array(predictions),
            probabilities=np.array(probabilities),
            confidence=np.array(confidences),
            bag_names=bag_names
        )
