import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import grad
from typing import Dict, Any, Tuple, Optional
import numpy as np


class GraphDRODualSolver:
    """Minimise the dual objective ``lambda * r - phi(lambda)`` over a model and ``lambda``.

    ``phi(lambda)`` is the mean, over training nodes, of the per-node value
    ``min_delta [lambda * d(delta) - loss(node + delta)]`` where ``delta``
    perturbs the node's feature row (and its ``s`` entry when present) and
    ``d`` is the ``kappa``-weighted squared norm of the perturbation.

    Required config keys: ``training.r``, ``training.kappa`` (indexed with
    ``'feature'``, ``'sensitive'`` and ``'label'``), ``training.dro`` and
    ``optimizer``. Optional keys read from ``training.dro``: ``lambda_init``,
    ``lambda_clamp_min``, ``lambda_clamp_max``, ``inner_lr``, ``inner_steps``;
    from ``optimizer``: ``lambda_lr``.

    NOTE(review): ``_solve_local_inf_problem`` calls ``model(data)`` while
    ``get_regularization_upper_bound`` / ``_estimate_lipschitz_constant`` call
    ``model(edge_index, x)``. Both conventions are kept exactly as they were;
    confirm which one the project's models actually implement.
    """

    def __init__(self, config: Dict[str, Any]):
        """Read hyper-parameters from ``config``; ``lambda`` is created lazily."""
        self.config = config
        training = config['training']
        self.r = training['r']
        self.kappa = training['kappa']

        dro_config = training['dro']
        self.lambda_init = dro_config.get('lambda_init', 1.0)
        self.lambda_clamp_min = dro_config.get('lambda_clamp_min', 0.0)
        self.lambda_clamp_max = dro_config.get('lambda_clamp_max', 10.0)
        # Step size / iteration count of the per-node inner minimisation.
        # 0.01 is the step size that used to be hard-coded; ``inner_steps=0``
        # evaluates the inner objective at the unperturbed node.
        self.inner_lr = dro_config.get('inner_lr', 0.01)
        self.inner_steps = dro_config.get('inner_steps', 5)

        self.lambda_lr = config['optimizer'].get('lambda_lr', 0.01)

        self.lambda_param: Optional[nn.Parameter] = None
        self.lambda_optimizer: Optional[torch.optim.Optimizer] = None

    def initialize_lambda(self, device) -> None:
        """Create the scalar ``lambda`` parameter on ``device`` and its Adam optimizer."""
        self.lambda_param = nn.Parameter(
            torch.tensor(self.lambda_init, dtype=torch.float32, device=device)
        )
        self.lambda_optimizer = torch.optim.Adam([self.lambda_param], lr=self.lambda_lr)

    @staticmethod
    def _as_float(tensor: torch.Tensor) -> torch.Tensor:
        """Return ``tensor`` unchanged if floating point, else cast to float32."""
        return tensor if tensor.is_floating_point() else tensor.float()

    @classmethod
    def _squared_difference(cls, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        """Sum of squared element-wise differences (squared Frobenius / L2 norm).

        Operands are cast to float first: norms are undefined for integer
        tensors (e.g. class labels) and ``-`` is unsupported for bool tensors.
        """
        return (cls._as_float(a) - cls._as_float(b)).pow(2).sum()

    def compute_wasserstein_distance(self, data_clean, data_perturbed):
        """Weighted squared distance between two graph data objects.

        Adds, for every field present (not ``None``) on both inputs:
        the squared difference of the dense adjacency matrices built from
        ``edge_index``/``num_nodes`` (weight 1), of ``x`` (weight
        ``kappa['feature']``), of ``s`` (``kappa['sensitive']``) and of ``y``
        (``kappa['label']``).

        Returns:
            A scalar tensor, or the Python float ``0.0`` when the inputs share
            none of these fields.
        """
        distance = 0.0

        edges_clean = getattr(data_clean, 'edge_index', None)
        edges_perturbed = getattr(data_perturbed, 'edge_index', None)
        if edges_clean is not None and edges_perturbed is not None:
            adj_clean = self._edge_index_to_adj(edges_clean, data_clean.num_nodes)
            adj_perturbed = self._edge_index_to_adj(
                edges_perturbed, data_perturbed.num_nodes
            )
            distance = distance + self._squared_difference(adj_clean, adj_perturbed)

        weighted_fields = (('x', 'feature'), ('s', 'sensitive'), ('y', 'label'))
        for attr, kappa_key in weighted_fields:
            clean = getattr(data_clean, attr, None)
            perturbed = getattr(data_perturbed, attr, None)
            if clean is None or perturbed is None:
                continue
            distance = distance + self.kappa[kappa_key] * self._squared_difference(
                clean, perturbed
            )

        return distance

    def _edge_index_to_adj(self, edge_index: torch.Tensor, num_nodes: int) -> torch.Tensor:
        """Build a dense ``num_nodes x num_nodes`` 0/1 matrix from ``edge_index``.

        Row 0 of ``edge_index`` indexes matrix rows and row 1 indexes columns;
        repeated edges collapse to a single 1.
        """
        adj = torch.zeros(num_nodes, num_nodes, device=edge_index.device)
        adj[edge_index[0], edge_index[1]] = 1.0
        return adj

    def compute_phi_lambda(self, model, data, lambda_param) -> torch.Tensor:
        """Mean per-node inner objective over the nodes selected by ``data.train_mask``.

        The result stays attached to the autograd graph of ``model`` and
        ``lambda_param`` so callers can backpropagate through it; one graph is
        retained per training node until that backward pass.

        Returns:
            A scalar tensor; ``0.0`` on ``data.x``'s device when the mask
            selects no node.
        """
        device = data.x.device
        train_indices = torch.where(data.train_mask)[0].tolist()

        phi_values = [
            self._solve_local_inf_problem(model, data, idx, lambda_param)
            for idx in train_indices
        ]
        if not phi_values:
            return torch.tensor(0.0, device=device)
        return torch.stack(phi_values).mean()

    def _solve_local_inf_problem(self, model, data, node_idx, lambda_param) -> torch.Tensor:
        """Approximate ``min_delta [lambda * d(delta) - loss]`` for one node.

        Runs ``self.inner_steps`` plain gradient-descent steps of size
        ``self.inner_lr`` on the perturbation, starting from zero, then
        re-evaluates the objective at the (detached) perturbation found.

        Returns:
            A scalar tensor that is differentiable w.r.t. the model parameters
            (through the loss) and w.r.t. ``lambda_param`` (through the
            distance term).
        """
        node_idx = int(node_idx)
        row = slice(node_idx, node_idx + 1)

        base_s = getattr(data, 's', None)
        if base_s is not None:
            # A float perturbation cannot be written into an integer/bool
            # tensor, so the sensitive attribute is perturbed in float space.
            base_s = self._as_float(base_s)

        delta_x = torch.zeros_like(data.x[row], requires_grad=True)
        delta_s = (
            None if base_s is None
            else torch.zeros_like(base_s[row], requires_grad=True)
        )
        deltas = [delta_x] if delta_s is None else [delta_x, delta_s]

        # Cloned once; only the perturbed fields are reassigned per evaluation.
        perturbed_data = data.clone()

        def objective(dx, ds, lam):
            new_x = data.x.clone()
            new_x[node_idx] = data.x[node_idx] + dx[0]
            perturbed_data.x = new_x

            distance = self.kappa['feature'] * dx.pow(2).sum()
            if ds is not None:
                new_s = base_s.clone()
                new_s[node_idx] = base_s[node_idx] + ds[0]
                perturbed_data.s = new_s
                distance = distance + self.kappa['sensitive'] * ds.pow(2).sum()

            # NOTE(review): model is called with the data object here; see the
            # class docstring about the differing call convention elsewhere.
            logits = model(perturbed_data)
            loss = F.cross_entropy(logits[row], data.y[row])
            return lam * distance - loss

        # lambda is a constant of the inner problem; detaching it keeps the
        # inner iterations from extending lambda's autograd graph.
        lam_const = lambda_param.detach() if torch.is_tensor(lambda_param) else lambda_param
        for _ in range(self.inner_steps):
            # autograd.grad (not .backward) so that model parameters do not
            # accumulate gradients from the inner iterations.
            step_grads = grad(objective(delta_x, delta_s, lam_const), deltas, allow_unused=True)
            with torch.no_grad():
                for delta, delta_grad in zip(deltas, step_grads):
                    if delta_grad is not None:
                        delta.sub_(delta_grad, alpha=self.inner_lr)

        # Evaluate at the fixed perturbation: gradients w.r.t. the model and
        # lambda are taken with the inner minimiser held constant.
        final_dx = delta_x.detach()
        final_ds = None if delta_s is None else delta_s.detach()
        return objective(final_dx, final_ds, lambda_param)

    def solve_dual_objective(self, model, data, theta_optimizer) -> Tuple[float, float]:
        """Take one descent step on the dual objective for the model and ``lambda``.

        Both the model parameters (via ``theta_optimizer``) and ``lambda`` (via
        the internal Adam optimizer) minimise ``lambda * r - phi(lambda)``;
        ``lambda`` is then clamped to the configured range.

        Returns:
            ``(dual_objective, lambda)`` as Python floats; the objective is the
            value computed before the parameter updates, ``lambda`` the value
            after the update and clamp.
        """
        device = data.x.device

        if self.lambda_param is None:
            self.initialize_lambda(device)

        phi_lambda = self.compute_phi_lambda(model, data, self.lambda_param)
        dual_objective = self.lambda_param * self.r - phi_lambda

        # A single backward pass feeds both optimizers. Backpropagating a
        # second time after an in-place parameter step would fail, and both
        # variables descend the same objective.
        theta_optimizer.zero_grad()
        self.lambda_optimizer.zero_grad()
        dual_objective.backward()
        theta_optimizer.step()
        self.lambda_optimizer.step()

        with torch.no_grad():
            self.lambda_param.clamp_(
                min=self.lambda_clamp_min,
                max=self.lambda_clamp_max,
            )

        return dual_objective.item(), self.lambda_param.item()

    def get_regularization_upper_bound(self, model, data) -> Tuple[torch.Tensor, float, float]:
        """Return ``base_loss + L * sqrt(r)`` together with its two ingredients.

        ``base_loss`` is the cross-entropy on the nodes selected by
        ``data.train_mask`` and ``L`` comes from
        ``_estimate_lipschitz_constant``.

        Returns:
            ``(upper_bound, base_loss, L)`` where ``upper_bound`` is a tensor
            and the other two are Python floats.
        """
        logits = model(data.edge_index, data.x)
        base_loss = F.cross_entropy(logits[data.train_mask], data.y[data.train_mask])

        lipschitz_constant = self._estimate_lipschitz_constant(model, data)

        upper_bound = base_loss + lipschitz_constant * (self.r ** 0.5)

        return upper_bound, base_loss.item(), lipschitz_constant.item()

    def _estimate_lipschitz_constant(self, model, data) -> torch.Tensor:
        """Return ``sqrt(||d loss / d x||_F^2 / kappa['feature'])`` at ``data.x``.

        The loss is the cross-entropy on the nodes selected by
        ``data.train_mask``. The gradient is taken w.r.t. a detached copy of
        ``data.x`` so the caller's tensor keeps its ``requires_grad`` state.
        """
        features = data.x.detach().clone().requires_grad_(True)

        logits = model(data.edge_index, features)
        loss = F.cross_entropy(logits[data.train_mask], data.y[data.train_mask])

        (grad_x,) = grad(loss, features)

        return torch.sqrt(grad_x.pow(2).sum() / self.kappa['feature'])