from torch.nn import functional as F
import torch
import numpy as np
from sklearn.linear_model import LogisticRegression
import os
import sys
import json
import argparse
import matplotlib.pyplot as plt
import time
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils import set_seed, setup_logger
from datasets.load_datasets import load_dataset
from models.resnet import ResNet9, CifarResNet18
from models.lenet import LeNet

def entropy(p, dim=-1, keepdim=False):
    """Shannon entropy (natural log) of probability tensor ``p`` along ``dim``.

    Entries with p <= 0 contribute 0 instead of the NaN that ``0 * log(0)``
    would produce.
    """
    zero = torch.zeros_like(p)
    p_log_p = torch.where(p > 0, p * p.log(), zero)
    return -p_log_p.sum(dim=dim, keepdim=keepdim)


def collect_prob(data_loader, model, batch_size=1):
    """Run ``model`` over every sample of ``data_loader.dataset`` and return softmax outputs.

    The dataset is re-wrapped in a fresh, non-shuffled DataLoader, so rows of the
    result follow dataset order whatever the batch size or shuffling of
    ``data_loader``.

    Args:
        data_loader: Loader whose ``.dataset`` is evaluated. Only the dataset is used.
        model: Module applied to each input batch. Its first parameter's device
            decides where inputs are moved.
        batch_size: Batch size of the internal loader. Defaults to 1 (the historical
            behaviour). Larger values are much faster and should give the same
            probabilities when the model is in eval mode.

    Returns:
        Softmax over the last dim of the model output, concatenated along dim 0
        and kept on the model's device. Presumably (len(dataset), num_classes)
        for a classifier that returns (batch, num_classes) logits.

    Raises:
        ValueError: if the dataset yields no samples. ``torch.cat`` would
            otherwise fail with an opaque message.
    """
    loader = torch.utils.data.DataLoader(
        data_loader.dataset, batch_size=batch_size, shuffle=False
    )
    # Looked up once instead of on every batch.
    device = next(model.parameters()).device

    probs = []
    with torch.no_grad():
        for batch in loader:
            # Batches are (data, target) or (data, _, target); only the input
            # is needed, so only it is moved to the device.
            data = batch[0].to(device)
            output = model(data)
            probs.append(F.softmax(output, dim=-1))

    if not probs:
        raise ValueError("collect_prob: the dataset is empty, nothing to evaluate")
    return torch.cat(probs)


def get_membership_attack_data(retain_loader, forget_loader, test_loader, model):
    """Build entropy features and membership labels for the attack classifier.

    Each sample is reduced to the entropy of the model's softmax output.

    Returns:
        X_f: forget-set entropies as a (N_forget, 1) numpy array.
        Y_f: all-ones labels for the forget set.
        X_r: retain entropies followed by test entropies, shape (N_retain + N_test, 1).
        Y_r: matching labels, 1 for retain samples and 0 for test samples.
    """
    retain_p, forget_p, test_p = (
        collect_prob(loader, model)
        for loader in (retain_loader, forget_loader, test_loader)
    )

    def _as_column(values):
        # Torch vector -> numpy column vector, the 2-D layout sklearn expects.
        return values.cpu().numpy().reshape(-1, 1)

    X_r = _as_column(torch.cat((entropy(retain_p), entropy(test_p))))
    Y_r = np.concatenate((np.ones(len(retain_p)), np.zeros(len(test_p))))

    X_f = _as_column(entropy(forget_p))
    Y_f = np.ones(len(forget_p))
    return X_f, Y_f, X_r, Y_r


def get_membership_attack_prob(retain_loader, forget_loader, test_loader, model):
    """Return the fraction of forget-set samples the attack classifies as members.

    A logistic-regression classifier is fitted on entropy features with
    retain samples labelled 1 and test samples labelled 0. It is then applied
    to the forget set, and the mean of its 0/1 predictions is returned.
    """
    X_f, _, X_r, Y_r = get_membership_attack_data(
        retain_loader, forget_loader, test_loader, model
    )
    # ``multi_class`` is deprecated since scikit-learn 1.5 and scheduled for
    # removal. Keep the historical setting where it is still accepted, and fall
    # back to the default once the constructor rejects the keyword.
    try:
        clf = LogisticRegression(
            class_weight="balanced", solver="lbfgs", multi_class="multinomial"
        )
    except TypeError:
        clf = LogisticRegression(class_weight="balanced", solver="lbfgs")
    clf.fit(X_r, Y_r)
    return clf.predict(X_f).mean()


