import os
import sys

# Put the parent directory of this file's directory at the front of sys.path
# so that the project-local packages imported below (datasets, models,
# measures) can be resolved when this file is run directly as a script.
# This must happen before those imports.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from datasets.csv_dataset import BreastCancerDataset, CovertypeDataset, DiabetesDataset, WineQualityDataset
from models.logistic_regression import LogisticRegression
from models.linear_regression import LinearRegression
# Raise the interpreter recursion limit to 100000 to allow for a large budget.
sys.setrecursionlimit(100000)

# Project-local model and the G-Shapley semivalue estimator.
from models.resnet import ResNet18
from measures.g_shapley import GShapley

# Third-party numerical / deep-learning tools.
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
# Command-line parsing.
import argparse


def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options of the data-valuation experiment.

    Args:
        argv: Optional sequence of argument strings to parse instead of the
            process command line. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description='Differentially Private Data Valuation')

    # Experiment setup.
    parser.add_argument('--dataset', type=str, default='cifar10', metavar='N',
                        help='dataset name: cifar10, mnist, breast-cancer, diabetes, '
                             'wine_quality or covertype (default: cifar10)')
    parser.add_argument('--model', type=str, default='resnet18', metavar='N',
                        help='model name: resnet18, logistic_regression or '
                             'linear_regression (default: resnet18)')
    # The default really is None (not 10): the value is forwarded as-is to the
    # classification models, so it has to be supplied for them.
    parser.add_argument('--num_classes', type=int, default=None, metavar='N',
                        help='number of classes passed to the model (default: None)')
    parser.add_argument('--num_players', type=int, default=100, metavar='N',
                        help='number of players (default: 100)')
    # The budget is handed to the estimator as its number of iterations.
    parser.add_argument('--budget', type=int, default=1000, metavar='N',
                        help='budget, i.e. number of G-Shapley iterations (default: 1000)')
    parser.add_argument('--utility', type=str, default=None, metavar='N',
                        help='utility function (default: None (negated loss))')
    parser.add_argument('--random_seed', type=int, default=0, metavar='N',
                        help='random seed (default: 0)')

    # Optimisation.
    parser.add_argument('--lr', type=float, default=None, metavar='LR',
                        help='learning rate (default: None)')
    parser.add_argument('--warmup_ratio', type=float, default=0, metavar='WR',
                        help='warmup ratio (default: 0)')
    parser.add_argument('--epochs', type=int, default=1, metavar='N',
                        help='number of epochs per training (default: 1)')
    parser.add_argument('--use_momentum', action='store_true', default=False,
                        help='use momentum (default: False)')

    # Differential privacy.
    parser.add_argument('--clipping_norm', type=float, default=-1, metavar='N',
                        help='clipping norm (default: -1 i.e. no DP)')
    parser.add_argument('--delta', type=float, default=1e-5, metavar='N',
                        help='DP delta (default: 1e-5)')
    parser.add_argument('--epsilon', type=float, default=8.0, metavar='N',
                        help='DP epsilon (default: 8.0)')

    return parser.parse_args(argv)


def _cache_paths(data_name):
    """Return the (train, val) cache file paths under ./data for ``data_name``."""
    return f'./data/train_{data_name}.pt', f'./data/val_{data_name}.pt'


def _load_or_build_split(data_name, build_split):
    """Return ``(trainset, valset)`` from the ./data cache, building it on a miss.

    Args:
        data_name: Cache key embedded in the file names.
        build_split: Zero-argument callable returning ``(trainset, valset)``;
            only invoked when one of the two cache files is missing. Its
            result is saved to the cache before being returned.
    """
    train_path, val_path = _cache_paths(data_name)
    if os.path.exists(train_path) and os.path.exists(val_path):
        print("loading data from ./data")
        # NOTE: torch.load unpickles Python objects, so ./data must only hold
        # cache files written by this script.
        # NOTE(review): recent PyTorch releases default to weights_only=True,
        # which may refuse to load pickled dataset objects - confirm against
        # the installed version.
        return torch.load(train_path), torch.load(val_path)

    trainset, valset = build_split()
    os.makedirs('./data', exist_ok=True)
    torch.save(trainset, train_path)
    torch.save(valset, val_path)
    return trainset, valset


def _subsample_image_split(full_trainset, pool_size, num_data, holdout_size=10000, val_size=1000):
    """Split an image training set into ``num_data`` train and ``val_size`` val items.

    The full set is first split into a pool of ``pool_size`` items and a
    held-out part of ``holdout_size`` items; the training set is then drawn
    from the pool and the validation set from the held-out part.
    """
    pool, holdout = torch.utils.data.random_split(full_trainset, [pool_size, holdout_size])
    trainset, _ = torch.utils.data.random_split(pool, [num_data, pool_size - num_data])
    _, valset = torch.utils.data.random_split(holdout, [holdout_size - val_size, val_size])

    # Report the label histogram so the user can check that all classes are present.
    targets = [item[1] for item in trainset]
    print("number of data for each class:", np.bincount(targets))
    return trainset, valset


def _build_cifar10_split(num_data):
    """Download CIFAR10 (if needed) and draw the train/validation subsets."""
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ]
    )
    full_trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                                 download=True, transform=transform)
    return _subsample_image_split(full_trainset, 40000, num_data)


