import os
import argparse

import torch
import numpy as np
import torch.multiprocessing as mp
from torch.utils.data import ConcatDataset
from sklearn.model_selection import train_test_split

from src import utils, federated
from src.controllers import *
from data.data.data_builder import get_data


def sequential_unlearning(controller, config):
    """Train, then remove the clients listed in ``config.order`` one at a time.

    After the initial training phase, each client in ``config.order`` is
    removed through ``controller.leave`` and training is resumed. Every
    training phase runs for ``controller.config.federated.num_rounds``
    rounds; the value is re-read before each phase.

    Args:
        controller: object exposing ``train(num_rounds)``, ``leave(client_id)``
            and a ``config.federated.num_rounds`` attribute.
        config: object whose ``order`` attribute is an iterable of client ids.
    """
    def run_training_phase(banner):
        # Announce the phase first, then read the round count and train.
        print(banner)
        controller.train(controller.config.federated.num_rounds)

    run_training_phase('********** TRAINING **********')
    for departing_client in config.order:
        print('********** LEAVING **********')
        print(f'Client {departing_client} leaves')
        controller.leave(departing_client)
        run_training_phase('********** RESUME TRAINING **********')

def continual_unlearning(controller, config):
    """Replay a scripted sequence of train / leave / join requests.

    Each entry of ``config.requests`` is a mapping with a ``'request_type'``
    key and the extra keys that type needs:

    * ``'train'``: ``'num_rounds'`` -> ``controller.train(num_rounds)``
    * ``'leave'``: ``'client_id'`` -> ``controller.leave(client_id)``
    * ``'join'``: ``'client_id'``, ``'dataloader'``, ``'attach_to'`` ->
      ``controller.join(client_id, dataloader, attach_to)``

    Requests with any other ``'request_type'`` are skipped silently.

    Args:
        controller: object exposing ``train``, ``leave`` and ``join``.
        config: object whose ``requests`` attribute is an iterable of
            request mappings as described above.
    """
    for req in config.requests:
        request_type = req['request_type']
        if request_type == 'train':
            print('********** TRAINING **********')
            controller.train(req['num_rounds'])
        elif request_type == 'leave':
            print('********** LEAVING **********')
            # Look the id up outside the f-string: reusing the enclosing quote
            # inside a replacement field is a SyntaxError before Python 3.12.
            client_id = req['client_id']
            print(f'Client {client_id} leaves')
            controller.leave(client_id)
        elif request_type == 'join':
            controller.join(req['client_id'], req['dataloader'], req['attach_to'])
     