def split_dataset_by_class(dataset, forget_class=0):
    """
    Split dataset into forget set and retain set by class

    Args:
        dataset: Original dataset. Each sample is expected to be a tuple whose
            last element is the label, e.g. ``(input, target)`` or
            ``(input, aux, target)``.
        forget_class: Class to be forgotten

    Returns:
        forget_dataset: Subset of samples whose label equals ``forget_class``
        retain_dataset: Subset of all remaining samples
    """
    targets = getattr(dataset, "targets", None)
    # Fast path: a dataset exposing a per-index ``targets`` sequence can be split
    # without loading and transforming every sample. It is only used when no
    # ``target_transform`` could remap labels inside __getitem__ and the lengths agree.
    # NOTE(review): assumes targets[i] is the label dataset[i] returns (true for
    # torchvision-style datasets); confirm for custom datasets.
    if (
        targets is not None
        and getattr(dataset, "target_transform", None) is None
        and len(targets) == len(dataset)
    ):
        labels = targets.tolist() if torch.is_tensor(targets) else targets
    else:
        labels = (sample[-1] for sample in dataset)

    forget_indices = []
    retain_indices = []
    for idx, target in enumerate(labels):
        if target == forget_class:
            forget_indices.append(idx)
        else:
            retain_indices.append(idx)

    forget_dataset = torch.utils.data.Subset(dataset, forget_indices)
    retain_dataset = torch.utils.data.Subset(dataset, retain_indices)

    return forget_dataset, retain_dataset


