import os
import sys
sys.path.insert(0, './')
import json
import yaml
import pickle
import argparse

import time
import numpy as np

import torch
import torch.nn as nn
from datetime import datetime

from opacus.validators import ModuleValidator

from util.DataParser import parse_data_mia, parse_shadow_data
from util.DeviceParser import parse_device
from util.MIAParser import parse_mia
from util.ModelParser import parse_model
from util.ParamParser import *
from util.Eval import AverageCalculator
from inspect import isfunction
from util.MetricsCalculation import accuracy, precision_p, precision_n, recall_p, recall_n, f1_p, f1_n, auc, tpr_fpr, get_metric


if __name__ == '__main__':
    # Membership-inference attack (MIA) driver.
    #
    # Flow: merge configuration (explicit command line > YAML > argparse
    # defaults), load a trained target model, build member / non-member (and
    # optional shadow) data loaders, fit the MIA returned by `parse_mia`, run
    # inference over paired member/non-member batches, compute aggregate
    # metrics, and pickle everything to `--out_file`.

    parser = argparse.ArgumentParser()

    parser.add_argument('--yaml', type = str, help = 'The yaml file to load default setting.')

    parser.add_argument('--mia_mode', type=str,  choices=["attack", "eval"])
    parser.add_argument('--mia_method', type = str, help = 'The method of the mia deployed.')
    parser.add_argument('--mia_params', action = DictParser, default = {}, help = 'Additional parameters of the mia deployed.')
    parser.add_argument('--fit_params', action = DictParser, default = {}, help = 'Additional parameters for MIA fitting.')

    parser.add_argument('--member_data', action = DictParser, help = 'The config of member data.')
    parser.add_argument('--nonmember_data', action = DictParser, help = 'The config of non-member data.')

    parser.add_argument('--train_data', action = DictParser, help = 'The config of the training data.')
    parser.add_argument('--shadow_data', action = DictParser, help = 'The config of the shadow data.')

    parser.add_argument('--arch', type = str, help = 'The model architecture.')
    parser.add_argument('--dataset', type = str, help = 'The dataset used to train the model.')
    parser.add_argument('--normalize', type = str, help = 'Whether the data is normalized, default is False.')
    parser.add_argument('--model2load', type = str, help = 'The pretrained model to load.')

    parser.add_argument('--out_file', type = str, help = 'The output file.')
    parser.add_argument('--gpu', type = str, help = 'Specify the GPU to use.')
    parser.add_argument('--load_epoch', type=int, default=200)
    parser.add_argument('--total_epoch', type=int, default=200)

    parser.add_argument('--is_split', action = 'store_true', help = 'The split of the dataset.')
    parser.add_argument('--split_seed', type=int, default=42, help='Split seed.')
    parser.add_argument('--is_shadow', action='store_true', help='If training a shadow model.')
    parser.add_argument('--shadow_ratio', type=float, default=0.8)
    parser.add_argument('--mislabel_ratio', type=float, default=0.)

    args = parser.parse_args()

    # Every argparse key is present in `config` (possibly None); YAML may add more.
    config = {key: None for key in vars(args)}

    # Load YAML file (an empty file makes safe_load return None).
    with open(args.yaml, 'r') as fopen:
        yaml_config = yaml.safe_load(fopen) or {}
    config.update(yaml_config)

    # Load command line config. A CLI value wins only when the key is still
    # unset or the user actually changed it from the argparse default.
    # Comparing against the default (rather than `is not None`) keeps defaults
    # such as `--load_epoch 200` or unset `store_true` flags (False) from
    # silently clobbering values provided in the YAML file.
    for key, value in vars(args).items():
        if config[key] is None or value != parser.get_default(key):
            config[key] = value

    # Original model and shadow model loading epoch, forwarded to both the MIA
    # constructor params and its fit params.
    for params_key in ('fit_params', 'mia_params'):
        config[params_key]['load_epoch'] = config['load_epoch']
        config[params_key]['total_epoch'] = config['total_epoch']

    # Config GPUs
    parse_device(config['gpu'])
    use_gpu = config['gpu'] != 'cpu' and torch.cuda.is_available()
    device = torch.device('cuda:0' if use_gpu else 'cpu')

    # Parse model.
    # NOTE(review): `config` always contains a "normalize" key (it is an
    # argparse option), so the `True` fallback is never used and None may be
    # passed through -- confirm parse_model's handling of None.
    model = parse_model(dataset=config['dataset'], arch=config['arch'], normalize=config.get("normalize", True))
    model = model.to(device)

    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only host.
    state_dict = torch.load(config['model2load'], map_location=device, weights_only=True)
    if "dpsgd" in config['model2load']:
        # Presumably an Opacus (DP-SGD) checkpoint: apply the same module fixes
        # used for private training, and strip the `_module.` wrapper prefix
        # from parameter names so they match the plain model.
        errors = ModuleValidator.validate(model, strict=False)
        if errors:
            model = ModuleValidator.fix(model)

        prefix = '_module.'
        state_dict = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()
        }
    model.load_state_dict(state_dict)

    # Parse data
    config['member_data']['mislabel_ratio'] = config['mislabel_ratio']
    config['nonmember_data']['mislabel_ratio'] = config['mislabel_ratio']
    if config['shadow_data'] is not None:
        config['shadow_data']['mislabel_ratio'] = config['mislabel_ratio']
        shadow_member_loaders, shadow_nonmember_loaders = parse_shadow_data(**config['shadow_data'])
    else:
        shadow_member_loaders, shadow_nonmember_loaders = None, None

    member_loader = parse_data_mia(**config['member_data'])
    nonmember_loader = parse_data_mia(**config['nonmember_data'])

    # Parse MIA; mia_mode selects attack vs. eval behavior in the attacker.
    attacker = parse_mia(method=config['mia_method'], mia_mode=config["mia_mode"], **config['mia_params'])

    # Prepare the item to save
    tosave = {'model_summary': str(model), 'config': config, 'runtime': None,
              'member_pred': {}, 'nonmember_pred': {}, 'prob_member': {}, 'prob_nonmember': {},
              'stats': {}, 'log': {'cmd': 'python ' + ' '.join(sys.argv), 'time': datetime.now().strftime('%Y/%m/%d, %H:%M:%S')}}

    for key in sorted(config.keys()):
        print('%s\t=>%s' % (key, config[key]))

    print("\n=== Debug: Checking batch format from member_loader ===")
    for batch in member_loader:
        print("Batch type:", type(batch))
        print("Batch len:", len(batch))
        print("Batch[0] type:", type(batch[0]))
        break

    # Fit MIA model
    fit_data_loaders = {
        "member_train": member_loader,
        "nonmember_train": nonmember_loader,
        "shadow_member": shadow_member_loaders,
        "shadow_nonmember": shadow_nonmember_loaders,
    }

    attacker.fit(model, fit_data_loaders, **config['fit_params'])

    # MIA Inference
    member_acc = AverageCalculator()
    nonmember_acc = AverageCalculator()
    # Latest result of get_metric (attack mode) or attacker.output() (eval mode).
    aggregation_result = None

    # NOTE: zip() stops at the shorter loader; surplus batches of the longer
    # loader are not evaluated.
    for idx, (member_data, nonmember_data) in enumerate(zip(member_loader, nonmember_loader)):
        # Batches are (data, label, idx) or (data, label, true_label, idx).
        if len(member_data) == 3:
            member_data_batch, member_label_batch, member_idx_batch = member_data
            nonmember_data_batch, nonmember_label_batch, nonmember_idx_batch = nonmember_data
        elif len(member_data) == 4:
            member_data_batch, member_label_batch, _, member_idx_batch = member_data
            nonmember_data_batch, nonmember_label_batch, _, nonmember_idx_batch = nonmember_data
        else:
            raise ValueError('The input data is not valid!')
        # Non-member indices are stored negated as keys of tosave['nonmember_pred'].
        nonmember_idx_batch = -nonmember_idx_batch

        n_member = member_data_batch.size(0)
        data_batch = torch.cat((member_data_batch, nonmember_data_batch)).to(device)
        label_batch = torch.cat((member_label_batch, nonmember_label_batch)).to(device)
        idx_batch = torch.cat((member_idx_batch, nonmember_idx_batch)).to(device)

        # get all output
        infer_result = attacker.infer(model, data_batch, label_batch)

        idx_member = idx_batch[:n_member]
        idx_nonmember = idx_batch[n_member:]

        # for attack mode
        if config["mia_mode"] == "attack":
            result_this_batch = infer_result[0]
            batch_result_member = result_this_batch[:n_member]
            batch_result_nonmember = result_this_batch[n_member:]

            for instance_idx, instance_result in zip(idx_member, batch_result_member):
                tosave['member_pred'][int(instance_idx)] = int(instance_result)
            for instance_idx, instance_result in zip(idx_nonmember, batch_result_nonmember):
                tosave['nonmember_pred'][int(instance_idx)] = int(instance_result)

            # get_metric receives the running accuracy calculators; its result
            # exposes at least "member_acc" and "nonmember_acc".
            aggregation_result = get_metric(nonmember_acc, member_acc,
                                            batch_result_member, batch_result_nonmember,
                                            tosave)

            sys.stdout.write('Member: Batch Idx: %d - Member Accuracy: %.2f%% - Nonmember Accuracy: %.2f%%\r'
                             % (idx, 100 * aggregation_result["member_acc"], 100 * aggregation_result["nonmember_acc"]))

    if config["mia_mode"] == "attack":
        sys.stdout.write('\n')  # terminate the '\r'-refreshed progress line

    if hasattr(attacker, "output") and config["mia_mode"] == "eval":
        aggregation_result = attacker.output()
        result_member = aggregation_result["member_pred"]
        result_nonmember = aggregation_result["nonmember_pred"]
    else:
        result_member = torch.tensor(list(tosave['member_pred'].values()))
        result_nonmember = torch.tensor(list(tosave['nonmember_pred'].values()))

    if aggregation_result is None:
        # Nothing populated the aggregate (no batches were processed, or eval
        # mode with an attacker lacking `output()`); use an empty dict so the
        # lookups below yield None instead of raising AttributeError.
        aggregation_result = {}

    tp, fn, tn, fp = [aggregation_result.get(k) for k in ("tp", "fn", "tn", "fp")]

    acc = accuracy(result_member, result_nonmember)
    prec_p = precision_p(result_member, result_nonmember)
    rec_p = recall_p(result_member, result_nonmember)
    prec_n = precision_n(result_member, result_nonmember)
    rec_n = recall_n(result_member, result_nonmember)
    f1_positive = f1_p(result_member, result_nonmember)
    f1_negative = f1_n(result_member, result_nonmember)

    # Implicit string concatenation keeps the message on one output line
    # (backslash continuations inside an f-string would embed the indentation),
    # and a single '%' is correct here since this is not %-formatting.
    print(f"Epoch:{config['load_epoch']}, Accuracy: {acc * 100.:.2f}%, "
          f"Precision@P: {prec_p * 100.:.2f}%, Recall@P: {rec_p * 100.:.2f}%, "
          f"F1@P: {f1_positive * 100.:.2f}%, Precision@N: {prec_n * 100.:.2f}%, "
          f"Recall@N: {rec_n * 100.:.2f}%, F1@N: {f1_negative * 100.:.2f}%, "
          f"Auc: {aggregation_result.get('auc', 0.0):.3f}, "
          f"TPR under 0.1% FPR: {aggregation_result.get('tpr01fpr')}, "
          f"TPR under 0.01% FPR: {aggregation_result.get('tpr001fpr')}")

    # NOTE: the two "tpr under ..." keys keep their historical literal '%%'
    # so existing readers of the pickle keep working.
    tosave['stats'] = {'tp': tp, 'tn': tn, 'fp': fp, 'fn': fn,
                       'member_acc': member_acc.average, 'nonmember_acc': nonmember_acc.average,
                       'accuracy': acc, 'precision@p': prec_p, 'recall@p': rec_p, 'precision@n': prec_n, 'recall@n': rec_n,
                       'f1@p': f1_positive, 'f1@n': f1_negative,
                       "auc": aggregation_result.get("auc"),
                       "tpr under 0.1%% fpr": aggregation_result.get("tpr01fpr"),
                       "tpr under 0.01%% fpr": aggregation_result.get("tpr001fpr"),
                       }

    # Save the result (exist_ok avoids a check-then-create race; `with`
    # guarantees the pickle file is flushed and closed).
    out_folder = os.path.dirname(config['out_file'])
    if out_folder:
        os.makedirs(out_folder, exist_ok=True)
    with open(config['out_file'], 'wb') as fout:
        pickle.dump(tosave, fout)

