import os
import sys
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from datasets.csv_dataset import BreastCancerDataset, CovertypeDataset, DiabetesDataset
from datasets.utils import perturb_labels
from models.logistic_regression import LogisticRegression
from models.cnn import CNN
# set recursion limit to 100000 to allow for large budget
sys.setrecursionlimit(100000)

# import model and semivalue
from models.resnet import ResNet18
from measures.g_shapley import GShapley

# import tools
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
# use argparse
import argparse


def parse_args() -> argparse.Namespace:
    # use argparse
    parser = argparse.ArgumentParser(description='Differentially Private Data Valuation')
    parser.add_argument('--dataset', type=str, default='cifar10', metavar='N',
                        help='dataset name (default: cifar10)')
    parser.add_argument('--model', type=str, default='resnet18', metavar='N',
                        help='model name (default: resnet18)')
    parser.add_argument('--num_classes', type=int, default=10, metavar='N',
                        help='number of classes (default: 10)')
    parser.add_argument('--lr', type=float, default=None, metavar='LR',)
    parser.add_argument('--warmup_ratio', type=float, default=0, metavar='WR',)
    parser.add_argument('--clipping_norm', type=float, default=-1, metavar='N',
                        help='clipping norm (default: -1 i.e. no DP)')
    parser.add_argument('--delta', type=float, default=1e-5, metavar='N',)
    parser.add_argument('--epsilon', type=float, default=8.0, metavar='N',)
    parser.add_argument('--budget', type=int, default=1000, metavar='N',
                        help='privacy budget (default: 1000)')
    parser.add_argument('--num_players', type=int, default=100, metavar='N',
                        help='number of players (default: 100)')
    parser.add_argument('--epochs', type=int, default=1, metavar='N',
                        help='number of epochs per training (default: 1)')
    parser.add_argument('--utility', type=str, default=None, metavar='N',
                        help='utility function (default: None (negated loss))')
    parser.add_argument('--use_momentum', action='store_true', default=False,
                        help='use momentum (default: False)')
    parser.add_argument('--flip_ratio', type=float, default=0.2, metavar='N',
                        help='ratio of flipped labels (default: 0.0)')
    parser.add_argument('--random_seed', type=int, default=0, metavar='N',)

    parsed_args = parser.parse_args()

    return parsed_args


