from torch.nn import functional as F
import torch
import numpy as np
from sklearn.linear_model import LogisticRegression
import os
import sys
import json
import argparse
import matplotlib.pyplot as plt
import time
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils import set_seed, setup_logger
from datasets.load_datasets import load_dataset
from models.resnet import ResNet9, CifarResNet18
from models.lenet import LeNet

def entropy(p, dim=-1, keepdim=False):
    """
    Compute the entropy (natural logarithm) of ``p`` along ``dim``.

    Args:
        p: Tensor of probabilities
        dim: Dimension that is summed over
        keepdim: Whether the summed dimension is kept with size 1

    Returns:
        Tensor holding ``-sum(p * log(p))`` along ``dim``; entries that are
        not strictly positive contribute exactly zero.
    """
    # Mask out p <= 0 so that 0 * log(0) (NaN) never reaches the sum.
    zero = p.new([0.0])
    weighted_log = torch.where(p > 0, p * p.log(), zero)
    total = weighted_log.sum(dim=dim, keepdim=keepdim)
    return -total


def collect_prob(data_loader, model, batch_size=1):
    """
    Run ``model`` over a loader's dataset and collect softmax probabilities.

    The incoming loader is only used for its ``dataset``; a fresh,
    non-shuffled loader is built so rows come back in dataset order.

    Args:
        data_loader: Loader whose ``dataset`` attribute is evaluated
        model: Module mapping an input batch to logits; must own at least
            one parameter (its device decides where inputs are moved)
        batch_size: Samples per forward pass. The default of 1 keeps the
            original per-sample evaluation; larger values need fewer forward
            passes. NOTE(review): outputs may differ across batch sizes for
            models whose output depends on batch composition (e.g. batch
            norm left in training mode) - confirm the model is in eval mode.

    Returns:
        Tensor of softmax outputs (over the last dimension) concatenated
        along dimension 0.

    Raises:
        ValueError: If a batch has neither 2 nor 3 elements.
        RuntimeError: From ``torch.cat`` if the dataset is empty.
    """
    eval_loader = torch.utils.data.DataLoader(
        data_loader.dataset, batch_size=batch_size, shuffle=False
    )
    # The model's device does not change while iterating: look it up once.
    device = next(model.parameters()).device

    prob = []
    with torch.no_grad():
        for batch in eval_loader:
            # Batches are (data, target) or (data, <extra>, target); only the
            # inputs are needed, so only they are moved to the device.
            if len(batch) == 2:
                data, _ = batch
            else:
                data, _, _ = batch

            output = model(data.to(device))
            prob.append(F.softmax(output, dim=-1).detach())
    return torch.cat(prob)


def get_membership_attack_data(retain_loader, forget_loader, test_loader, model):
    """
    Build entropy features and membership labels for the attack classifier.

    Retain samples are labelled 1 and test samples 0; together they form the
    attack training set. Forget samples are all labelled 1.

    Args:
        retain_loader: Loader over the retained training samples
        forget_loader: Loader over the samples to be forgotten
        test_loader: Loader over held-out samples
        model: Model whose output probabilities are analysed

    Returns:
        X_f: Column vector of entropies for the forget samples
        Y_f: Array of ones, one per forget sample
        X_r: Column vector of entropies for retain samples followed by test samples
        Y_r: Matching labels (ones for retain, zeros for test)
    """
    # Evaluated in the order retain, forget, test.
    retain_prob, forget_prob, test_prob = (
        collect_prob(loader, model)
        for loader in (retain_loader, forget_loader, test_loader)
    )

    train_entropy = torch.cat([entropy(retain_prob), entropy(test_prob)])
    X_r = train_entropy.cpu().numpy().reshape(-1, 1)
    Y_r = np.hstack((np.ones(len(retain_prob)), np.zeros(len(test_prob))))

    forget_entropy = entropy(forget_prob)
    X_f = forget_entropy.cpu().numpy().reshape(-1, 1)
    Y_f = np.ones(len(forget_prob))

    return X_f, Y_f, X_r, Y_r