class MembershipInferenceAttack:
    """
    Membership inference attack class

    Analyzes model output distributions to infer whether samples were in the training set,
    used to evaluate model privacy leakage risk and unlearning effectiveness.

    ``run_attack`` fits a logistic-regression classifier on the per-sample
    prediction entropy of the retain split (label 1) versus the test split
    (label 0), then scores the forget split with it.
    """

    # Output classes per known dataset name; any other name falls back to 10.
    _NUM_CLASSES = {'mnist': 10, 'cifar10': 10, 'cifar100': 100, 'svhn': 10}

    def __init__(self, model_path, dataset_name='mnist', forget_class=0,
                 device=None, log_dir="logs/attack", seed=42):
        """
        Initialize membership inference attack

        Args:
            model_path: Model weight file path
            dataset_name: Dataset name
            forget_class: Forget class
            device: Computing device (defaults to CUDA when available, else CPU)
            log_dir: Log save directory
            seed: Random seed
        """
        set_seed(seed)

        self.device = device or torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')

        self.dataset_name = dataset_name
        self.forget_class = forget_class
        self.model_path = model_path

        # Logs go to <log_dir>/<dataset_name>/<checkpoint name without '.pth'>.
        model_name = os.path.basename(model_path).replace('.pth', '')
        self.log_dir = os.path.join(log_dir, dataset_name, model_name)
        os.makedirs(self.log_dir, exist_ok=True)

        self.logger, self.log_file = setup_logger(
            f"membership_attack_{model_name}", self.log_dir)

        self.logger.info("Initializing membership inference attack")
        self.logger.info(f"Model path: {model_path}")
        self.logger.info(f"Dataset: {self.dataset_name}")
        self.logger.info(f"Forget class: {self.forget_class}")
        self.logger.info(f"Device: {self.device}")

        self.retain_loader, self.forget_loader, self.test_loader = self._prepare_datasets()

        self.model = self._load_model(model_path)

        self.logger.info(f"Forget set size: {len(self.forget_loader.dataset)}")
        self.logger.info(f"Retain set size: {len(self.retain_loader.dataset)}")
        self.logger.info(f"Test set size: {len(self.test_loader.dataset)}")

    def _prepare_datasets(self):
        """Prepare datasets: split into forget set, retain set and test set.

        The forget/retain split is taken from the training split returned by
        ``load_dataset``. The test loader is returned unchanged.

        Returns:
            (retain_loader, forget_loader, test_loader)
        """
        train_loader, test_loader, _ = load_dataset(
            self.dataset_name, batch_size=128, num_workers=4
        )

        forget_dataset, retain_dataset = split_dataset_by_class(
            train_loader.dataset, self.forget_class
        )

        forget_loader = torch.utils.data.DataLoader(
            forget_dataset, batch_size=128, shuffle=False, num_workers=4
        )
        retain_loader = torch.utils.data.DataLoader(
            retain_dataset, batch_size=128, shuffle=False, num_workers=4
        )

        return retain_loader, forget_loader, test_loader

    def _get_model_info_from_path(self, model_path):
        """Infer (model_type, in_channels, num_classes) from the checkpoint name and dataset.

        The architecture comes from filename keywords, checked in priority order.
        Without a keyword it defaults to 'lenet' for MNIST and 'resnet9' otherwise.
        """
        filename = os.path.basename(model_path).lower()

        if 'lenet' in filename:
            model_type = 'lenet'
        elif 'resnet9' in filename:
            model_type = 'resnet9'
        elif 'resnet18' in filename:  # also covers 'cifarresnet18'
            model_type = 'resnet18'
        else:
            model_type = 'lenet' if self.dataset_name == 'mnist' else 'resnet9'

        in_channels = 1 if self.dataset_name == 'mnist' else 3
        num_classes = self._NUM_CLASSES.get(self.dataset_name, 10)

        return model_type, in_channels, num_classes

    def _create_model(self, model_type, in_channels, num_classes):
        """Create model instance based on model type.

        Raises:
            ValueError: for an unrecognised ``model_type``.
        """
        builders = {'lenet': LeNet, 'resnet9': ResNet9, 'resnet18': CifarResNet18}
        if model_type not in builders:
            raise ValueError(f"Unsupported model type: {model_type}")
        return builders[model_type](num_classes=num_classes, in_channels=in_channels)

    def _load_model(self, model_path):
        """Build the inferred architecture, load its weights and return it in eval mode.

        The checkpoint may be a raw state dict, or a dict that wraps one under
        'model_state_dict' or 'state_dict'. Loading errors are logged and re-raised.
        """
        model_type, in_channels, num_classes = self._get_model_info_from_path(
            model_path)

        self.logger.info(
            f"Inferred model info: type={model_type}, input_channels={in_channels}, num_classes={num_classes}")

        model = self._create_model(model_type, in_channels, num_classes)

        try:
            checkpoint = torch.load(
                model_path, map_location=self.device, weights_only=True)

            state_dict = checkpoint
            if isinstance(checkpoint, dict):
                for wrapper_key in ('model_state_dict', 'state_dict'):
                    if wrapper_key in checkpoint:
                        state_dict = checkpoint[wrapper_key]
                        break
            model.load_state_dict(state_dict)

            model = model.to(self.device)
            model.eval()

            self.logger.info(f"Successfully loaded model: {model_path}")
            return model

        except Exception as e:
            self.logger.exception(f"Failed to load model: {e}")
            raise

    @staticmethod
    def _build_attack_classifier():
        """Create the attack's logistic-regression classifier.

        ``multi_class`` is deprecated since scikit-learn 1.5 and scheduled for
        removal. The historical setting is kept where it is still accepted, and
        the default is used once the constructor rejects the keyword.
        """
        try:
            return LogisticRegression(
                class_weight="balanced", solver="lbfgs", multi_class="multinomial"
            )
        except TypeError:
            return LogisticRegression(class_weight="balanced", solver="lbfgs")

    def run_attack(self):
        """
        Execute membership inference attack

        Returns:
            dict: Attack results. 'attack_success_rate' is the fraction of forget
            samples predicted as members. The '*_membership_prob' entries
            summarise P(member) over the forget set. Per-sample predictions and
            probabilities and the fitted classifier are also included.
        """
        self.logger.info("Starting membership inference attack")
        start_time = time.time()

        X_f, Y_f, X_r, Y_r = get_membership_attack_data(
            self.retain_loader, self.forget_loader, self.test_loader, self.model
        )

        clf = self._build_attack_classifier()
        clf.fit(X_r, Y_r)

        forget_predictions = clf.predict(X_f)
        # Labels are {0, 1}; sklearn stores classes_ sorted, so column 1 is P(member).
        forget_probabilities = clf.predict_proba(X_f)[:, 1]

        attack_success_rate = forget_predictions.mean()

        avg_membership_prob = forget_probabilities.mean()
        max_membership_prob = forget_probabilities.max()
        min_membership_prob = forget_probabilities.min()

        elapsed = time.time() - start_time

        self.logger.info(f"Membership inference attack completed (time: {elapsed:.2f}s)")
        self.logger.info(f"Attack success rate: {attack_success_rate:.4f}")
        self.logger.info(f"Average membership probability: {avg_membership_prob:.4f}")
        self.logger.info(f"Maximum membership probability: {max_membership_prob:.4f}")
        self.logger.info(f"Minimum membership probability: {min_membership_prob:.4f}")

        return {
            'attack_success_rate': attack_success_rate,
            'avg_membership_prob': avg_membership_prob,
            'max_membership_prob': max_membership_prob,
            'min_membership_prob': min_membership_prob,
            'forget_predictions': forget_predictions,
            'forget_probabilities': forget_probabilities,
            'classifier': clf
        }


