"""
Model definitions for HKAN experiments
"""
import torch
import torch.nn as nn
import numpy as np

# Import official KAN library
try:
    from kan import KAN
    print("Official KAN library imported successfully")
except ImportError:
    print("KAN library not found. Please install: pip install pykan")
    # The bare `exit` helper is injected by the `site` module for interactive
    # sessions and is not guaranteed to exist (e.g. under `python -S`);
    # raising SystemExit terminates with the same status code everywhere.
    raise SystemExit(1)


def calculate_factor_quality_score(factors, config):
    """Calculate the Factor Quality Score (FQS) of a set of factors.

    The factors are stacked into a matrix with one column per factor and the
    score is a weighted sum of three terms:

    * independence: ``1 - mean(|r|)`` over all pairwise correlations,
    * stability: ``exp(-mean(log(1 + var)))`` over the per-factor variances,
    * sparsity: fraction of entries whose magnitude does not exceed 10% of
      the mean absolute value of the matrix.

    Args:
        factors: Sequence of equally long 1-D value sequences, one per
            factor. May be ``None``.
        config: Object providing ``fqs_independence_weight``,
            ``fqs_stability_weight`` and ``fqs_sparsity_weight``.

    Returns:
        The weighted score, or the neutral value ``0.5`` when fewer than two
        factors are given, the input does not form a 2-D factor matrix, the
        correlation matrix contains NaN, or the computation raises an
        ordinary exception.
    """
    neutral_score = 0.5

    if factors is None or len(factors) < 2:
        return neutral_score

    factors_array = np.array(factors).T
    # Scalar "factors" yield a 1-D array; check the rank first so the column
    # count lookup cannot raise IndexError.
    if factors_array.ndim != 2 or factors_array.shape[1] < 2:
        return neutral_score

    try:
        correlation_matrix = np.corrcoef(factors_array.T)
        # A constant factor has zero variance and produces NaN correlations.
        if np.any(np.isnan(correlation_matrix)):
            return neutral_score

        # Independence score: strictly-upper triangle holds each pair once.
        off_diagonal = correlation_matrix[np.triu_indices_from(correlation_matrix, k=1)]
        independence_score = 1 - np.mean(np.abs(off_diagonal))

        # Stability score
        factor_vars = np.var(factors_array, axis=0)
        stability_score = np.exp(-np.mean(np.log(1 + factor_vars)))

        # Sparsity score: an entry counts as "active" when it exceeds tau.
        tau = np.mean(np.abs(factors_array)) * 0.1
        sparsity_activations = np.mean(np.abs(factors_array) > tau, axis=0)
        sparsity_score = 1.0 - np.mean(sparsity_activations)

        # Calculate FQS
        fqs = (config.fqs_independence_weight * independence_score +
               config.fqs_stability_weight * stability_score +
               config.fqs_sparsity_weight * sparsity_score)
        return fqs
    except Exception:
        # Best-effort metric: fall back to the neutral score on ordinary
        # failures, but let KeyboardInterrupt / SystemExit propagate.
        return neutral_score


class SubKANOfficial(nn.Module):
    """Single-output KAN for one feature group, backed by the official library.

    The wrapped network has the layer widths
    ``[input_dim, kan_hidden_multiplier * input_dim + 1, 1]``.
    """

    def __init__(self, input_dim, config, device='cpu'):
        """Build the wrapped KAN.

        Args:
            input_dim: Number of input features of this group.
            config: Object providing ``kan_hidden_multiplier``, ``kan_grid``,
                ``kan_k`` and ``seed`` (plus the ``kan_reg_lamb_*`` values
                read by ``get_regularization_loss``).
            device: Device handed to the ``KAN`` constructor.
        """
        super().__init__()
        self.config = config
        self.device = device

        # Layer widths: input -> hidden -> one scalar output.
        layer_widths = [
            input_dim,
            config.kan_hidden_multiplier * input_dim + 1,
            1,
        ]
        self.kan = KAN(
            layer_widths,
            grid=config.kan_grid,
            k=config.kan_k,
            device=device,
            seed=config.seed,
        )

    def forward(self, x):
        """Return the wrapped KAN's output for ``x``."""
        network = self.kan
        return network(x)

    def get_regularization_loss(self):
        """Return the wrapped KAN's ``reg`` value using the configured weights."""
        regularizer = self.kan.reg
        cfg = self.config
        return regularizer(
            'edge_forward_spline_n',
            cfg.kan_reg_lamb_l1,
            cfg.kan_reg_lamb_entropy,
            cfg.kan_reg_lamb_coef,
            cfg.kan_reg_lamb_coefdiff,
        )