def get_membership_attack_prob(retain_loader, forget_loader, test_loader, model):
    """
    Fit the entropy-based attack classifier and score the forget set.

    Args:
        retain_loader: Loader over the retained training samples
        forget_loader: Loader over the samples to be forgotten
        test_loader: Loader over held-out samples
        model: Model under attack

    Returns:
        Mean predicted label on the forget set, i.e. the fraction of forget
        samples the classifier predicts as label 1 (the retain label).
    """
    forget_features, _, train_features, train_labels = get_membership_attack_data(
        retain_loader, forget_loader, test_loader, model
    )
    # NOTE(review): recent scikit-learn releases deprecate `multi_class`;
    # confirm against the pinned version before upgrading.
    attack_model = LogisticRegression(
        class_weight="balanced", solver="lbfgs", multi_class="multinomial"
    )
    attack_model.fit(train_features, train_labels)
    return attack_model.predict(forget_features).mean()


def split_dataset_by_class(dataset, forget_class=0):
    """
    Partition a labelled dataset into the samples of one class and the rest.

    The dataset is walked exactly once; every item must unpack as a
    ``(sample, target)`` pair.

    Args:
        dataset: Original dataset
        forget_class: Label whose samples go into the forget set

    Returns:
        forget_dataset: Subset holding the samples labelled ``forget_class``
        retain_dataset: Subset holding every other sample
    """
    forget_indices, retain_indices = [], []

    for position, (_, label) in enumerate(dataset):
        bucket = forget_indices if label == forget_class else retain_indices
        bucket.append(position)

    make_subset = torch.utils.data.Subset
    return make_subset(dataset, forget_indices), make_subset(dataset, retain_indices)


