import random
from torch.utils.data import Dataset, ConcatDataset, DataLoader
import numpy as np
import torch
from NLPTasks_wo_SuperGLUE import *
from tqdm import tqdm
import itertools
import os
import h5py
import math
import copy
from pathlib import Path
import re
import json
import nltk


def nltk_line_tokenizer(line):
    """Return the word tokens of *line* as produced by ``nltk.word_tokenize``."""
    tokens = nltk.word_tokenize(line)
    return tokens

class MetaH5Construction:
	"""Token-id view of one split of a single task.

	Attributes set by the constructor:
	    key_list: the keys returned by ``single_task.get_valid_keys()``.
	    key2ids: maps each key to the token ids of ``"<Key>: "`` (key lower-cased,
	        then capitalized).
	    split_data_list: one dict per datapoint of the split, mapping each key
	        to the token ids of that datapoint's value.
	"""

	def __init__(self, single_task, split, tokenizer):
		def encode(text):
			# The last id of every encoded sequence is dropped
			# (presumably the EOS id the tokenizer appends - TODO confirm).
			return tokenizer(text, return_tensors="np")['input_ids'][0, :-1].tolist()

		self.key_list = single_task.get_valid_keys()
		self.key2ids = {
			key: encode(key.lower().capitalize() + ": ")
			for key in self.key_list
		}

		self.split_data_list = []
		for query_datapoint in tqdm(single_task.get_split(split)):
			encoded_values = {
				key: encode(single_task.get_value_from_key(query_datapoint, key))
				for key in self.key_list
			}
			self.split_data_list.append(encoded_values)

def get_random_span(seq, l, n):
	"""Pick start positions for ``n`` non-overlapping spans of length ``l`` in ``seq``.

	``n`` distinct positions are sampled from a range shortened by
	``(l - 1) * n``; the k-th smallest sample is then shifted right by
	``k * (l - 1)`` so consecutive starts are at least ``l`` apart.

	Returns the start positions in ascending order. ``random.sample`` raises
	``ValueError`` when the shortened range holds fewer than ``n`` positions.
	"""
	slot_count = len(seq) - (l - 1) * n
	picks = sorted(random.sample(range(slot_count), n))
	return [pick + rank * (l - 1) for rank, pick in enumerate(picks)]

def get_chunk_type(tag_name):
    """Split a tag such as ``"B-PER"`` into ``(class, type)``.

    The class is the text before the first ``-`` and the type is the text
    after the last ``-``; a tag without ``-`` is returned as both.
    """
    pieces = tag_name.split('-')
    return pieces[0], pieces[-1]

def get_chunks(seq):
    """Group a tag sequence into ``(type, start, end)`` chunks.

    ``"O"`` tags are outside any chunk. A chunk is closed when an ``"O"`` tag
    is met, when the tag type changes, or when a tag of class ``"B"`` starts
    a new chunk. ``end`` is exclusive.
    """
    outside = "O"
    spans = []
    open_type = None
    open_start = None

    for position, tag in enumerate(seq):
        if tag == outside:
            # An outside tag terminates whatever chunk is currently open.
            if open_type is not None:
                spans.append((open_type, open_start, position))
                open_type, open_start = None, None
            continue

        tag_class, tag_type = get_chunk_type(tag)
        if open_type is None:
            open_type, open_start = tag_type, position
            continue

        if tag_class == "B" or tag_type != open_type:
            # Either an explicit begin marker or a type switch: close and reopen.
            spans.append((open_type, open_start, position))
            open_type, open_start = tag_type, position

    # Flush the chunk that runs to the end of the sequence, if any.
    if open_type is not None:
        spans.append((open_type, open_start, len(seq)))
    return spans

def read_conll(file_path):
    """Read a two-column token/tag file into ``(tokens, tags)`` pairs.

    The stripped file text is cut into documents at blank lines (a blank line
    may contain a single tab). Within a document, only lines that split into
    exactly two whitespace-separated fields contribute a token and a tag.

    Returns a list with one ``(tokens, tags)`` tuple of lists per document.
    """
    text = Path(file_path).read_text().strip()

    documents = []
    for block in re.split(r'\n\t?\n', text):
        rows = (row.split() for row in block.split('\n'))
        pairs = [fields for fields in rows if len(fields) == 2]
        documents.append(([token for token, _ in pairs], [tag for _, tag in pairs]))

    return documents

