import torch
import torch.nn.functional as F
from torch.autograd import grad
from typing import Dict, Any, Optional
import numpy as np


class LipschitzRegularizer:
    """Gradient-based local Lipschitz penalty on the node features ``data.x``.

    Configuration keys (all optional):
        p_norm: norm used to aggregate the per-input estimates (default 2).
        power_iters: number of refinement steps for the probe direction
            (default 1).
        lambda_lip: multiplier applied in ``compute_regularization_term``
            (default 1.0).
        enabled: when false, ``compute_local_lipschitz`` returns a zero tensor
            (default True).
    """

    def __init__(self, config: Dict[str, Any]):
        self.p_norm = config.get('p_norm', 2)
        self.power_iters = config.get('power_iters', 1)
        self.lambda_lip = config.get('lambda_lip', 1.0)
        self.enabled = config.get('enabled', True)

    def compute_local_lipschitz(self, model, data, loss_fn=None):
        """Estimate how sensitive the training loss is to ``data.x``.

        Args:
            model: callable invoked as ``model(data_modified)``.
            data: object exposing ``x``, ``y``, ``train_mask`` and a
                namedtuple-style ``_replace``; an ``s`` attribute is optional.
            loss_fn: optional ``loss_fn(logits, targets)``; cross-entropy is
                used when omitted.

        Returns:
            A scalar tensor. It is built with ``create_graph=True`` so it can
            be back-propagated into the model parameters. A zero tensor is
            returned when the regularizer is disabled or the loss has no
            gradient with respect to ``x``.
        """
        if not self.enabled:
            return torch.tensor(0.0, device=data.x.device)

        # Differentiate w.r.t. a fresh leaf copy so the caller's tensor is
        # never touched.
        x = data.x.clone().detach().requires_grad_(True)
        s = data.s.clone().detach() if hasattr(data, 's') else None

        data_modified = data._replace(x=x)
        if s is not None:
            data_modified = data_modified._replace(s=s)

        logits = model(data_modified)

        if loss_fn is None:
            loss = F.cross_entropy(logits[data.train_mask], data.y[data.train_mask])
        else:
            loss = loss_fn(logits[data.train_mask], data.y[data.train_mask])

        # Only the features are differentiated. The previous code also tried
        # to append an undefined name here, which raised NameError on every
        # enabled call.
        inputs = [x]

        # allow_unused=True yields None (instead of raising) for an input the
        # loss does not depend on; such entries are skipped below.
        gradients = grad(
            loss,
            inputs,
            create_graph=True,
            retain_graph=True,
            allow_unused=True,
        )

        local_lips = []

        for wrt, g in zip(inputs, gradients):
            if g is None:
                continue

            v = torch.randn_like(g)
            for _ in range(self.power_iters):
                hv = None
                if g.requires_grad:
                    # Hessian-vector product; None when g has no dependence
                    # on this input.
                    hv = grad(
                        (g * v).sum(),
                        wrt,
                        retain_graph=True,
                        allow_unused=True,
                    )[0]
                if hv is None:
                    v = v / (v.norm() + 1e-8)
                else:
                    v = hv / (hv.norm() + 1e-8)

            # v carries no graph, so this is differentiable through g only.
            local_lips.append((g * v).sum().abs())

        if not local_lips:
            return torch.tensor(0.0, device=data.x.device)

        local_lips_tensor = torch.stack(local_lips)
        if self.p_norm == 1:
            global_lip = local_lips_tensor.sum()
        elif self.p_norm == 2:
            global_lip = torch.sqrt((local_lips_tensor ** 2).sum())
        elif self.p_norm == float('inf'):
            global_lip = local_lips_tensor.max()
        else:
            global_lip = (local_lips_tensor ** self.p_norm).sum() ** (1.0 / self.p_norm)

        return global_lip

    def compute_regularization_term(self, model, data, r: float):
        """Return ``lambda_lip * local_lipschitz * sqrt(r)`` as a tensor."""
        local_lip = self.compute_local_lipschitz(model, data)
        return self.lambda_lip * local_lip * torch.sqrt(torch.tensor(r, device=data.x.device))


