import argparse
import inspect
import json
import os
from typing import Dict, Any, List

import numpy as np
import torch
import yaml

from core.data import load_dataset
from core.models import create_model_from_config
from core.perturbations import NodeCentricPerturbations
from core.metrics import accuracy, sp, eo
from core.metrics import compute_fairness_metrics


class GraphDROEvaluator:
    """Evaluate a trained graph model for accuracy, fairness and robustness.

    Construction loads the dataset and the model checkpoint onto
    ``self.device`` (CUDA when available, otherwise CPU). The ``evaluate_*``
    methods return plain metric dictionaries. ``print_results`` and
    ``save_results`` render or persist the dictionary produced by
    ``evaluate_comprehensive``.
    """

    def __init__(self, config: Dict[str, Any], model_path: str):
        """Load the data and model described by *config* and *model_path*.

        Args:
            config: Parsed configuration dict. Its optional ``'evaluation'``
                section, and that section's ``'attacks'`` sub-section, are
                stored on the instance.
            model_path: Path to a ``.pt`` / ``.pth`` checkpoint file.

        Raises:
            FileNotFoundError: If *model_path* does not exist.
            ValueError: If *model_path* has an unsupported extension.
        """
        self.config = config
        self.model_path = model_path
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.eval_config = config.get('evaluation', {})
        self.attack_config = self.eval_config.get('attacks', {})
        self.perturbations = NodeCentricPerturbations(config)
        self._load_data()
        self._load_model()

    def _load_data(self):
        """Load the dataset and move it to ``self.device``.

        Sets ``self.data``, ``self.in_dim`` and ``self.out_dim`` from
        ``load_dataset(self.config)``.
        """
        print(" Loading dataset...")
        self.data, self.in_dim, self.out_dim = load_dataset(self.config)
        self.data = self.data.to(self.device)
        print(f" Data loaded: {self.data.num_nodes} nodes")

    def _load_model(self):
        """Build the model from config, load checkpoint weights, and switch to eval mode.

        Accepted checkpoint contents:
          * a dict with a ``'model_state_dict'`` entry (optionally ``'epoch'``);
          * a whole pickled ``torch.nn.Module``;
          * anything else, which is treated as a raw state dict.

        Raises:
            FileNotFoundError: If ``self.model_path`` does not exist.
            ValueError: If the file extension is not ``.pt`` or ``.pth``.
        """
        print(" Loading model...")
        self.model = create_model_from_config(self.config, self.in_dim, self.out_dim)
        if not os.path.exists(self.model_path):
            raise FileNotFoundError(f"Model file not found: {self.model_path}")
        if not self.model_path.endswith(('.pt', '.pth')):
            raise ValueError(f"Unsupported model format: {self.model_path}")
        # SECURITY: weights_only=False runs the full pickle loader, which can
        # execute arbitrary code. Only load checkpoints from trusted sources.
        checkpoint = torch.load(self.model_path, map_location=self.device, weights_only=False)
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            self.model.load_state_dict(checkpoint['model_state_dict'])
            print(f" Model loaded from checkpoint: Epoch {checkpoint.get('epoch', 'Unknown')}")
        elif isinstance(checkpoint, torch.nn.Module):
            # A whole module was pickled; copy its weights into the freshly built model.
            self.model.load_state_dict(checkpoint.state_dict())
            print(" Model loaded from serialized module")
        else:
            self.model.load_state_dict(checkpoint)
            print(" Model loaded from weights file")
        self.model = self.model.to(self.device)
        self.model.eval()

    @staticmethod
    def _masked_metrics(preds, y, s, mask) -> Dict[str, Any]:
        """Compute accuracy and fairness metrics on the nodes selected by *mask*.

        Returns a dict with keys ``'acc'``, ``'sp'``, ``'eo'``, ``'dp'`` and
        ``'eop'``. The last two come from ``compute_fairness_metrics``.
        """
        p, labels, sens = preds[mask], y[mask], s[mask]
        acc = accuracy(p, labels)
        sp_gap = sp(p, sens)
        eo_gap = eo(p, labels, sens)
        fairness_metrics = compute_fairness_metrics(p, labels, sens)
        return {
            'acc': acc,
            'sp': sp_gap,
            'eo': eo_gap,
            'dp': fairness_metrics['demographic_parity'],
            'eop': fairness_metrics['equality_of_opportunity'],
        }

    def evaluate_clean(self) -> Dict[str, float]:
        """Evaluate the model on the unperturbed graph.

        Returns:
            A dict with keys ``'{split}_{metric}'`` for split in
            train/val/test and metric in acc/sp/eo/dp/eop. A split whose
            mask selects no nodes is omitted entirely.
        """
        print(" Evaluating clean performance...")
        results = {}
        with torch.no_grad():
            preds = self.model(self.data.edge_index, self.data.x).argmax(1)
            for split in ('train', 'val', 'test'):
                mask = getattr(self.data, f'{split}_mask')
                if mask.sum() > 0:
                    split_metrics = self._masked_metrics(preds, self.data.y, self.data.s, mask)
                    for metric_name, value in split_metrics.items():
                        results[f'{split}_{metric_name}'] = value
        return results

    def _run_fgsm(self, eps):
        """Run the FGSM feature attack, forwarding *eps* when the attack accepts it.

        The original code printed ``eps`` but never passed it on, so the CLI
        value was silently ignored. The attack's signature is not visible
        here, so *eps* is forwarded only when ``_fgsm_attack`` declares an
        ``eps`` parameter or ``**kwargs``. Otherwise a note is printed and the
        attack runs with its own settings.
        """
        fgsm = self.perturbations._fgsm_attack
        try:
            params = inspect.signature(fgsm).parameters
        except (TypeError, ValueError):
            params = {}
        accepts_eps = 'eps' in params or any(
            p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
        )
        if accepts_eps:
            return fgsm(self.data, self.model, eps=eps)
        print(f"  Note: _fgsm_attack does not accept 'eps'; eps={eps} was not forwarded")
        return fgsm(self.data, self.model)

    def evaluate_robustness(self, attack_params: Dict[str, Any]) -> Dict[str, Dict[str, float]]:
        """Run each enabled attack and evaluate the model on the perturbed data.

        Args:
            attack_params: Flags ``'struct'``, ``'feat'``, ``'sens'`` and
                ``'label'`` enable the attacks. ``'rho'``, ``'eps'``,
                ``'gamma'`` and ``'eta'`` set their strengths.

        Returns:
            A mapping from attack name (``'struct'``, ``'fgsm'``,
            ``'sens_flip'``, ``'label_flip'``) to the test-split metrics
            returned by ``_evaluate_attacked_data``.

        Every attack is applied to ``self.data``.
        NOTE(review): if any perturbation method mutates its input in place,
        later attacks would compound. Confirm they return copies.
        """
        results = {}
        if attack_params.get('struct', False):
            rho = attack_params.get('rho', 3)
            print(f"Structural attack (rho={rho})...")
            attacked_data = self.perturbations.random_edge_attack(self.data, rho=rho)
            results['struct'] = self._evaluate_attacked_data(attacked_data, 'struct')
        if attack_params.get('feat', False):
            eps = attack_params.get('eps', 0.05)
            print(f"FGSM attack (eps={eps})...")
            attacked_data = self._run_fgsm(eps)
            results['fgsm'] = self._evaluate_attacked_data(attacked_data, 'fgsm')
        if attack_params.get('sens', False):
            gamma = attack_params.get('gamma', 0.3)
            sens_idx = self.config.get('sensitive_idx', 0)
            print(f"Sensitive attribute flip (gamma={gamma})...")
            attacked_data = self.perturbations.simple_sensitive_flip_attack(
                self.data, gamma=gamma, sensitive_idx=sens_idx
            )
            results['sens_flip'] = self._evaluate_attacked_data(attacked_data, 'sens_flip')
        if attack_params.get('label', False):
            eta = attack_params.get('eta', 0.2)
            print(f"Label flip (eta={eta})...")
            attacked_data = self.perturbations.simple_label_flip_attack(self.data, eta=eta)
            # Metrics below compare predictions against attacked_data.y,
            # i.e. against the flipped labels on the test mask.
            results['label_flip'] = self._evaluate_attacked_data(attacked_data, 'label_flip')
        return results

    def _evaluate_attacked_data(self, attacked_data, attack_name: str) -> Dict[str, float]:
        """Evaluate the model on *attacked_data*, restricted to its test mask.

        *attack_name* is accepted for interface compatibility but is unused.
        Returns a dict with keys acc/sp/eo/dp/eop.
        """
        with torch.no_grad():
            preds = self.model(attacked_data.edge_index, attacked_data.x).argmax(1)
            return self._masked_metrics(
                preds, attacked_data.y, attacked_data.s, attacked_data.test_mask
            )

    def evaluate_comprehensive(self, attack_params: Dict[str, Any]) -> Dict[str, Any]:
        """Run clean and attacked evaluations and combine them into one report.

        Returns:
            A dict with keys ``'clean'``, ``'attacked'``,
            ``'robustness_degradation'`` (per attack: absolute accuracy drop
            in percentage points, and that drop relative to clean accuracy)
            and ``'summary'``.

        Raises:
            KeyError: If the clean evaluation produced no ``'test_acc'``
                (for example, an empty test mask).
        """
        print(" Starting comprehensive evaluation...")
        print("=" * 60)
        clean_results = self.evaluate_clean()
        robust_results = self.evaluate_robustness(attack_params)
        clean_test_acc = clean_results['test_acc']
        robustness_degradation = {}
        for attack_type, attack_results in robust_results.items():
            degradation = (clean_test_acc - attack_results['acc']) * 100
            robustness_degradation[attack_type] = {
                'acc_drop': degradation,
                'relative_drop': degradation / (clean_test_acc * 100) if clean_test_acc > 0 else 0,
            }
        return {
            'clean': clean_results,
            'attacked': robust_results,
            'robustness_degradation': robustness_degradation,
            'summary': self._compute_summary_metrics(clean_results, robust_results),
        }

    def _compute_summary_metrics(self, clean_results: Dict[str, float],
                                 robust_results: Dict[str, Dict[str, float]]) -> Dict[str, float]:
        """Aggregate clean and attacked metrics into scalar summary values.

        Here "fairness" is ``|sp| + |eo|``; lower values are fairer. The score is::

            comprehensive_score = clean_test_acc
                                  - 0.3 * clean_fairness
                                  - 0.2 * (avg_acc_drop / 100)

        When there are no attack results, ``avg_acc_drop`` is treated as 0 and
        the ``avg_*`` keys are omitted.
        """
        summary = {
            'clean_test_acc': clean_results['test_acc'],
            'clean_fairness': abs(clean_results['test_sp']) + abs(clean_results['test_eo']),
        }
        if robust_results:
            attacked = list(robust_results.values())
            avg_robust_acc = np.mean([r['acc'] for r in attacked])
            summary['avg_robust_acc'] = avg_robust_acc
            summary['avg_robust_fairness'] = np.mean([abs(r['sp']) + abs(r['eo']) for r in attacked])
            summary['avg_acc_drop'] = (clean_results['test_acc'] - avg_robust_acc) * 100
        fairness_penalty = summary['clean_fairness']
        robustness_penalty = summary.get('avg_acc_drop', 0) / 100
        summary['comprehensive_score'] = (
            summary['clean_test_acc']
            - 0.3 * fairness_penalty
            - 0.2 * robustness_penalty
        )
        return summary

    def print_results(self, results: Dict[str, Any]):
        """Pretty-print the report returned by ``evaluate_comprehensive``."""
        print("\n" + "=" * 60)
        print(" Evaluation Results Summary")
        print("=" * 60)
        clean = results['clean']
        print(f"\n Clean Performance:")
        print(f"  Test Accuracy: {clean['test_acc']:.4f}")
        print(f"  Statistical Parity: {abs(clean['test_sp']):.4f}")
        print(f"  Equality of Opportunity: {abs(clean['test_eo']):.4f}")
        print(f"  Fairness Score: {1 - (abs(clean['test_sp']) + abs(clean['test_eo']))/2:.4f}")
        if results['attacked']:
            print(f"\n Robustness Evaluation:")
            for attack_type, attack_results in results['attacked'].items():
                degradation = results['robustness_degradation'][attack_type]
                print(f"  {attack_type:12s}: "
                      f"Acc={attack_results['acc']:.4f} "
                      f"(↓{degradation['acc_drop']:.2f}%), "
                      f"|SP|={abs(attack_results['sp']):.4f}, "
                      f"|EO|={abs(attack_results['eo']):.4f}")
        summary = results['summary']
        print(f"\n Summary Metrics:")
        print(f"  Comprehensive Score:     {summary['comprehensive_score']:.4f}")
        # Without attacks, fall back to the clean metrics for the "average" lines.
        print(f"  Average Accuracy:   {summary.get('avg_robust_acc', summary['clean_test_acc']):.4f}")
        print(f"  Average Fairness:   {summary.get('avg_robust_fairness', summary['clean_fairness']):.4f}")
        if 'avg_acc_drop' in summary:
            print(f"  Average Accuracy Drop: {summary['avg_acc_drop']:.2f}%")
        print("=" * 60)

    @staticmethod
    def _to_jsonable(obj: Any) -> Any:
        """Recursively convert numpy / torch values into JSON-serialisable Python types.

        Tensors and arrays become (nested) lists or Python scalars. numpy
        scalars (float/int/bool) become Python scalars. Dicts are converted
        value-wise, and lists and tuples become lists. Anything else is
        returned unchanged.
        """
        if isinstance(obj, torch.Tensor):
            # .tolist() yields a plain Python scalar for 0-dim tensors.
            return obj.detach().cpu().tolist()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):
            # Covers np.floating, np.integer and np.bool_ scalars.
            return obj.item()
        if isinstance(obj, dict):
            return {key: GraphDROEvaluator._to_jsonable(value) for key, value in obj.items()}
        if isinstance(obj, (list, tuple)):
            return [GraphDROEvaluator._to_jsonable(item) for item in obj]
        return obj

    def save_results(self, results: Dict[str, Any], save_path: str):
        """Write *results* as indented JSON to *save_path*.

        Parent directories are created when *save_path* has a directory
        component. A bare filename such as ``results.json`` is written to the
        current directory instead of crashing in ``os.makedirs('')``.
        """
        directory = os.path.dirname(save_path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        with open(save_path, 'w', encoding='utf-8') as f:
            json.dump(self._to_jsonable(results), f, indent=2)
        print(f" Evaluation results saved: {save_path}")


def main(args):
    """Run a full evaluation from parsed command-line arguments.

    Loads the YAML config at ``args.config`` and builds a
    ``GraphDROEvaluator`` for ``args.ckpt``. It then runs the attacks
    enabled on *args*, prints the report, and optionally saves it to
    ``args.save_results``.

    Returns:
        The report dict produced by ``evaluate_comprehensive``.
    """
    # Use a context manager so the config file handle is closed promptly.
    with open(args.config, encoding='utf-8') as config_file:
        config = yaml.safe_load(config_file)
    evaluator = GraphDROEvaluator(config, args.ckpt)
    attack_params = {
        'struct': args.struct,
        'rho': args.rho,
        'feat': args.feat,
        'eps': args.eps,
        'sens': args.sens,
        'gamma': args.gamma,
        'label': args.label,
        'eta': args.eta,
    }
    results = evaluator.evaluate_comprehensive(attack_params)
    evaluator.print_results(results)
    if args.save_results:
        evaluator.save_results(results, args.save_results)
    return results


if __name__ == "__main__":
    # Command-line entry point: parse options, expand --all-attacks, run main().
    parser = argparse.ArgumentParser(description='GraphDRO Evaluation')
    parser.add_argument("--ckpt", required=True, help="Model checkpoint path")
    parser.add_argument("--config", default="configs/pokec.yaml", help="Configuration file path")
    # Each attack has an on/off switch followed by its strength option;
    # the table order fixes the order in which options appear in --help.
    attack_options = (
        ("--struct", "Enable structural attack", "--rho", int, 3, "Structural attack strength"),
        ("--feat", "Enable feature attack", "--eps", float, 0.05, "FGSM attack strength"),
        ("--sens", "Enable sensitive attribute attack", "--gamma", float, 0.3, "Sensitive attribute flip ratio"),
        ("--label", "Enable label attack", "--eta", float, 0.2, "Label flip ratio"),
    )
    for switch, switch_help, knob, knob_type, knob_default, knob_help in attack_options:
        parser.add_argument(switch, action="store_true", help=switch_help)
        parser.add_argument(knob, type=knob_type, default=knob_default, help=knob_help)
    parser.add_argument("--save-results", help="Path to save results")
    parser.add_argument("--all-attacks", action="store_true", help="Enable all attacks")
    args = parser.parse_args()
    if args.all_attacks:
        # --all-attacks is shorthand for turning on every individual switch.
        for attack_flag in ("struct", "feat", "sens", "label"):
            setattr(args, attack_flag, True)
    main(args)