import os
import logging
import copy
from tqdm import tqdm

# Set at import time, i.e. before `load_heteroNLP_data` creates its tokenizer.
# NOTE(review): presumably turns off the tokenizer library's own parallelism
# so it does not clash with multi-worker DataLoaders -- confirm.
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


def _expand_per_group(values, counts):
    """Repeat ``values[i]`` ``counts[i]`` times for every group ``i``.

    Args:
        values: per-group settings, indexed by group position.
        counts: number of clients in each group.

    Returns:
        A flat list holding one entry per client, in group order.

    Raises:
        IndexError: if ``values`` has fewer entries than ``counts``.
    """
    expanded = []
    for group_id, num_clients in enumerate(counts):
        expanded.extend([values[group_id]] * num_clients)
    return expanded


def modified_cfg(cfg, cfg_client):
    """Adjust ``cfg`` and ``cfg_client`` in place for hetero-NLP data loading.

    Steps, in order:

    1. Create ``<cfg.outdir>/temp``.
    2. When ``cfg.data.is_debug`` is set, shrink the run: six clients (one
       per dataset), two rounds, two local update steps, batch size 1, and
       empty ``atc_load_from`` / ``save_to`` / ``cache_dir``.
    3. When ``cfg.federate.save_to`` is non-empty, re-root it under
       ``cfg.outdir`` and create that directory.
    4. Expand per-group settings into per-client lists:
       stage ``'assign'`` fills ``cfg.model.downstream_tasks``; stage
       ``'contrast'`` rewrites ``cfg.aggregator.num_agg_topk`` (falling
       back to ``cfg.federate.client_num`` for every client when the
       configured list is empty).
    5. When ``cfg_client`` is given, dump it to
       ``<cfg.outdir>/config_client.yaml`` and then add one
       ``'client_<id>'`` entry per client, pointing at the
       ``'client_group_<k>'`` entry of its group (ids start at 1).

    Args:
        cfg: global configuration; modified in place.
        cfg_client: client configuration keyed by ``'client_group_<k>'``,
            or ``None`` to skip step 5; modified in place.

    Returns:
        The same ``(cfg, cfg_client)`` objects that were passed in.
    """
    # create the temp dir for output
    os.makedirs(os.path.join(cfg.outdir, 'temp'), exist_ok=True)

    # modification for debug (load a small subset from the whole dataset).
    # Applied before the `save_to` handling below so that no checkpoint
    # directory is created for a path that debug mode discards anyway.
    if cfg.data.is_debug:
        cfg.federate.client_num = 6
        cfg.data.hetero_data_name = [
            'imdb', 'agnews', 'squad', 'newsqa', 'cnndm', 'msqg'
        ]
        cfg.data.num_of_client_for_data = [1, 1, 1, 1, 1, 1]
        cfg.federate.total_round_num = 2
        cfg.train.local_update_steps = 2
        # TODO
        cfg.federate.atc_load_from = ''
        cfg.federate.save_to = ''
        cfg.data.cache_dir = ''
        cfg.data.batch_size = 1

    if cfg.federate.save_to:
        cfg.federate.save_to = os.path.join(cfg.outdir, cfg.federate.save_to)
        os.makedirs(cfg.federate.save_to, exist_ok=True)

    if cfg.model.stage == 'assign':
        # One downstream task name per client, in group order.
        cfg.model.downstream_tasks = _expand_per_group(
            cfg.data.hetero_data_name, cfg.data.num_of_client_for_data)
    elif cfg.model.stage == 'contrast':
        if cfg.aggregator.num_agg_topk:
            # Per-group top-k values were configured: one copy per client.
            num_agg_topk = _expand_per_group(
                cfg.aggregator.num_agg_topk,
                cfg.data.num_of_client_for_data)
        else:
            # Nothing configured: every client gets `client_num`.
            num_agg_topk = []
            for num_clients in cfg.data.num_of_client_for_data:
                num_agg_topk.extend([cfg.federate.client_num] * num_clients)
        cfg.aggregator.num_agg_topk = num_agg_topk

    # client_config
    if cfg_client is not None:
        from contextlib import redirect_stdout

        # The dump is written before the per-client entries (and the debug
        # tweaks) below are applied, so the file holds the group-level view.
        dump_path = os.path.join(cfg.outdir, 'config_client.yaml')
        with open(dump_path, 'w') as outfile, redirect_stdout(outfile):
            tmp_cfg = copy.deepcopy(cfg_client)
            # NOTE(review): presumably the registered check functions must
            # not end up in the dumped file -- confirm.
            tmp_cfg.cfg_check_funcs = []
            print(tmp_cfg.dump())

        client_start_id = 1
        for group_id, num_clients in enumerate(
                cfg.data.num_of_client_for_data):
            group_cfg = cfg_client[f'client_group_{group_id + 1}']
            if cfg.data.is_debug:
                # Cap local steps at 5 and force batch size 1 for the group.
                group_cfg.train.local_update_steps = min(
                    group_cfg.train.local_update_steps, 5)
                group_cfg.data.batch_size = 1
            # Every client of a group shares the same group config object.
            for client_id in range(client_start_id,
                                   client_start_id + num_clients):
                cfg_client[f'client_{client_id}'] = group_cfg
            client_start_id += num_clients

    return cfg, cfg_client