class GraphDROLoss:
    """Combine cross-entropy, a Lipschitz/DRO penalty and a fairness gap.

    The config dict may carry the sub-dicts ``'lipschitz'``, ``'kappa'``,
    ``'fairness'`` and ``'dro'``. The ``kappa_*`` values and ``beta`` are
    stored on the instance but are not read by any method of this class.
    """

    def __init__(self, config: Dict[str, Any]):
        self.config = config
        self.lipschitz_reg = LipschitzRegularizer(config.get('lipschitz', {}))
        self.kappa = config.get('kappa', {})
        self.fairness_config = config.get('fairness', {})
        self.dro_config = config.get('dro', {})

        kappa_cfg = self.kappa
        self.kappa_feature = kappa_cfg.get('feature', 1.0)
        self.kappa_edge = kappa_cfg.get('edge', 0.5)
        self.kappa_sensitive = kappa_cfg.get('sensitive', 2.0)
        self.kappa_label = kappa_cfg.get('label', 0.1)

        fairness_cfg = self.fairness_config
        self.alpha = fairness_cfg.get('alpha', 0.5)
        self.beta = fairness_cfg.get('beta', 1.0)

    def compute_base_loss(self, model, data):
        """Cross-entropy of ``model(data.edge_index, data.x)`` on the training nodes."""
        outputs = model(data.edge_index, data.x)
        train_outputs = outputs[data.train_mask]
        train_targets = data.y[data.train_mask]
        return F.cross_entropy(train_outputs, train_targets)

    def compute_fairness_loss(self, model, data):
        """Absolute gap in mean class-1 probability between ``s == 0`` and ``s == 1``.

        Only training nodes are considered. A zero tensor is returned when
        ``data`` has no ``s`` attribute or when either group has no training
        node.
        """
        outputs = model(data.edge_index, data.x)
        probabilities = torch.softmax(outputs, dim=1)

        if not hasattr(data, 's'):
            return torch.tensor(0.0, device=data.x.device)

        group0 = (data.s == 0) & data.train_mask
        group1 = (data.s == 1) & data.train_mask

        # The gap is undefined when one of the groups is empty.
        if not (group0.sum() > 0 and group1.sum() > 0):
            return torch.tensor(0.0, device=data.x.device)

        rate0 = probabilities[group0, 1].mean()
        rate1 = probabilities[group1, 1].mean()
        return torch.abs(rate0 - rate1)

    def compute_dro_loss(self, model, data, lambda_param: torch.Tensor, r: float):
        """Return ``(total_loss, components)``.

        ``total_loss`` is ``base_loss + (lambda_param * r + lip_reg)
        + alpha * fairness_loss``; ``components`` maps each named part to a
        Python float obtained with ``.item()``.
        """
        ce = self.compute_base_loss(model, data)
        lipschitz_penalty = self.lipschitz_reg.compute_regularization_term(model, data, r)
        gap = self.compute_fairness_loss(model, data)

        robust_part = lambda_param * r + lipschitz_penalty
        fair_part = self.alpha * gap
        total = ce + robust_part + fair_part

        components = {
            'base_loss': ce.item(),
            'dro_term': robust_part.item(),
            'lip_reg': lipschitz_penalty.item(),
            'fairness_loss': gap.item(),
            'total_loss': total.item(),
            'lambda': lambda_param.item(),
        }
        return total, components


def dro_loss(model, data, lambda_param, r, config=None):
    """Return the DRO training loss for ``model`` on ``data``.

    With a ``config`` the loss comes from ``GraphDROLoss(config)`` (only the
    total is returned, the component dict is discarded); without one the
    lightweight ``_simple_dro_loss`` is used.
    """
    if config is not None:
        loss_module = GraphDROLoss(config)
        total, _components = loss_module.compute_dro_loss(model, data, lambda_param, r)
        return total
    return _simple_dro_loss(model, data, lambda_param, r)


def _simple_dro_loss(model, data, lambda_param, r):
    logits = model(data.edge_index, data.x)
    ce_loss = F.cross_entropy(logits[data.train_mask], data.y[data.train_mask])
    
    x = data.x.clone().detach().requires_grad_(True)
    data_modified = data._replace(x=x)
    logits_grad = model(data_modified)
    loss_grad = F.cross_entropy(logits_grad[data.train_mask], data.y[data.train_mask])
    
    g = grad(loss_grad, x, retain_graph=False)[0]
    
    v = torch.randn_like(g)
    hv = grad((g * v).sum(), x, retain_graph=True)[0]
    v = hv / (hv.norm() + 1e-8)
    
    lip_const = (g * v).sum().abs()
    
    return ce_loss + lambda_param * r + lambda_param * lip_const


def create_loss_function(config: Dict[str, Any]):
    """Build a ``GraphDROLoss`` from the ``'training'`` section of ``config``.

    An empty dict is used when the section is missing.
    """
    training_section = config.get('training', {})
    return GraphDROLoss(training_section)


from .metrics import compute_fairness_metrics