import os
import operator
import argparse

import torch
import torch.multiprocessing as mp
import numpy as np
from torch.utils.data import ConcatDataset

import utils
import federated
from controllers import *
from data.data.data_builder import get_data

def prepare_data_models(config):
    """Build (or load from cache) the data loaders and create the server model.

    The loaders are cached on disk under ``cache/`` in a file whose name is
    derived from ``config.data.name`` and the seed. On a cache hit the loaders
    and tokenizer are read back with ``torch.load``; on a miss they are built
    and then written with ``torch.save``.

    Args:
        config: experiment configuration. This function reads
            ``config.data.name``, ``config.federated`` (including
            ``num_clients``) and ``config.local.train_batch_size`` /
            ``config.local.eval_batch_size``.

    Returns:
        A tuple ``(server_model, tokenizer, client_loaders, eval_loader)``.
        ``tokenizer`` is ``None`` when the data is not a 'huggingface' dataset.

    NOTE(review): the seed comes from the module-level ``args`` object, which
    only exists when this file is run as a script -- confirm before importing
    this function from elsewhere.
    NOTE(review): the cache key only contains the dataset name and the seed,
    so other configuration changes appear to reuse an existing cache file --
    confirm this is intended.
    NOTE(review): recent PyTorch releases default ``torch.load`` to
    ``weights_only=True``, which may reject pickled loader objects -- verify
    against the installed version.
    """
    os.makedirs('cache', exist_ok=True)
    dataset_tag = config.data.name.replace('/', '_')
    cache_filename = 'cache_' + dataset_tag + f'_seed-{args.seed}' + '.pt'
    cache_data_path = os.path.join('cache', cache_filename)
    print(cache_data_path)

    def _eval_style_loader(dataset):
        # Unshuffled loader using the evaluation batch size.
        return utils.get_dataloader(dataset,
                                    batch_size=config.local.eval_batch_size,
                                    shuffle=False)

    if not os.path.exists(cache_data_path):
        # Cache miss: build everything from scratch.
        if 'huggingface' in config.data.name:
            client_loaders, tokenizer = get_data(config)
            # Entry 0 is used as the global evaluation loader.
            eval_loader = client_loaders.pop(0)

            # Shift the remaining entries down by one so that client ids
            # start at 0, then drop the now-duplicated last key.
            num_clients = config.federated.num_clients
            for target_idx in range(num_clients):
                client_loaders[target_idx] = client_loaders[target_idx + 1]
            del client_loaders[num_clients]
        else:
            raw_data = utils.init_data(config)
            tokenizer = None
            print("Creating non-iid train data...")
            client_train_sets = federated.create_noniid_data(
                raw_data.train_set, config.federated, train=True)
            print("Creating non-iid test data...")
            client_test_sets = federated.create_noniid_data(
                raw_data.test_set, config.federated, train=False)

            # One train/test loader pair per client.
            client_loaders = {
                client_id: {
                    "train": utils.get_dataloader(
                        client_train_sets[client_id],
                        batch_size=config.local.train_batch_size,
                        shuffle=True),
                    "test": _eval_style_loader(client_test_sets[client_id]),
                }
                for client_id in range(config.federated.num_clients)
            }

            # Global evaluation loaders over the union of all client splits.
            eval_loader = {
                "train": _eval_style_loader(ConcatDataset(client_train_sets)),
                "test": _eval_style_loader(ConcatDataset(client_test_sets)),
            }

        torch.save((client_loaders, eval_loader, tokenizer), cache_data_path)
        print('Saved cached data to', cache_data_path)
    else:
        # Cache hit: reuse the previously saved loaders and tokenizer.
        client_loaders, eval_loader, tokenizer = torch.load(cache_data_path)
        print('Loaded cached data from', cache_data_path)

    # The server model is always created fresh; it is never cached.
    server_model = utils.init_model(config)
    return server_model, tokenizer, client_loaders, eval_loader

def unlearning(controller, id):
    """Run one train / leave / resume-training cycle on ``controller``.

    The controller is trained for ``controller.config.federated.num_rounds``
    rounds, client ``id`` is then removed via ``controller.leave``, and the
    controller is trained again for ``num_rounds`` rounds. Progress banners
    are printed to stdout before each phase.

    Args:
        controller: object exposing ``train(num_rounds)``, ``leave(client_id)``
            and ``config.federated.num_rounds``.
        id: identifier of the client that leaves between the two phases.
    """
    def _train_configured_rounds():
        # The round count is looked up on every call rather than cached, so
        # each phase uses whatever value the config holds at that moment.
        controller.train(controller.config.federated.num_rounds)

    print("********** TRAINING **********")
    _train_configured_rounds()

    print("********** LEAVING **********")
    print(f"Client {id} leaves")
    controller.leave(id)

    print("********** RESUME TRAINING **********")
    _train_configured_rounds()
    