def read_cls_sen_data(_path):
	"""Read a tab-separated file and return the first two columns of each line.

	Every line is stripped before splitting on tabs; lines with fewer than
	two fields are skipped. Returns a list of ``(field0, field1)`` tuples.
	"""
	records = []
	with open(_path) as handle:
		for raw_line in handle:
			fields = raw_line.strip().split('\t')
			if len(fields) >= 2:
				records.append((fields[0], fields[1]))
	return records

def read_pair_data(_path):
	"""Read a tab-separated file and return the first three columns of each line.

	Every line is stripped before splitting on tabs; lines with fewer than
	three fields are skipped. Returns a list of 3-tuples.
	"""
	records = []
	with open(_path) as handle:
		for raw_line in handle:
			fields = raw_line.strip().split('\t')
			if len(fields) >= 3:
				records.append((fields[0], fields[1], fields[2]))
	return records

class MetaH5NLPTaskDynamic(Dataset):
	"""Dataset that builds masked ``key: value`` sequences for one task from an HDF5 file.

	Each item is an encoder input made of ``key-ids + value-ids`` segments in
	which some values are replaced by sentinel ids, optionally surrounded by
	fully visible examples drawn from the train split, together with the
	decoder target holding the hidden values.

	The HDF5 file is read for the datasets ``<split>_instance_count``,
	``train_instance_count``, ``keyids_<key>`` and ``<split>_key_<key>_values``.

	Sentinel ids count downward from 32099 (presumably T5's ``<extra_id_0>`` -
	TODO confirm against the tokenizer in use).
	"""

	def __init__(self, config, task, task_index, path_, split, tokenizer, max_length, is_train=True):
		"""Read the task metadata from the HDF5 file and precompute key combinations.

		Args:
		    config: object providing ``use_t5_span``, ``nlu_only``, ``nlg_only``,
		        ``key_replacement_path``, ``meta_task_max_sample``,
		        ``mask_warmup_epochs`` and, at item time,
		        ``meta_task_max_value_length``, ``shuffle_example`` and
		        ``enable_new_task_embeddings``.
		    task: task object; only ``get_nlu_keys()`` is called on it.
		    task_index: integer returned with every item unless
		        ``config.enable_new_task_embeddings`` is set.
		    path_: path of the task's HDF5 file.
		    split: split name used to build the HDF5 dataset names.
		    tokenizer: callable tokenizer exposing ``eos_token_id``.
		    max_length: maximum length of the returned id arrays.
		    is_train: when false, curriculum masking is disabled.
		"""
		self.split = split
		self.eos_token_id = tokenizer.eos_token_id
		self.sep_token_id = 0
		self.max_length = max_length
		self.use_t5_span = config.use_t5_span
		self.path_ = path_
		self.config = config
		self.is_train = is_train
		self.task = task
		self.tokenizer = tokenizer
		self._task_index = task_index
		self.nlu_only = self.config.nlu_only
		self.nlg_only = self.config.nlg_only

		self.key_to_index = None
		self.key_replacement_dict = None
		if len(config.key_replacement_path) > 0:
			# Comma-separated file, first line skipped as a header: each row lists
			# alternative surface forms of one key. The first column (normalized)
			# is the lookup key and is itself one of the alternatives.
			self.key_replacement_dict = {}
			with open(config.key_replacement_path) as out:
				for idx, line in enumerate(out):
					if idx == 0: continue
					line = line.strip()
					items = line.split(',')
					key = self.get_query_key(items[0])
					if key not in self.key_replacement_dict:
						self.key_replacement_dict[key] = []
					for key_replacement in items:
						key_replacement = key_replacement.strip()
						if len(key_replacement) > 0:
							key_ids = tokenizer(key_replacement.lower().capitalize() + ": ", return_tensors="np")['input_ids'][0, :-1].tolist()
							self.key_replacement_dict[key].append(key_ids)

		assert not (self.nlg_only and self.nlu_only), "nlg_only and nlu_only are exclusive"

		with h5py.File(self.path_, 'r') as _hdf5:
			self.orginal_len = _hdf5['%s_instance_count' % split][0]
			self.data_len = min(self.config.meta_task_max_sample, self.orginal_len)
			self.train_data_index = list(range(_hdf5['train_instance_count'][0]))
			self.nlu_key_list = self.task.get_nlu_keys()
			# All stored keys except the NLU keys, which are represented by the
			# single placeholder 'nlu_key' resolved per item.
			self.key_list = [key[7:] for key in _hdf5.keys() if key.startswith('keyids_') and key[7:] not in self.nlu_key_list]
			self.key_list.append('nlu_key')
			self.nlu_visible_keys = [key for key in self.key_list if not key == "nlu_key"]

			# Every proper subset of key_list (the full set is excluded), starting
			# with the empty combination.
			self.key_combinations = [[]]
			for i in range(1, len(self.key_list)):
				self.key_combinations += list(itertools.combinations(self.key_list, i))
			self.switch_prob = 1.0 / (2.0 - 2.0 * len(self.nlu_key_list) / len(self.key_combinations))

			# Combinations that do NOT show every non-NLU key.
			self.nlg_only_key_combinations = []
			for key_conbination in self.key_combinations:
				if all(key in key_conbination for key in self.nlu_visible_keys):
					continue
				self.nlg_only_key_combinations.append(key_conbination)

		self.mask_warmup_steps = config.mask_warmup_epochs * self.data_len if self.is_train else -1
		# NOTE(review): nothing in this class advances current_steps; presumably
		# the training loop updates it - confirm.
		self.current_steps = 0

	def __len__(self):
		"""Return the number of items exposed (capped by ``config.meta_task_max_sample``)."""
		return self.data_len

	def _get_t5_span(self, original_sequence, span_start_index):
		"""Mask spans of ``original_sequence`` using sentinel ids counting down from ``span_start_index``.

		Sequences of at most 10 ids are replaced entirely by one sentinel.
		Longer ones get 1-3 random spans covering about 15% of the ids in total.

		Returns:
		    ``(input_seq, output_seq, next_sentinel)`` where ``next_sentinel`` is
		    the first sentinel id not used yet.
		"""
		if len(original_sequence) <= 10:
			input_seq = [span_start_index]
			output_seq = [span_start_index] + original_sequence
			# Sentinel ids are consumed downward everywhere else in this class;
			# the original returned span_start_index + 1 here, which left the
			# descending sequence (and went above 32099 on the first span).
			return input_seq, output_seq, span_start_index - 1
		else:
			# Fewer spans are used when few sentinel ids remain.
			if len(original_sequence) <= 30 or span_start_index == 32097:
				span_count = 1
			elif len(original_sequence) <= 60 or span_start_index == 32098:
				span_count = 2
			else:
				span_count = 3

			span_length = math.ceil((len(original_sequence) * 0.15) / span_count)
			start_points = get_random_span(original_sequence, span_length, span_count)
			input_seq = []
			output_seq = []
			prev_start_point = 0
			for start_point in start_points:
				input_seq += original_sequence[prev_start_point:start_point] + [span_start_index]
				output_seq += [span_start_index] + original_sequence[start_point: start_point + span_length]
				span_start_index -= 1
				prev_start_point = start_point + span_length
			input_seq += original_sequence[prev_start_point:]
			return input_seq, output_seq, span_start_index

	def _get_instance_t5_span_representation(self, data_index, used_nlu_key, add_seperator=False):
		"""Build one instance using span masking on the values that are not shown.

		Args:
		    data_index: row of the instance in the ``<split>`` value datasets.
		    used_nlu_key: real key substituted for the ``'nlu_key'`` placeholder.
		    add_seperator: append ``sep_token_id`` to the input ids. Added so the
		        call in ``__getitem__`` (which passes this keyword and later
		        overwrites the last input id with EOS) works; the default keeps
		        the previous output unchanged.

		Returns:
		    ``(input_ids, output_ids, show_attributes, False)``.
		"""
		show_attributes = random.choice(self.key_combinations)

		input_ids = []
		output_ids = []
		span_index = 32099

		for key in self.key_list:
			if key == "nlu_key":
				real_key = used_nlu_key
			else:
				real_key = key

			input_ids += self.data_hdf5['keyids_%s' % real_key][0].tolist()
			value_ids = self.data_hdf5['%s_key_%s_values' % (self.split, real_key)][data_index].tolist()
			if key in show_attributes and span_index > 32096:
				input_ids += value_ids
			else:
				value_input_seq, value_output_seq, span_index = self._get_t5_span(value_ids, span_index)
				input_ids += value_input_seq
				output_ids += value_output_seq

		output_ids.append(span_index)
		output_ids.append(self.eos_token_id)

		if add_seperator:
			input_ids.append(self.sep_token_id)

		return input_ids, output_ids, show_attributes, False

	def get_query_key(self, key):
		"""Normalize ``key`` for lookups: lower-case with all whitespace removed."""
		return ''.join(key.strip().lower().split())

	def curriculum_learning_mask(self, special_token_id, value_ids):
		"""Hide all or part of ``value_ids`` behind ``special_token_id``.

		The whole value is hidden when warm-up is disabled or finished, or when
		the value has at most 10 ids. Otherwise one random contiguous span is
		hidden whose share grows linearly from 0.15 towards 1.0 with
		``current_steps / mask_warmup_steps``.

		Returns:
		    ``(input_sub_ids, output_sub_ids)``.
		"""
		if self.mask_warmup_steps <= 0 or self.current_steps >= self.mask_warmup_steps or len(value_ids) <= 10:
			return [special_token_id], [special_token_id] + value_ids
		else:
			mask_ratio = 0.15 + 0.85 * (self.current_steps / self.mask_warmup_steps)
			span_length = int(len(value_ids) * mask_ratio)
			start_point = get_random_span(value_ids, span_length, 1)[0]
			input_sub_ids = value_ids[:start_point] + [special_token_id] + value_ids[start_point + span_length:]
			output_sub_ids = [special_token_id] + value_ids[start_point: start_point + span_length]
			return input_sub_ids, output_sub_ids

	def _get_instance_representation(self, data_index, used_nlu_key, show_attributes=None, add_seperator=False, use_train_data=False, key_mapping=None):
		"""Build one instance where each hidden value is replaced by its own sentinel.

		Args:
		    data_index: row of the instance in the value datasets.
		    used_nlu_key: real key substituted for the ``'nlu_key'`` placeholder.
		    show_attributes: keys whose values stay visible; sampled when ``None``.
		    add_seperator: append ``sep_token_id`` to the input ids.
		    use_train_data: read values from the ``train`` split instead of
		        ``self.split``.
		    key_mapping: optional dict from normalized key to replacement key ids.

		Returns:
		    ``(input_ids, output_ids, show_attributes, enable_nlu)``. ``enable_nlu``
		    is true when every non-NLU key was shown by the sampling done here;
		    it stays false when ``show_attributes`` is supplied by the caller.
		"""
		enable_nlu = False
		if show_attributes is None:
			if self.nlg_only:
				show_attributes = random.choice(self.nlg_only_key_combinations)
			else:
				if (not self.nlu_only) and random.random() < self.switch_prob:
					show_attributes = random.choice(self.key_combinations)
					enable_nlu = all(key in show_attributes for key in self.nlu_visible_keys)
				else:
					enable_nlu = True
					show_attributes = self.nlu_visible_keys

		input_ids = []
		output_ids = []
		miss_span_count = 0
		for key in self.key_list:
			if key == "nlu_key":
				real_key = used_nlu_key
			else:
				real_key = key

			if key_mapping is None:
				input_ids += self.data_hdf5['keyids_%s' % real_key][0].tolist()
			else:
				query_key = self.get_query_key(real_key)
				input_ids += key_mapping[query_key]

			value_ids = self.data_hdf5['%s_key_%s_values' % (self.split if not use_train_data else 'train', real_key)][data_index].tolist()
			value_ids = value_ids[:self.config.meta_task_max_value_length]
			if key in show_attributes:
				input_ids += value_ids
			else:
				special_token_id = 32099 - miss_span_count
				input_sub_ids, output_sub_ids = self.curriculum_learning_mask(special_token_id, value_ids)
				input_ids += input_sub_ids
				output_ids += output_sub_ids
				miss_span_count += 1

		if len(output_ids) > 0:
			# Close the target with the next unused sentinel followed by EOS.
			output_ids.append(32099 - miss_span_count)
			output_ids.append(self.eos_token_id)

		if add_seperator:
			input_ids.append(self.sep_token_id)

		return input_ids, output_ids, show_attributes, enable_nlu

	def __getitem__(self, index):
		"""Return ``(input_ids, output_ids, task_index, task_type)`` for one item.

		``input_ids`` and ``output_ids`` are int64 numpy arrays truncated to
		``max_length``; ``task_type`` is 0 when the item is an NLU-style
		instance and 1 otherwise.
		"""
		if self.data_len >= self.orginal_len:
			index = index % self.orginal_len
		else:
			# The dataset is sub-sampled: item i draws a random instance from the
			# i-th stride of the original data. The upper clamp is the last valid
			# original index; the original clamped to data_len, which collapsed
			# most items onto a single instance whenever scale >= 2.
			scale = self.orginal_len // self.data_len
			min_v, max_v = index * scale, (index + 1) * scale - 1
			index = min(random.randint(min_v, max_v), self.orginal_len - 1)

		# Opened lazily on first access rather than in __init__
		# (presumably so each DataLoader worker gets its own handle - confirm).
		if not hasattr(self, 'data_hdf5'):
			self.data_hdf5 = h5py.File(self.path_, 'r')

		if len(self.nlu_key_list) > 1:
			show_nlu_key = random.sample(self.nlu_key_list, k=1)[0]
		else:
			show_nlu_key = self.nlu_key_list[0]

		# Key order is re-randomized in place for every item.
		random.shuffle(self.key_list)

		key_mapping = None
		if self.key_replacement_dict is not None:
			key_mapping = {}
			for key in self.key_list + [show_nlu_key]:
				if key == 'nlu_key': continue
				query_key = self.get_query_key(key)
				key_mapping[query_key] = random.choice(self.key_replacement_dict[query_key])

		if self.use_t5_span:
			input_ids, output_ids, used_kv, enable_nlu = self._get_instance_t5_span_representation(index, show_nlu_key, add_seperator=True)
		else:
			input_ids, output_ids, used_kv, enable_nlu = self._get_instance_representation(index, show_nlu_key, add_seperator=True, key_mapping=key_mapping)

		saved_instances = []
		total_length = len(input_ids)

		# Collect up to 16 fully visible train examples that still fit in the
		# length budget. The sample size is bounded by the population so small
		# train splits do not make random.sample raise ValueError.
		example_pool_size = min(16, len(self.train_data_index))
		for d_index in random.sample(self.train_data_index, k=example_pool_size):
			if self.split == "train" and d_index == index: continue
			if total_length <= self.max_length - 1:
				full_example_ids, _, _, _ = self._get_instance_representation(d_index, show_nlu_key, show_attributes=self.key_list, add_seperator=True, use_train_data=True, key_mapping=key_mapping)
				if len(full_example_ids) + total_length <= self.max_length - 1:
					saved_instances.append(full_example_ids)
					total_length += len(full_example_ids)
			else:
				break

		if len(saved_instances) > 0:
			low_bound = 1
			# if not all key-value pairs are masked, it is possible not to include full example
			if len(used_kv) > 0: low_bound = 0
			selected_instance_num = random.choice(range(low_bound, len(saved_instances) + 1))

			candidates_list = saved_instances[:selected_instance_num] + [input_ids]
			if self.config.shuffle_example:
				random.shuffle(candidates_list)

			final_input_ids = []
			for candidate in candidates_list:
				final_input_ids += candidate
			input_ids = final_input_ids

		# Every segment ends with a separator; the final one becomes EOS.
		input_ids[-1] = self.eos_token_id

		input_np = np.array(input_ids).astype(np.int64)[:self.max_length]
		output_np = np.array(output_ids).astype(np.int64)[:self.max_length]

		if self.config.enable_new_task_embeddings:
			_task_index = 0
		else:
			_task_index = self._task_index

		return input_np, output_np, _task_index, 0 if enable_nlu else 1