def load_heteroNLP_data(config, client_cfgs):
    """Build the per-client dataloaders for the heterogeneous NLP tasks.

    ``config`` and ``client_cfgs`` are first passed through
    ``modified_cfg`` (which edits them in place); the data of every client
    is then loaded, tokenized and wrapped into ``DataLoader`` objects.

    Args:
        config: global configuration.
        client_cfgs: per-client configuration. It is only read for stages
            other than ``'assign'``; in stage ``'assign'`` the global
            ``config`` is used for every client.

    Returns:
        ``(datadict, modified_config)``. ``datadict`` maps the 1-based
        client id to a dict with the keys ``'val'`` and ``'test'`` plus
        either ``'train'`` or, in stage ``'contrast'``, ``'train_raw'``
        and ``'train_contrast'``. Each value is a dict with the keys
        ``'dataloader'``, ``'encoded'`` and ``'examples'``.
    """
    from torch.utils.data import DataLoader
    from federatedscope.nlp.hetero_tasks.dataset.utils import setup_tokenizer
    from federatedscope.nlp.hetero_tasks.dataset.get_data import \
        HeteroNLPDataLoader, SynthDataProcessor
    from federatedscope.nlp.hetero_tasks.dataloader.datacollator import \
        DataCollator

    class HeteroNLPDataset(object):
        """List-like container holding one dataloader dict per client."""

        # Dataset names `_process_data` knows how to handle.
        ALL_DATA_NAME = ['imdb', 'agnews', 'squad', 'newsqa', 'cnndm', 'msqg']

        def __init__(self, config, client_cfgs):
            """Load the raw data and preprocess it for every client."""
            self.root = config.data.root

            # `hetero_data_name` and `num_of_client_for_data` are parallel
            # lists (`modified_cfg` pairs them by position). When an
            # unsupported name is dropped, its client count has to be
            # dropped too, otherwise every later dataset would be paired
            # with the wrong count.
            client_counts = config.data.num_of_client_for_data
            self.data_name = []
            kept_counts = []
            dropped_any = False
            for group_id, each_data in enumerate(
                    config.data.hetero_data_name):
                if each_data not in self.ALL_DATA_NAME:
                    logger.warning(f'We have not provided {each_data} in '
                                   f'HeteroNLPDataset.')
                    dropped_any = True
                    continue
                self.data_name.append(each_data)
                if group_id < len(client_counts):
                    kept_counts.append(client_counts[group_id])
            # Keep the configured list untouched when nothing was dropped.
            self.num_of_clients = kept_counts if dropped_any \
                else client_counts

            data = HeteroNLPDataLoader(
                data_dir=self.root,
                data_name=self.data_name,
                num_of_clients=self.num_of_clients).get_data()

            self.processed_data = self._preprocess(config, client_cfgs, data)

        def __getitem__(self, idx):
            """Return the dataloader dict stored at 0-based position idx."""
            return self.processed_data[idx]

        def __len__(self):
            """Return the number of preprocessed clients."""
            return len(self.processed_data)

        @staticmethod
        def _make_split_entry(split_data, batch_size, shuffle,
                              **shared_kwargs):
            """Wrap one processed split into the dict stored per split.

            ``split_data`` is what ``_process_data`` returned: item 0 is
            handed to the ``DataLoader`` as its dataset, items 1 and 2 are
            stored as ``'encoded'`` and ``'examples'``.
            """
            return {
                'dataloader': DataLoader(dataset=split_data[0],
                                         batch_size=batch_size,
                                         shuffle=shuffle,
                                         **shared_kwargs),
                'encoded': split_data[1],
                'examples': split_data[2]
            }

        def _preprocess(self, config, client_cfgs, data):
            """Build the list of per-client dataloader dicts.

            Args:
                config: global configuration.
                client_cfgs: mapping with one ``'client_<id>'`` entry per
                    client; unused in stage ``'assign'``.
                data: raw data indexed as ``data[split][client_id - 1]``
                    for the splits ``'train'``, ``'val'`` and ``'test'``.

            Returns:
                A list with one dict per client, ordered by client id.
            """
            use_pretrain_task = config.model.stage == 'assign'
            use_contrastive = config.model.stage == 'contrast'
            tokenizer = setup_tokenizer(config.model.model_type)
            # The collator is only used in the 'assign' stage.
            data_collator = DataCollator(tokenizer=tokenizer) \
                if use_pretrain_task else None
            is_debug = config.data.is_debug  # load a subset of data

            # DataLoader arguments that are identical for every client
            # and every split.
            shared_kwargs = dict(num_workers=config.data.num_workers,
                                 collate_fn=data_collator,
                                 pin_memory=config.use_gpu)

            processed_data = []
            for client_id in tqdm(range(1, config.federate.client_num + 1)):
                if use_pretrain_task:
                    applied_cfg = config
                    cur_task = applied_cfg.model.downstream_tasks[client_id -
                                                                  1]
                else:
                    applied_cfg = client_cfgs['client_{}'.format(client_id)]
                    cur_task = applied_cfg.model.task

                train_data, val_data, test_data = [
                    self._process_data(data=data[split][client_id - 1],
                                       data_name=cur_task,
                                       split=split,
                                       tokenizer=tokenizer,
                                       model_type=config.model.model_type,
                                       cache_dir=config.data.cache_dir,
                                       cfg=applied_cfg.data,
                                       client_id=client_id,
                                       pretrain=use_pretrain_task,
                                       is_debug=is_debug)
                    for split in ['train', 'val', 'test']
                ]

                batch_size = applied_cfg.data.batch_size
                dataloader = {
                    'val': self._make_split_entry(val_data, batch_size,
                                                  False, **shared_kwargs),
                    'test': self._make_split_entry(test_data, batch_size,
                                                   False, **shared_kwargs),
                }
                if not use_contrastive:
                    dataloader['train'] = self._make_split_entry(
                        train_data, batch_size, config.data.shuffle,
                        **shared_kwargs)
                else:
                    # Same training split twice: once with the configured
                    # shuffling, once in fixed order.
                    dataloader['train_raw'] = self._make_split_entry(
                        train_data, batch_size, config.data.shuffle,
                        **shared_kwargs)
                    dataloader['train_contrast'] = self._make_split_entry(
                        train_data, batch_size, False, **shared_kwargs)
                processed_data.append(dataloader)

            if use_contrastive:
                logger.info(
                    'Preprocessing synthetic dataset for contrastive learning')
                synth_data_processor = SynthDataProcessor(
                    config, processed_data)
                synth_data_processor.save_data()

            return processed_data

        def _process_data(self, data, data_name, split, tokenizer, model_type,
                          cache_dir, cfg, client_id, pretrain, is_debug):
            """Run the dataset-specific processing function on one split.

            The processing function is chosen by ``data_name`` and imported
            lazily from its own module.

            Raises:
                NotImplementedError: if ``data_name`` is not supported.
            """
            if data_name == 'imdb':
                from federatedscope.nlp.hetero_tasks.dataset.imdb import \
                    process_imdb_dataset
                process_func = process_imdb_dataset
            elif data_name == 'agnews':
                from federatedscope.nlp.hetero_tasks.dataset.agnews import \
                    process_agnews_dataset
                process_func = process_agnews_dataset
            elif data_name == 'squad':
                from federatedscope.nlp.hetero_tasks.dataset.squad import \
                    process_squad_dataset
                process_func = process_squad_dataset
            elif data_name == 'newsqa':
                from federatedscope.nlp.hetero_tasks.dataset.newsqa import \
                    process_newsqa_dataset
                process_func = process_newsqa_dataset
            elif data_name == 'cnndm':
                from federatedscope.nlp.hetero_tasks.dataset.cnndm import \
                    process_cnndm_dataset
                process_func = process_cnndm_dataset
            elif data_name == 'msqg':
                from federatedscope.nlp.hetero_tasks.dataset.msqg import \
                    process_msqg_dataset
                process_func = process_msqg_dataset
            else:
                raise NotImplementedError(
                    f'Not process function is provided for {data_name}')

            # Length limits are optional in the data config; a missing
            # one is passed on as None.
            max_seq_len = getattr(cfg, 'max_seq_len', None)
            max_query_len = getattr(cfg, 'max_query_len', None)
            trunc_stride = getattr(cfg, 'trunc_stride', None)
            max_tgt_len = getattr(cfg, 'max_tgt_len', None)

            # All processing functions are called with the same keyword
            # set; `max_src_len` and `raw_cache_dir` reuse `max_seq_len`
            # and `cache_dir`.
            return process_func(data=data,
                                split=split,
                                tokenizer=tokenizer,
                                max_seq_len=max_seq_len,
                                max_query_len=max_query_len,
                                trunc_stride=trunc_stride,
                                max_src_len=max_seq_len,
                                max_tgt_len=max_tgt_len,
                                model_type=model_type,
                                cache_dir=cache_dir,
                                raw_cache_dir=cache_dir,
                                client_id=client_id,
                                pretrain=pretrain,
                                is_debug=is_debug)

    modified_config, modified_client_cfgs = modified_cfg(config, client_cfgs)
    dataset = HeteroNLPDataset(modified_config, modified_client_cfgs)
    # Convert to dict keyed by the 1-based client id
    datadict = dict(enumerate(dataset.processed_data, start=1))

    return datadict, modified_config
