import argparse
import os
import torch
import helper
from torch.utils.data import ConcatDataset
from helper import OFL, Q_FL
from src import utils, federated

# Device that holds the server model: the first CUDA device when one is visible, otherwise the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

def _global_train(q_FL, train_loaders, testloader, config):
    """Call ``q_FL.global_train`` with the local hyperparameters from ``config.local``."""
    q_FL.global_train(train_loaders, testloader,
                      client_iters=config.local.num_epochs,
                      learning_rate=config.local.lr,
                      grad_clip=config.local.grad_clip,
                      weight_decay=config.local.weight_decay,
                      momentum=config.local.momentum)


def sequential_unlearning(
    q_FL, traindata_splits, train_loaders, testloader, config, args,
):
    """Train ``q_FL``, then remove clients one at a time, retraining after each removal.

    Clients leave in the order given by ``config.sequential_unlearning.order``.
    Each removal calls ``q_FL.delete`` with ``delete_ratio=1.0``. The function then
    replaces ``traindata_splits`` / ``train_loaders`` with the versions that call
    returns, and the next ``global_train`` call uses them.

    Args:
        q_FL: federated model wrapper; must provide ``global_train`` and ``delete``.
        traindata_splits: training datasets keyed by client id, passed to ``q_FL.delete``.
        train_loaders: training loaders keyed by client id, passed to ``q_FL.global_train``.
        testloader: evaluation loader, passed to both ``global_train`` and ``delete``.
        config: experiment config; reads ``config.local.{num_epochs, lr, grad_clip,
            weight_decay, momentum}`` and ``config.sequential_unlearning.order``.
        args: unused; kept so existing callers keep working.

    Returns:
        None. All effects happen through ``q_FL``.
    """
    # train quantized federated model
    print("********** TRAINING **********")
    _global_train(q_FL, train_loaders, testloader, config)

    for client_id in config.sequential_unlearning.order:
        print("********** LEAVING **********")
        print(f"Client {client_id} leaves")
        # delete_ratio=1.0 -- presumably the client's entire data, matching the "leaves" message.
        # NOTE(review): the first two return values (retrain signal, deleted subset) are
        # ignored, so the model is always retrained after a deletion -- confirm whether
        # the `retrain` value should gate the retraining below.
        _retrain, _deleted_data_subset, traindata_splits, train_loaders = q_FL.delete(
            traindata_splits,
            client_id=client_id,
            delete_ratio=1.0,
            testloader=testloader,
        )
        _global_train(q_FL, train_loaders, testloader, config)

def _parse_args():
    """Parse the command-line options of the sequential-unlearning experiment."""
    parser = argparse.ArgumentParser()
    # NOTE(review): --dataset is parsed but never read in this script; the dataset that is
    # actually loaded comes from `config.data.name` in the YAML config.
    parser.add_argument('--dataset', type=str, help='Dataset name', default='CIFAR10')
    parser.add_argument('--config_path', type=str, default='../../configs/mnist-sequential.yml', help='path to configuration file')
    parser.add_argument('--seed', type=int, default=1, help='reproducible random seed')
    parser.add_argument('--out_dir', type=str, help='Output directory', default='results/CIFAR10')
    parser.add_argument('--q_alpha', type=float, default=0.002, help='Quantization magnitude')
    return parser.parse_args()


def _torch_load_trusted(path):
    """Load a file that this script itself wrote with ``torch.save``.

    The cache stores arbitrary Python objects (the data loaders built in
    ``_load_or_build_loaders``). Since torch 2.6, ``torch.load`` defaults to
    ``weights_only=True``, whose restricted unpickler refuses such objects, so this
    helper requests full unpickling explicitly. Full unpickling can execute
    arbitrary code: only use it on files produced locally by this script.
    """
    try:
        return torch.load(path, weights_only=False)
    except TypeError:
        # torch < 1.13 has no `weights_only` keyword and always fully unpickles.
        return torch.load(path)