def process_tensor(tensor_list, last_dim, output_mask=False):
    """Zero-pad a list of arrays along their first axis and stack them into a tensor.

    Args:
        tensor_list: non-empty list of numpy arrays; the dtype of the first
            one is used for the padded result.
        last_dim: size of a trailing feature axis; when it is not positive the
            result is 2-D ``(batch, max_len)``, otherwise
            ``(batch, max_len, last_dim)``.
        output_mask: also return a float32 mask with 1 on real positions.

    Returns:
        The padded tensor, or ``(tensor, mask)`` when ``output_mask`` is true.
    """
    lengths = [item.shape[0] for item in tensor_list]
    batch = len(tensor_list)
    longest = max(lengths)

    padded_shape = (batch, longest, last_dim) if last_dim > 0 else (batch, longest)
    padded = np.zeros(padded_shape, dtype=tensor_list[0].dtype)
    mask = np.zeros((batch, longest), dtype=np.float32)

    for row, item in enumerate(tensor_list):
        size = lengths[row]
        if size > 0:
            padded[row, :size] = item
            mask[row, :size] = 1

    padded_tensor = torch.from_numpy(padded)
    if not output_mask:
        return padded_tensor
    return padded_tensor, torch.from_numpy(mask)

def _data_wrapper(dataset):
    """Collate a list of dataset items into the batch dict consumed by the model.

    Positions 0 and 1 of every item are padded into the encoder and decoder id
    tensors; padded decoder positions are set to -100. The remaining fields
    are filled according to the item length (7, 6, 4 or 3), read from the
    first item; anything not provided stays at zeros / ``None``.
    """
    def column(position):
        return [item[position] for item in dataset]

    def long_column(position):
        return torch.tensor(column(position)).long()

    encoder_input_ids, encoder_mask = process_tensor(column(0), 0, output_mask=True)
    decoder_input_ids, decoder_mask = process_tensor(column(1), 0, output_mask=True)
    decoder_input_ids[decoder_mask == 0] = -100

    gt_x = None
    gt_y = None
    data_index = None
    task_index = torch.tensor([0 for _ in dataset]).long()
    task_type_index = torch.tensor([0 for _ in dataset]).long()
    prefix_ids = torch.tensor([0 for _ in dataset]).long()

    item_width = len(dataset[0])
    if item_width == 7:
        data_index = column(5)
        prefix_ids = long_column(4)
        task_index = long_column(4)
        task_type_index = long_column(6)
        gt_y = column(3)
        gt_x = column(2)
    elif item_width == 6:
        data_index = column(5)
        prefix_ids = long_column(4)
        gt_y = column(3)
        gt_x = column(2)
    elif item_width == 4:
        task_index = long_column(2)
        task_type_index = long_column(3)
    elif item_width == 3:
        prefix_ids = long_column(2)

    return {
        "encoder_input_ids": encoder_input_ids,
        "encoder_mask": encoder_mask,
        "decoder_input_ids": decoder_input_ids,
        "task_ids": task_index,
        "task_type_ids": task_type_index,
        "prefix_ids": prefix_ids,
        "gt_x": gt_x,
        "gt_y": gt_y,
        "data_index": data_index,
    }