class MembershipInferenceAttack:
    """
    Entropy-based membership inference attack against a saved classifier.

    On construction the training data is split into a forget set (one class)
    and a retain set, and the model weights are loaded from disk.
    ``run_attack`` then fits a logistic-regression classifier on output
    entropies (retain vs. test samples) and applies it to the forget set,
    which is used to evaluate privacy leakage and unlearning effectiveness.
    """

    def __init__(self, model_path, dataset_name='mnist', forget_class=0,
                 device=None, log_dir="logs/attack", seed=42):
        """
        Set up logging, data loaders and the model under attack.

        Args:
            model_path: Path to model weights file
            dataset_name: Name of dataset
            forget_class: Class to forget
            device: Computing device; when falsy, CUDA is used if available,
                otherwise the CPU
            log_dir: Root directory for logs; a ``<dataset>/<model name>``
                sub-directory is created beneath it
            seed: Random seed
        """
        set_seed(seed)

        if device:
            self.device = device
        else:
            self.device = torch.device(
                'cuda' if torch.cuda.is_available() else 'cpu')

        self.dataset_name = dataset_name
        self.forget_class = forget_class
        self.model_path = model_path

        # The weights file name (with '.pth' removed) names the log folder.
        model_name = os.path.basename(model_path).replace('.pth', '')
        self.log_dir = os.path.join(log_dir, dataset_name, model_name)
        os.makedirs(self.log_dir, exist_ok=True)

        self.logger, self.log_file = setup_logger(
            f"membership_attack_{model_name}", self.log_dir)

        for message in (
            "Initialize membership inference attack",
            f"Model path: {model_path}",
            f"Dataset: {self.dataset_name}",
            f"Forget class: {self.forget_class}",
            f"Device: {self.device}",
        ):
            self.logger.info(message)

        loaders = self._prepare_datasets()
        self.retain_loader, self.forget_loader, self.test_loader = loaders

        self.model = self._load_model(model_path)

        for set_name, loader in (
            ("Forget", self.forget_loader),
            ("Retain", self.retain_loader),
            ("Test", self.test_loader),
        ):
            self.logger.info(f"{set_name} set size: {len(loader.dataset)}")

    def _prepare_datasets(self):
        """
        Build the retain, forget and test loaders.

        Returns:
            Tuple ``(retain_loader, forget_loader, test_loader)``; the first
            two are non-shuffled loaders over the class-based split of the
            training dataset, the last is the loader from ``load_dataset``.
        """
        # The third value returned by load_dataset is not needed here.
        train_loader, test_loader, _ = load_dataset(
            self.dataset_name, batch_size=128, num_workers=4
        )

        forget_dataset, retain_dataset = split_dataset_by_class(
            train_loader.dataset, self.forget_class
        )

        loader_options = dict(batch_size=128, shuffle=False, num_workers=4)
        forget_loader = torch.utils.data.DataLoader(
            forget_dataset, **loader_options)
        retain_loader = torch.utils.data.DataLoader(
            retain_dataset, **loader_options)

        return retain_loader, forget_loader, test_loader

    def _get_model_info_from_path(self, model_path):
        """
        Infer architecture, input channels and class count.

        The architecture comes from the (lower-cased) file name, falling
        back to a per-dataset default; channels and class count come from
        ``self.dataset_name`` alone.

        Returns:
            Tuple ``(model_type, in_channels, num_classes)``.
        """
        filename = os.path.basename(model_path).lower()
        is_mnist = self.dataset_name == 'mnist'

        # Checked in order; the first marker found in the file name wins.
        markers = (
            ('lenet', 'lenet'),
            ('resnet9', 'resnet9'),
            ('cifarresnet18', 'resnet18'),
            ('resnet18', 'resnet18'),
        )
        model_type = None
        for marker, architecture in markers:
            if marker in filename:
                model_type = architecture
                break
        if model_type is None:
            model_type = 'lenet' if is_mnist else 'resnet9'

        in_channels = 1 if is_mnist else 3
        # Only cifar100 has 100 classes; every other dataset maps to 10.
        num_classes = 100 if self.dataset_name == 'cifar100' else 10

        return model_type, in_channels, num_classes

    def _create_model(self, model_type, in_channels, num_classes):
        """
        Instantiate the network for ``model_type``.

        Raises:
            ValueError: If ``model_type`` is not 'lenet', 'resnet9' or 'resnet18'.
        """
        if model_type == 'lenet':
            architecture = LeNet
        elif model_type == 'resnet9':
            architecture = ResNet9
        elif model_type == 'resnet18':
            architecture = CifarResNet18
        else:
            raise ValueError(f"Unsupported model type: {model_type}")

        return architecture(num_classes=num_classes, in_channels=in_channels)

    def _load_model(self, model_path):
        """
        Create the model, load its weights and switch it to eval mode.

        Accepts either a bare state dict or a dict wrapping it under
        'model_state_dict' or 'state_dict'. Any failure while loading is
        logged and re-raised.

        Returns:
            The model, moved to ``self.device`` and in eval mode.
        """
        model_type, in_channels, num_classes = self._get_model_info_from_path(
            model_path)

        self.logger.info(
            f"Inferred model info: type={model_type}, input_channels={in_channels}, num_classes={num_classes}")

        model = self._create_model(model_type, in_channels, num_classes)

        try:
            checkpoint = torch.load(
                model_path, map_location=self.device, weights_only=True)

            # Unwrap the state dict when the checkpoint nests it under a
            # known key; otherwise the checkpoint itself is the state dict.
            state_dict = checkpoint
            if isinstance(checkpoint, dict):
                for wrapper_key in ('model_state_dict', 'state_dict'):
                    if wrapper_key in checkpoint:
                        state_dict = checkpoint[wrapper_key]
                        break
            model.load_state_dict(state_dict)

            model = model.to(self.device)
            model.eval()

            self.logger.info(f"Successfully loaded model: {model_path}")
            return model

        except Exception as e:
            self.logger.error(f"Failed to load model: {e}")
            raise

    def run_attack(self):
        """
        Execute membership inference attack

        Returns:
            dict: 'attack_success_rate' (mean predicted label on the forget
            set), 'avg_membership_prob' / 'max_membership_prob' /
            'min_membership_prob' (statistics of the predicted probability
            of label 1), the raw 'forget_predictions' and
            'forget_probabilities' arrays, and the fitted 'classifier'.
        """
        self.logger.info("Start executing membership inference attack")
        start_time = time.time()

        X_f, _, X_r, Y_r = get_membership_attack_data(
            self.retain_loader, self.forget_loader, self.test_loader, self.model
        )

        # NOTE(review): recent scikit-learn releases deprecate `multi_class`;
        # confirm against the pinned version before upgrading.
        clf = LogisticRegression(
            class_weight="balanced", solver="lbfgs", multi_class="multinomial"
        )
        clf.fit(X_r, Y_r)

        forget_predictions = clf.predict(X_f)
        # Column 1 of predict_proba is read as the probability of label 1.
        forget_probabilities = clf.predict_proba(X_f)[:, 1]

        summary = {
            'attack_success_rate': forget_predictions.mean(),
            'avg_membership_prob': forget_probabilities.mean(),
            'max_membership_prob': forget_probabilities.max(),
            'min_membership_prob': forget_probabilities.min(),
        }

        elapsed = time.time() - start_time

        self.logger.info(f"Membership inference attack completed (time: {elapsed:.2f}s)")
        self.logger.info(f"Attack success rate: {summary['attack_success_rate']:.4f}")
        self.logger.info(f"Average membership probability: {summary['avg_membership_prob']:.4f}")
        self.logger.info(f"Maximum membership probability: {summary['max_membership_prob']:.4f}")
        self.logger.info(f"Minimum membership probability: {summary['min_membership_prob']:.4f}")

        return {
            **summary,
            'forget_predictions': forget_predictions,
            'forget_probabilities': forget_probabilities,
            'classifier': clf,
        }


