import os
import sys



sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from datasets.csv_dataset import BreastCancerDataset, CovertypeDataset, DiabetesDataset
from datasets.utils import perturb_labels
from models.logistic_regression import LogisticRegression
from models.cnn import CNN
# set recursion limit to 100000 to allow for large budget
sys.setrecursionlimit(100000)

# import model and semivalue
from models.resnet import ResNet18
from measures.fl import FL

# import tools
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
# use argparse
import argparse


def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line flags for a data-valuation experiment.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``; ``None`` (the default) keeps the usual CLI
            behaviour. Useful for tests and programmatic invocation.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description='Differentially Private Data Valuation')
    parser.add_argument('--dataset', type=str, default='cifar10', metavar='N',
                        help='dataset name (default: cifar10)')
    parser.add_argument('--model', type=str, default='resnet18', metavar='N',
                        help='model name (default: resnet18)')
    # Both supported datasets (cifar10, mnist) have 10 classes; a None default
    # would be forwarded unchanged to the model constructors.
    parser.add_argument('--num_classes', type=int, default=10, metavar='N',
                        help='number of classes (default: 10)')
    parser.add_argument('--lr', type=float, default=None, metavar='LR',
                        help='learning rate (default: None)')
    parser.add_argument('--warmup_ratio', type=float, default=0, metavar='WR',
                        help='warmup ratio (default: 0)')
    parser.add_argument('--clipping_norm', type=float, default=-1, metavar='N',
                        help='clipping norm (default: -1 i.e. no DP)')
    parser.add_argument('--delta', type=float, default=1e-5, metavar='N',
                        help='DP privacy parameter delta (default: 1e-5)')
    parser.add_argument('--epsilon', type=float, default=8.0, metavar='N',
                        help='DP privacy parameter epsilon (default: 8.0)')
    parser.add_argument('--budget', type=int, default=1000, metavar='N',
                        help='number of iterations to run (default: 1000)')
    parser.add_argument('--num_players', type=int, default=100, metavar='N',
                        help='number of players (default: 100)')
    parser.add_argument('--dataset_size', type=int, default=1, metavar='N',
                        help='number of data points in each dataset (default: 1)')
    parser.add_argument('--use_momentum', action='store_true', default=False,
                        help='use momentum (default: False)')
    parser.add_argument('--flip_ratio', type=float, default=0.2, metavar='N',
                        help='ratio of flipped labels (default: 0.2)')
    parser.add_argument('--flip_method', type=str, default='label', metavar='N',
                        help='flipping label or feature (default: label)')
    parser.add_argument('--random_seed', type=int, default=0, metavar='N',
                        help='random seed (default: 0)')

    return parser.parse_args(argv)


# Number of classes in each supported dataset; used when --num_classes is not given.
_DATASET_NUM_CLASSES = {'cifar10': 10, 'mnist': 10}


def _build_experiment_name(args: argparse.Namespace) -> str:
    """Return the run identifier embedded in the saved score filename."""
    prefix = f"{args.dataset}_{args.model}"
    suffix = (f"_budget_{args.budget}_num_players_{args.num_players}"
              f"_dataset_size_{args.dataset_size}_seed_{args.random_seed}")
    if args.clipping_norm < 0:
        # A negative clipping norm means "no DP" (see the --clipping_norm help).
        return f"{prefix}_no_dp{suffix}"
    momentum = 'use_momentum' if args.use_momentum else 'no_momentum'
    return (f"{prefix}_clipping_norm_{args.clipping_norm}"
            f"_eps_{args.epsilon}_delta_{args.delta}{suffix}_{momentum}")


def _load_cached_dataset(path: str):
    """Load a dataset object that was previously written with ``torch.save``."""
    # The cache holds pickled Dataset/Subset objects rather than plain tensors, so
    # the ``weights_only=True`` default of recent PyTorch releases rejects it.
    # Releases older than 1.13 do not know the keyword at all, hence the fallback.
    # NOTE: this unpickles the file; only use ./data caches you created yourself.
    try:
        return torch.load(path, weights_only=False)
    except TypeError:
        return torch.load(path)