def main():
    args = parse_args()
    # global var
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"using {device}")
    num_data = args.num_players
    num_classes = args.num_classes
    budget = args.budget

    if args.clipping_norm < 0:
        experiment_name = f"{args.dataset}_{args.model}_no_dp_budget_{budget}_num_players_{num_data}_utility_{'negated loss' if args.utility is None else args.utility}_seed_{args.random_seed}"
    else:
        
        experiment_name = f"{args.dataset}_{args.model}_warmup_{args.warmup_ratio}_eps_{args.epsilon}_delta_{args.delta}_budget_{budget}_num_players_{num_data}_utility_{'negated loss' if args.utility is None else args.utility}_seed_{args.random_seed}_{'use_momentum' if args.use_momentum else 'no_momentum'}"
    

    data_name = f"{args.dataset}_num_players_{num_data}"
    # load dataset
    if args.dataset == 'cifar10':
        num_attributes = 3
        # first check if data with the same budget has been stored in data folder
        if os.path.exists(f'./data/train_{data_name}_flipped.pt') and os.path.exists(f'./data/val_{data_name}_flipped.pt'):
            print("loading data from ./data")
            trainset = torch.load(f'./data/train_{data_name}_flipped.pt')
            valset = torch.load(f'./data/val_{data_name}_flipped.pt')
        else:
            # load CIFAR10 dataset with torchvision
            transform = transforms.Compose([
                        transforms.Resize(256),
                        transforms.CenterCrop(224),
                        transforms.ToTensor(),
                        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ])

            trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                                    download=True, transform=transform)

            # split trainset into trainset and validationset
            trainset, valset = torch.utils.data.random_split(trainset, [num_data, 50000 - num_data])

            # perturb labels of trainset
            trainset = perturb_labels(trainset, args.flip_ratio)

            # use only 1000 data for validation
            _, valset = torch.utils.data.random_split(valset, [50000 - num_data - 1000, 1000])

            # make sure all 10 classes are included
            targets = [item[1] for item in trainset]
            print("number of data for each class:", np.bincount(targets))

            # save the trainset and valset
            os.makedirs('./data', exist_ok=True)
            torch.save(trainset, f'./data/train_{data_name}_flipped.pt')
            torch.save(valset, f'./data/val_{data_name}_flipped.pt')
    elif args.dataset == 'breast-cancer':
        dataset = BreastCancerDataset('./data/breast-cancer.csv', {"M": 0, "B": 1}, label_column='diagnosis')
        num_attributes = dataset.num_attributes()
        # first check if data with the same budget has been stored in data folder
        if os.path.exists(f'./data/train_{data_name}_flipped.pt') and os.path.exists(f'./data/val_{data_name}_flipped.pt'):
            print("loading data from ./data")
            trainset = torch.load(f'./data/train_{data_name}_flipped.pt')
            valset = torch.load(f'./data/val_{data_name}_flipped.pt')
        else:
            
            trainset, valset = torch.utils.data.random_split(dataset, [num_data, len(dataset) - num_data])
            # perturb labels of trainset
            trainset = perturb_labels(trainset, args.flip_ratio)
            # save the trainset and valset
            os.makedirs('./data', exist_ok=True)
            torch.save(trainset, f'./data/train_{data_name}_flipped.pt')
            torch.save(valset, f'./data/val_{data_name}_flipped.pt')
    elif args.dataset == 'mnist':
        num_attributes = 1
        # first check if data with the same budget has been stored in data folder
        if os.path.exists(f'./data/train_{data_name}_flipped.pt') and os.path.exists(f'./data/val_{data_name}_flipped.pt'):
            print("loading data from ./data")
            trainset = torch.load(f'./data/train_{data_name}_flipped.pt')
            valset = torch.load(f'./data/val_{data_name}_flipped.pt')
        else:
            transform = transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Normalize((0.1307,), (0.3081,)),
                ]
            )

            trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                                    download=True, transform=transform)

            # split trainset into trainset and validationset
            trainset, valset = torch.utils.data.random_split(trainset, [num_data, 60000 - num_data])

            # perturb labels of trainset
            trainset = perturb_labels(trainset, args.flip_ratio)

            # use only 1000 data for validation
            _, valset = torch.utils.data.random_split(valset, [59000 - num_data, 1000])

            # make sure all 10 classes are included
            targets = [item[1] for item in trainset]
            print("number of data for each class:", np.bincount(targets))

            # save the trainset and valset
            os.makedirs('./data', exist_ok=True)
            torch.save(trainset, f'./data/train_{data_name}_flipped.pt')
            torch.save(valset, f'./data/val_{data_name}_flipped.pt')
    elif args.dataset == 'diabetes':
        dataset = DiabetesDataset('./data/diabetes.csv', {0: 0, 1: 1}, label_column='Outcome')
        num_attributes = dataset.num_attributes()
        # first check if data with the same budget has been stored in data folder
        if os.path.exists(f'./data/train_{data_name}_flipped.pt') and os.path.exists(f'./data/val_{data_name}_flipped.pt'):
            print("loading data from ./data")
            trainset = torch.load(f'./data/train_{data_name}_flipped.pt')
            valset = torch.load(f'./data/val_{data_name}_flipped.pt')
        else:
            trainset, valset = torch.utils.data.random_split(dataset, [num_data, len(dataset) - num_data])
            # perturb labels of trainset
            trainset = perturb_labels(trainset, args.flip_ratio)
            # save the trainset and valset
            os.makedirs('./data', exist_ok=True)
            torch.save(trainset, f'./data/train_{data_name}_flipped.pt')
            torch.save(valset, f'./data/val_{data_name}_flipped.pt')
    elif args.dataset == 'covertype':
        num_attributes = 54
        # first check if data with the same budget has been stored in data folder
        if os.path.exists(f'./data/train_{data_name}_flipped.pt') and os.path.exists(f'./data/val_{data_name}_flipped.pt'):
            print("loading data from ./data")
            trainset = torch.load(f'./data/train_{data_name}_flipped.pt')
            valset = torch.load(f'./data/val_{data_name}_flipped.pt')
        else:
            dataset = CovertypeDataset({1: 0, 2: 1, 3: 2, 4: 3, 5: 4, 6: 5, 7: 6}, label_column='Cover_Type')
            trainset, valset = torch.utils.data.random_split(dataset, [num_data, len(dataset) - num_data])
            # use at most 1000 data for validation
            if len(valset) > 1000:
                _, valset = torch.utils.data.random_split(valset, [len(valset) - 1000, 1000])
            # perturb labels of trainset
            trainset = perturb_labels(trainset, args.flip_ratio)
            # save the trainset and valset
            os.makedirs('./data', exist_ok=True)
            torch.save(trainset, f'./data/train_{data_name}_flipped.pt')
            torch.save(valset, f'./data/val_{data_name}_flipped.pt')
    else:
        raise NotImplementedError

    # load model
    if args.model == 'resnet18':
        net = ResNet18(num_classes=num_classes)
    elif args.model == 'logistic_regression':
        net = LogisticRegression(input_dim=num_attributes, output_dim=num_classes)
    elif args.model == "cnn":
        net = CNN(input_dim=num_attributes, output_dim=num_classes)
    else:
        raise NotImplementedError

    # create G-Shapley instance
    g_shapley = GShapley(trainset, valset, net, args=args)

    # # run experiment
    scores_shap, scores_banzhaf, scores_beta_41, scores_beta_161 = g_shapley.run(num_iters=budget, clipping_norm=args.clipping_norm, epsilon=args.epsilon, delta=args.delta)
    # save scores, create a folder called results if not exist
    os.makedirs('./results/q_value', exist_ok=True)
    np.save(f'./results/q_value/scores_shapley_{experiment_name}.npy', scores_shap)
    np.save(f'./results/q_value/scores_banzhaf_{experiment_name}.npy', scores_banzhaf)
    np.save(f'./results/q_value/scores_beta_alpha_4_beta_1_{experiment_name}.npy', scores_beta_41)
    np.save(f'./results/q_value/scores_beta_alpha_16_beta_1_{experiment_name}.npy', scores_beta_161)

    scores_loo = g_shapley.run_loo(clipping_norm=args.clipping_norm, epsilon=args.epsilon, delta=args.delta)
    os.makedirs('./results/q_value', exist_ok=True)
    np.save(f'./results/q_value/scores_loo_{experiment_name}.npy', scores_loo)

    print("scores saved")


if __name__ == '__main__':
    # Run the experiment only when executed as a script, never on import.
    main()