def get_meta_task():
	"""Instantiate every entry of ``TASK_NAME_TO_CLS``, passing the task name to its class.

	Returns the instances as a list in the iteration order of ``TASK_NAME_TO_CLS``.
	"""
	return [TASK_NAME_TO_CLS[task_name](task_name) for task_name in TASK_NAME_TO_CLS]

def get_meta_nlp_data(tokenizer, task_list, split, batch_size, max_length=512, shuffle=False, pre_load_training=False, distributed=False, is_root=True, is_train=True):
	"""Wrap each task in a ``MetaNLPTask`` and return a loader over their concatenation.

	Args:
	    tokenizer, split, max_length, pre_load_training, is_train: forwarded to
	        ``MetaNLPTask``.
	    task_list: tasks to wrap.
	    batch_size: batch size of the loader.
	    shuffle: shuffle flag, given to the sampler when ``distributed`` is set
	        and to the loader otherwise.
	    distributed: use a ``DistributedSampler``.
	    is_root: print the dataset size.

	Returns:
	    ``(loader, updated_single_task)`` where the second element lists the
	    ``single_task`` attribute of every wrapped task.
	"""
	meta_task_list = []
	updated_single_task = []
	for task_ in task_list:
		wrapped_task = MetaNLPTask(task_, split, tokenizer, max_length=max_length, pre_load_training=pre_load_training, is_train=is_train)
		meta_task_list.append(wrapped_task)
		updated_single_task.append(wrapped_task.single_task)

	combined_dataset = ConcatDataset(meta_task_list)

	if is_root:
		print("%s Data Size %d" % (split, len(combined_dataset)))

	if not distributed:
		loader = DataLoader(combined_dataset, batch_size=batch_size, num_workers=4, collate_fn=_data_wrapper, shuffle=shuffle)
		return loader, updated_single_task

	# With a sampler, shuffling is delegated to the sampler itself.
	sampler = torch.utils.data.distributed.DistributedSampler(combined_dataset, shuffle=shuffle)
	loader = DataLoader(combined_dataset, batch_size=batch_size, num_workers=4, collate_fn=_data_wrapper, sampler=sampler)
	return loader, updated_single_task

