import torch
from models import MLP
from data_utils import prepare_data
from fairness_test import test_auto_nammd
from utils import get_output
from config import datasets, HYPERPARAMS

import numpy as np
import os

def train_model(train_dl, model, lr, epochs, device, save_dir=None):
    """Train ``model`` in place with Adam and ``BCEWithLogitsLoss``.

    Args:
        train_dl: iterable of ``(X, y)`` batches, re-iterated once per epoch.
            ``X`` may be a tensor, NumPy array or nested list and is
            converted to ``float32``. ``y`` must be a tensor with the same
            shape as the model output. It is cast to the output's dtype
            because the loss requires floating-point targets.
        model: ``torch.nn.Module`` to optimise. It is left in train mode.
        lr: Adam learning rate.
        epochs: number of full passes over ``train_dl``.
        device: device each batch is moved to. ``model`` is assumed to
            already live there.
        save_dir: if truthy, the directory (created if missing) where the
            final ``state_dict`` is written as ``model.pt``.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = torch.nn.BCEWithLogitsLoss()
    model.train()
    for _ in range(epochs):
        for X, y in train_dl:
            # as_tensor avoids the needless copy (and the copy-construct
            # warning) that torch.tensor() incurs when X is already a tensor.
            X = torch.as_tensor(X, dtype=torch.float32, device=device)
            optimizer.zero_grad()
            output = model(X)
            # BCEWithLogitsLoss needs float targets matching the logits'
            # dtype; one-hot labels frequently arrive as integer tensors.
            target = y.to(device=device, dtype=output.dtype)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
        torch.save(model.state_dict(), os.path.join(save_dir, "model.pt"))

def evaluate_model(dl, model, device):
    """Compute the classification accuracy of ``model`` over loader ``dl``.

    Both the model output and the targets ``y`` are reduced with ``argmax``
    over dim 1. Targets are therefore expected to be one-hot (or per-class
    score) tensors.

    Args:
        dl: iterable of ``(X, y)`` batches. ``X`` may be a tensor, NumPy
            array or nested list and is converted to ``float32``.
        model: module mapping ``X`` to per-class logits. It assumes a 2-D
            ``(batch, n_classes)`` output (TODO confirm for ``MLP``).
        device: device the batches are moved to.

    Returns:
        ``(accuracy, preds)``: the fraction of correct predictions as a float,
        and the predicted class indices as a 1-D NumPy array in loader order.

    Note:
        The model is left in eval mode. An empty loader raises from
        ``torch.cat`` because there are no batches to concatenate.
    """
    model.eval()
    correct, total = 0, 0
    preds = []
    with torch.no_grad():
        for X, y in dl:
            X = torch.as_tensor(X, dtype=torch.float32, device=device)
            y = y.to(device)
            logits = model(X)
            # Take argmax of the raw logits. Sigmoid is monotonic, so it
            # cannot change the ranking, but in float32 it saturates to
            # exactly 1.0 for large logits. That creates ties, which argmax
            # resolves to the first index, i.e. the wrong class.
            pred = torch.argmax(logits, dim=1)
            true = torch.argmax(y, dim=1)
            correct += (pred == true).sum().item()
            total += len(y)
            preds.append(pred.cpu())
    # Concatenate before dividing so an empty loader fails in torch.cat,
    # as before, rather than with ZeroDivisionError.
    all_preds = torch.cat(preds).numpy()
    return correct / total, all_preds

def _fit_and_score(data, n_hidden, lr, epochs, device):
    """Train a fresh MLP on one ``prepare_data`` result and print its scores.

    Args:
        data: the 7-tuple returned by ``prepare_data``. Its first five items
            are the train loader, the two test loaders, and the input/output
            dimensions; the last two items are ignored.
        n_hidden, lr, epochs, device: forwarded to ``MLP`` / ``train_model``.

    Returns:
        ``(model, acc_0, unfairness)``, where ``acc_0`` is the accuracy on
        test loader 0 and ``unfairness`` is ``acc_0 - acc_1``.
    """
    train_dl, test_dl_0, test_dl_1, input_dim, output_dim, _, _ = data
    model = MLP(input_dim, n_hidden, output_dim).to(device)
    train_model(train_dl, model, lr, epochs, device)
    acc_0, _ = evaluate_model(test_dl_0, model, device)
    acc_1, _ = evaluate_model(test_dl_1, model, device)
    unfairness = acc_0 - acc_1
    print(f'Accuracy: {acc_0:.3f}, Unfairness: {unfairness:.3f}')
    return model, acc_0, unfairness


def _record(bucket, acc, unfairness):
    """Append one run's accuracy and unfairness to a results bucket."""
    bucket['acc'].append(acc)
    bucket['unfairness'].append(unfairness)


def _print_summary(label, bucket):
    """Print the mean and (population) std of a bucket's accuracy and unfairness."""
    print(f"{label}:", np.mean(bucket['acc']), "+-", np.std(bucket['acc']),
          "| Unfairness:", np.mean(bucket['unfairness']), "+-", np.std(bucket['unfairness']))


