import os
import logging
import random
import csv
import json
import gzip
import zipfile
import shutil
import copy
import torch
import numpy as np
from tqdm import tqdm
from federatedscope.core.data.utils import download_url
from federatedscope.core.hpu_manager import HPUManager
from federatedscope.core.gpu_manager import GPUManager
from federatedscope.nlp.hetero_tasks.model.model import ATCModel

logger = logging.getLogger(__name__)


class HeteroNLPDataLoader(object):
    """
    Load hetero NLP task datasets (including multiple datasets), split them
    into train/val/test, and partition into several clients.

    Args:
        data_dir: root directory holding one sub-directory per dataset; a
            dataset whose sub-directory is missing is downloaded into it.
        data_name: iterable of dataset names; the names handled by the
            readers are ``imdb``, ``agnews``, ``squad``, ``newsqa``,
            ``cnndm`` and ``msqg``.
        num_of_clients: iterable parallel to ``data_name`` giving the
            number of clients each dataset is partitioned into.
        split: ``(train_fraction, val_fraction)`` applied to every
            client's shard of the ``train`` split.
    """
    def __init__(self, data_dir, data_name, num_of_clients, split=(0.9, 0.1)):
        self.data_dir = data_dir
        self.data_name = data_name
        self.num_of_clients = num_of_clients
        self.split = split  # split for train:val
        self.train_data = []
        self.val_data = []
        self.test_data = []

    def get_data(self):
        """
        Load every configured dataset and return the per-client shards.

        Returns:
            dict with keys ``'train'``, ``'val'`` and ``'test'``; each value
            is a list with one list of samples per client, clients of the
            first dataset first.
        """
        # Start from empty lists so that calling this method more than once
        # does not append a second copy of every shard.
        self.train_data, self.val_data, self.test_data = [], [], []
        train_frac, val_frac = self.split[0], self.split[1]

        for each_data, each_client_num in zip(self.data_name,
                                              self.num_of_clients):
            train_and_val_data = self._load(each_data, 'train',
                                            each_client_num)
            each_test_data = self._load(each_data, 'test', each_client_num)

            for data in train_and_val_data:
                num_train = int(train_frac * len(data))
                num_val = int(val_frac * len(data))
                self.train_data.append(data[:num_train])
                # Slice from an explicit start index: ``data[-0:]`` is the
                # whole list, which would turn an empty validation set into
                # a copy of the full shard.
                val_start = max(len(data) - num_val, 0)
                self.val_data.append(data[val_start:])
            self.test_data.extend(each_test_data)

        return {
            'train': self.train_data,
            'val': self.val_data,
            'test': self.test_data
        }

    def _load(self, dataset, split, num_of_client):
        """
        Read one split of one dataset and partition it among the clients.

        Args:
            dataset: dataset name (see the class docstring).
            split: file/directory stem of the split, e.g. ``'train'`` or
                ``'test'``.
            num_of_client: number of contiguous shards to produce.

        Returns:
            list of ``num_of_client`` lists of sample dicts.
        """
        data_dir = os.path.join(self.data_dir, dataset)
        if not os.path.exists(data_dir):
            logger.info(f'Start tp download the dataset {dataset} ...')
            self._download_and_extract(dataset)

        # read data
        data = []
        if dataset == 'imdb':
            # One review per file; positive reviews are read first.
            # NOTE: os.listdir returns entries in arbitrary order, so the
            # shuffled result can differ between file systems even when the
            # random seed is fixed.
            for label, polarity in ((1, 'pos'), (0, 'neg')):
                polarity_dir = os.path.join(data_dir, split, polarity)
                for file in os.listdir(polarity_dir):
                    path = os.path.join(polarity_dir, file)
                    with open(path, encoding='utf-8') as f:
                        line = f.readline()
                    data.append({'text': line, 'label': label})
            random.shuffle(data)

        elif dataset == 'agnews':
            with open(os.path.join(data_dir, split + '.csv'),
                      encoding="utf-8") as csv_file:
                csv_reader = csv.reader(csv_file,
                                        quotechar='"',
                                        delimiter=",",
                                        quoting=csv.QUOTE_ALL,
                                        skipinitialspace=True)
                for row in csv_reader:
                    label, title, description = row
                    # labels in the file are 1-based
                    label = int(label) - 1
                    text = ' [SEP] '.join((title, description))
                    data.append({'text': text, 'label': label})

        elif dataset == 'squad':
            with open(os.path.join(data_dir, split + '.json'),
                      'r',
                      encoding='utf-8') as reader:
                raw_data = json.load(reader)['data']
            for line in raw_data:
                for para in line['paragraphs']:
                    context = para['context']
                    for qa in para['qas']:
                        data.append({'context': context, 'qa': qa})

        elif dataset == 'newsqa':
            with gzip.GzipFile(os.path.join(data_dir, split + '.jsonl.gz'),
                               'r') as reader:
                # the first line is skipped (presumably a header record)
                content = reader.read().decode('utf-8').strip().split('\n')[1:]
                raw_data = [json.loads(line) for line in content]
            for line in raw_data:
                context = line['context']
                for qa in line['qas']:
                    data.append({'context': context, 'qa': qa})

        elif dataset in {'cnndm', 'msqg'}:
            src_file = os.path.join(data_dir, split + '.src')
            tgt_file = os.path.join(data_dir, split + '.tgt')
            with open(src_file, encoding='utf-8') as f:
                src_data = [
                    line.strip().replace('<S_SEP>', '[SEP]') for line in f
                ]
            with open(tgt_file, encoding='utf-8') as f:
                tgt_data = [
                    line.strip().replace('<S_SEP>', '[SEP]') for line in f
                ]
            for src, tgt in zip(src_data, tgt_data):
                data.append({'src': src, 'tgt': tgt})

        else:
            # No reader matches: every client ends up with an empty shard.
            logger.warning(f'No reader is defined for dataset {dataset}, '
                           f'its ({split}) split is left empty')

        # split data
        logger.info(f'Spliting dataset {dataset} ({split})')
        splited_data = self._partition(data, num_of_client)
        logger.info(f'Dataset {dataset} ({split}) is splited into '
                    f'{[len(x) for x in splited_data]}')

        return splited_data

    @staticmethod
    def _partition(data, num_of_client):
        """
        Cut ``data`` into ``num_of_client`` contiguous shards.

        Every shard holds ``len(data) // num_of_client`` samples except the
        last one, which also receives the remainder.
        """
        shard_size = len(data) // num_of_client
        shards = []
        start = 0
        for i in range(num_of_client):
            is_last = i == num_of_client - 1
            size = len(data) - start if is_last else shard_size
            shards.append(data[start:start + size])
            start += size
        return shards

    def _download_and_extract(self, dataset):
        """
        Fetch ``<dataset>.zip`` and unpack it to ``<data_dir>/<dataset>``.

        The archive is extracted into a scratch ``<dataset>_raw`` directory,
        its ``<dataset>`` folder is moved into ``data_dir``, and both the
        archive and the scratch directory are removed afterwards.
        """
        url = 'https://federatedscope.oss-cn-beijing.aliyuncs.com'
        os.makedirs(self.data_dir, exist_ok=True)
        download_url(f'{url}/{dataset}.zip', self.data_dir)

        zip_path = os.path.join(self.data_dir, f'{dataset}.zip')
        raw_dir = os.path.join(self.data_dir, dataset + '_raw')
        extract_dir = os.path.join(self.data_dir, dataset)
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(raw_dir)
        # data_dir already exists, so the folder is moved inside it
        shutil.move(os.path.join(raw_dir, dataset), self.data_dir)
        ds_store = os.path.join(extract_dir, '.DS_Store')
        if os.path.exists(ds_store):
            os.remove(ds_store)
        os.remove(zip_path)
        shutil.rmtree(raw_dir)