def get_h5py_nlp_data(config, split, batch_size, tokenizer, max_length, use_t5_span=False, mask_warmup_steps=-1, shuffle=False, distributed=False, is_root=True, is_train=True):
	"""Build a ``DataLoader`` over the HDF5-backed datasets of all registered tasks.

	One ``MetaH5NLPTaskDynamic`` is created per entry of ``TASK_NAME_TO_CLS``,
	reading ``meta_task_h5_wo_SuperBLUE/<task_name>.h5``.

	Args:
	    config, split, tokenizer, max_length, is_train: forwarded to
	        ``MetaH5NLPTaskDynamic``.
	    batch_size: batch size of the loader (incomplete last batch is dropped).
	    use_t5_span, mask_warmup_steps: accepted but not used in this function.
	    shuffle: shuffle flag, given to the sampler when ``distributed`` is set
	        and to the loader otherwise.
	    distributed: use a ``DistributedSampler``.
	    is_root: print the task count and dataset size.
	"""
	per_task_datasets = []

	for task_name in TASK_NAME_TO_CLS:
		task_index = TASK_NAME_LIST.index(task_name)
		h5_path = os.path.join("meta_task_h5_wo_SuperBLUE/", task_name + ".h5")
		task = TASK_NAME_TO_CLS[task_name](task_name)
		per_task_datasets.append(MetaH5NLPTaskDynamic(config, task, task_index, h5_path, split, tokenizer, max_length, is_train=is_train))

	combined_dataset = ConcatDataset(per_task_datasets)

	if is_root:
		print("Task Number %d" % len(TASK_NAME_TO_CLS))
		print("%s Data Size %d" % (split, len(combined_dataset)))

	if not distributed:
		return DataLoader(combined_dataset, pin_memory=True, batch_size=batch_size, num_workers=8, collate_fn=_data_wrapper, shuffle=shuffle, drop_last=True)

	# With a sampler, shuffling is delegated to the sampler itself.
	sampler = torch.utils.data.distributed.DistributedSampler(combined_dataset, shuffle=shuffle)
	return DataLoader(combined_dataset, pin_memory=True, batch_size=batch_size, num_workers=8, collate_fn=_data_wrapper, sampler=sampler, drop_last=True)