def compare_multiple_membership_inference(model_configs, dataset_name='mnist',
                                          forget_class=0, save_dir="results/attack"):
    """
    Compare membership inference attack results across multiple models

    Args:
        model_configs: Model configuration list, format: [{'path': 'path', 'label': 'label'}, ...]
        dataset_name: Dataset name
        forget_class: Forget class
        save_dir: Result save directory

    Returns:
        dict mapping each successfully evaluated model's label to its
        ``run_attack`` result. Models that fail are reported and skipped.
    """
    import csv

    os.makedirs(save_dir, exist_ok=True)

    # Labels and results are collected in lockstep. If a model fails, later
    # results therefore cannot be attached to the wrong label. The previous code
    # used model_configs[:len(attack_results)], which shifted labels after a failure.
    succeeded_labels = []
    attack_results = []

    for config in model_configs:
        model_path = config['path']
        model_label = config['label']

        print(f"  Processing {model_label}...")

        try:
            attacker = MembershipInferenceAttack(
                model_path, dataset_name, forget_class,
                log_dir=os.path.join(save_dir, "logs")
            )
            result = attacker.run_attack()
        except Exception as e:
            print(f"    Error: Failed to process {model_label} - {e}")
            continue

        succeeded_labels.append(model_label)
        attack_results.append(result)

    if attack_results:
        _create_membership_comparison_visualization(
            attack_results, forget_class, succeeded_labels, save_dir
        )

        csv_path = os.path.join(save_dir, f"membership_inference_results_{dataset_name}_class_{forget_class}.csv")
        # csv.writer quotes labels containing commas, which would otherwise corrupt the row.
        with open(csv_path, 'w', newline='') as f:
            writer = csv.writer(f, lineterminator='\n')
            writer.writerow(["Model", "Attack_Success_Rate", "Avg_Membership_Prob",
                             "Max_Membership_Prob", "Min_Membership_Prob"])
            for model_label, result in zip(succeeded_labels, attack_results):
                writer.writerow([
                    model_label,
                    f"{result['attack_success_rate']:.6f}",
                    f"{result['avg_membership_prob']:.6f}",
                    f"{result['max_membership_prob']:.6f}",
                    f"{result['min_membership_prob']:.6f}",
                ])

        print(f"  Membership inference attack results saved to: {csv_path}")

    return dict(zip(succeeded_labels, attack_results))