if __name__ == '__main__':
    # ------------------------------------------------------------------
    # Command-line interface
    # ------------------------------------------------------------------
    parser = argparse.ArgumentParser()
    parser.add_argument('--setting', type=str, choices=['sequential', 'continual'],
                        help='Experiment setting')
    parser.add_argument('--config_path', type=str, default='configs/mnist.yml',
                        help='path to configuration file')
    parser.add_argument('--method', type=str, default='mmt',
                        help='unlearning method')
    parser.add_argument('--seed', type=int, default=1,
                        help='reproducible random seed')
    parser.add_argument('--parallel', action='store_true',
                        help='whether to parallel training (used in BMT and MMT)')
    parser.add_argument('--save_dir', type=str, default='results/',
                        help='directory to save results')
    parser.add_argument('--noniid_ratio', type=float, default=None,
                        help='noniid ratio per client')
    args = parser.parse_args()
    utils.set_seed(args.seed)

    # Show the setting that was actually requested. --setting has no default
    # and is not required, so it may be None.
    setting_label = args.setting.upper() if args.setting else 'UNSPECIFIED'
    print(f'=========== {setting_label} SETTING ===========')
    print('Method:', args.method)

    # ------------------------------------------------------------------
    # Configuration and run metadata
    # ------------------------------------------------------------------
    config = utils.load_yaml_object(args.config_path)
    os.makedirs(args.save_dir, exist_ok=True)
    # Store the CLI arguments and the config as loaded (before the overrides
    # below) next to the results.
    metadata = {'args': vars(args), 'config': config}
    metadata_path = os.path.join(args.save_dir, 'metadata.yml')
    utils.dump_yaml_object(metadata_path, metadata)

    # NOTE(review): parallelism is driven by the config file here; the
    # --parallel flag is parsed but not read in this block.
    if config.parallel.do_parallel:
        mp.set_start_method('spawn', force=True)    # allow spawning to run multiprocessing

    config.save_dir = args.save_dir
    config.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # override non-iid configuration if specified again in the arguments
    if args.noniid_ratio is not None:
        config.federated.noniid_ratio = args.noniid_ratio

    # ------------------------------------------------------------------
    # Data: reuse the cached loaders when present, otherwise build them
    # ------------------------------------------------------------------
    os.makedirs('cache', exist_ok=True)
    # NOTE(review): the cache file name depends only on the dataset name and
    # the seed, so a different noniid_ratio / num_clients reuses an existing
    # cache file. Delete the file when those settings change.
    cache_data_path = os.path.join('cache', 'cache_unequal' + config.data.name.replace('/', '_') + f'_seed-{args.seed}' + '.pt')
    print(cache_data_path)
    if os.path.exists(cache_data_path):
        # load cached data if possible
        # NOTE(review): torch.load unpickles arbitrary Python objects; only
        # load cache files written by this script. Newer PyTorch releases
        # default to weights_only=True, which may reject this tuple of
        # loaders -- confirm against the installed version.
        client_loaders, eval_loader, tokenizer = torch.load(cache_data_path)
        print('Loaded cached data from', cache_data_path)
    else:
        # prepare data
        if 'huggingface' in config.data.name:
            client_loaders, tokenizer = get_data(config)
            # Entry 0 is used as the evaluation loader; the remaining entries
            # are shifted down so client keys start at 0.
            eval_loader = client_loaders.pop(0)

            for c_idx in range(config.federated.num_clients):
                client_loaders[c_idx] = client_loaders[c_idx + 1]
            del client_loaders[config.federated.num_clients]
        else:
            dataset = utils.init_data(config)
            tokenizer = None
            print('Creating non-iid train data...')
            client_train_sets = federated.create_noniid_data(dataset.train_set, config.federated, train=True)
            print('Creating non-iid test data...')
            client_test_sets = federated.create_noniid_data(dataset.test_set, config.federated, train=False)

            # Make client sizes unequal: every client hands a random share
            # (at most 90%) of its training samples to a randomly drawn
            # client, which may be itself.
            n_clients = len(client_train_sets)
            for i in range(n_clients):
                # Keep the draw order (randint, then random) so results stay
                # reproducible for a given seed.
                transfer_client_id = np.random.randint(low=0, high=n_clients)
                transfer_ratio = min(0.9, np.random.random())

                kept_samples, moved_samples = train_test_split(
                    client_train_sets[i],
                    test_size=transfer_ratio,
                    random_state=args.seed,
                    shuffle=True,
                )
                client_train_sets[i] = kept_samples
                client_train_sets[transfer_client_id] = ConcatDataset((client_train_sets[transfer_client_id], moved_samples))
                print(f'To give client {i}:', len(client_train_sets[i]))
                print(f'To receive client {transfer_client_id}:', len(client_train_sets[transfer_client_id]))

            for i in range(n_clients):
                print('Client {}: {} train samples'.format(i, len(client_train_sets[i])))

            # Per-client train/test loaders keyed by client index.
            client_loaders = {}
            for c_idx in range(config.federated.num_clients):
                client_loaders[c_idx] = {
                    'train': utils.get_dataloader(client_train_sets[c_idx],
                                                  batch_size=config.local.train_batch_size,
                                                  shuffle=True),
                    'test': utils.get_dataloader(client_test_sets[c_idx],
                                                 batch_size=config.local.eval_batch_size,
                                                 shuffle=False)
                }

            # Global evaluation loaders over the union of all client data.
            eval_loader = {
                'train': utils.get_dataloader(ConcatDataset(client_train_sets),
                                              batch_size=config.local.eval_batch_size,
                                              shuffle=False),
                'test': utils.get_dataloader(ConcatDataset(client_test_sets),
                                             batch_size=config.local.eval_batch_size,
                                             shuffle=False)
            }

        cache_payload = (client_loaders, eval_loader, tokenizer)
        torch.save(cache_payload, cache_data_path)
        print('Saved cached data to', cache_data_path)

    # create server model
    server_model = utils.init_model(config)

    # In the continual setting, clients that join later are removed from the
    # initial pool and their loader is attached to the join request.
    if args.setting == 'continual':
        for req in config.continual_unlearning.requests:
            if req['request_type'] == 'join':
                req['dataloader'] = client_loaders.pop(req['client_id'])

    # initialize controller and perform unlearning
    func_args = (server_model, client_loaders, eval_loader)
    func_kwargs = {'config': config, 'tokenizer': tokenizer}

    if args.method == 'retraining':
        controller = RetrainingController(*func_args, **func_kwargs)
    elif 'fedcio' in args.method:
        num_clusters = int(args.method.split('_')[-1])  # e.g. fedcio_3 means clients are clustered into 3 groups
        controller = FedCIOController(*func_args, num_clusters=num_clusters, seed=args.seed, **func_kwargs)
    elif args.method == 'bmt':
        controller = BMTController(*func_args, **func_kwargs)
    elif args.method == 'mmt':
        controller = MMTController(*func_args, **func_kwargs)
    elif 'subopt_mmt' in args.method:
        branching_factor = int(args.method.split('_')[-1])  # e.g. subopt_mmt_3, subopt_mmt_6
        controller = MMTController(*func_args, k=branching_factor, **func_kwargs)
    elif args.method == 'standalone':
        controller = StandaloneController(*func_args, **func_kwargs)
    elif args.method == 'greedy':
        # Subscription sequence: the configured leaving order first, then
        # every remaining client in index order.
        subscription_seq = config.sequential_unlearning.order.copy()
        for client_idx in range(config.federated.num_clients):
            if client_idx not in subscription_seq:
                subscription_seq.append(client_idx)
        controller = GreedyController(*func_args, subscription_seq=subscription_seq, **func_kwargs)
    else:
        raise ValueError('Wrong unlearning type: {}'.format(args.method))

    if args.setting == 'sequential':
        sequential_unlearning(controller, config.sequential_unlearning)
    elif args.setting == 'continual':
        continual_unlearning(controller, config.continual_unlearning)
    else:
        # Reached only when --setting was omitted (argparse leaves it None).
        print('No --setting given; nothing to run.')