if __name__ == "__main__":
    # Script entry point: samples which client leaves in each experiment from
    # a random (non-uniform) probability vector, then runs one
    # train / leave / resume cycle per sampled client with the chosen
    # unlearning controller.

    # Counter is needed for the sampling summary printed below; import it
    # explicitly instead of relying on a star-import to provide the name.
    from collections import Counter

    parser = argparse.ArgumentParser()
    parser.add_argument("--config_path", type=str, default="configs/mnist.yml",
                        help="path to configuration file")
    parser.add_argument("--unlearning_type", type=str, default="mmt",
                        help="unlearning method")
    parser.add_argument("--seed", type=int, default=1, 
                        help="reproducible random seed")
    # NOTE(review): within this block --parallel is only recorded in the
    # metadata dump; the multiprocessing setup below is driven by
    # config.parallel.do_parallel -- confirm this is intended.
    parser.add_argument("--parallel", action="store_true",
                        help="whether to parallel training (used in BMT and MMT)")
    parser.add_argument("--save_dir", type=str, default="results/", 
                        help="directory to save results")
    parser.add_argument("--noniid_ratio", type=float, default=None,
                        help="noniid ratio per client")
    args = parser.parse_args()
    utils.set_seed(args.seed)

    print("=========== UNLEARNING WITH NON-UNIFORM PROBABILITIES ===========")
    print("Method:", args.unlearning_type)

    # Load the configuration and record it, together with the CLI arguments,
    # next to the results.
    config = utils.load_yaml_object(args.config_path)
    os.makedirs(args.save_dir, exist_ok=True)
    metadata = {"args": vars(args), "config": config}
    metadata_path = os.path.join(args.save_dir, "metadata.yml")
    utils.dump_yaml_object(metadata_path, metadata)

    if config.parallel.do_parallel:
        # Use the 'spawn' start method so multiprocessing workers can be created.
        mp.set_start_method('spawn', force=True)

    config.save_dir = args.save_dir
    config.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Override the non-iid configuration if it is specified on the command line.
    if args.noniid_ratio is not None:
        config.federated.noniid_ratio = args.noniid_ratio

    # Draw one random weight per client and normalise the weights into a
    # probability distribution over which client leaves.
    # Alternative schemes that are currently disabled:
    # unlearning_probs = np.random.dirichlet(alpha=config.federated.num_clients)  # alpha = [1, 2, ..., num_clients]
    # unlearning_probs = np.array([(i + 1) * 1.0 for i in range(config.federated.num_clients)])
    client_indices = list(range(config.federated.num_clients))
    unlearning_probs = np.array([np.random.random_sample() for _ in client_indices])
    unlearning_probs /= unlearning_probs.sum()
    print('Unlearning probablities:', np.around(unlearning_probs, 3))

    # Sample (with replacement) the client removed in each of the 100 experiments.
    mcmc_removed_clients = np.random.choice(client_indices, size=100, replace=True, p=unlearning_probs)
    print('MCMC removed client ids:', mcmc_removed_clients)
    # Per-client removal counts, sorted by ascending count.
    print('Counting:', sorted(Counter(mcmc_removed_clients).items(), key=operator.itemgetter(1)))

    # Keep a copy of the configuration so that each experiment starts from it
    # again after the per-experiment fields below have been overwritten.
    original_config = config.copy()
    for exp_id, removed_client_idx in enumerate(mcmc_removed_clients):
        config.sequential_unlearning.order = [removed_client_idx]
        exp_dir = os.path.join(args.save_dir, "exp_" + str(exp_id))
        config.save_dir = exp_dir
        os.makedirs(exp_dir, exist_ok=True)
        server_model, tokenizer, client_loaders, eval_loader = prepare_data_models(config)

        # Initialize the controller and perform unlearning.
        func_args = (server_model, client_loaders, eval_loader)
        func_kwargs = {"config": config, "tokenizer": tokenizer}

        if args.unlearning_type == "mmt":
            controller = MMTController(*func_args, **func_kwargs)
        elif args.unlearning_type == "mmt_shannon":
            controller = ShannonFanoMMTController(*func_args, unlearning_probs=unlearning_probs, **func_kwargs)
        elif args.unlearning_type == "mmt_huffman":
            controller = HuffmanMMTController(*func_args, unlearning_probs=unlearning_probs, **func_kwargs)
        else:
            raise Exception("Wrong unlearning type: {}".format(args.unlearning_type))
        unlearning(controller, removed_client_idx)
        config = original_config.copy()