def _create_membership_comparison_visualization(attack_results, forget_class,
                                                model_labels, save_dir):
    """Save a two-panel bar chart comparing models' attack results.

    Left panel: attack success rate. Right panel: average membership probability.
    Each bar is annotated with its value. The figure is written to
    ``<save_dir>/membership_inference_comparison_class_<forget_class>.png``.

    Args:
        attack_results: ``run_attack`` result dicts, one per model.
        forget_class: Forgotten class, used in the title and file name.
        model_labels: Bar labels, aligned with ``attack_results``.
        save_dir: Output directory.
    """
    num_models = len(attack_results)
    # Only three distinct colours are defined.
    colors = ['blue', 'orange', 'green'][:num_models]
    panels = (
        ('attack_success_rate', 'Attack Success Rate Comparison', 'Attack Success Rate'),
        ('avg_membership_prob', 'Average Membership Probability Comparison',
         'Average Membership Probability'),
    )

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    # Always close the figure, even if plotting or saving fails. Otherwise
    # repeated failures leak open figures.
    try:
        for ax, (metric, title, ylabel) in zip(axes, panels):
            values = [result[metric] for result in attack_results]
            bars = ax.bar(model_labels, values, alpha=0.7, color=colors)
            ax.set_title(title)
            ax.set_ylabel(ylabel)
            ax.set_ylim(0, 1)
            for bar, value in zip(bars, values):
                ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.01,
                        f'{value:.3f}', ha='center', va='bottom')

        fig.suptitle(f'Membership Inference Attack Comparison - Forget Class: {forget_class}')
        fig.tight_layout()

        save_path = os.path.join(
            save_dir, f"membership_inference_comparison_class_{forget_class}.png")
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
    finally:
        plt.close(fig)


