import copy
import time
from typing import Dict, Any, Tuple, List, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data

try:
    import wandb
    WANDB_AVAILABLE = True
except ImportError:
    WANDB_AVAILABLE = False

from core.perturbations import NodeCentricPerturbations
from core.metrics import compute_fairness_metrics, evaluate_robustness, compute_comprehensive_metrics


class GraphDROTrainer:
    """Distributionally robust trainer for node-level tasks on a single graph.

    Each epoch builds a mixed empirical distribution (the clean graph plus
    ``K`` perturbed copies from ``NodeCentricPerturbations``), minimises the
    weighted supervised loss over it and adds ``lambda_lip * L * sqrt(r)``,
    where ``L`` is a sampled estimate of the model's per-node Lipschitz
    constant. Validation accuracy (R^2 when ``data.y`` is floating point)
    selects the checkpoint that is restored at the end of ``train``.

    ``config['training']`` must provide ``device``, ``r`` and ``kappa``;
    ``config['optimizer']`` must provide ``lr``. Every other key is optional
    and defaulted in ``__init__`` / ``setup_optimizers``.
    """

    # Per-node Lipschitz value substituted when an estimate cannot be computed
    # (best-effort fallback that keeps the regulariser finite).
    _LIPSCHITZ_FALLBACK = 0.001

    def __init__(self, config: Dict[str, Any]):
        """Read hyper-parameters from ``config`` and build the perturbation model."""
        self.config = config
        self.device = torch.device(config['training']['device'])
        self.perturbation_model = NodeCentricPerturbations(config)

        training_config = config['training']
        # 'num_epochs' takes precedence over the legacy 'epochs' key.
        self.num_epochs = int(training_config.get('num_epochs', training_config.get('epochs', 200)))
        # Counted in validation checks (one every 5 epochs), not in epochs.
        self.patience = int(training_config.get('patience', 50))
        self.early_stop = bool(training_config.get('early_stop', True))

        self.r = float(training_config['r'])
        # Build a float copy instead of converting the caller's dict in place.
        self.kappa = {key: float(value) for key, value in training_config['kappa'].items()}

        # eta: weight of the clean graph in the mixed distribution;
        # K: number of perturbed copies sharing the remaining (1 - eta).
        self.eta = float(training_config.get('eta', 0.5))
        self.K = int(training_config.get('K', 5))
        self.mode = training_config.get('mode', 'unified')

        lipschitz_config = training_config.get('lipschitz', {})
        self.lipschitz_enabled = bool(lipschitz_config.get('enabled', True))
        self.p_norm = int(lipschitz_config.get('p_norm', 2))
        self.lambda_lip = float(lipschitz_config.get('lambda_lip', 1.5))

        fairness_config = training_config.get('fairness', {})
        self.alpha = float(fairness_config.get('alpha', 0.6))
        self.beta = float(fairness_config.get('beta', 1.2))

        dro_config = training_config.get('dro', {})
        self.lambda_init = float(dro_config.get('lambda_init', 1.0))
        self.lambda_clamp_min = float(dro_config.get('lambda_clamp_min', 0.0))
        self.lambda_clamp_max = float(dro_config.get('lambda_clamp_max', 15.0))

        eval_config = config.get('evaluation', {})
        self.eval_robust = bool(eval_config.get('robust', True))
        self.eval_fair = bool(eval_config.get('fair', True))

        self.use_wandb = bool(config.get('logging', {}).get('use_wandb', False))
        self.verbose = bool(config.get('logging', {}).get('verbose', True))

        self.best_val_acc = 0.0
        self.best_model_state = None
        self.training_history = []
        # Created by setup_optimizers().
        self.optimizer = None
        self.scheduler = None

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _is_regression(data: 'Data') -> bool:
        """Return True when the targets are floating point (treated as regression)."""
        return torch.is_floating_point(data.y)

    @staticmethod
    def _supervised_loss(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """MSE for floating-point targets, cross-entropy for integer class labels."""
        if torch.is_floating_point(targets):
            return F.mse_loss(logits.squeeze(), targets.float())
        return F.cross_entropy(logits, targets)

    # ------------------------------------------------------------------ setup

    def setup_optimizers(self, model: nn.Module):
        """Create ``self.optimizer`` and ``self.scheduler`` for ``model``.

        Reads ``config['optimizer']``:
          - ``name``: 'adam' (default) or 'sgd'.
          - ``lr``: required.
          - ``weight_decay``: default 1e-5.
          - ``momentum``: SGD only, default 0.9.
          - ``scheduler``: optional; honoured when ``use`` is true.
            'step' needs ``step_size`` and ``gamma``; 'cosine' anneals over
            ``self.num_epochs``.
        ``self.scheduler`` is None when no scheduler is requested.

        Raises:
            ValueError: if the optimizer or scheduler name is not supported.
        """
        optimizer_config = self.config['optimizer']
        optimizer_name = optimizer_config.get('name', 'adam')
        if optimizer_name not in ('adam', 'sgd'):
            raise ValueError(f"Unsupported optimizer: {optimizer_name!r}")

        lr = float(optimizer_config['lr'])
        weight_decay = float(optimizer_config.get('weight_decay', 1e-5))
        if optimizer_name == 'adam':
            self.optimizer = torch.optim.Adam(
                model.parameters(),
                lr=lr,
                weight_decay=weight_decay
            )
        else:
            momentum = float(optimizer_config.get('momentum', 0.9))
            self.optimizer = torch.optim.SGD(
                model.parameters(),
                lr=lr,
                momentum=momentum,
                weight_decay=weight_decay
            )

        # Always (re)set, so an unrecognised configuration can never leave a
        # stale or missing attribute behind.
        self.scheduler = None
        scheduler_config = optimizer_config.get('scheduler', {})
        if scheduler_config.get('use', False):
            scheduler_name = scheduler_config['name']
            if scheduler_name == 'step':
                self.scheduler = torch.optim.lr_scheduler.StepLR(
                    self.optimizer,
                    step_size=int(scheduler_config['step_size']),
                    gamma=float(scheduler_config['gamma'])
                )
            elif scheduler_name == 'cosine':
                self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                    self.optimizer,
                    T_max=self.num_epochs
                )
            else:
                raise ValueError(f"Unsupported scheduler: {scheduler_name!r}")

    def generate_mixed_empirical_distribution(self, data: 'Data') -> Tuple[List['Data'], List[float]]:
        """Return ``[data, *K perturbed copies]`` and their mixture weights.

        The clean graph gets weight ``eta``; each of the ``K`` perturbed copies
        gets ``(1 - eta) / K``. With ``K <= 0`` there are no copies and the
        clean graph carries all the mass (weight 1.0).
        """
        if self.K <= 0:
            return [data], [1.0]
        perturbed_graphs = self.perturbation_model.generate_k_fold_perturbations(data, self.K)
        mixed_graphs = [data] + perturbed_graphs
        weights = [self.eta] + [(1 - self.eta) / self.K] * self.K
        return mixed_graphs, weights

    # ------------------------------------------------------------------ Lipschitz

    def compute_lipschitz_constant(self, model: nn.Module, data: 'Data', node_idx: int) -> float:
        """Estimate the local Lipschitz constant of ``model`` at node ``node_idx``.

        Training nodes (``data.train_mask[node_idx]``) use the label-based loss
        gradient (``_compute_first_order_lipschitz``). All other nodes use the
        gradient of the output norm (``_compute_second_order_lipschitz``).
        The model is evaluated in eval mode; its previous train/eval mode is
        restored afterwards.
        """
        was_training = model.training
        model.eval()
        try:
            if bool(data.train_mask[node_idx]):
                return self._compute_first_order_lipschitz(model, data, node_idx)
            return self._compute_second_order_lipschitz(model, data, node_idx)
        finally:
            model.train(was_training)

    def _compute_first_order_lipschitz(self, model: nn.Module, data: 'Data', node_idx: int) -> float:
        """Kappa-weighted norm of the node's loss gradient.

        Returns ``sqrt(sum_k ||g_k||^2 / kappa_k)`` over the enabled (kappa > 0)
        components:
          - 'feature': gradient w.r.t. the node's feature row.
          - 'sensitive': gradient w.r.t. column ``data.sens_attr_idx`` of that
            row, if the attribute exists.
          - 'edge' and 'label': contribute 0 (see the comment below).
        Returns 0.0 when no component is enabled.
        """
        # Differentiate w.r.t. a fresh leaf so the caller's data.x is not
        # switched to requires_grad=True as a side effect.
        x_full = data.x.detach().requires_grad_(True)
        logits = model(data.edge_index, x_full)

        node_logits = logits[node_idx:node_idx + 1]
        node_target = data.y[node_idx:node_idx + 1]
        if self._is_regression(data):
            loss = F.mse_loss(node_logits.reshape(-1), node_target.float().reshape(-1))
        else:
            loss = F.cross_entropy(node_logits, node_target)

        # The result is reduced to a Python float below, so no higher-order
        # graph is needed.
        (grad_x_full,) = torch.autograd.grad(loss, [x_full], allow_unused=True)
        if grad_x_full is None:
            grad_x_full = torch.zeros_like(x_full)

        grad_x_n = grad_x_full[node_idx]
        sens_attr_idx = int(data.sens_attr_idx) if hasattr(data, 'sens_attr_idx') else -1
        if sens_attr_idx >= 0:
            grad_s_n = grad_x_full[node_idx, sens_attr_idx:sens_attr_idx + 1]
        else:
            grad_s_n = torch.zeros(1, device=x_full.device, dtype=x_full.dtype)

        zero = torch.zeros((), device=x_full.device, dtype=x_full.dtype)
        terms = []
        if self.kappa.get('edge', 0.0) > 0:
            # The model only receives (edge_index, x), so the loss has no
            # differentiable path to a dense adjacency: dL/dA is identically
            # zero and this term is 0. Kept so every enabled kappa component
            # is represented.
            terms.append(zero)
        if self.kappa.get('feature', 0.0) > 0:
            terms.append(torch.norm(grad_x_n, p=2) ** 2 / self.kappa['feature'])
        if self.kappa.get('sensitive', 0.0) > 0:
            terms.append(torch.norm(grad_s_n, p=2) ** 2 / self.kappa['sensitive'])
        if self.kappa.get('label', 0.0) > 0:
            terms.append(zero)

        if not terms:
            return 0.0
        return torch.sqrt(sum(terms)).item()

    def _compute_second_order_lipschitz(self, model: nn.Module, data: 'Data', node_idx: int) -> float:
        """L2 norm of d||logits[node_idx]|| / d x[node_idx] (label-free estimate).

        Returns 0.0 when the output does not depend on the features. Returns
        ``_LIPSCHITZ_FALLBACK`` if the forward/backward pass fails
        (best-effort).
        """
        try:
            x_full = data.x.detach().requires_grad_(True)
            logits = model(data.edge_index, x_full)
            output_norm = torch.norm(logits[node_idx])
            (grad_x_full,) = torch.autograd.grad(output_norm, [x_full], allow_unused=True)
        except Exception:
            return self._LIPSCHITZ_FALLBACK
        if grad_x_full is None:
            return 0.0
        return torch.norm(grad_x_full[node_idx], p=2).item()

    def compute_global_lipschitz_regularizer(self, model: nn.Module, data: 'Data') -> float:
        """Power mean (order ``p_norm``) of per-node Lipschitz estimates.

        Graphs with more than 1000 nodes are subsampled to
        ``min(1000, num_nodes // 10)`` random nodes (``torch.randperm``); smaller
        graphs use every node. A node whose estimate raises contributes
        ``_LIPSCHITZ_FALLBACK``.
        """
        num_nodes = data.num_nodes
        if num_nodes > 1000:
            sample_size = min(1000, num_nodes // 10)
            node_indices = torch.randperm(num_nodes)[:sample_size].tolist()
        else:
            node_indices = range(num_nodes)

        powered = []
        for node_idx in node_indices:
            try:
                estimate = self.compute_lipschitz_constant(model, data, int(node_idx))
            except Exception:
                estimate = self._LIPSCHITZ_FALLBACK
            powered.append(estimate ** self.p_norm)

        if not powered:
            return self._LIPSCHITZ_FALLBACK
        return (sum(powered) / len(powered)) ** (1.0 / self.p_norm)

    # ------------------------------------------------------------------ objective

    def compute_graphdro_objective(self, model: nn.Module, mixed_graphs: List['Data'],
                                   weights: List[float]) -> Tuple[torch.Tensor, torch.Tensor, float]:
        """Weighted supervised risk over the mixture plus the Lipschitz term.

        Returns ``(total_loss, empirical_risk, lipschitz_term)`` with
        ``total_loss = empirical_risk + lambda_lip * lipschitz_term`` and
        ``lipschitz_term = L * sqrt(r)``; ``L`` is measured on the clean graph
        (``mixed_graphs[0]``).

        NOTE(review): ``lipschitz_term`` is a Python float (every per-node
        estimate goes through ``.item()``). It therefore adds a constant to
        ``total_loss`` and contributes no gradient. It shifts the reported loss
        but does not regularise training. Confirm whether that is intended.

        Raises:
            RuntimeError: if the model fails on every graph of the mixture
                (there would be no differentiable loss to back-propagate).
        """
        empirical_risk = 0.0
        num_used = 0
        for graph, weight in zip(mixed_graphs, weights):
            try:
                logits = model(graph.edge_index, graph.x)
                loss = self._supervised_loss(logits[graph.train_mask], graph.y[graph.train_mask])
            except Exception:
                # Best-effort: a graph the model cannot consume is skipped
                # (weights are not renormalised).
                continue
            empirical_risk = empirical_risk + weight * loss
            num_used += 1

        if num_used == 0:
            raise RuntimeError("GraphDRO objective: no graph in the mixture produced a loss")

        lipschitz_term = 0.0
        if self.lipschitz_enabled:
            try:
                lipschitz_reg = self.compute_global_lipschitz_regularizer(model, mixed_graphs[0])
                if lipschitz_reg is None:
                    lipschitz_reg = self._LIPSCHITZ_FALLBACK
                lipschitz_term = lipschitz_reg * np.sqrt(self.r)
            except Exception:
                lipschitz_term = 0.0

        total_loss = empirical_risk + self.lambda_lip * lipschitz_term
        return total_loss, empirical_risk, lipschitz_term

    def _compute_fairness_aware_loss(self, logits: torch.Tensor, data: 'Data') -> torch.Tensor:
        """Supervised loss on training nodes plus ``beta`` x statistical-parity loss (if ``data.s`` exists)."""
        base_loss = self._supervised_loss(logits[data.train_mask], data.y[data.train_mask])
        if hasattr(data, 's'):
            return base_loss + self.beta * self._compute_statistical_parity_loss(logits, data)
        return base_loss

    def _compute_statistical_parity_loss(self, logits: torch.Tensor, data: 'Data') -> torch.Tensor:
        """Gap between the groups ``s == 0`` and ``s == 1`` over all nodes.

        Regression: absolute difference of the mean predictions.
        Classification: L1 distance between the mean softmax vectors.
        Returns 0 when either group is empty.
        """
        s0_mask = (data.s == 0)
        s1_mask = (data.s == 1)
        if s0_mask.sum() == 0 or s1_mask.sum() == 0:
            return torch.tensor(0.0, device=logits.device)
        if self._is_regression(data):
            s0_pred_mean = logits[s0_mask].squeeze().mean()
            s1_pred_mean = logits[s1_mask].squeeze().mean()
            return torch.abs(s0_pred_mean - s1_pred_mean)
        probs = F.softmax(logits, dim=1)
        return torch.norm(probs[s0_mask].mean(dim=0) - probs[s1_mask].mean(dim=0), p=1)

    # ------------------------------------------------------------------ training

    def train(self, model: nn.Module, data: 'Data') -> Dict[str, Any]:
        """Run GraphDRO training and return the training history.

        Validation runs every 5 epochs. The best-scoring state dict is kept and
        restored at the end. Early stopping triggers after ``patience``
        consecutive validation checks without improvement. An epoch whose
        objective raises is skipped.

        Returns:
            Dict with 'train_losses', 'val_accuracies' (one per validation
            check), 'lipschitz_values', 'best_val_acc' (-inf if no validation
            ran), 'final_results' (from ``evaluate_comprehensive``) and
            'config'.
        """
        # Move the model first so the optimizer is created over the final,
        # on-device parameters.
        model = model.to(self.device)
        data = data.to(self.device)
        self.setup_optimizers(model)

        # Reset checkpoint selection so a reused trainer never restores a state
        # from an earlier run. Starting at -inf lets regression runs whose
        # validation R^2 stays <= 0 still keep their best checkpoint.
        self.best_val_acc = float('-inf')
        self.best_model_state = None

        train_losses = []
        val_accuracies = []
        lipschitz_values = []
        patience_counter = 0

        print(f"Starting GraphDRO training (mode: {self.mode})")
        print(f"Wasserstein radius r = {self.r}")
        print(f"Weight parameters κ = {self.kappa}")

        for epoch in range(self.num_epochs):
            model.train()
            mixed_graphs, weights = self.generate_mixed_empirical_distribution(data)

            try:
                total_loss, empirical_risk, lipschitz_reg = self.compute_graphdro_objective(model, mixed_graphs, weights)
            except Exception as e:
                print(f"Debug: Error in compute_graphdro_objective: {e}")
                print(f"Debug: Error type: {type(e)}")
                continue

            self.optimizer.zero_grad()
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            self.optimizer.step()

            if self.scheduler is not None:
                self.scheduler.step()

            train_losses.append(total_loss.item())
            lipschitz_values.append(lipschitz_reg)

            if epoch % 5 == 0:
                val_acc = self.evaluate(model, data, split='val')
                val_accuracies.append(val_acc)

                if val_acc > self.best_val_acc:
                    self.best_val_acc = val_acc
                    self.best_model_state = copy.deepcopy(model.state_dict())
                    patience_counter = 0
                else:
                    patience_counter += 1

                if self.verbose:
                    print(f"Epoch {epoch:3d}: Loss={total_loss:.4f} "
                          f"(ERM={empirical_risk:.4f}, Lip={lipschitz_reg:.4f}), "
                          f"Val_Acc={val_acc:.4f}")

                if self.early_stop and patience_counter >= self.patience:
                    print(f"Early stopping at epoch {epoch}")
                    break

            if epoch % 10 == 0 and self.device.type == 'cuda':
                torch.cuda.empty_cache()

        if self.best_model_state is not None:
            model.load_state_dict(self.best_model_state)

        final_results = self.evaluate_comprehensive(model, data)

        return {
            'train_losses': train_losses,
            'val_accuracies': val_accuracies,
            'lipschitz_values': lipschitz_values,
            'best_val_acc': self.best_val_acc,
            'final_results': final_results,
            'config': self.config
        }

    # ------------------------------------------------------------------ evaluation

    def evaluate(self, model: nn.Module, data: 'Data', split: str = 'test') -> float:
        """Score ``model`` on the 'train', 'val' or (any other value) 'test' mask.

        Classification: accuracy of the argmax prediction.
        Regression (floating-point ``data.y``): R^2, or 0.0 when the targets
        on the split have zero variance.
        """
        model.eval()
        mask = data.train_mask if split == 'train' else data.val_mask if split == 'val' else data.test_mask
        with torch.no_grad():
            logits = model(data.edge_index, data.x)
            if self._is_regression(data):
                pred = logits[mask].squeeze()
                y_true = data.y[mask].float()
                ss_tot = ((y_true - y_true.mean()) ** 2).sum()
                ss_res = ((y_true - pred) ** 2).sum()
                score = 1 - (ss_res / ss_tot) if ss_tot > 0 else 0.0
            else:
                pred = logits[mask].argmax(dim=1)
                score = (pred == data.y[mask]).float().mean()
        # float() accepts both a 0-dim tensor and the plain 0.0 fallback.
        return float(score)

    def evaluate_comprehensive(self, model: nn.Module, data: 'Data') -> Dict[str, Any]:
        """Per-split scores plus optional fairness and robustness metrics."""
        model.eval()
        results = {f'{split}_accuracy': self.evaluate(model, data, split)
                   for split in ['train', 'val', 'test']}

        if self.eval_fair and hasattr(data, 's'):
            results.update(self._evaluate_fairness(model, data))

        if self.eval_robust:
            results.update(self._evaluate_robustness(model, data))

        return results

    def _evaluate_fairness(self, model: nn.Module, data: 'Data') -> Dict[str, float]:
        """Group gaps between ``s == 0`` and ``s == 1`` test nodes.

        Classification:
          - 'sp_violation': |P(pred=1 | s=0) - P(pred=1 | s=1)|.
          - 'eo_violation': the same gap restricted to y == 1 (TPR gap).
        Regression:
          - 'sp_violation': mean-prediction gap.
          - 'eo_violation': mean-absolute-error gap.
        A gap is 0.0 when a required group is empty.
        """
        model.eval()
        with torch.no_grad():
            logits = model(data.edge_index, data.x)
            is_regression = self._is_regression(data)
            pred = logits.squeeze() if is_regression else logits.argmax(dim=1)

            s0_mask = (data.s == 0) & data.test_mask
            s1_mask = (data.s == 1) & data.test_mask

            sp_violation = 0.0
            eo_violation = 0.0
            if s0_mask.sum() > 0 and s1_mask.sum() > 0:
                if is_regression:
                    sp_violation = torch.abs(pred[s0_mask].float().mean() - pred[s1_mask].float().mean())
                    s0_pred_error = (pred[s0_mask] - data.y[s0_mask].float()).abs().mean()
                    s1_pred_error = (pred[s1_mask] - data.y[s1_mask].float()).abs().mean()
                    eo_violation = torch.abs(s0_pred_error - s1_pred_error)
                else:
                    s0_pos_rate = (pred[s0_mask] == 1).float().mean()
                    s1_pos_rate = (pred[s1_mask] == 1).float().mean()
                    sp_violation = torch.abs(s0_pos_rate - s1_pos_rate)

                    s0_y1_mask = s0_mask & (data.y == 1)
                    s1_y1_mask = s1_mask & (data.y == 1)
                    if s0_y1_mask.sum() > 0 and s1_y1_mask.sum() > 0:
                        s0_tpr = (pred[s0_y1_mask] == 1).float().mean()
                        s1_tpr = (pred[s1_y1_mask] == 1).float().mean()
                        eo_violation = torch.abs(s0_tpr - s1_tpr)

        return {
            'sp_violation': sp_violation.item() if torch.is_tensor(sp_violation) else sp_violation,
            'eo_violation': eo_violation.item() if torch.is_tensor(eo_violation) else eo_violation
        }

    def _evaluate_robustness(self, model: nn.Module, data: 'Data') -> Dict[str, float]:
        """Test score on the clean graph vs. one ``combined_noise`` perturbation of it."""
        perturbed_data = self.perturbation_model.combined_noise(data)
        clean_acc = self.evaluate(model, data, 'test')
        robust_acc = self.evaluate(model, perturbed_data, 'test')
        return {
            'clean_accuracy': clean_acc,
            'robust_accuracy': robust_acc,
            'robustness_drop': clean_acc - robust_acc
        }

    def verify_theoretical_bounds(self, model: nn.Module, data: 'Data') -> Dict[str, Any]:
        """Compare the observed test-score drop under ``combined_noise`` with ``L * sqrt(r)``.

        ``L`` is ``compute_global_lipschitz_regularizer`` on ``data``.
        'bound_satisfied' is True when the drop does not exceed the bound;
        'bound_gap' is ``bound - drop``.
        """
        L_theta = self.compute_global_lipschitz_regularizer(model, data)
        theoretical_bound = L_theta * np.sqrt(self.r)
        clean_acc = self.evaluate(model, data, 'test')
        perturbed_data = self.perturbation_model.combined_noise(data)
        robust_acc = self.evaluate(model, perturbed_data, 'test')
        actual_drop = clean_acc - robust_acc
        return {
            'lipschitz_constant': L_theta,
            'theoretical_bound': theoretical_bound,
            'actual_robustness_drop': actual_drop,
            'bound_satisfied': actual_drop <= theoretical_bound,
            'bound_gap': theoretical_bound - actual_drop
        }

    def analyze_perturbation_sensitivity(self, model: nn.Module, data: 'Data') -> Dict[str, Any]:
        """Test score under ``combined_noise`` with one shared magnitude per run.

        The magnitude takes each value in [0.01, 0.05, 0.1, 0.2, 0.3]. Results
        are keyed 'epsilon_<value>'.
        """
        sensitivity_results = {}
        for epsilon in [0.01, 0.05, 0.1, 0.2, 0.3]:
            perturbed_data = self.perturbation_model.combined_noise(
                data, epsilon_e=epsilon, epsilon_x=epsilon,
                epsilon_l=epsilon, gamma=epsilon
            )
            sensitivity_results[f'epsilon_{epsilon}'] = self.evaluate(model, perturbed_data, 'test')
        return sensitivity_results