import argparse, os, sys, random, logging
from collections import defaultdict as ddict, Counter

import numpy as np
from pickle5 import pickle
from tqdm import tqdm

import torch
from transformers import AutoTokenizer

sys.path.append(os.path.join(sys.path[0], '..'))
from src.utils.data import dataset_info, data_keys
from src.utils.dataset_builder import builder_dict

# Debug-level logging; each record is prefixed with ms since start and the thread name.
logging.basicConfig(
    format='%(relativeCreated)6d %(threadName)s %(message)s',
    level=logging.DEBUG,
)
logger = logging.getLogger(__name__)


def set_random_seed(seed):
    """Seed Python's, NumPy's and PyTorch's global RNGs and force deterministic cuDNN.

    Args:
        seed: Integer seed applied to ``random``, ``numpy.random`` and ``torch``.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    cudnn = torch.backends.cudnn
    # Disable autotuning and request deterministic kernels for reproducible runs.
    cudnn.deterministic = True
    cudnn.benchmark = False

def stratified_sampling(x, n_samples, stratify):
    """Draw ``n_samples`` distinct rows of ``x``, proportionally from contiguous strata.

    ``x`` is assumed to be ordered so that its first ``stratify[0]`` rows form stratum 0,
    the next ``stratify[1]`` rows form stratum 1, and so on. Each stratum contributes
    ``floor(size * n_samples / n_total)`` rows drawn without replacement. Any shortfall
    caused by that rounding is topped up with rows drawn uniformly from those not yet
    chosen, so no index appears twice.

    Args:
        x: Array-like supporting ``x.shape[0]`` and ``x[index_array, ...]`` indexing
            (e.g. a NumPy array or torch tensor).
        n_samples: Total number of rows to draw; must not exceed ``x.shape[0]``.
        stratify: Sizes of consecutive strata; must sum to ``x.shape[0]``.

    Returns:
        ``(samples, sampled_idcs)``: the selected rows and their integer indices into
        ``x``. Per-stratum draws come first, followed by any top-up draws.
    """
    n_total = x.shape[0]
    assert sum(stratify) == n_total

    # Integer floor division: exact, with no float rounding in the per-stratum quota.
    n_strat_samples = [int(size) * n_samples // n_total for size in stratify]
    bounds = np.cumsum([0] + list(stratify))
    sampled_idcs = [
        np.random.choice(range(bounds[i], bounds[i + 1]), replace=False, size=n_strat)
        for i, n_strat in enumerate(n_strat_samples)
    ]

    n_missing = n_samples - sum(n_strat_samples)
    if n_missing > 0:
        # Draw the top-up only from rows not already taken; drawing from all rows
        # (as before) could return the same example twice.
        remaining = np.setdiff1d(np.arange(n_total), np.concatenate(sampled_idcs))
        sampled_idcs.append(np.random.choice(remaining, replace=False, size=n_missing))

    sampled_idcs = np.concatenate(sampled_idcs)
    samples = x[sampled_idcs, ...]

    return samples, sampled_idcs

def sample_dataset(data_path, dataset_dict, split, num_samples, seed, io_mode, keys=None):
    """Subsample fields of ``dataset_dict`` to the same ``num_samples`` examples.

    The chosen example indices are cached as ``{split}_split_{num_samples}_{seed}.pkl``
    in the parent directory of ``data_path`` and reused when that file exists. A new
    split may only be drawn in ``'I-O'`` mode; other modes must find a cached split.

    If the parent directory path contains ``'esnli'``, the split is stratified by
    ``dataset_dict['label']``. Otherwise it is the first ``num_samples`` entries of a
    ``torch.randperm`` over ``dataset_dict['item_idx']``.

    Args:
        data_path: Directory of this configuration's pickles; the split file lives in
            its parent directory.
        dataset_dict: Mapping from field name to a per-example list. Modified in place.
        split: Split name; only used in the cached split filename.
        num_samples: Number of examples to keep.
        seed: Only used in the cached split filename. The global RNGs are not seeded
            here and must be seeded by the caller.
        io_mode: Model I/O mode; must be ``'I-O'`` when no cached split exists.
        keys: Fields to subsample. Defaults to every key of ``dataset_dict``. The
            ``'item_idx'`` field is replaced by the sampled indices themselves.

    Returns:
        ``dataset_dict`` with the selected fields subsampled.
    """
    # NOTE(review): when a rationale_src subdirectory is appended to data_path, the
    # parent is the io_mode directory rather than the split directory. Confirm intended.
    split_dir = os.path.split(data_path)[0]
    split_path = os.path.join(split_dir, f'{split}_split_{num_samples}_{seed}.pkl')
    if os.path.exists(split_path):
        with open(split_path, 'rb') as f:
            sampled_split = pickle.load(f)
    else:
        assert io_mode == 'I-O'
        if 'esnli' in split_dir:
            labels = np.asarray(dataset_dict['label'])
            # stratified_sampling treats each stratum as a contiguous run of rows, so
            # group the examples by label first (stable sort) and map the sampled
            # positions back to original indices. Already-grouped labels give an
            # identity permutation.
            order = np.argsort(labels, kind='stable')
            sorted_labels = labels[order]
            label_counts = list(Counter(sorted_labels.tolist()).values())
            _, positions = stratified_sampling(sorted_labels, num_samples, label_counts)
            sampled_split = order[positions]
        else:
            sampled_split = torch.randperm(len(dataset_dict['item_idx']))[:num_samples].numpy()

        sampled_split = list(sampled_split)
        with open(split_path, 'wb') as f:
            pickle.dump(sampled_split, f)

    # Default to the dict's own keys rather than the script-global ``args``, which only
    # exists when this file runs as __main__.
    if keys is None:
        keys = list(dataset_dict)
    for key in keys:
        dataset_dict[key] = sampled_split if key == 'item_idx' else [dataset_dict[key][i] for i in sampled_split]

    return dataset_dict

def load_dataset(data_path, keys, split, num_samples, seed):
    """Read one pickle per field from ``data_path`` into a ``defaultdict(list)``.

    Files are named ``{key}.pkl`` when ``num_samples`` is None, otherwise
    ``{key}_{num_samples}_{seed}.pkl``.

    Args:
        data_path: Directory containing the pickles.
        keys: Field names to load.
        split: Split name; only used in the progress-bar label.
        num_samples: Subsample size tag, or None for the full split.
        seed: Seed tag used in subsampled filenames.

    Returns:
        ``defaultdict(list)`` mapping each key to its unpickled value.
    """
    suffix = '' if num_samples is None else f'_{num_samples}_{seed}'
    loaded = ddict(list)
    for field in tqdm(keys, desc=f'Loading {split} dataset'):
        pkl_path = os.path.join(data_path, f'{field}{suffix}.pkl')
        with open(pkl_path, 'rb') as handle:
            loaded[field] = pickle.load(handle)
    return loaded

def save_dataset(data_path, dataset_dict, split, num_samples, seed):
    """Pickle each field listed in ``data_keys[args.arch]`` into ``data_path``.

    Files are named ``{key}.pkl`` when ``num_samples`` is None, otherwise
    ``{key}_{num_samples}_{seed}.pkl``. This reads the module-level ``args`` set in
    the ``__main__`` block.

    Args:
        data_path: Output directory.
        dataset_dict: Mapping from field name to the object to pickle.
        split: Split name; only used in the progress-bar label.
        num_samples: Subsample size tag, or None for the full split.
        seed: Seed tag used in subsampled filenames.
    """
    suffix = f'_{num_samples}_{seed}' if num_samples is not None else ''
    for field in tqdm(data_keys[args.arch], desc=f'Saving {split} dataset'):
        pkl_path = os.path.join(data_path, f'{field}{suffix}.pkl')
        with open(pkl_path, 'wb') as handle:
            pickle.dump(dataset_dict[field], handle)

def main(args):
    """Preprocess one dataset/arch/split/io_mode configuration and pickle the result.

    Without ``args.num_samples``, the dataset builder processes the full split and the
    result is saved as ``{key}.pkl`` files under the configuration's data directory.
    Nothing is done if all of those files already exist. With ``args.num_samples``,
    the full-split pickles must already exist. They are loaded, subsampled by
    ``sample_dataset`` and saved as ``{key}_{num_samples}_{seed}.pkl``.

    Args:
        args: Parsed command-line namespace (see the ``__main__`` block).
    """
    set_random_seed(args.seed)
    # Sanity-check the command-line arguments.
    for arg_key, arg_val in vars(args).items():
        if arg_key == 'num_samples':
            assert arg_val is None or arg_val >= 1
        elif arg_key == 'rationale_src':
            if arg_val is None:
                assert args.io_mode == 'I-O'
            elif arg_val != 'gold':
                assert arg_val == 'gpt-neox' # Currently, only gpt-neox rationales are available for multiple datasets
                assert args.dataset in ['ecqa', 'ecqa_unk', 'openbookqa', 'strategyqa', 'qasc', 'eqasc']
        else:
            assert arg_val is not None

    split, num_examples = dataset_info[args.dataset][args.split]
    if args.num_samples is not None:
        assert args.num_samples < num_examples

    data_path = os.path.join(args.data_dir, args.dataset, args.arch, args.split, args.io_mode)
    if args.rationale_src not in [None, 'gold']:
        data_path = os.path.join(data_path, args.rationale_src)
    os.makedirs(data_path, exist_ok=True)

    missing_data_keys = [x for x in data_keys[args.arch] if not os.path.exists(os.path.join(data_path, f'{x}.pkl'))]
    if args.num_samples is None and not missing_data_keys:
        # Everything is already on disk. Without this guard, execution would reach
        # save_dataset() with dataset_dict unbound and raise UnboundLocalError.
        logger.info('All data files already exist in %s; nothing to do.', data_path)
        return

    # NOTE(review): kept ahead of the branch, as in the original ordering. The
    # tokenizer/builder may consume RNG state, so moving them could change new splits.
    tokenizer = AutoTokenizer.from_pretrained(args.arch, model_max_length=args.model_max_length)
    dataset_builder = builder_dict[args.dataset](args.dataset, args.split, split, args.io_mode, args.rationale_src, tokenizer)

    if args.num_samples is None:
        dataset_dict = dataset_builder.process_instances()
    else:
        # Subsampling reads the full-split pickles, so they must be built first.
        assert not missing_data_keys, f'Full-split data missing in {data_path}: {missing_data_keys}'

        missing_sample_data_keys = [x for x in data_keys[args.arch] if not os.path.exists(os.path.join(data_path, f'{x}_{args.num_samples}_{args.seed}.pkl'))]
        assert missing_sample_data_keys, f'Sampled data already exists in {data_path}'

        dataset_dict = load_dataset(data_path, data_keys[args.arch], args.split, None, args.seed)
        dataset_dict = sample_dataset(data_path, dataset_dict, args.split, args.num_samples, args.seed, args.io_mode)

    save_dataset(data_path, dataset_dict, args.split, args.num_samples, args.seed)


if __name__ == '__main__':
    # Datasets accepted by --dataset (kept in the original order for --help output).
    dataset_choices = [
        'esnli',
        'ecqa', 'ecqa_unk',
        'openbookqa',
        'strategyqa',
        'creak',
        'qasc', 'qasc_unk', 'eqasc',
        'quartz',
        'aqua_rat', 'aqua_rat_unk',
        'winowhy_0', 'winowhy_1', 'winowhy_2', 'winowhy_3', 'winowhy_4',
        'comve', 'comve_unk',
        'mnli_matched', 'mnli_mismatched',
        'anli_r1', 'anli_r2', 'anli_r3',
    ]
    arch_choices = ['t5-small', 't5-base', 't5-large', 't5-3b']
    io_mode_choices = ['I-O', 'IR-O', 'I-OR', 'I-RO', 'IshuffledR-O', 'IreplacedR-O']
    rationale_choices = [None, 'gold', 'gpt-neox', 'gpt-2', 'gpt-j', 'gpt-3']

    parser = argparse.ArgumentParser(description='Dataset preprocessing.')
    add = parser.add_argument
    add('--data_dir', type=str, default='../data/', help='Root directory for datasets.')
    add('--dataset', type=str, choices=dataset_choices)
    add('--arch', type=str, choices=arch_choices)
    add('--model_max_length', type=int, default=512, help='Maximum model input length.')
    add('--split', type=str, help='Dataset split.', choices=['train', 'dev', 'test'])
    add('--inhouse_split_seed', type=int, default=0, help='Random seed for inhouse split.')
    add('--num_samples', type=int, help='Number of examples to sample. None means all available examples are used.')
    add('--seed', type=int, default=0, help='Random seed.')
    add('--io_mode', type=str, help='Model I/O mode.', choices=io_mode_choices)
    add('--rationale_src', type=str, help='Rationale source.', choices=rationale_choices)

    # ``args`` must remain a module-level name: save_dataset() reads the global ``args``.
    args = parser.parse_args()
    main(args)