def compare_multiple_membership_inference(model_configs, dataset_name='mnist',
                                          forget_class=0, save_dir="results/attack"):
    """
    Compare membership inference attack results across multiple models

    Each model is attacked in turn. A model whose attack raises is reported
    on stdout and skipped; the remaining models are still processed. When at
    least one attack succeeds, a comparison plot and a CSV summary are
    written to ``save_dir``.

    Args:
        model_configs: List of model configurations, format: [{'path': 'path', 'label': 'label'}, ...]
        dataset_name: Name of dataset
        forget_class: Class to forget
        save_dir: Directory to save results

    Returns:
        dict mapping each successfully processed model label to the result
        dict returned by ``MembershipInferenceAttack.run_attack``.

    Raises:
        KeyError: If a configuration lacks 'path' or 'label'.
    """
    os.makedirs(save_dir, exist_ok=True)

    # Labels and results are appended together, so they stay aligned even
    # when a model in the middle of the list fails and is skipped.
    processed_labels = []
    attack_results = []

    for config in model_configs:
        model_path = config['path']
        model_label = config['label']

        print(f"  Processing {model_label}...")

        # Deliberately broad: one broken model must not abort the comparison.
        try:
            attacker = MembershipInferenceAttack(
                model_path, dataset_name, forget_class,
                log_dir=os.path.join(save_dir, "logs")
            )
            result = attacker.run_attack()
        except Exception as e:
            print(f"    Error: Failed to process {model_label} - {e}")
            continue

        processed_labels.append(model_label)
        attack_results.append(result)

    if attack_results:
        _create_membership_comparison_visualization(
            attack_results, forget_class, processed_labels, save_dir
        )

        csv_path = os.path.join(save_dir, f"membership_inference_results_{dataset_name}_class_{forget_class}.csv")
        with open(csv_path, 'w') as f:
            f.write("Model,Attack_Success_Rate,Avg_Membership_Prob,Max_Membership_Prob,Min_Membership_Prob\n")

            for model_label, result in zip(processed_labels, attack_results):
                f.write(f"{model_label},{result['attack_success_rate']:.6f},{result['avg_membership_prob']:.6f},"
                        f"{result['max_membership_prob']:.6f},{result['min_membership_prob']:.6f}\n")

        print(f"  Membership inference attack results saved to: {csv_path}")

    # With duplicate labels the later result wins.
    return dict(zip(processed_labels, attack_results))


def _create_membership_comparison_visualization(attack_results, forget_class,
                                                model_labels, save_dir):
    """
    Draw two bar charts comparing the models and save them as one PNG.

    The left panel shows each model's attack success rate, the right panel
    its average membership probability; every bar is annotated with its
    value. The figure is written to
    ``<save_dir>/membership_inference_comparison_class_<forget_class>.png``.

    Args:
        attack_results: Result dicts providing 'attack_success_rate' and
            'avg_membership_prob'
        forget_class: Class to forget (used in the title and file name)
        model_labels: Bar labels, one per entry of ``attack_results``
        save_dir: Directory to save the figure
    """
    palette = ['blue', 'orange', 'green'][:len(attack_results)]

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))

    # (result key, panel title, y-axis label) for the left and right panel.
    panels = (
        ('attack_success_rate',
         'Membership Inference Attack Success Rate Comparison',
         'Attack Success Rate'),
        ('avg_membership_prob',
         'Average Membership Probability Comparison',
         'Average Membership Probability'),
    )

    for ax, (metric, title, ylabel) in zip(axes, panels):
        values = [result[metric] for result in attack_results]
        bars = ax.bar(model_labels, values, alpha=0.7, color=palette)
        ax.set_title(title)
        ax.set_ylabel(ylabel)
        ax.set_ylim(0, 1)

        # Print each value just above its bar.
        for bar, value in zip(bars, values):
            ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
                    f'{value:.3f}', ha='center', va='bottom')

    plt.suptitle(f'Membership Inference Attack Comparison - Forget Class: {forget_class}')
    plt.tight_layout()

    save_path = os.path.join(
        save_dir, f"membership_inference_comparison_class_{forget_class}.png")
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()