class SynthDataProcessor(object):
    """
    Generate and cache synthetic (feature, token) pairs for hetero NLP tasks.

    Every client's model encodes that client's ``train_contrast`` data; the
    resulting hidden states are merged across clients with random per-sample
    weights and decoded into tokens by the average of the clients' LM heads.

    Args:
        config: configuration node; the fields read here are ``gaudi``,
            ``use_gpu``, ``device``, ``outdir``, ``model``,
            ``federate.atc_load_from``, ``data.is_debug``,
            ``data.cache_dir`` and the ``data.hetero_synth_*`` options.
        datasets: indexable by ``0 .. len(datasets) - 1``; each entry
            provides ``['train_contrast']['dataloader']``.
    """
    def __init__(self, config, datasets):
        self.cfg = config
        if self.cfg.gaudi:
            self.device = HPUManager(
                hpu_available=self.cfg.use_gpu,
                specified_device=self.cfg.device).auto_choice()
        else:
            self.device = GPUManager(
                gpu_available=self.cfg.use_gpu,
                specified_device=self.cfg.device).auto_choice()
        self.pretrain_dir = config.federate.atc_load_from
        self.cache_dir = 'cache_debug' if \
            config.data.is_debug else config.data.cache_dir
        self.save_dir = os.path.join(self.cache_dir, 'synthetic')
        self.batch_size = config.data.hetero_synth_batch_size
        self.datasets = datasets
        self.num_clients = len(datasets)
        self.synth_prim_weight = config.data.hetero_synth_prim_weight
        self.synth_feat_dim = config.data.hetero_synth_feat_dim
        self.models = {}  # client id (1-based) -> loaded model

    def save_data(self):
        """
        Create the synthetic data and write it below ``self.save_dir``.

        Nothing is done when ``self.save_dir`` already exists. The
        intermediate ``tmp_feat.memmap`` file in ``cfg.outdir`` is removed
        when the method finishes, whether it succeeded or raised.
        """
        if os.path.exists(self.save_dir):
            return

        dataloaders = [
            self.datasets[idx]['train_contrast']['dataloader']
            for idx in range(self.num_clients)
        ]
        # Every client contributes as many samples as the smallest dataset
        # holds. The sequence length is read from each dataset's first
        # sample only -- assumes samples are already padded to one length
        # per client (TODO confirm).
        max_sz = min(len(loader.dataset) for loader in dataloaders)
        max_len = max(
            len(loader.dataset[0]['token_ids']) for loader in dataloaders)

        tmp_path = os.path.join(self.cfg.outdir, 'tmp_feat.memmap')
        try:
            enc_hiddens = np.memmap(filename=tmp_path,
                                    shape=(self.num_clients, max_sz, max_len,
                                           self.synth_feat_dim),
                                    mode='w+',
                                    dtype=np.float32)
            self._get_models()

            logger.info('Generating synthetic encoder hidden states')
            for client_id in tqdm(range(1, self.num_clients + 1)):
                enc_hid = self._encode_client(client_id,
                                              dataloaders[client_id - 1],
                                              max_len)
                enc_hiddens[client_id - 1] = enc_hid[:max_sz]

            all_hids = torch.from_numpy(enc_hiddens)
            all_weights = self._build_merge_weights(self.num_clients, max_sz,
                                                    self.synth_prim_weight)
            # weighted sum over the client axis
            avg_hids = (all_hids * all_weights[:, :, None, None]).sum(0)

            logger.info('Generating synthetic input tokens')
            lm_head = self._get_avg_lm_head().to(self.device)
            with torch.no_grad():
                pred_toks = torch.cat([
                    lm_head(avg_hids[i:i + self.batch_size].to(
                        self.device)).detach().cpu().argmax(dim=-1)
                    for i in tqdm(range(0, avg_hids.size(0), self.batch_size))
                ])

            if self.cache_dir:
                self._write_cache(avg_hids, pred_toks)
        finally:
            # The scratch file is only needed while merging.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    def _encode_client(self, client_id, dataloader, max_len):
        """
        Run one client's encoder over its dataloader.

        Returns:
            CPU tensor of hidden states concatenated over all batches and
            zero-padded along dim 1 up to ``max_len``.
        """
        model = self.models[client_id]
        model.eval()
        model.to(self.device)
        enc_hid = []
        # Inference only: skip building an autograd graph for every batch.
        with torch.no_grad():
            for data_batch in tqdm(dataloader, total=len(dataloader)):
                enc_out = model.model.encoder(
                    input_ids=data_batch['token_ids'].to(self.device),
                    attention_mask=data_batch['attention_mask'].to(
                        self.device),
                    token_type_ids=data_batch['token_type_ids'].to(
                        self.device),
                )
                enc_hid.append(enc_out.last_hidden_state.detach().cpu())
        model.to('cpu')

        enc_hid = torch.cat(enc_hid)
        if enc_hid.size(1) < max_len:
            padding = torch.zeros(enc_hid.size(0), max_len - enc_hid.size(1),
                                  self.synth_feat_dim)
            enc_hid = torch.cat([enc_hid, padding], dim=1)
        return enc_hid

    @staticmethod
    def _build_merge_weights(num_clients, num_samples, prim_weight):
        """
        Draw the per-sample weights used to merge the clients' features.

        For every sample one randomly chosen "primary" client receives
        ``prim_weight`` and the remaining clients share ``1 - prim_weight``
        equally. With a single client there is nothing to merge, so its
        weight is 1.

        Returns:
            tensor of shape ``(num_clients, num_samples)``.
        """
        all_weights = torch.ones(num_clients, num_samples)
        if num_clients == 1:
            return all_weights

        prim_indices = [
            random.randint(0, num_clients - 1) for _ in range(num_samples)
        ]  # avoid over-smooth results when setting
        # equal merging weights to all clients
        all_weights *= (1 - prim_weight) / (num_clients - 1)
        for i, pi in enumerate(prim_indices):
            all_weights[pi, i] = prim_weight
        return all_weights

    def _write_cache(self, avg_hids, pred_toks):
        """
        Store features, tokens and their shapes below ``self.save_dir``.

        File names carry ``self.synth_prim_weight``; the shapes are written
        to ``shapes.json``.
        """
        logger.info('Saving synthetic data to \'{}\''.format(self.save_dir))
        os.makedirs(self.save_dir, exist_ok=True)
        feat_shape = tuple(avg_hids.size())
        tok_shape = tuple(pred_toks.size())
        saved_feats = np.memmap(
            filename=os.path.join(
                self.save_dir,
                'feature_{}.memmap'.format(self.synth_prim_weight)),
            shape=feat_shape,
            mode='w+',
            dtype=np.float32,
        )
        saved_toks = np.memmap(
            filename=os.path.join(
                self.save_dir,
                'token_{}.memmap'.format(self.synth_prim_weight)),
            shape=tok_shape,
            mode='w+',
            dtype=np.int64,
        )
        saved_feats[:] = avg_hids.numpy()
        saved_toks[:] = pred_toks.numpy()
        # Make sure the data reaches disk before the shapes are recorded.
        saved_feats.flush()
        saved_toks.flush()
        shapes = {'feature': list(feat_shape), 'token': list(tok_shape)}
        with open(os.path.join(self.save_dir, 'shapes.json'), 'w') as f:
            json.dump(shapes, f)

    def _get_models(self):
        """Instantiate and load one model per client into ``self.models``."""
        for client_id in range(1, self.num_clients + 1):
            self.models[client_id] = self._load_model(ATCModel(self.cfg.model),
                                                      client_id)

    def _get_avg_lm_head(self):
        """
        Return a new LM head whose parameters are the mean over all clients.

        The clients' own heads are left unchanged; the result is a copy of
        client 1's head loaded with the averaged state.
        """
        state_dicts = [
            self.models[k].lm_head.state_dict()
            for k in range(1, self.num_clients + 1)
        ]
        num_heads = len(state_dicts)
        # Private copy of the first state dict serves as the accumulator.
        avg_param = copy.deepcopy(state_dicts[0])
        for name in avg_param:
            avg_param[name] = sum(sd[name].float() / num_heads
                                  for sd in state_dicts)
        avg_lm_head = copy.deepcopy(self.models[1].lm_head)
        avg_lm_head.load_state_dict(avg_param)
        return avg_lm_head

    def _load_model(self, model, client_id):
        """
        Load pretrained weights for ``client_id`` into ``model``.

        The global checkpoint is applied first and, when present, the client
        checkpoint overrides it. The client checkpoint is only consulted
        when the global one exists. Without a global checkpoint ``model`` is
        returned unchanged.
        """
        global_dir = os.path.join(self.pretrain_dir, 'global')
        client_dir = os.path.join(self.pretrain_dir, 'client')
        global_ckpt_path = os.path.join(global_dir,
                                        'global_model_{}.pt'.format(client_id))
        client_ckpt_path = os.path.join(client_dir,
                                        'client_model_{}.pt'.format(client_id))
        if not os.path.exists(global_ckpt_path):
            logger.warning('No checkpoint found at \'{}\', the model of '
                           'client {} keeps its initial weights'.format(
                               global_ckpt_path, client_id))
            return model

        # NOTE(review): torch.load unpickles the file; only point
        # ``federate.atc_load_from`` at checkpoints from a trusted source.
        model_ckpt = model.state_dict()
        logger.info('Loading model from \'{}\''.format(global_ckpt_path))
        global_ckpt = torch.load(global_ckpt_path,
                                 map_location='cpu')['model']
        model_ckpt.update(global_ckpt)
        if os.path.exists(client_ckpt_path):
            logger.info('Updating model from \'{}\''.format(client_ckpt_path))
            client_ckpt = torch.load(client_ckpt_path,
                                     map_location='cpu')['model']
            model_ckpt.update(client_ckpt)
        model.load_state_dict(model_ckpt)
        return model