class HKANClassification(nn.Module):
    """Hierarchical KAN for binary classification.

    Each feature group gets its own ``SubKANOfficial`` that reduces the
    group's inputs to one scalar factor per sample. The factors are
    concatenated column-wise and, when there is more than one group, passed
    through a fusion KAN ending in a single output unit. With exactly one
    group the factor itself is returned and the fusion KAN is not applied
    (it is still constructed, so its parameters remain part of the module).

    Attributes set by ``forward``:
        last_combined_factors_for_reg: Concatenated factor tensor of the most
            recent forward pass (still attached to the autograd graph).
        last_factors_for_fqs: List of flattened NumPy copies of the
            individual factors of the most recent forward pass.
    """

    def __init__(self, group_feature_info, config, device='cpu'):
        """Build the sub-KANs and the fusion KAN.

        Args:
            group_feature_info: Mapping from group name to a mapping that
                contains at least the key ``'input_dim'``.
            config: Object providing ``fusion_hidden_layers`` (an iterable of
                hidden widths), ``kan_grid``, ``kan_k``, ``seed``, the
                ``kan_reg_lamb_*`` weights, the ``lambda_*`` factor
                regularization weights, and whatever ``SubKANOfficial`` reads.
            device: Device handed to every ``KAN`` constructor.
        """
        super().__init__()
        self.group_names = list(group_feature_info.keys())
        self.num_groups = len(self.group_names)
        self.config = config
        self.device = device

        # Create sub-KAN networks
        self.sub_kans = nn.ModuleDict({
            name: SubKANOfficial(info['input_dim'], config, device)
            for name, info in group_feature_info.items()
        })

        # Fusion network using official KAN. Unpacking accepts any iterable
        # of hidden widths (list or tuple), not only a list.
        fusion_config = [self.num_groups, *config.fusion_hidden_layers, 1]
        self.fusion_kan = KAN(fusion_config, grid=config.kan_grid, k=config.kan_k, device=device, seed=config.seed)

        self.last_combined_factors_for_reg = None
        self.last_factors_for_fqs = None

    def forward(self, group_inputs):
        """Run all sub-KANs and fuse their factors.

        Args:
            group_inputs: Mapping from group name to that group's input
                tensor; must contain every name in ``self.group_names``.

        Returns:
            The fusion KAN output, or the single factor when there is only
            one group.
        """
        # Compute outputs from all sub-networks, in group order
        factors = [self.sub_kans[name](group_inputs[name]) for name in self.group_names]

        # Combine factors and cache them for the regularizer / FQS metric
        combined_factors = torch.cat(factors, dim=1)
        self.last_combined_factors_for_reg = combined_factors
        self.last_factors_for_fqs = [f.detach().cpu().numpy().flatten() for f in factors]

        # Pass through fusion network (nothing to fuse for a single group)
        if self.num_groups == 1:
            return combined_factors
        return self.fusion_kan(combined_factors)

    def get_factor_regularization_loss(self):
        """Calculate the factor regularization loss of the last forward pass.

        Sums three weighted terms: squared pairwise correlations between
        factors (``lambda_decorrelation``), mean absolute factor value
        (``lambda_sparsity``) and summed per-factor variance
        (``lambda_stability``).

        Returns:
            A scalar tensor; zero when ``forward`` has not been called yet.
        """
        factors = self.last_combined_factors_for_reg
        if factors is None:
            return torch.tensor(0.0, device=self.device)

        cfg = self.config
        # Allocate next to the factors so the result follows the model even
        # if it was moved after construction.
        factor_reg = factors.new_zeros(())

        # Correlation and sample variance are undefined (NaN) for a single
        # sample, so both statistics need at least two rows.
        has_multiple_samples = factors.shape[0] > 1

        # Decorrelation regularization
        if has_multiple_samples and factors.shape[1] > 1:
            corr_matrix = torch.corrcoef(factors.T)
            # Constant factors give NaN correlations; skip the term then.
            if not torch.isnan(corr_matrix).any():
                factor_reg = factor_reg + cfg.lambda_decorrelation * torch.sum(
                    torch.triu(corr_matrix, diagonal=1) ** 2
                )

        # Sparsity regularization
        factor_reg = factor_reg + cfg.lambda_sparsity * torch.mean(torch.abs(factors))

        # Stability regularization
        if has_multiple_samples:
            factor_reg = factor_reg + cfg.lambda_stability * torch.sum(torch.var(factors, dim=0))

        return factor_reg

    def get_total_regularization_loss(self):
        """Calculate the total regularization loss.

        Adds up the regularization of every sub-KAN, of the fusion KAN (only
        when more than one group exists, i.e. when it is actually used) and
        the factor regularization loss.
        """
        total_reg_loss = torch.tensor(0.0, device=self.device)

        # Sub-KAN regularization losses
        for sub_kan in self.sub_kans.values():
            total_reg_loss = total_reg_loss + sub_kan.get_regularization_loss()

        # Fusion KAN regularization loss
        if self.num_groups > 1:
            total_reg_loss = total_reg_loss + self.fusion_kan.reg(
                'edge_forward_spline_n',
                self.config.kan_reg_lamb_l1,
                self.config.kan_reg_lamb_entropy,
                self.config.kan_reg_lamb_coef,
                self.config.kan_reg_lamb_coefdiff
            )

        # Factor regularization loss
        total_reg_loss = total_reg_loss + self.get_factor_regularization_loss()

        return total_reg_loss

    def count_parameters(self):
        """Count trainable parameters (those with ``requires_grad`` set)."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)


class PureKANClassifier(nn.Module):
    """Flat KAN classifier: a single KAN with no per-group sub-networks."""

    def __init__(self, input_dim, config, device='cpu'):
        """Build a ``[input_dim, hidden_size, 1]`` KAN and print its layout.

        Args:
            input_dim: Number of input features.
            config: Object providing ``kan_grid``, ``kan_k`` and ``seed``;
                ``hidden_size`` is optional and defaults to 32.
            device: Device handed to the ``KAN`` constructor.
        """
        super().__init__()
        self.config = config
        self.device = device

        # Hidden width comes from the config when present, otherwise 32.
        hidden_size = getattr(config, 'hidden_size', 32)
        kan_config = [input_dim, hidden_size, 1]
        self.kan = KAN(
            kan_config,
            grid=config.kan_grid,
            k=config.kan_k,
            device=device,
            seed=config.seed,
        )

        print(f"Pure KAN architecture: {kan_config}")
        print(f"Input dimension: {input_dim}, Hidden dimension: {hidden_size}")

    def forward(self, x):
        """Return the wrapped KAN's output for ``x``."""
        network = self.kan
        return network(x)

    def get_regularization_loss(self):
        """Return the wrapped KAN's ``reg`` value using the configured weights."""
        regularizer = self.kan.reg
        cfg = self.config
        return regularizer(
            'edge_forward_spline_n',
            cfg.kan_reg_lamb_l1,
            cfg.kan_reg_lamb_entropy,
            cfg.kan_reg_lamb_coef,
            cfg.kan_reg_lamb_coefdiff,
        )

    def count_parameters(self):
        """Return the number of elements in all trainable parameters."""
        trainable_total = 0
        for param in self.parameters():
            if param.requires_grad:
                trainable_total += param.numel()
        return trainable_total