from torch.utils.data import Dataset
from src.utils.util import *
from collections import Counter
from transformers import AutoTokenizer
import math
import random
import torch
import os
from torch.utils.data import RandomSampler, DataLoader, SequentialSampler

def set_seed(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Value passed unchanged to ``random.seed``, ``np.random.seed``
            and ``torch.manual_seed`` (in that order).
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)


class MyDataset(Dataset):
    """Index-aligned container of input ids, input masks and labels.

    The three sequences are stored as given; item ``idx`` is a dict holding
    the ``idx``-th entry of each of them.
    """

    # Keys of the dict returned by ``__getitem__``; each key is also the name
    # of the attribute it is read from.
    _FIELDS = ("input_ids", "input_masks", "labels")

    def __init__(self, input_ids, input_masks, labels):
        self.input_ids, self.input_masks, self.labels = input_ids, input_masks, labels

    def __len__(self):
        # The dataset length is defined by the number of input-id rows.
        return len(self.input_ids)

    def __getitem__(self, idx):
        return {field: getattr(self, field)[idx] for field in self._FIELDS}
    
class MyData:
    """Reads the train/eval/test TSV splits and builds their dataloaders.

    Construction reads three tab-separated files from
    ``<args.data_dir>/<args.dataset>/``, chooses a random subset of "known"
    classes, splits the training examples into labeled and unlabeled subsets,
    and tokenizes everything into ``DataLoader`` objects.

    Attributes set by ``__init__``:
        train_df, eval_df, test_df: the raw dataframes.
        gt_label_list / gpt3_label_list: stripped, de-duplicated label values
            of the training ``label`` / ``gpt3_label`` columns.
        known_label_list: the randomly chosen known classes.
        all_label_list: every label used for evaluation ids.
        train_label_map / all_label_map: label -> integer id lookups.
        train_num_labels / num_labels: label counts scaled by
            ``args.cluster_num_factor``.
        train_labeled_loader, train_unlabeled_loader (may be ``None``),
        eval_dataloader, test_dataloader.
    """

    # Label types whose labeled training examples are selected by GPT-3 label.
    _GPT3_LABEL_TYPES = ("gpt3_only", "gpt3_top_k")
    _DEFAULT_TRAIN_FILE = "train_gpt3_labels_fixed.tsv"

    def __init__(self, args, exp_dict):
        """Load the data and build all dataloaders.

        Args:
            args: object providing ``data_dir``, ``dataset``,
                ``known_cls_ratio`` and ``cluster_num_factor``.
            exp_dict: experiment configuration; reads ``seed``,
                ``label_type``, ``labeled_ratio``, ``model_name``,
                ``max_seq_length``, ``batch_size`` and optionally
                ``train_file``.

        Raises:
            ValueError: if ``exp_dict["label_type"]`` is not one of
                ``original``, ``gpt3_only``, ``gpt3_top_k`` or
                ``original_and_gpt3``.
        """
        self.exp_dict = exp_dict
        set_seed(exp_dict["seed"])

        self.train_file = (
            exp_dict["train_file"] if "train_file" in exp_dict else self._DEFAULT_TRAIN_FILE
        )

        # Read all of the data into pandas dataframes for the train, eval, and test sets
        data_root = os.path.join(args.data_dir, args.dataset)
        self.train_df = pd.read_csv(os.path.join(data_root, self.train_file), sep="\t")
        self.eval_df = pd.read_csv(os.path.join(data_root, "eval_fixed.tsv"), sep="\t")
        self.test_df = pd.read_csv(os.path.join(data_root, "test_fixed.tsv"), sep="\t")

        # Ground-truth and GPT-3 labels of the train set. Both are stripped the
        # same way get_input_examples() strips them, and de-duplicated *after*
        # stripping, so that every example label is a key of the label maps
        # and no class is counted twice.
        self.gt_label_list = self._unique_stripped(self.train_df["label"])
        self.gpt3_label_list = self._unique_stripped(self.train_df["gpt3_label"])

        label_type = exp_dict["label_type"]
        if label_type == "original":
            self.all_label_list = self.gt_label_list
            known_pool = self.gt_label_list
        elif label_type in self._GPT3_LABEL_TYPES:
            # TODO reimplement the top k labels
            known_pool = self.gpt3_label_list
            # The original labels are added so they can be used for the evaluation
            self.all_label_list = np.unique(self.gt_label_list + self.gpt3_label_list)
        elif label_type == "original_and_gpt3":
            # Only the GT labels are candidates for known classes
            known_pool = self.gt_label_list
            self.all_label_list = np.unique(self.gt_label_list + self.gpt3_label_list)
        else:
            raise ValueError(f"Unsupported label_type: {label_type!r}")

        self.n_known_cls = round(len(known_pool) * args.known_cls_ratio)
        self.known_label_list = list(
            np.random.choice(np.array(known_pool), self.n_known_cls, replace=False)
        )

        self.train_label_map = self.make_label_map(self.known_label_list)
        self.all_label_map = self.make_label_map(self.all_label_list)

        self.train_num_labels = int(len(self.known_label_list) * args.cluster_num_factor)
        self.num_labels = int(len(self.all_label_list) * args.cluster_num_factor)

        # Get the examples from the dataframes
        self.train_examples = self.get_input_examples(self.train_df, exp_dict, 'train')
        self.eval_examples = self.get_input_examples(self.eval_df, exp_dict, 'eval')
        self.test_examples = self.get_input_examples(self.test_df, exp_dict, 'test')

        self.train_labeled_examples, self.train_unlabeled_examples = (
            self.get_labeled_and_unlabeled_examples(self.train_examples, label_type)
        )

        print('num_labeled_samples',len(self.train_labeled_examples))
        print('num_unlabeled_samples',len(self.train_unlabeled_examples))

        assert len(self.train_labeled_examples) + len(self.train_unlabeled_examples) == len(self.train_examples)

        # Get the dataloaders for the train, eval, and test sets
        if label_type == "original_and_gpt3":
            # Everything is trained on in one loader: the "unlabeled" examples
            # carry use_gt_label=False and are therefore given their GPT-3 label.
            self.train_labeled_examples = self.train_labeled_examples + self.train_unlabeled_examples
            self.train_labeled_loader = self.get_dataloader(self.train_labeled_examples, exp_dict, 'train', label_type)
            self.train_unlabeled_loader = None
        else:
            self.train_labeled_loader = self.get_dataloader(self.train_labeled_examples, exp_dict, 'train', label_type)
            if len(self.train_unlabeled_examples) > 0:
                self.train_unlabeled_loader = self.get_dataloader(self.train_unlabeled_examples, exp_dict, 'unlabeled', label_type)
            else:
                self.train_unlabeled_loader = None
        self.eval_dataloader = self.get_dataloader(self.eval_examples, exp_dict, 'eval', label_type)
        self.test_dataloader = self.get_dataloader(self.test_examples, exp_dict, 'test', label_type)

    @staticmethod
    def _unique_stripped(values):
        """Return the whitespace-stripped values, de-duplicated in first-seen order."""
        return list(dict.fromkeys(value.strip() for value in values))

    def make_label_map(self, label_list):
        """Return a dict mapping each label to its position in ``label_list``."""
        return {label: i for i, label in enumerate(label_list)}

    def get_input_examples(self, dataframe, exp_dict, mode='train'):
        """Create one ``InputExample`` per dataframe row.

        Args:
            dataframe: must have ``text`` and ``label`` columns, plus
                ``gpt3_label`` when ``mode == 'train'``.
            exp_dict: unused; kept for interface compatibility.
            mode: the GPT-3 label is only read for ``'train'``; for any other
                mode ``gpt3_label`` is left as ``None``.

        Returns:
            List of examples whose ``guid`` is the dataframe index value and
            whose labels are whitespace-stripped.
        """
        examples = []
        for guid, row in dataframe.iterrows():
            gpt3_label = row["gpt3_label"].strip() if mode == 'train' else None
            examples.append(
                InputExample(guid=guid, text_a=row["text"], label=row["label"].strip(), gpt3_label=gpt3_label)
            )
        return examples

    def get_labeled_and_unlabeled_examples(self, examples, label_type='original'):
        """Split ``examples`` into a labeled and an unlabeled subset.

        For each known class, ``labeled_ratio`` of its examples (exactly one
        if the class has a single example) are sampled with ``random.sample``
        as labeled. Classes are matched on the GPT-3 label for the GPT-3
        label types and on the ground-truth label otherwise.

        Side effect: sets ``use_gt_label = False`` on every unlabeled example
        and, for the GPT-3 label types, on the labeled ones as well.

        Returns:
            ``(labeled_examples, unlabeled_examples)``, each in input order.
        """
        uses_gpt3_labels = label_type in self._GPT3_LABEL_TYPES
        if uses_gpt3_labels:
            train_labels = np.array([example.gpt3_label.strip() for example in examples], dtype=str)
        else:
            train_labels = np.array([example.label for example in examples], dtype=str)

        labeled_ids = set()
        for label in self.known_label_list:
            positions = np.where(train_labels == label)[0].tolist()
            if len(positions) == 1:
                num = 1
            else:
                num = int(round(len(positions) * self.exp_dict["labeled_ratio"]))
            labeled_ids.update(random.sample(positions, num))

        train_labeled_examples, train_unlabeled_examples = [], []
        for idx, example in enumerate(examples):
            # Set membership keeps this pass linear in the number of examples.
            if idx in labeled_ids:
                if uses_gpt3_labels:
                    example.use_gt_label = False
                train_labeled_examples.append(example)
            else:
                example.use_gt_label = False
                train_unlabeled_examples.append(example)

        return train_labeled_examples, train_unlabeled_examples

    def _get_tokenizer(self, model_name):
        """Return the tokenizer for ``model_name``, loading it only when the name changes."""
        if getattr(self, "_tokenizer_name", None) != model_name:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self._tokenizer_name = model_name
        return self.tokenizer

    def _label_id_for(self, example, mode, label_type):
        """Return the integer label id of ``example`` for the given loader mode.

        * ``'train'``: the ground-truth label if ``example.use_gt_label`` is
          true, the GPT-3 label otherwise. For ``original_and_gpt3`` both are
          looked up in ``all_label_map`` so that every id in the merged train
          loader comes from the same index space; for the other label types
          ``train_label_map`` is used.
        * ``'eval'`` / ``'test'``: ground-truth label via ``all_label_map``.
        * anything else (e.g. ``'unlabeled'``): ``-1``.
        """
        if mode == 'train':
            label = example.label if example.use_gt_label else example.gpt3_label
            if label_type == "original_and_gpt3":
                return self.all_label_map[label]
            return self.train_label_map[label]
        if mode == 'eval' or mode == 'test':
            return self.all_label_map[example.label]
        return -1

    def get_dataloader(self, examples, exp_dict, mode='train', label_type='original'):
        """Tokenize the examples and create a dataloader for them.

        Args:
            examples: objects exposing ``text_a``, ``label``, ``gpt3_label``
                and ``use_gt_label``.
            exp_dict: reads ``model_name`` and ``max_seq_length``.
            mode: one of ``'train'``, ``'unlabeled'``, ``'eval'``, ``'test'``.
            label_type: the experiment's label type; it only affects which
                label map is used in ``'train'`` mode.

        Returns:
            The ``DataLoader`` built by ``get_loader``.
        """
        tokenizer = self._get_tokenizer(exp_dict["model_name"])

        input_ids, input_mask_array, label_ids = [], [], []
        for example in examples:
            encoded = tokenizer(
                example.text_a,
                add_special_tokens=True,
                max_length=exp_dict["max_seq_length"],
                padding="max_length",
                truncation=True,
                return_attention_mask=True,
            )
            input_ids.append(encoded["input_ids"])
            # Use the tokenizer's own mask rather than `token_id > 0`: the pad
            # id is not 0 for every tokenizer, and id 0 can be a real token.
            input_mask_array.append(encoded["attention_mask"])
            label_ids.append(self._label_id_for(example, mode, label_type))

        input_ids = torch.tensor(input_ids, dtype=torch.long)
        input_masks = torch.tensor(input_mask_array, dtype=torch.long)
        label_ids = torch.tensor(label_ids, dtype=torch.long)

        return self.get_loader(input_ids, input_masks, label_ids, mode)

    def get_loader(self, input_ids, input_masks, label_ids, mode='train'):
        """Wrap the tensors in a ``MyDataset`` and return a ``DataLoader``.

        ``'train'`` and ``'unlabeled'`` use a ``RandomSampler``; ``'eval'``
        and ``'test'`` use a ``SequentialSampler``. The batch size is
        ``self.exp_dict["batch_size"]``.

        Raises:
            ValueError: for any other ``mode``.
        """
        if mode == 'train' or mode == 'unlabeled':
            sampler_cls = RandomSampler
        elif mode == 'eval' or mode == 'test':
            sampler_cls = SequentialSampler
        else:
            raise ValueError(f"Unsupported dataloader mode: {mode!r}")

        dataset = MyDataset(input_ids, input_masks, label_ids)
        return DataLoader(dataset, sampler=sampler_cls(dataset), batch_size=self.exp_dict["batch_size"])

class InputExample(object):
    """A single text example together with its ground-truth and GPT-3 labels."""

    def __init__(self, guid, text_a, label=None, gpt3_label=None):
        """Constructs an InputExample.

        Args:
            guid: Unique id for the example.
            text_a: string. The untokenized text of the example.
            label: (Optional) string. The ground-truth label of the example.
            gpt3_label: (Optional) string. The GPT-3 generated label of the
            example.

        Every new example starts with ``use_gt_label`` set to True.
        """
        self.guid, self.text_a = guid, text_a
        self.label, self.gpt3_label = label, gpt3_label
        self.use_gt_label = True