def run_experiment(dataset_key='stu', n_runs=3, epsilons=(0, 0.1, 0.3)):
    """Compare the SAA, SSA and SFA training regimes on one dataset.

    Each of the ``n_runs`` runs trains:
      * SAA: an MLP on all attributes (``prepare_data(path, -1, test_set)``);
      * SSA: an MLP with every attribute in ``test_set`` dropped;
      * SFA: for each epsilon, an MLP with only the attributes flagged as
        unfair dropped. The flagging is currently a placeholder rule: all
        sensitive attributes at ``eps == 0``, none otherwise.

    "Accuracy" is the accuracy on test loader 0. "Unfairness" is the
    accuracy on loader 0 minus the accuracy on loader 1.

    Args:
        dataset_key: key into ``config.datasets``. The entry must provide
            ``path``, ``test_set`` and ``n_hidden``.
        n_runs: number of independent repetitions.
        epsilons: iterable of SFA thresholds. They are used as dict keys, so
            duplicate values share one bucket.

    Returns:
        ``{'SAA': bucket, 'SSA': bucket, 'SFA': {eps: bucket}}``, where each
        bucket is ``{'acc': [...], 'unfairness': [...]}`` with one entry
        appended per run.
    """
    # Iterated several times below; materialise once so generators also work.
    epsilons = list(epsilons)

    # Load config
    dataset = datasets[dataset_key]
    path = dataset['path']
    test_set = dataset['test_set']
    # Indexable copy, so positions flagged by the fairness test can be mapped
    # back to the attribute ids that prepare_data expects.
    test_attrs = list(test_set)
    n_hidden = dataset['n_hidden']
    lr = HYPERPARAMS['lr']
    epochs = HYPERPARAMS['epochs']
    device = HYPERPARAMS['device']

    results = {
        'SAA': {'acc': [], 'unfairness': []},
        'SSA': {'acc': [], 'unfairness': []},
        'SFA': {eps: {'acc': [], 'unfairness': []} for eps in epsilons},
    }

    for run in range(n_runs):
        print(f"\n====== Run {run + 1}/{n_runs} ======")

        # -------- SAA: train with all features, including sensitive ones
        print("-------- SAA (all attributes):")
        saa_data = prepare_data(path, -1, test_set)
        model_saa, acc, unfairness = _fit_and_score(saa_data, n_hidden, lr, epochs, device)
        _record(results['SAA'], acc, unfairness)

        # -------- SSA: remove all sensitive features
        print("-------- SSA (all sensitive attributes removed):")
        # NOTE(review): unlike SAA, test_set is not passed as the third
        # argument here. Confirm that prepare_data's default yields the same
        # group split for test_dl_0 / test_dl_1.
        _, acc, unfairness = _fit_and_score(
            prepare_data(path, test_set), n_hidden, lr, epochs, device)
        _record(results['SSA'], acc, unfairness)

        # --------- SFA: DCFT-guided, remove only unfair attributes (NAMMD)
        print("-------- DCFT/NAMMD Statistical Fairness Test:")
        # The "original" representation must come from the SAA test loader.
        # Reusing a test_dl_0 variable would pick up the SSA loader instead,
        # whose sensitive attributes have been removed.
        saa_test_dl_0 = saa_data[1]
        ori_embedding = get_output(saa_test_dl_0, model_saa, device)
        # Counterfactual sets: drop each sensitive attribute in turn.
        # NOTE(review): if prepare_data removes the column (changing
        # input_dim), model_saa cannot consume test_dl_drop. Confirm.
        test_embeddings = []
        for idx in test_attrs:
            _, test_dl_drop, _, _, _, _, _ = prepare_data(path, idx)
            test_embeddings.append(get_output(test_dl_drop, model_saa, device))
        # TODO: feed ori_embedding / test_embeddings into test_auto_nammd.
        # The placeholder decision below does not use them yet.

        for eps in epsilons:
            # Positions into test_attrs / test_embeddings judged unfair.
            # Placeholder: replace with the actual NAMMD/DCFT results.
            unfair_positions = list(range(len(test_attrs))) if eps == 0 else []
            # prepare_data takes attribute ids (as in the drop loop above),
            # not positions within test_set.
            unfair_move_list = [test_attrs[i] for i in unfair_positions]
            print(f"-------- SFA (epsilon={eps}): Unfair attributes: {unfair_move_list}")
            if unfair_move_list:
                sfa_data = prepare_data(path, unfair_move_list)
            else:
                sfa_data = prepare_data(path)
            _, acc, unfairness = _fit_and_score(sfa_data, n_hidden, lr, epochs, device)
            _record(results['SFA'][eps], acc, unfairness)

    # Print summary
    print("\n====== Experiment Results (Mean ± Std) ======")
    _print_summary("SAA", results['SAA'])
    _print_summary("SSA", results['SSA'])
    for eps in epsilons:
        _print_summary(f"SFA (eps={eps})", results['SFA'][eps])

    return results

if __name__ == '__main__':
    # Script entry point: three repetitions on the 'stu' dataset over a
    # small grid of SFA epsilons.
    results = run_experiment(
        dataset_key='stu',
        n_runs=3,
        epsilons=[0, 0.1, 0.3],
    )
    # `results` can be persisted (e.g. to CSV) or plotted as needed.
