import sys
import os
import numpy as np
import pandas as pd
import random

import torch
from torch.utils.data import TensorDataset, RandomSampler, SequentialSampler, DataLoader

def set_seed(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Integer seed handed unchanged to ``random.seed``,
            ``np.random.seed`` and ``torch.manual_seed``, in that order.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

class Data:
    """Load an intent dataset and build train/valid/test ``DataLoader`` objects.

    A random subset of the labels found in ``train.tsv`` is chosen as the
    "known" classes. Train and validation examples are restricted to those
    classes; test examples of every other class are relabelled with a single
    extra "unknown" id (``len(all_label_list)``).

    Args:
        args: Namespace-like object providing ``seed``, ``dataset``,
            ``train_batch_size``, ``test_batch_size`` and
            ``known_class_ratio``.
        tokenizer: Callable tokenizer invoked once per utterance; its result
            must expose ``'input_ids'`` and ``'attention_mask'`` entries.

    Raises:
        KeyError: If ``args.dataset`` is not a key of ``MAX_SEQ_LENGTHS``.
    """

    # Maximum tokenized sequence length for each supported dataset.
    MAX_SEQ_LENGTHS = {'oos': 30, 'stackoverflow': 55, 'banking': 55}

    # TSV file backing each split name accepted by the methods below.
    SPLIT_FILES = {'train': 'train.tsv', 'valid': 'dev.tsv', 'test': 'test.tsv'}

    def __init__(self, args, tokenizer):
        # Seed first: the known-class sampling below draws from NumPy's global RNG.
        set_seed(args.seed)

        self.bert_tokenizer = tokenizer
        self.max_length = self.MAX_SEQ_LENGTHS[args.dataset]
        self.train_batch_size = args.train_batch_size
        self.test_batch_size = args.test_batch_size

        self.known_class_ratio = args.known_class_ratio

        # NOTE: relative path, resolved against the current working directory.
        self.data_dir = f'../Adaptive-Decision-Boundary/data/{args.dataset}/'
        self.all_label_list = self.get_labels()  # np.ndarray of unique labels
        self.all_label_id = self.make_labels_id(self.all_label_list)  # dict: label -> id

        self.n_known_label = round(len(self.all_label_list) * self.known_class_ratio)  # int

        # Kept exactly as a single np.random.choice call so that a given seed
        # keeps selecting the same known classes.
        self.known_label_list = list(np.random.choice(np.array(self.all_label_list), self.n_known_label, replace=False))
        self.known_label_id = [self.all_label_id[name] for name in self.known_label_list]  # list of int

        # The unknown class takes the first id after all real labels.
        self.unknown_label_id = len(self.all_label_list)

        if args.dataset == "oos":
            self.all_label_id['ooi'] = self.unknown_label_id  # out-of-intent
        else:
            self.all_label_id['oos'] = self.unknown_label_id

        self.num_known_label = len(self.known_label_list)  # int

        # Load each split, convert it to features and wrap it in a DataLoader.
        train_data = self.read_dataset('train')
        valid_data = self.read_dataset('valid')
        test_data = self.read_dataset('test')

        self.train_dataset = self.get_dataloader(self.convert_to_example(train_data, mode='train'), 'train')
        self.valid_dataset = self.get_dataloader(self.convert_to_example(valid_data, mode='valid'), 'valid')
        self.test_dataset = self.get_dataloader(self.convert_to_example(test_data, mode='test'), 'test')

    def get_labels(self):
        """Return the unique values of the ``label`` column of ``train.tsv``.

        Returns:
            The array produced by ``np.unique`` over the training labels.
        """
        data = pd.read_csv(os.path.join(self.data_dir, "train.tsv"), sep="\t")
        return np.unique(np.array(data['label']))

    def make_labels_id(self, label_list):
        """Map every label to its position in ``label_list``.

        Args:
            label_list: Iterable of hashable labels.

        Returns:
            dict from label to its zero-based index (for a repeated label the
            last index wins).
        """
        return {label: idx for idx, label in enumerate(label_list)}

    def read_dataset(self, mode):
        """Read the TSV file of one split from ``self.data_dir``.

        Args:
            mode: ``'train'``, ``'valid'`` (reads ``dev.tsv``) or ``'test'``.

        Returns:
            The ``pandas.DataFrame`` parsed from the tab-separated file.

        Raises:
            ValueError: If ``mode`` is not one of the supported split names.
        """
        try:
            file_name = self.SPLIT_FILES[mode]
        except KeyError:
            raise ValueError(
                f"unknown mode {mode!r}; expected one of {sorted(self.SPLIT_FILES)}"
            ) from None

        return pd.read_csv(os.path.join(self.data_dir, file_name), sep='\t')

    def _encode(self, text, label_id):
        """Tokenize one utterance and bundle the result with ``label_id``."""
        inputs = self.bert_tokenizer(
            text,
            return_tensors='pt',
            truncation=True,
            max_length=self.max_length,
            padding='max_length',  # replaces the deprecated pad_to_max_length=True
            add_special_tokens=True,
        )

        # The tokenizer output is keyed by input_ids, token_type_ids and
        # attention_mask. For a single sample each value is a torch tensor of
        # shape [1, max_seq_length], hence the squeeze before converting to a list.
        return InputFeatures(
            input_ids=inputs['input_ids'].squeeze(0).tolist(),
            attention_mask=inputs['attention_mask'].squeeze(0).tolist(),
            label=label_id,
        )

    def convert_to_example(self, data, mode):
        """Turn the rows of a split into a list of ``InputFeatures``.

        The first column of ``data`` is used as the utterance and the last
        column as its label.

        * ``'train'`` / ``'valid'``: only rows whose label is a known class
          are kept.
        * ``'test'``: every row is kept; rows whose label is not a known
          class are relabelled with ``self.unknown_label_id``.

        A label missing from ``self.all_label_id`` is treated as not known.

        Args:
            data: ``pandas.DataFrame`` as returned by ``read_dataset``.
            mode: ``'train'``, ``'valid'`` or ``'test'``.

        Returns:
            list of ``InputFeatures`` in row order.

        Raises:
            ValueError: If ``mode`` is not one of the supported split names.
        """
        if mode not in self.SPLIT_FILES:
            raise ValueError(
                f"unknown mode {mode!r}; expected one of {sorted(self.SPLIT_FILES)}"
            )

        keep_unknown = mode == 'test'
        # Built once per call so the per-row membership test is O(1).
        known_ids = set(self.known_label_id)

        list_of_examples = []
        texts = data.iloc[:, 0].tolist()
        labels = data.iloc[:, -1].tolist()
        for text, label in zip(texts, labels):
            label_id = self.all_label_id.get(label)
            if label_id not in known_ids:
                if not keep_unknown:
                    continue
                label_id = self.unknown_label_id
            list_of_examples.append(self._encode(text, label_id))

        return list_of_examples

    def get_dataloader(self, features, mode):
        """Pack features into a ``TensorDataset`` and wrap it in a ``DataLoader``.

        Args:
            features: Sequence of objects exposing ``input_ids``,
                ``attention_mask`` and ``label``.
            mode: ``'train'`` uses a ``RandomSampler`` with
                ``self.train_batch_size``; ``'valid'`` and ``'test'`` use a
                ``SequentialSampler`` with ``self.test_batch_size``.

        Returns:
            ``DataLoader`` yielding ``(input_ids, attention_mask, label)``
            batches of ``torch.long`` tensors.

        Raises:
            ValueError: If ``mode`` is not one of the supported split names.
        """
        if mode == 'train':
            sampler_cls, batch_size = RandomSampler, self.train_batch_size
        elif mode in ('valid', 'test'):
            sampler_cls, batch_size = SequentialSampler, self.test_batch_size
        else:
            raise ValueError(
                f"unknown mode {mode!r}; expected one of {sorted(self.SPLIT_FILES)}"
            )

        input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
        attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
        label = torch.tensor([f.label for f in features], dtype=torch.long)

        datatensor = TensorDataset(input_ids, attention_mask, label)

        return DataLoader(datatensor, sampler=sampler_cls(datatensor), batch_size=batch_size)

class InputFeatures:
    """Container for the features of a single example.

    Attributes:
        input_ids: Token ids of the example, stored as given.
        attention_mask: Attention mask of the example, stored as given.
        label: Label of the example, stored as given.
    """

    def __init__(self, input_ids, attention_mask, label):
        self.input_ids, self.attention_mask, self.label = (
            input_ids,
            attention_mask,
            label,
        )