def compare_membership_inference(original_model_path, unlearned_model_path,
                                 dataset_name='mnist', forget_class=0,
                                 save_dir="results/attack"):
    """
    Compare membership inference attack results between original and unlearned models (backward compatibility)

    Thin wrapper around ``compare_multiple_membership_inference`` using the
    labels 'Original Model' and 'Unlearned Model'.

    Returns:
        dict with:
          'original' / 'unlearned': the per-model result dict, or {} when no
              result came back under that label;
          'success_rate_drop': original minus unlearned attack success rate;
          'membership_prob_drop': original minus unlearned average membership probability.
        A drop is NaN when either side's result is missing. Previously a 0
        placeholder was substituted, which made a failed model look like a real drop.
    """
    model_configs = [
        {'path': original_model_path, 'label': 'Original Model'},
        {'path': unlearned_model_path, 'label': 'Unlearned Model'}
    ]

    results = compare_multiple_membership_inference(
        model_configs, dataset_name, forget_class, save_dir
    )

    original_result = results.get('Original Model', {})
    unlearned_result = results.get('Unlearned Model', {})

    def _drop(metric):
        # Only meaningful when both models produced the metric.
        if metric in original_result and metric in unlearned_result:
            return original_result[metric] - unlearned_result[metric]
        return float('nan')

    return {
        'original': original_result,
        'unlearned': unlearned_result,
        'success_rate_drop': _drop('attack_success_rate'),
        'membership_prob_drop': _drop('avg_membership_prob')
    }


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Execute membership inference attack evaluation')
    parser.add_argument('--config', type=str, default='attack/model_configs.json',
                        help='Model configuration file path (JSON)')
    parser.add_argument('--model-path', type=str,
                        help='Single model path (for single model attack test)')
    parser.add_argument('--dataset', type=str, default='mnist',
                        help='Dataset name (default: mnist)')
    parser.add_argument('--forget-class', type=int, default=0,
                        help='Class to forget (default: 0)')
    parser.add_argument('--output', type=str, default='results/membership_attack_comparison_summary.csv',
                        help='Result output CSV file path')
    parser.add_argument('--quiet', action='store_true',
                        help='Quiet mode, reduce output information')
    args = parser.parse_args()

    def _ensure_parent_dir(path):
        """Create the directory that will hold ``path``.

        os.makedirs('') raises, so a bare filename (empty dirname) is skipped.
        """
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)

    def _method_from_label(label):
        """Classify a label as 'original', 'forget', 'retrain' or 'unknown' by substring."""
        lowered = label.lower()
        for method in ('original', 'forget', 'retrain'):
            if method in lowered:
                return method
        return 'unknown'

    # ---- Single-model mode ------------------------------------------------
    if args.model_path:
        if not args.quiet:
            print("Executing single model membership inference attack test")
            print(f"Model path: {args.model_path}")
            print(f"Dataset: {args.dataset}")
            print(f"Forget class: {args.forget_class}")

        try:
            attacker = MembershipInferenceAttack(
                model_path=args.model_path,
                dataset_name=args.dataset,
                forget_class=args.forget_class,
                log_dir=f"results/single_model_attack_{args.dataset}"
            )

            result = attacker.run_attack()

            single_result_path = args.output.replace('.csv', '_single_model.csv')
            _ensure_parent_dir(single_result_path)

            with open(single_result_path, 'w') as f:
                f.write("Model_Path,Dataset,Forget_Class,Attack_Success_Rate,Avg_Membership_Prob,Max_Membership_Prob,Min_Membership_Prob\n")
                f.write(f"{args.model_path},{args.dataset},{args.forget_class},"
                        f"{result['attack_success_rate']:.6f},{result['avg_membership_prob']:.6f},"
                        f"{result['max_membership_prob']:.6f},{result['min_membership_prob']:.6f}\n")

            print("\nSingle model attack results:")
            print(f"  Attack success rate: {result['attack_success_rate']:.4f}")
            print(f"  Average membership probability: {result['avg_membership_prob']:.4f}")
            print(f"  Maximum membership probability: {result['max_membership_prob']:.4f}")
            print(f"  Minimum membership probability: {result['min_membership_prob']:.4f}")
            print(f"  Detailed results saved to: {single_result_path}")

        except Exception as e:
            print(f"Single model attack failed: {e}")
            sys.exit(1)

        sys.exit(0)

    # ---- Multi-model mode: load configuration -----------------------------
    config_path = os.path.join(os.path.dirname(
        os.path.dirname(os.path.abspath(__file__))), args.config)

    if not os.path.exists(config_path):
        print(f"Warning: Configuration file {config_path} does not exist")
        if not args.quiet:
            print("Hint: Use --model-path parameter for single model attack test")
        sys.exit(1)

    with open(config_path, 'r') as f:
        config_data = json.load(f)

    model_configs = []
    if 'models' in config_data:
        groups = config_data['models']
        # Two grouped layouts are recognised from the first group:
        #   {"models": [{"dataset": ..., "models": [...]}, ...]}
        #   {"models": [{"type": ...,    "models": [...]}, ...]}
        # The group key is copied onto every model entry. An empty list
        # leaves model_configs empty instead of raising IndexError.
        if groups and 'dataset' in groups[0]:
            group_key = 'dataset'
        elif groups and 'type' in groups[0]:
            group_key = 'type'
        else:
            group_key = None
        if group_key is not None:
            for group in groups:
                for model in group['models']:
                    model_config = model.copy()
                    model_config[group_key] = group[group_key]
                    model_configs.append(model_config)
    else:
        # Anything else is taken as a flat list of model configurations.
        model_configs = config_data

    if not args.quiet:
        print(f"Loaded {len(model_configs)} model configurations from {config_path}")

    if len(model_configs) < 1:
        print("Error: At least one model configuration required")
        print("Hint: Use --model-path parameter for single model attack test")
        sys.exit(1)

    # ---- Run attacks per (dataset, architecture) ---------------------------
    # Sorted: iteration order of a set of strings varies between interpreter
    # runs (hash randomization), which made processing and CSV row order unstable.
    method_types = sorted({config.get('type', 'unknown') for config in model_configs}, key=str)
    datasets = sorted({config.get('dataset', 'unknown') for config in model_configs}, key=str)
    all_results = {}

    for dataset_name in datasets:
        if dataset_name == 'unknown':
            continue

        if not args.quiet:
            print(f"\nProcessing dataset: {dataset_name.upper()}")

        # The architecture is recognised from substrings of each model's label.
        architectures = set()
        for config in model_configs:
            if config.get('dataset') != dataset_name:
                continue
            label = config.get('label', '')
            if 'ResNet9' in label:
                architectures.add('ResNet9')
            elif 'LeNet' in label:
                architectures.add('LeNet')
            elif 'CifarResNet18' in label or 'ResNet18' in label:
                architectures.add('ResNet18')

        for arch in sorted(architectures):
            if not args.quiet:
                print(f"  Processing architecture: {arch}")

            # First matching config per method type for this dataset/architecture.
            arch_configs = {}
            for method_type in method_types:
                if method_type == 'unknown':
                    continue

                for config in model_configs:
                    if (config.get('dataset') == dataset_name and
                            config.get('type') == method_type and
                            arch in config.get('label', '')):
                        arch_configs[method_type] = config
                        break

            if not arch_configs:
                if not args.quiet:
                    print(f"    Skipping {arch}, no models found")
                continue

            compare_configs = [
                {'path': config['path'], 'label': config['label']}
                for config in arch_configs.values()
            ]

            result_key = f"{dataset_name}_{arch}"
            try:
                results = compare_multiple_membership_inference(
                    model_configs=compare_configs,
                    dataset_name=dataset_name,
                    forget_class=args.forget_class,
                    save_dir=f"results/method_comparison_{dataset_name}_{arch}"
                )

                all_results[result_key] = {
                    'dataset': dataset_name,
                    'architecture': arch,
                    'results': results
                }

                if not args.quiet:
                    print(f"    {result_key} attack completed, processed {len(results)} models")

            except Exception as e:
                print(f"    Error processing {result_key}: {e}")
                continue

    if not all_results:
        print("Error: No processable model configurations found")
        print("Hint: Use --model-path parameter for single model attack test")
        sys.exit(1)

    # ---- Summary CSV: one row per evaluated model --------------------------
    csv_path = args.output
    _ensure_parent_dir(csv_path)

    with open(csv_path, 'w') as f:
        f.write("Method,Dataset,Architecture,Attack_Success_Rate,Avg_Membership_Prob,Max_Membership_Prob,Min_Membership_Prob\n")

        for result_data in all_results.values():
            dataset_name = result_data['dataset']
            architecture = result_data['architecture']

            for label, attack_result in result_data['results'].items():
                method_type = _method_from_label(label)

                attack_rate = attack_result.get('attack_success_rate', 0)
                avg_prob = attack_result.get('avg_membership_prob', 0)
                max_prob = attack_result.get('max_membership_prob', 0)
                min_prob = attack_result.get('min_membership_prob', 0)

                f.write(f"{method_type},{dataset_name},{architecture},{attack_rate:.6f},{avg_prob:.6f},{max_prob:.6f},{min_prob:.6f}\n")

    # ---- Method comparison CSV: drops relative to the original model -------
    if len(all_results) > 1 or any(len(result_data['results']) > 1 for result_data in all_results.values()):
        comparison_csv_path = os.path.join(os.path.dirname(csv_path), "membership_attack_method_comparison.csv")
        with open(comparison_csv_path, 'w') as f:
            f.write("Dataset,Architecture,Original_Attack_Rate,Forget_Attack_Rate,Retrain_Attack_Rate,")
            f.write("Forget_Drop,Retrain_Drop,Original_Prob,Forget_Prob,Retrain_Prob,")
            f.write("Forget_Prob_Drop,Retrain_Prob_Drop\n")

            for result_data in all_results.values():
                dataset = result_data['dataset']
                arch = result_data['architecture']

                # Last label of each method wins, as before.
                by_method = {}
                for label, attack_result in result_data['results'].items():
                    by_method[_method_from_label(label)] = attack_result
                original_result = by_method.get('original')
                forget_result = by_method.get('forget')
                retrain_result = by_method.get('retrain')

                if original_result and (forget_result or retrain_result):
                    orig_rate = original_result.get('attack_success_rate', 0)
                    forget_rate = forget_result.get('attack_success_rate', 0) if forget_result else 0
                    retrain_rate = retrain_result.get('attack_success_rate', 0) if retrain_result else 0

                    orig_prob = original_result.get('avg_membership_prob', 0)
                    forget_prob = forget_result.get('avg_membership_prob', 0) if forget_result else 0
                    retrain_prob = retrain_result.get('avg_membership_prob', 0) if retrain_result else 0

                    forget_drop = orig_rate - forget_rate if forget_result else 0
                    retrain_drop = orig_rate - retrain_rate if retrain_result else 0
                    forget_prob_drop = orig_prob - forget_prob if forget_result else 0
                    retrain_prob_drop = orig_prob - retrain_prob if retrain_result else 0

                    f.write(f"{dataset},{arch},{orig_rate:.6f},{forget_rate:.6f},{retrain_rate:.6f},")
                    f.write(f"{forget_drop:.6f},{retrain_drop:.6f},{orig_prob:.6f},{forget_prob:.6f},")
                    f.write(f"{retrain_prob:.6f},{forget_prob_drop:.6f},{retrain_prob_drop:.6f}\n")

        print(f"Method comparison results saved to: {comparison_csv_path}")

    print(f"Membership inference attack results saved to: {csv_path}")