def _load_or_build_loaders(config, seed):
    """Return ``(client_loaders, eval_loader, tokenizer)``, reusing an on-disk cache.

    ``client_loaders`` maps each client index to ``{'train': loader, 'test': loader}``.
    ``eval_loader`` holds ``'train'``/``'test'`` loaders over the concatenation of all
    client splits. ``tokenizer`` is always ``None`` when freshly built.

    NOTE(review): the cache file name depends only on ``config.data.name`` and ``seed``.
    Changing the federated partition settings or the batch sizes without changing either
    of those silently reuses a stale cache.
    """
    os.makedirs('cache', exist_ok=True)
    cache_data_path = os.path.join('cache', 'cache_' + config.data.name.replace('/', '_') + f'_seed-{seed}' + '.pt')
    print(cache_data_path)

    if os.path.exists(cache_data_path):
        # load cached data if possible
        client_loaders, eval_loader, tokenizer = _torch_load_trusted(cache_data_path)
        print('Loaded cached data from', cache_data_path)
        return client_loaders, eval_loader, tokenizer

    raw_data = utils.init_data(config, root_dir='~/data/')
    tokenizer = None
    print('Creating non-iid train data...')
    client_train_sets = federated.create_noniid_data(raw_data.train_set, config.federated, train=True)
    print('Creating non-iid test data...')
    client_test_sets = federated.create_noniid_data(raw_data.test_set, config.federated, train=False)

    client_loaders = {
        c_idx: {
            'train': utils.get_dataloader(client_train_sets[c_idx],
                                          batch_size=config.local.train_batch_size,
                                          shuffle=True),
            'test': utils.get_dataloader(client_test_sets[c_idx],
                                         batch_size=config.local.eval_batch_size,
                                         shuffle=False),
        }
        for c_idx in range(config.federated.num_clients)
    }

    eval_loader = {
        'train': utils.get_dataloader(ConcatDataset(client_train_sets),
                                      batch_size=config.local.eval_batch_size,
                                      shuffle=False),
        'test': utils.get_dataloader(ConcatDataset(client_test_sets),
                                     batch_size=config.local.eval_batch_size,
                                     shuffle=False),
    }
    torch.save((client_loaders, eval_loader, tokenizer), cache_data_path)
    print('Saved cached data to', cache_data_path)
    return client_loaders, eval_loader, tokenizer


def main():
    """Run the experiment: prepare client data, build the Q_FL server, run sequential unlearning."""
    args = _parse_args()

    os.makedirs(args.out_dir, exist_ok=True)
    # NOTE(review): seeding happens before data preparation. On a cache hit the partitioning
    # code does not run, so the RNG state seen by `utils.init_model` may differ from a fresh
    # run -- confirm this does not matter for reproducibility.
    utils.set_seed(args.seed)

    config = utils.load_yaml_object(args.config_path)

    # plug in our data
    client_loaders, eval_loader, _tokenizer = _load_or_build_loaders(config, args.seed)

    # convert to exact-fun data loaders: per-client train loaders and their datasets
    train_loaders = {c_idx: loaders['train'] for c_idx, loaders in client_loaders.items()}
    traindata_splits = {c_idx: loader.dataset for c_idx, loader in train_loaders.items()}
    testloader = eval_loader['test']

    # create server model
    server_model = utils.init_model(config)
    server_model.to(device)
    print(server_model)

    q_FL = Q_FL(server_model,
                num_clients=config.federated.num_clients,
                rounds=config.federated.num_rounds,
                batch_size=config.local.train_batch_size,
                q_alpha=args.q_alpha,
                log_path=os.path.join(args.out_dir, 'logs.json'),
                device=device)

    sequential_unlearning(q_FL, traindata_splits, train_loaders, testloader, config, args)

    # TODO: the metric computation below is disabled. It references `original_FL` and
    # `deleted_data_ls2set`, which are not defined anywhere in this script.
    # # compute metrics
    # efficiency = helper.SPAE(q_FL.global_model, original_FL.global_model, 'eff',
    #                          testloader, deleted_data_ls2set, device=device)
    # print('Efficiency:', efficiency)

    # certify = helper.SPAE(q_FL.global_model, original_FL.global_model, ' ',
    #                       testloader, deleted_data_ls2set, device=device)
    # print('Certify:', certify)


if __name__ == '__main__':
    main()