def get_h5py_nlp_dataset(config, split, tokenizer, max_length, use_t5_span=False, mask_warmup_steps=-1, is_root=True, is_train=True):
	"""Return the concatenation of the HDF5-backed datasets of all registered tasks.

	One ``MetaH5NLPTaskDynamic`` is created per entry of ``TASK_NAME_TO_CLS``,
	reading ``meta_task_h5_wo_SuperBLUE/<task_name>.h5``.

	Args:
	    config, split, tokenizer, max_length, is_train: forwarded to
	        ``MetaH5NLPTaskDynamic``.
	    use_t5_span, mask_warmup_steps: accepted but not used in this function.
	    is_root: print the task count and dataset size.
	"""
	per_task_datasets = []

	for task_name in TASK_NAME_TO_CLS:
		task_index = TASK_NAME_LIST.index(task_name)
		h5_path = os.path.join("meta_task_h5_wo_SuperBLUE/", task_name + ".h5")
		task = TASK_NAME_TO_CLS[task_name](task_name)
		dataset = MetaH5NLPTaskDynamic(config, task, task_index, h5_path, split, tokenizer, max_length, is_train=is_train)
		per_task_datasets.append(dataset)

	combined_dataset = ConcatDataset(per_task_datasets)

	if is_root:
		print("Task Number %d" % len(TASK_NAME_TO_CLS))
		print("%s Data Size %d" % (split, len(combined_dataset)))

	return combined_dataset

	