def compare_membership_inference(original_model_path, unlearned_model_path,
                                 dataset_name='mnist', forget_class=0,
                                 save_dir="results/attack"):
    """
    Compare an original and an unlearned model (kept for backward compatibility).

    Thin wrapper around ``compare_multiple_membership_inference`` using the
    fixed labels 'Original Model' and 'Unlearned Model'.

    Returns:
        dict with 'original' and 'unlearned' (the per-model result dicts,
        empty when that model is missing from the results) plus
        'success_rate_drop' and 'membership_prob_drop' (original minus
        unlearned, with a missing value counted as 0).
    """
    results = compare_multiple_membership_inference(
        [
            {'path': original_model_path, 'label': 'Original Model'},
            {'path': unlearned_model_path, 'label': 'Unlearned Model'},
        ],
        dataset_name, forget_class, save_dir
    )

    original_result = results.get('Original Model', {})
    unlearned_result = results.get('Unlearned Model', {})

    def metric_drop(metric):
        # Difference original - unlearned; absent metrics count as 0.
        return original_result.get(metric, 0) - unlearned_result.get(metric, 0)

    return {
        'original': original_result,
        'unlearned': unlearned_result,
        'success_rate_drop': metric_drop('attack_success_rate'),
        'membership_prob_drop': metric_drop('avg_membership_prob'),
    }