def _load_or_build_splits(data_name, dataset_cls, transform, total_num_data,
                          num_val, flip_ratio):
    """Return ``(trainset, valset)``, reusing the ./data cache when present.

    On a cache miss, the torchvision training split is downloaded. It is then
    split into ``total_num_data`` training points (passed through
    ``perturb_labels``) and the remainder, from which ``num_val`` points are
    sampled for validation. Both subsets are written to the cache.

    Raises:
        ValueError: if ``total_num_data`` is not positive or leaves fewer than
            ``num_val`` points for validation.
    """
    train_path = f'./data/train_{data_name}_flipped.pt'
    val_path = f'./data/val_{data_name}_flipped.pt'
    if os.path.exists(train_path) and os.path.exists(val_path):
        print("loading data from ./data")
        return _load_cached_dataset(train_path), _load_cached_dataset(val_path)

    full_trainset = dataset_cls(root='./data', train=True, download=True, transform=transform)
    num_total = len(full_trainset)
    # random_split does not reject every negative length, so an oversized request
    # could silently yield a validation set of the wrong size; fail loudly instead.
    if total_num_data < 1 or total_num_data + num_val > num_total:
        raise ValueError(
            f"num_players * dataset_size must be between 1 and {num_total - num_val} "
            f"(dataset has {num_total} training points, {num_val} reserved for "
            f"validation), got {total_num_data}")

    # split the training data into the players' data and a held-out remainder
    trainset, valset = torch.utils.data.random_split(
        full_trainset, [total_num_data, num_total - total_num_data])

    # perturb labels of trainset
    trainset = perturb_labels(trainset, flip_ratio)

    # keep only num_val held-out points for validation
    _, valset = torch.utils.data.random_split(valset, [len(valset) - num_val, num_val])

    # report the class histogram so missing classes are visible
    targets = [item[1] for item in trainset]
    print("number of data for each class:", np.bincount(targets))

    os.makedirs('./data', exist_ok=True)
    torch.save(trainset, train_path)
    torch.save(valset, val_path)
    return trainset, valset


def main():
    """Run one federated-learning data-valuation experiment from CLI flags.

    Steps:
        1. Load, or build and cache under ./data, a training subset of
           ``num_players * dataset_size`` points whose labels go through
           ``perturb_labels``, plus a small validation subset.
        2. Build the requested model.
        3. Run ``FL.run`` for ``--budget`` iterations.
        4. Save the returned scores to
           ./results/fl/scores_shapley_<experiment_name>.npy.

    Raises:
        NotImplementedError: for an unsupported ``--dataset`` or ``--model``.
        ValueError: if the requested number of training points does not fit in
            the dataset's training split alongside the validation points.
    """
    args = parse_args()
    # global var
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"using {device}")

    # --random_seed used to appear only in the experiment name. Seed the global RNGs
    # so the random train/val split is reproducible. This presumably also covers
    # perturb_labels if it draws from the torch/NumPy RNGs; confirm.
    torch.manual_seed(args.random_seed)
    np.random.seed(args.random_seed)

    if args.num_classes is None:
        # Forwarding None to the model constructors would fail; infer it instead.
        args.num_classes = _DATASET_NUM_CLASSES.get(args.dataset)
    num_data = args.num_players
    num_classes = args.num_classes
    budget = args.budget
    total_num_data = num_data * args.dataset_size

    experiment_name = _build_experiment_name(args)

    # Cache key for the generated splits. flip_ratio is included so that a run with a
    # different --flip_ratio does not silently reuse labels perturbed at another ratio.
    # NOTE(review): flip_method only changes the cache key here; perturb_labels is
    # called with flip_ratio alone. Confirm whether it should also get flip_method.
    data_name = (f"{args.dataset}_num_players_{num_data}_{args.dataset_size}"
                 f"_{args.flip_method}_flip_{args.flip_ratio}")

    # select the dataset and its preprocessing
    if args.dataset == 'cifar10':
        in_channels = 3
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        dataset_cls, num_val = torchvision.datasets.CIFAR10, 200
    elif args.dataset == 'mnist':
        in_channels = 1
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        dataset_cls, num_val = torchvision.datasets.MNIST, 1000
    else:
        raise NotImplementedError

    trainset, valset = _load_or_build_splits(
        data_name, dataset_cls, transform, total_num_data, num_val, args.flip_ratio)

    # load model
    if args.model == 'resnet18':
        net = ResNet18(num_classes=num_classes)
    elif args.model == "cnn":
        net = CNN(input_dim=in_channels, output_dim=num_classes)
    else:
        raise NotImplementedError

    # create federated learning instance
    fl_instance = FL(trainset, valset, net, args=args)

    # run experiment
    all_scores = fl_instance.run(num_iters=budget, clipping_norm=args.clipping_norm,
                                 epsilon=args.epsilon, delta=args.delta)
    # save scores, creating ./results/fl if it does not exist
    os.makedirs('./results/fl', exist_ok=True)
    np.save(f'./results/fl/scores_shapley_{experiment_name}.npy', all_scores)
    print("scores saved")


if __name__ == "__main__":
    # Entry point when the file is executed as a script; importing it runs nothing.
    main()