def _build_mnist_split(num_data):
    """Download MNIST (if needed) and draw the train/validation subsets."""
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ]
    )
    full_trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                               download=True, transform=transform)
    return _subsample_image_split(full_trainset, 50000, num_data)


def _split_tabular(dataset, num_data, max_val=None):
    """Split ``dataset`` into ``num_data`` training items and the rest for validation.

    When ``max_val`` is given and the validation part is larger, it is
    randomly subsampled down to ``max_val`` items.
    """
    trainset, valset = torch.utils.data.random_split(dataset, [num_data, len(dataset) - num_data])
    if max_val is not None and len(valset) > max_val:
        _, valset = torch.utils.data.random_split(valset, [len(valset) - max_val, max_val])
    return trainset, valset


def main():
    """Run one G-Shapley data-valuation experiment configured from the command line.

    Loads (or builds and caches under ./data) the train/validation split for
    the selected dataset, instantiates the selected model, runs ``GShapley``
    for ``--budget`` iterations and saves the resulting scores to
    ``./results/budget/scores_<experiment_name>.npy``.

    Raises:
        NotImplementedError: If the dataset or model name is not supported.
        ValueError: If the selected model cannot be used with the selected
            dataset (e.g. resnet18 on a tabular dataset).
    """
    args = parse_args()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"using {device}")
    num_data = args.num_players
    num_classes = args.num_classes
    budget = args.budget

    if args.clipping_norm < 0:
        # A negative clipping norm means the run is not differentially private.
        experiment_name = (
            f"{args.dataset}_{args.model}_no_dp_budget_{budget}"
            f"_num_players_{num_data}_seed_{args.random_seed}"
        )
    else:
        momentum_tag = 'use_momentum' if args.use_momentum else 'no_momentum'
        experiment_name = (
            f"{args.dataset}_{args.model}_clipping_norm_{args.clipping_norm}"
            f"_eps_{args.epsilon}_delta_{args.delta}_budget_{budget}"
            f"_num_players_{num_data}_seed_{args.random_seed}_{momentum_tag}"
        )

    # The cached split depends only on the dataset and the number of players.
    data_name = f"{args.dataset}_num_players_{num_data}"

    # Model input descriptors; each dataset branch sets the ones it supports.
    in_channels = None
    num_attributes = None

    # load dataset
    if args.dataset == 'cifar10':
        in_channels = 3
        trainset, valset = _load_or_build_split(
            data_name, lambda: _build_cifar10_split(num_data))
    elif args.dataset == 'breast-cancer':
        # The CSV is always read because the attribute count comes from it.
        dataset = BreastCancerDataset('./data/breast-cancer.csv', {"M": 0, "B": 1}, label_column='diagnosis')
        num_attributes = dataset.num_attributes()
        trainset, valset = _load_or_build_split(
            data_name, lambda: _split_tabular(dataset, num_data))
    elif args.dataset == 'mnist':
        in_channels = 1
        num_attributes = 28 * 28
        trainset, valset = _load_or_build_split(
            data_name, lambda: _build_mnist_split(num_data))
    elif args.dataset == 'wine_quality':
        num_attributes = 11
        # The dataset object is only constructed on a cache miss.
        trainset, valset = _load_or_build_split(
            data_name,
            lambda: _split_tabular(WineQualityDataset(label_column='quality'), num_data, max_val=1000))
    elif args.dataset == 'diabetes':
        # The CSV is always read because the attribute count comes from it.
        dataset = DiabetesDataset('./data/diabetes.csv', {0: 0, 1: 1}, label_column='Outcome')
        num_attributes = dataset.num_attributes()
        trainset, valset = _load_or_build_split(
            data_name, lambda: _split_tabular(dataset, num_data))
    elif args.dataset == 'covertype':
        num_attributes = 51
        # The dataset object is only constructed on a cache miss.
        trainset, valset = _load_or_build_split(
            data_name,
            lambda: _split_tabular(
                CovertypeDataset({1: 0, 2: 1, 3: 2, 4: 3, 5: 4, 6: 5, 7: 6}, label_column='Cover_Type'),
                num_data, max_val=1000))
    else:
        raise NotImplementedError(f"unsupported dataset: {args.dataset}")

    # load model
    if args.model == 'resnet18':
        if in_channels is None:
            raise ValueError(
                f"model {args.model} needs an image dataset, but dataset {args.dataset} is not one")
        net = ResNet18(num_channels=in_channels, num_classes=num_classes)
    elif args.model in ('logistic_regression', 'linear_regression'):
        if num_attributes is None:
            raise ValueError(
                f"model {args.model} needs a dataset with a known number of attributes, "
                f"but dataset {args.dataset} does not define one")
        if args.model == 'logistic_regression':
            net = LogisticRegression(input_dim=num_attributes, output_dim=num_classes)
        else:
            net = LinearRegression(input_dim=num_attributes)
    else:
        raise NotImplementedError(f"unsupported model: {args.model}")

    # create G-Shapley instance
    g_shapley = GShapley(trainset, valset, net, args=args)

    # run experiment; only the first of the four returned values is kept
    all_scores, _, _, _ = g_shapley.run(num_iters=budget, clipping_norm=args.clipping_norm, epsilon=args.epsilon, delta=args.delta)

    # save scores, creating the results folder if it does not exist
    os.makedirs('./results/budget', exist_ok=True)
    np.save(f'./results/budget/scores_{experiment_name}.npy', all_scores)
    print("scores saved")

# Run the experiment only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