if __name__ == "__main__":
    # Command-line entry point: load the model list from a JSON file, run the
    # attack per dataset, and write one summary CSV across all datasets.
    parser = argparse.ArgumentParser(description='Execute membership inference attack evaluation')
    parser.add_argument('--config', type=str, default='attack/model_configs.json',
                        help='Path to model configuration file (JSON)')
    parser.add_argument('--forget-class', type=int, default=0,
                        help='Class to forget (default: 0)')
    parser.add_argument('--output', type=str, default='results/membership_attack_comparison_summary.csv',
                        help='Output CSV file path for results')
    parser.add_argument('--quiet', action='store_true',
                        help='Quiet mode, reduce output information')
    args = parser.parse_args()

    model_configs = []
    # The config path is resolved against the parent of this file's directory.
    config_path = os.path.join(os.path.dirname(
        os.path.dirname(os.path.abspath(__file__))), args.config)

    if os.path.exists(config_path):
        with open(config_path, 'r') as f:
            config_data = json.load(f)

            if 'models' in config_data:
                groups = config_data['models']
                # An empty "models" list has no first group to inspect.
                first_group = groups[0] if groups else {}

                if 'dataset' in first_group:
                    # Grouped by dataset: copy the group's dataset onto each model.
                    for dataset_group in groups:
                        dataset_name = dataset_group['dataset']
                        for model in dataset_group['models']:
                            model_config = model.copy()
                            model_config['dataset'] = dataset_name
                            model_configs.append(model_config)
                elif 'type' in first_group:
                    # Grouped by method type: copy the group's type onto each model.
                    for method_group in groups:
                        method_type = method_group['type']
                        for model in method_group['models']:
                            model_config = model.copy()
                            model_config['type'] = method_type
                            model_configs.append(model_config)
            else:
                # No "models" key: the document itself is used as the list.
                model_configs = config_data

        if not args.quiet:
            print(f"Loaded {len(model_configs)} model configurations from {config_path}")
    else:
        print(f"Warning: Configuration file {config_path} does not exist, using default configuration")

    datasets = ['mnist', 'cifar10', 'svhn']
    all_results = {}

    all_forget_drops = []
    all_retrain_drops = []
    all_forget_prob_drops = []
    all_retrain_prob_drops = []

    for dataset_name in datasets:
        if not args.quiet:
            print(f"\nProcessing dataset: {dataset_name.upper()}")

        # .get(): configs from the type-grouped layout carry no 'dataset'
        # key and must be skipped rather than raise KeyError.
        current_configs = [
            config for config in model_configs
            if config.get('dataset') == dataset_name]

        if not current_configs:
            if not args.quiet:
                print(f"  No model configurations found for {dataset_name}")
            continue

        valid_configs = []
        for config in current_configs:
            if os.path.exists(config['path']):
                valid_configs.append(config)
            else:
                if not args.quiet:
                    print(f"  Warning: Model file does not exist - {config['path']}")

        if not valid_configs:
            if not args.quiet:
                print(f"  No valid model files for {dataset_name}")
            continue

        # Deliberately broad: a failure on one dataset must not stop the others.
        try:
            results = compare_multiple_membership_inference(
                model_configs=valid_configs,
                dataset_name=dataset_name,
                forget_class=args.forget_class,
                save_dir=f"results/multi_membership_attack_comparison_{dataset_name}"
            )

            all_results[dataset_name] = results
            if not args.quiet:
                print(f"  {dataset_name} attack completed, processed {len(results)} models")

        except Exception as e:
            print(f"  Error processing {dataset_name}: {e}")
            continue

    csv_path = args.output
    # A bare file name has an empty dirname, and os.makedirs('') raises.
    output_dir = os.path.dirname(csv_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    with open(csv_path, 'w') as f:
        f.write(
            "Dataset,Model,Original_Attack_Rate,Forget_Attack_Rate,Retrain_Attack_Rate,Forget_Drop,Retrain_Drop,")
        f.write(
            "Original_Prob,Forget_Prob,Retrain_Prob,Forget_Prob_Drop,Retrain_Prob_Drop\n")

        for dataset_name, results in all_results.items():
            # Group results by architecture, then by role (original / forget /
            # retrain); both are read from substrings of the model label.
            model_groups = {}
            for model_label, result in results.items():
                if 'ResNet9' in model_label:
                    model_type = 'ResNet9'
                elif 'LeNet' in model_label:
                    model_type = 'LeNet'
                elif 'CifarResNet18' in model_label:
                    model_type = 'ResNet18'
                else:
                    # Labels naming no known architecture are left out of the summary.
                    continue

                if model_type not in model_groups:
                    model_groups[model_type] = {}

                if 'original' in model_label.lower():
                    model_groups[model_type]['original'] = result
                elif 'forget' in model_label.lower():
                    model_groups[model_type]['forget'] = result
                elif 'retrain' in model_label.lower():
                    model_groups[model_type]['retrain'] = result

            for model_type, group_results in model_groups.items():
                # A missing role yields an empty dict, so its metrics read as 0.
                original = group_results.get('original', {})
                forget = group_results.get('forget', {})
                retrain = group_results.get('retrain', {})

                orig_rate = original.get('attack_success_rate', 0)
                forget_rate = forget.get('attack_success_rate', 0)
                retrain_rate = retrain.get('attack_success_rate', 0)

                orig_prob = original.get('avg_membership_prob', 0)
                forget_prob = forget.get('avg_membership_prob', 0)
                retrain_prob = retrain.get('avg_membership_prob', 0)

                # Drops are measured relative to the original model.
                forget_drop = orig_rate - forget_rate
                retrain_drop = orig_rate - retrain_rate
                forget_prob_drop = orig_prob - forget_prob
                retrain_prob_drop = orig_prob - retrain_prob

                # Only strictly positive rate drops are accumulated.
                if forget_drop > 0:
                    all_forget_drops.append(forget_drop)
                    all_forget_prob_drops.append(forget_prob_drop)

                if retrain_drop > 0:
                    all_retrain_drops.append(retrain_drop)
                    all_retrain_prob_drops.append(retrain_prob_drop)

                f.write(
                    f"{dataset_name},{model_type},{orig_rate:.6f},{forget_rate:.6f},{retrain_rate:.6f},")
                f.write(
                    f"{forget_drop:.6f},{retrain_drop:.6f},{orig_prob:.6f},{forget_prob:.6f},")
                f.write(
                    f"{retrain_prob:.6f},{forget_prob_drop:.6f},{retrain_prob_drop:.6f}\n")

    print(f"Membership inference attack results saved to: {csv_path}")