from torch.utils.data import Dataset
from src.utils.util import *
from collections import Counter

def set_seed(seed):
    """Seed the global RNGs of ``random``, NumPy and PyTorch with one value.

    Args:
        seed: value passed, in this order, to ``random.seed``,
            ``np.random.seed`` and ``torch.manual_seed``.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)


class MyDataset(Dataset):
    """Dataset wrapper that also carries labeled/unlabeled example counts.

    Items and length are delegated unchanged to the wrapped dataset; the two
    counts are only stored as attributes for callers to read.

    Args:
        tensor_dataset: the indexable dataset being wrapped.
        num_ground_truth_labeled: number of examples flagged as ground truth.
        num_ground_truth_unlabeled: number of remaining examples.
    """

    def __init__(self, tensor_dataset, num_ground_truth_labeled, num_ground_truth_unlabeled):
        (self.tensor_dataset,
         self.num_ground_truth_labeled,
         self.num_ground_truth_unlabeled) = (tensor_dataset,
                                             num_ground_truth_labeled,
                                             num_ground_truth_unlabeled)

    def __len__(self):
        wrapped = self.tensor_dataset
        return len(wrapped)

    def __getitem__(self, idx):
        wrapped = self.tensor_dataset
        return wrapped[idx]
    
class Data:
    """Builds label vocabularies, example splits and dataloaders for one run.

    The label setup is chosen by ``exp_dict["label_type"]``:

    1. ``"original"`` -- ground-truth labels only (the DAC baseline).
    2. ``"original_and_gpt3"`` -- ground-truth labels for the known classes and
       GPT-3 labels for the unlabeled training examples.
    3. ``"gpt3_only"`` / ``"gpt3_top_k"`` -- GPT-3 labels for training; the
       ground-truth labels are added to the vocabulary so evaluation can still
       map them.
    """

    def __init__(self, args, exp_dict):
        self.exp_dict = exp_dict
        set_seed(args.seed)
        max_seq_lengths = {'clinc': 30, 'stackoverflow': 45, 'banking': 55}
        args.max_seq_length = max_seq_lengths[args.dataset]

        processor = DatasetProcessor(exp_dict, args.label_type)
        self.data_dir = os.path.join(args.data_dir, args.dataset)

        label_type = exp_dict["label_type"]
        if label_type == "original":
            self.all_label_list = processor.get_labels(self.data_dir, args.label_type)
            self._pick_known_labels(args)

        elif label_type in ("gpt3_only", "gpt3_top_k"):
            # Train on GPT-3 labels only, but keep the ground-truth labels in the
            # vocabulary because the eval/test examples carry them.
            self.all_label_list = processor.get_labels(self.data_dir, args.label_type, exp_dict["top_k_gpt3_labels"])
            self._pick_known_labels(args)
            original_labels = processor.get_labels(self.data_dir, "original")
            self.all_label_list = np.unique(np.concatenate((self.all_label_list, original_labels)))

        elif label_type == "original_and_gpt3":
            # Known classes are drawn from the ground-truth labels only ...
            self.all_label_list = processor.get_labels(self.data_dir, "original")
            self._pick_known_labels(args)
            # ... but unlabeled examples are converted with their GPT-3 label, so
            # those labels must be in the vocabulary as well.
            gpt3_label_list = processor.get_labels(self.data_dir, "gpt3_only")
            self.all_label_list = np.unique(np.concatenate((self.all_label_list, gpt3_label_list)))

        else:
            raise ValueError("Unknown label_type: %r" % (label_type,))

        self.train_label_map = self.make_label_map(self.known_label_list)
        self.all_label_map = self.make_label_map(self.all_label_list)
        # Every loader (labeled, semi, eval, test) converts labels through the
        # full vocabulary, so ids are consistent across loaders.
        # NOTE(review): confirm the model heads expect ids in all_label_map.
        self.label_map = self.all_label_map

        self.num_labels = int(len(self.all_label_list) * args.cluster_num_factor)

        self.train_labeled_examples, self.train_unlabeled_examples = self.get_examples(processor, args, 'train')
        print('num_labeled_samples', len(self.train_labeled_examples))
        print('num_unlabeled_samples', len(self.train_unlabeled_examples))
        self.eval_examples = self.get_examples(processor, args, 'eval')
        self.test_examples = self.get_examples(processor, args, 'test')

        self.train_labeled_dataloader = self.get_loader(self.train_labeled_examples, args, 'train')

        # The semi loader holds every training example, labeled first; unlabeled
        # examples get label id -1 (except for 'original_and_gpt3'). In the
        # original Deep Aligned Clustering it is only used to extract features
        # for the whole training set.
        (self.semi_input_ids, self.semi_input_mask, self.semi_segment_ids,
         self.semi_label_ids, self.semi_ground_truth) = self.get_semi(
            self.train_labeled_examples, self.train_unlabeled_examples, args)
        self.train_semi_dataloader = self.get_semi_loader(
            self.semi_input_ids, self.semi_input_mask, self.semi_segment_ids,
            self.semi_label_ids, self.semi_ground_truth, args)

        self.eval_dataloader = self.get_loader(self.eval_examples, args, 'eval')
        self.test_dataloader = self.get_loader(self.test_examples, args, 'test')

    def _pick_known_labels(self, args):
        """Randomly choose ``known_cls_ratio`` of ``self.all_label_list`` as known classes."""
        self.n_known_cls = round(len(self.all_label_list) * args.known_cls_ratio)
        self.known_label_list = list(np.random.choice(np.array(self.all_label_list), self.n_known_cls, replace=False))

    @staticmethod
    def make_label_map(label_list):
        """Return a dict mapping each label to its index in ``label_list``."""
        return {label: i for i, label in enumerate(label_list)}

    def get_examples(self, processor, args, mode='train'):
        """Load the examples of one split.

        Args:
            processor: object whose ``get_examples(data_dir, mode)`` returns the
                raw examples (each with a ``label`` string).
            args: needs ``labeled_ratio`` for the train split.
            mode: ``'train'``, ``'eval'`` or ``'test'``.

        Returns:
            For ``'train'``, a ``(labeled, unlabeled)`` pair of lists. For each
            known label, ``round(count * labeled_ratio)`` of its examples are
            sampled as labeled; everything else is unlabeled (original order kept).
            For ``'eval'``/``'test'``, all examples of the split.

        Raises:
            ValueError: for any other ``mode``.
        """
        ori_examples = processor.get_examples(self.data_dir, mode)

        if mode == 'train':
            train_labels = np.array([example.label.strip() for example in ori_examples])
            train_labeled_ids = set()  # set: O(1) membership in the split loop below
            for label in self.known_label_list:
                pos = np.flatnonzero(train_labels == label).tolist()
                num = round(len(pos) * args.labeled_ratio)
                train_labeled_ids.update(random.sample(pos, num))

            # TODO add assertion to check that all of the unlabeled examples do not match with the GT known labels
            train_labeled_examples, train_unlabeled_examples = [], []
            for idx, example in enumerate(ori_examples):
                if idx in train_labeled_ids:
                    train_labeled_examples.append(example)
                else:
                    train_unlabeled_examples.append(example)
            return train_labeled_examples, train_unlabeled_examples

        if mode in ('eval', 'test'):
            # The eval split is deliberately NOT filtered to known classes: this
            # version evaluates discovery performance, not known-class accuracy.
            return ori_examples

        raise ValueError("Unknown mode: %r" % (mode,))

    def get_semi(self, labeled_examples, unlabeled_examples, args):
        """Tensorise labeled + unlabeled training examples into one set.

        Unlabeled rows get label id -1, except when ``args.label_type`` is
        ``'original_and_gpt3'``, where they keep their GPT-3 label id.

        Returns:
            ``(input_ids, input_mask, segment_ids, label_ids, ground_truth)`` as
            long tensors, labeled rows followed by unlabeled rows.
        """
        tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=True)
        labeled_features = convert_examples_to_features(labeled_examples, self.label_map, args.max_seq_length, tokenizer, args.label_type)
        unlabeled_features = convert_examples_to_features(unlabeled_examples, self.label_map, args.max_seq_length, tokenizer, args.label_type, unlabeled_features=True)

        def to_tensor(features, attr):
            return torch.tensor([getattr(f, attr) for f in features], dtype=torch.long)

        if args.label_type == 'original_and_gpt3':
            # Train with GT labels for labeled rows and GPT-3 labels for unlabeled rows.
            unlabeled_label_ids = to_tensor(unlabeled_features, 'label_id')
        else:
            unlabeled_label_ids = torch.tensor([-1] * len(unlabeled_features), dtype=torch.long)

        semi_input_ids = torch.cat([to_tensor(labeled_features, 'input_ids'), to_tensor(unlabeled_features, 'input_ids')])
        semi_input_mask = torch.cat([to_tensor(labeled_features, 'input_mask'), to_tensor(unlabeled_features, 'input_mask')])
        semi_segment_ids = torch.cat([to_tensor(labeled_features, 'segment_ids'), to_tensor(unlabeled_features, 'segment_ids')])
        semi_label_ids = torch.cat([to_tensor(labeled_features, 'label_id'), unlabeled_label_ids])
        semi_ground_truth = torch.cat([to_tensor(labeled_features, 'ground_truth'), to_tensor(unlabeled_features, 'ground_truth')])
        return semi_input_ids, semi_input_mask, semi_segment_ids, semi_label_ids, semi_ground_truth

    def get_semi_loader(self, semi_input_ids, semi_input_mask, semi_segment_ids, semi_label_ids, semi_ground_truth, args):
        """Wrap the semi tensors in a sequential DataLoader.

        The wrapped dataset records how many rows have ``ground_truth == 1`` and
        how many do not.
        """
        semi_data = TensorDataset(semi_input_ids, semi_input_mask, semi_segment_ids, semi_label_ids, semi_ground_truth)
        # Vectorised count instead of indexing every dataset row in Python.
        num_labeled_gt = int((semi_ground_truth == 1).sum())
        num_unlabeled_gpt3 = len(semi_data) - num_labeled_gt

        my_dataset = MyDataset(semi_data, num_labeled_gt, num_unlabeled_gpt3)
        semi_sampler = SequentialSampler(my_dataset)
        return DataLoader(my_dataset, sampler=semi_sampler, batch_size=args.train_batch_size)

    def get_loader(self, examples, args, mode='train'):
        """Tensorise ``examples`` and wrap them in a DataLoader.

        ``'train'`` uses a RandomSampler with ``args.train_batch_size``;
        ``'eval'``/``'test'`` use a SequentialSampler with ``args.eval_batch_size``.

        Raises:
            ValueError: for any other ``mode``.
        """
        if mode == 'train':
            sampler_cls, batch_size = RandomSampler, args.train_batch_size
        elif mode in ('eval', 'test'):
            sampler_cls, batch_size = SequentialSampler, args.eval_batch_size
        else:
            raise ValueError("Unknown mode: %r" % (mode,))

        tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=True)
        features = convert_examples_to_features(examples, self.label_map, args.max_seq_length, tokenizer, args.label_type)

        input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
        input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
        segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
        label_ids = torch.tensor([f.label_id for f in features], dtype=torch.long)
        ground_truth = torch.tensor([f.ground_truth for f in features], dtype=torch.long)
        data = TensorDataset(input_ids, input_mask, segment_ids, label_ids, ground_truth)

        return DataLoader(data, sampler=sampler_cls(data), batch_size=batch_size)

class InputExample(object):
    """One raw (untokenized) example for simple sequence classification.

    Args:
        guid: unique identifier of the example.
        text_a: untokenized text of the first sequence; for single-sequence
            tasks this is the only text that must be given.
        text_b: optional untokenized text of a second sequence, only needed
            for sequence-pair tasks.
        label: optional label string of the example.
        gpt3_label: optional GPT-3 label string stored next to ``label``.
    """

    def __init__(self, guid, text_a, text_b=None, label=None, gpt3_label=None):
        self.guid, self.text_a, self.text_b = guid, text_a, text_b
        self.label, self.gpt3_label = label, gpt3_label


class InputFeatures(object):
    """Model-ready inputs for a single example.

    Attributes:
        input_ids: token ids (zero-padded by ``convert_examples_to_features``).
        input_mask: 1 for real tokens, 0 for padding.
        segment_ids: token-type ids (0 for the first sequence, 1 for the second).
        label_id: integer label id.
        ground_truth: flag recorded alongside ``label_id``.
    """

    def __init__(self, input_ids, input_mask, segment_ids, label_id, ground_truth):
        (self.input_ids, self.input_mask, self.segment_ids,
         self.label_id, self.ground_truth) = (input_ids, input_mask, segment_ids,
                                              label_id, ground_truth)


class DataProcessor(object):
    """Base class for data converters for sequence classification data sets."""

    @classmethod
    def _read_tsv(cls, input_file, quotechar=None):
        """Read a tab-separated file into a list of rows.

        Args:
            input_file: path of the ``.tsv`` file (read as UTF-8).
            quotechar: forwarded to ``csv.reader``; the default ``None`` means
                no quote character is recognised.

        Returns:
            A list of rows, each a list of cell strings; the header row is
            included.
        """
        # newline='' is what the csv module requires for correct line handling;
        # an explicit encoding makes the result independent of the locale.
        with open(input_file, "r", encoding="utf-8", newline="") as f:
            return list(csv.reader(f, delimiter="\t", quotechar=quotechar))

class DatasetProcessor(DataProcessor):
    """Reads a dataset directory's TSV files and builds InputExample objects.

    Rows are read as ``text, label, gpt3_label`` by position (columns 0, 1, 2).
    ``get_labels`` reads ``train_gpt3_labels.tsv`` by header name (``label`` and
    ``intent``); presumably these are the same columns 1 and 2 -- verify
    against the data files.
    """

    def __init__(self, exp_dict, label_type='original'):
        self.label_type = label_type
        self.exp_dict = exp_dict

    def get_examples(self, data_dir, mode):
        """Return the InputExamples of one split.

        Only the train split is built with ``self.label_type``; the eval and test
        splits always use the ``'original'`` label column.

        Raises:
            ValueError: if ``mode`` is not ``'train'``, ``'eval'`` or ``'test'``.
        """
        if mode == 'train':
            return self._create_examples(self._read_tsv(os.path.join(data_dir, "train_gpt3_labels.tsv")), "train", self.label_type)
        if mode == 'eval':
            return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
        if mode == 'test':
            return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
        raise ValueError("Unknown mode: %r" % (mode,))

    def get_labels(self, data_dir, label_type='original', top_k=50):
        """Return the sorted unique labels of ``train_gpt3_labels.tsv``.

        Args:
            data_dir: directory containing ``train_gpt3_labels.tsv``.
            label_type: ``'original'`` for the ``label`` column (as-is),
                ``'gpt3_only'`` for the stripped ``intent`` column, or
                ``'gpt3_top_k'`` for stripped intents occurring more than
                ``top_k`` times.
            top_k: frequency threshold for ``'gpt3_top_k'`` (strictly greater).

        Raises:
            ValueError: for any other ``label_type``.
        """
        import pandas as pd
        train_df = pd.read_csv(os.path.join(data_dir, "train_gpt3_labels.tsv"), sep="\t")
        if label_type == 'original':
            return np.unique(np.array(train_df['label']))
        if label_type == 'gpt3_only':
            return np.unique([label.strip() for label in train_df['intent']])
        if label_type == 'gpt3_top_k':
            # Strip before counting: examples use the stripped GPT-3 label, so
            # labels differing only in whitespace must be counted together.
            counts = train_df['intent'].str.strip().value_counts()
            return np.unique(counts[counts > top_k].index.tolist())
        raise ValueError("Unknown label_type: %r" % (label_type,))

    def _create_examples(self, lines, set_type, label_type='original'):
        """Create InputExamples from TSV rows, skipping the header row.

        ``label`` is the stripped GPT-3 column for ``'gpt3_only'``/``'gpt3_top_k'``,
        the stripped original column for ``'original_and_gpt3'`` (which also
        sets ``gpt3_label``), and the raw original column otherwise.
        """
        examples = []
        for i, line in enumerate(lines):
            if i == 0:
                continue  # header row
            gpt3_label = None
            if label_type in ('gpt3_only', 'gpt3_top_k'):
                label = line[2].strip()
            elif label_type == 'original_and_gpt3':
                label = line[1].strip()
                gpt3_label = line[2].strip()
            else:
                label = line[1]
            examples.append(
                InputExample(guid="%s-%s" % (set_type, i), text_a=line[0], text_b=None, label=label, gpt3_label=gpt3_label))
        return examples

def convert_examples_to_features(examples, label_map, max_seq_length, tokenizer, label_type='original', unlabeled_features=False):
    """Tokenize examples into fixed-length BERT inputs.

    Args:
        examples: iterable of InputExample-like objects (``text_a``, ``text_b``,
            ``label``, ``gpt3_label``).
        label_map: dict mapping a label string to its integer id.
        max_seq_length: length every sequence is truncated / zero-padded to,
            including the ``[CLS]`` and ``[SEP]`` tokens.
        tokenizer: object providing ``tokenize`` and ``convert_tokens_to_ids``.
        label_type: label setup; affects how ``label_id`` is chosen (below).
        unlabeled_features: whether ``examples`` are unlabeled training examples.

    Returns:
        A list of InputFeatures whose ``label_id``/``ground_truth`` are:

        * unlabeled examples with ``label_type == 'original_and_gpt3'``: the id of
          ``gpt3_label``, ``ground_truth=False``;
        * otherwise, if ``label`` is in ``label_map``: its id, ``ground_truth=True``;
        * otherwise, for ``label_type == 'gpt3_top_k'``: the sentinel id 99999,
          ``ground_truth=False``.

    Raises:
        ValueError: if a label is missing from ``label_map`` in any other case.
        KeyError: if an unlabeled example's ``gpt3_label`` is missing from
            ``label_map`` (``'original_and_gpt3'``).
    """
    features = []
    for example in examples:
        tokens_a = tokenizer.tokenize(example.text_a)

        tokens_b = None
        if example.text_b:
            tokens_b = tokenizer.tokenize(example.text_b)
            # Truncates tokens_a/tokens_b in place; "- 3" accounts for [CLS], [SEP], [SEP].
            _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
        elif len(tokens_a) > max_seq_length - 2:
            # Account for [CLS] and [SEP] with "- 2".
            tokens_a = tokens_a[:(max_seq_length - 2)]

        # BERT convention: [CLS] A [SEP] (B [SEP]); segment id 0 marks the first
        # sequence and 1 the second. [CLS] is used as the sentence vector.
        tokens = ["[CLS]"] + tokens_a + ["[SEP]"]
        segment_ids = [0] * len(tokens)
        if tokens_b:
            tokens += tokens_b + ["[SEP]"]
            segment_ids += [1] * (len(tokens_b) + 1)

        input_ids = tokenizer.convert_tokens_to_ids(tokens)
        # 1 for real tokens, 0 for padding: only real tokens are attended to.
        input_mask = [1] * len(input_ids)

        # Zero-pad up to the sequence length.
        padding = [0] * (max_seq_length - len(input_ids))
        input_ids += padding
        input_mask += padding
        segment_ids += padding

        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length

        if unlabeled_features and label_type == 'original_and_gpt3':
            label_id = label_map[example.gpt3_label]
            ground_truth = False
        elif example.label in label_map:
            label_id = label_map[example.label]
            ground_truth = True
        elif label_type == 'gpt3_top_k':
            label_id = 99999
            ground_truth = False
        else:
            # Must raise: otherwise label_id/ground_truth would be unbound or
            # silently reused from the previous example.
            raise ValueError("Label %r not found in label map" % (example.label,))

        features.append(
            InputFeatures(input_ids=input_ids,
                          input_mask=input_mask,
                          segment_ids=segment_ids,
                          label_id=label_id,
                          ground_truth=ground_truth))
    return features


def _truncate_seq_pair(tokens_a, tokens_b, max_length):
    """Truncates a sequence pair in place to the maximum length."""
    # This is a simple heuristic which will always truncate the longer sequence
    # one token at a time. This makes more sense than truncating an equal percent
    # of tokens from each, since if one sequence is very short then each token
    # that's truncated likely contains more information than a longer sequence.
    while True:
        total_length = len(tokens_a) + len(tokens_b)
        if total_length <= max_length:
            break
        if len(tokens_a) > len(tokens_b):
            tokens_a.pop(0)  # For dialogue context
        else:
            tokens_b.pop()
