""" Fine-tuning on A Classification Task with pretrained Transformer """

import itertools
import csv
import fire

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import tokenization
import bert
import optim
import train
from scipy import stats
from utils import set_seeds, get_device, truncate_tokens_pair

class CsvDataset(Dataset):
    """ Dataset Class for CSV file

    Reads a tab-separated file, turns every row selected by `get_instances`
    into an instance, runs the instance through each callable in `pipeline`
    and finally stacks the processed instances column-wise into tensors.
    Subclasses provide `labels` and implement `get_instances`.
    """
    labels = None # label names; set by each task subclass

    def __init__(self, file, pipeline=None): # csv file and pipeline object
        """ file: path of the tab-separated data file
            pipeline: iterable of callables applied in order to every
                      instance; None (the default) means no preprocessing """
        Dataset.__init__(self)
        # None instead of a mutable `[]` default, which would be shared
        # between all calls.
        steps = [] if pipeline is None else pipeline
        data = []
        # Explicit encoding so that parsing does not depend on the locale.
        with open(file, "r", encoding="utf-8") as f:
            # list of splitted lines : line is also list
            lines = csv.reader(f, delimiter='\t', quotechar=None)
            for instance in self.get_instances(lines): # instance : tuple of fields
                for proc in steps: # a bunch of pre-processing
                    instance = proc(instance)
                data.append(instance)

        # To Tensors
        # Files whose path contains 'STS-B' carry a real-valued target in the
        # 4th column (index 3); every other column is stored as integers.
        # str() lets path-like objects be passed as well as plain strings.
        float_target = 'STS-B' in str(file)
        self.tensors = [
            torch.tensor(column,
                         dtype=torch.float if (float_target and i == 3) else torch.long)
            for i, column in enumerate(zip(*data))
        ]

    def __len__(self):
        """ Number of instances (0 when the file produced no instance) """
        # With no instance there are no columns, hence no tensors to index.
        return self.tensors[0].size(0) if self.tensors else 0

    def __getitem__(self, index):
        """ The `index`-th entry of every column tensor, as a tuple """
        return tuple(tensor[index] for tensor in self.tensors)

    def get_instances(self, lines):
        """ get instance array from (csv-separated) line list """
        raise NotImplementedError


class MRPC(CsvDataset):
    """ Dataset class for MRPC """
    labels = ("0", "1") # label names

    # The constructor is inherited from CsvDataset: the former override only
    # forwarded (file, pipeline) to it and added a mutable default argument.

    def get_instances(self, lines):
        """ Yield (label, text_a, text_b) taken from columns 0, 3 and 4 """
        rows = itertools.islice(lines, 1, None) # skip header
        for row in rows:
            yield row[0], row[3], row[4] # label, text_a, text_b

class STSB(CsvDataset):
    """ Dataset class for STS-B (real-valued target, see TokenIndexing) """
    # Trailing comma makes this a 1-tuple; ("0") is just the string "0".
    labels = ("0",) # label names

    # The constructor is inherited from CsvDataset: the former override only
    # forwarded (file, pipeline) to it and added a mutable default argument.

    def get_instances(self, lines):
        """ Yield (label, text_a, text_b): last column, columns 7 and 8 """
        rows = itertools.islice(lines, 1, None) # skip header
        for row in rows:
            yield row[-1], row[7], row[8] # label, text_a, text_b


class CoLA(CsvDataset):
    """ Dataset class for CoLA (single-sentence task: text_b is None) """
    labels = ("0", "1") # label names

    # The constructor is inherited from CsvDataset: the former override only
    # forwarded (file, pipeline) to it and added a mutable default argument.

    def get_instances(self, lines):
        """ Yield (label, text_a, None) taken from columns 1 and 3 """
        # Unlike the sibling datasets, every row is used: nothing is skipped
        # at the top of the file.
        for row in lines:
            yield row[1], row[3], None # label, text_a, text_b


class RTE(CsvDataset):
    """ Dataset class for RTE """
    labels = ("not_entailment", "entailment") # label names

    # The constructor is inherited from CsvDataset: the former override only
    # forwarded (file, pipeline) to it and added a mutable default argument.

    def get_instances(self, lines):
        """ Yield (label, text_a, text_b): last column, columns 1 and 2 """
        rows = itertools.islice(lines, 1, None) # skip header
        for row in rows:
            yield row[-1], row[1], row[2] # label, text_a, text_b


class WNLI(CsvDataset):
    """ Dataset class for WNLI """
    labels = ("0", "1") # label names

    # The constructor is inherited from CsvDataset: the former override only
    # forwarded (file, pipeline) to it and added a mutable default argument.

    def get_instances(self, lines):
        """ Yield (label, text_a, text_b): last column, columns 1 and 2 """
        rows = itertools.islice(lines, 1, None) # skip header
        for row in rows:
            yield row[-1], row[1], row[2] # label, text_a, text_b


class SST2(CsvDataset):
    """ Dataset class for SST-2 (single-sentence task: text_b is None) """
    labels = ("0", "1") # label names

    # The constructor is inherited from CsvDataset: the former override only
    # forwarded (file, pipeline) to it and added a mutable default argument.

    def get_instances(self, lines):
        """ Yield (label, text_a, None) taken from columns 1 and 0 """
        rows = itertools.islice(lines, 1, None) # skip header
        for row in rows:
            yield row[1], row[0], None # label, text_a, text_b


class QNLI(CsvDataset):
    """ Dataset class for QNLI """
    labels = ("not_entailment", "entailment") # label names

    # The constructor is inherited from CsvDataset: the former override only
    # forwarded (file, pipeline) to it and added a mutable default argument.

    def get_instances(self, lines):
        """ Yield (label, text_a, text_b): last column, columns 1 and 2 """
        rows = itertools.islice(lines, 1, None) # skip header
        for row in rows:
            yield row[-1], row[1], row[2] # label, text_a, text_b

class QQP(CsvDataset):
    """ Dataset class for QQP """
    labels = ("0", "1") # label names

    # The constructor is inherited from CsvDataset: the former override only
    # forwarded (file, pipeline) to it and added a mutable default argument.

    def get_instances(self, lines):
        """ Yield (label, text_a, text_b) taken from columns 5, 3 and 4 """
        rows = itertools.islice(lines, 1, None) # skip header
        for row in rows:
            yield row[5], row[3], row[4] # label, text_a, text_b

class MNLI(CsvDataset):
    """ Dataset class for MNLI """
    labels = ("contradiction", "entailment", "neutral") # label names

    # The constructor is inherited from CsvDataset: the former override only
    # forwarded (file, pipeline) to it and added a mutable default argument.

    def get_instances(self, lines):
        """ Yield (label, text_a, text_b): last column, columns 8 and 9 """
        rows = itertools.islice(lines, 1, None) # skip header
        for row in rows:
            yield row[-1], row[8], row[9] # label, text_a, text_b


def dataset_class(task):
    """ Mapping from task string to Dataset Class

    Raises KeyError, naming the supported tasks, when `task` is unknown.
    """
    table = {
        'mrpc': MRPC,
        'stsb': STSB,
        'cola': CoLA,
        'rte': RTE,
        'wnli': WNLI,
        'sst2': SST2,
        'qnli': QNLI,
        'qqp': QQP,
        'mnli': MNLI,
    }
    try:
        return table[task]
    except KeyError:
        # Same exception type as before, but with an actionable message.
        raise KeyError("unknown task {!r}; expected one of: {}".format(
            task, ", ".join(sorted(table)))) from None


class Pipeline:
    """ Base class of the preprocessing steps : instances are callable """

    def __init__(self):
        super().__init__()

    def __call__(self, instance):
        """ Transform one instance; every concrete step overrides this """
        raise NotImplementedError


class Tokenizing(Pipeline):
    """ Pipeline step : tokenize the texts of a (label, text_a, text_b) instance """

    def __init__(self, preprocessor, tokenize):
        super().__init__()
        # callable applied to the label and to each text before tokenizing
        # (e.g. text normalization)
        self.preprocessor = preprocessor
        # callable turning a preprocessed text into tokens
        self.tokenize = tokenize

    def __call__(self, instance):
        """ Return (label, tokens_a, tokens_b); tokens_b is [] without text_b """
        raw_label, first_text, second_text = instance

        label = self.preprocessor(raw_label)
        tokens_a = self.tokenize(self.preprocessor(first_text))

        # A missing (None) or empty second text yields an empty token list.
        if second_text:
            tokens_b = self.tokenize(self.preprocessor(second_text))
        else:
            tokens_b = []

        return (label, tokens_a, tokens_b)


class AddSpecialTokensWithTruncation(Pipeline):
    """ Pipeline step : truncate the token lists, then add [CLS] and [SEP] """

    def __init__(self, max_len=512):
        super().__init__()
        self.max_len = max_len # length budget, special tokens included

    def __call__(self, instance):
        """ Return (label, tokens_a, tokens_b) with the special tokens added """
        label, tokens_a, tokens_b = instance

        # Reserve room for the special tokens:
        #   3 for [CLS] text_a [SEP] text_b [SEP]
        #   2 for [CLS] text_a [SEP]
        n_special = 3 if tokens_b else 2
        budget = self.max_len - n_special
        # The return value is not used, so the helper is expected to shorten
        # the two lists in place.
        truncate_tokens_pair(tokens_a, tokens_b, budget)

        # Add Special Tokens
        tokens_a = ['[CLS]'] + tokens_a + ['[SEP]']
        if tokens_b:
            tokens_b = tokens_b + ['[SEP]']
        else:
            tokens_b = []

        return (label, tokens_a, tokens_b)


class TokenIndexing(Pipeline):
    """ Pipeline step : map tokens to indexes, encode the label, zero-pad """

    def __init__(self, indexer, labels, task, max_len=512):
        super().__init__()
        self.indexer = indexer # function : tokens to indexes
        # label name -> position of that name in `labels`
        self.label_map = dict((name, i) for i, name in enumerate(labels))
        self.max_len = max_len
        self.task = task

    def __call__(self, instance):
        """ Return (input_ids, segment_ids, input_mask, label_id) """
        label, tokens_a, tokens_b = instance
        len_a, len_b = len(tokens_a), len(tokens_b)

        input_ids = self.indexer(tokens_a + tokens_b)
        segment_ids = [0]*len_a + [1]*len_b # token type ids
        input_mask = [1]*(len_a + len_b)

        if self.task == 'stsb':
            # Real-valued target, scaled by 0.2 -- presumably to bring a 0-5
            # score into [0, 1]; confirm against the data.
            label_id = 0.2 * float(label)
        else:
            label_id = self.label_map[label]

        # zero padding up to max_len (nothing is added when already longer)
        n_pad = self.max_len - len(input_ids)
        for sequence in (input_ids, segment_ids, input_mask):
            sequence.extend([0]*n_pad)

        return (input_ids, segment_ids, input_mask, label_id)


class Classifier(nn.Module):
    """ Classifier with Transformer

    A Transformer encoder followed by a tanh-activated linear layer on the
    first position of the sequence, dropout, and a linear layer producing
    one logit per label.
    """
    def __init__(self, cfg, n_labels):
        """ cfg: model configuration (`dim` and `p_drop_hidden` are read here)
            n_labels: number of output logits """
        super().__init__()
        # NOTE(review): this used to read `models.Transformer`, but no module
        # named `models` is imported in this file (NameError at run time).
        # `bert` is imported and otherwise unused, so it is assumed to be the
        # module that provides Transformer -- confirm.
        self.transformer = bert.Transformer(cfg)
        self.fc = nn.Linear(cfg.dim, cfg.dim)
        self.activ = nn.Tanh()
        self.drop = nn.Dropout(cfg.p_drop_hidden)
        self.classifier = nn.Linear(cfg.dim, n_labels)

    def forward(self, input_ids, segment_ids, input_mask):
        """ Return the logits computed from the encoder output """
        h = self.transformer(input_ids, segment_ids, input_mask)
        # only use the first h in the sequence
        pooled_h = self.activ(self.fc(h[:, 0]))
        logits = self.classifier(self.drop(pooled_h))
        return logits

class RegreClassifier(nn.Module):
    """ Regressor with Transformer

    Same layout as Classifier, except that the head has a single output
    passed through a sigmoid, so every prediction lies in (0, 1).
    """
    def __init__(self, cfg, n_labels):
        """ cfg: model configuration (`dim` and `p_drop_hidden` are read here)
            n_labels: unused, the head always has one output; kept so the
                      constructor matches Classifier's """
        super().__init__()
        # NOTE(review): this used to read `models.Transformer`, but no module
        # named `models` is imported in this file (NameError at run time).
        # `bert` is imported and otherwise unused, so it is assumed to be the
        # module that provides Transformer -- confirm.
        self.transformer = bert.Transformer(cfg)
        self.fc = nn.Linear(cfg.dim, cfg.dim)
        self.activ = nn.Tanh()
        self.drop = nn.Dropout(cfg.p_drop_hidden)
        self.classifier = nn.Linear(cfg.dim, 1)
        self.activ2 = nn.Sigmoid()

    def forward(self, input_ids, segment_ids, input_mask):
        """ Return the sigmoid-squashed output, one value per sequence """
        h = self.transformer(input_ids, segment_ids, input_mask)
        # only use the first h in the sequence
        pooled_h = self.activ(self.fc(h[:, 0]))
        logits = self.activ2(self.classifier(self.drop(pooled_h)))
        return logits


def main(task='mrpc',
         train_cfg='config/train_mrpc.json',
         model_cfg='config/bert_base.json',
         data_file='../glue/MRPC/train.tsv',
         model_file=None,
         pretrain_file='../uncased_L-12_H-768_A-12/bert_model.ckpt',
         data_parallel=True,
         vocab='data/vocab.txt',
         save_dir='model/finetune',
         max_len=128,
         mode='train',
         lr=None
         ):
    """ Fine-tune ('train') or score ('eval') a model on one task

    task: key understood by dataset_class(); 'stsb' selects the regression
          model with an MSE loss, any other task the classifier with a
          cross-entropy loss
    train_cfg, model_cfg: paths of the JSON training / model configurations
    data_file: tab-separated data file of the task
    model_file, pretrain_file, data_parallel: forwarded to the trainer
    vocab: vocabulary file of the tokenizer
    save_dir: forwarded to the trainer at construction
    max_len: length of every padded input sequence
    mode: 'train' or 'eval'
    lr: when truthy, replaces the learning rate of the training configuration

    Returns None in 'train' mode and the task score in 'eval' mode.
    Raises ValueError for any other mode.
    """
    # Fail before loading any data instead of silently doing nothing at the end.
    if mode not in ('train', 'eval'):
        raise ValueError("mode must be 'train' or 'eval', got {!r}".format(mode))

    cfg = train.Config.from_json(train_cfg)
    if lr:
        print(cfg.lr)
        cfg = cfg._replace(lr=lr)
        print(cfg.lr)
    # NOTE(review): this used to read `models.Config`, but no module named
    # `models` is imported in this file (NameError at run time). `bert` is
    # imported and otherwise unused, so it is assumed to be the module that
    # provides Config -- confirm.
    model_cfg = bert.Config.from_json(model_cfg)

    set_seeds(cfg.seed)

    tokenizer = tokenization.FullTokenizer(vocab_file=vocab, do_lower_case=True)
    TaskDataset = dataset_class(task) # task dataset class according to the task
    pipeline = [Tokenizing(tokenizer.convert_to_unicode, tokenizer.tokenize),
                AddSpecialTokensWithTruncation(max_len),
                TokenIndexing(tokenizer.convert_tokens_to_ids,
                              TaskDataset.labels, task, max_len)]
    dataset = TaskDataset(data_file, pipeline)
    data_iter = DataLoader(dataset, batch_size=cfg.batch_size, shuffle=True)

    n_labels = len(TaskDataset.labels)
    print('label num = {a}'.format(a=n_labels))
    if task != 'stsb':
        criterion = nn.CrossEntropyLoss()
        model = Classifier(model_cfg, n_labels)
    else:
        criterion = nn.MSELoss()
        model = RegreClassifier(model_cfg, n_labels)

    trainer = train.Trainer(cfg,
                            model,
                            data_iter,
                            optim.optim4GPU(cfg, model),
                            save_dir, get_device())

    if mode == 'train':
        def get_loss(model, batch, global_step): # make sure loss is a scalar tensor
            input_ids, segment_ids, input_mask, label_id = batch
            logits = model(input_ids, segment_ids, input_mask)
            if task == 'stsb':
                # flatten the single-output head so it lines up with the targets
                logits = logits.view(-1)
            return criterion(logits, label_id)

        trainer.train(get_loss, model_file, pretrain_file, data_parallel)
        return None

    # mode == 'eval'
    results, labels = trainer.eval(model_file, data_parallel)
    total_accuracy = score_func(task, results, labels)
    print('Accuracy:', total_accuracy)
    return total_accuracy


def score_func(task, preds, labels):
    """ Score the predictions of one task

    task: task name; selects the metric
          - 'sst2', 'mrpc', 'qqp', 'qnli', 'rte', 'wnli', 'mnli': accuracy
          - 'cola': Matthews correlation coefficient
          - 'stsb': Spearman rank correlation
    preds: iterable of per-batch model outputs (one row per example)
    labels: iterable of per-batch target tensors, aligned with `preds`

    Returns the score as a float. Raises ValueError for an unknown task.
    """
    # acc
    if task in ('sst2', 'mrpc', 'qqp', 'qnli', 'rte', 'wnli', 'mnli'):
        # predicted class = index of the largest logit in each row
        hits = [(pred.max(1)[1] == label).float()
                for pred, label in zip(preds, labels)]
        return torch.cat(hits).mean().item()

    if task == 'cola':
        tp = tn = fp = fn = 0
        for pred, label in zip(preds, labels):
            _, pred = pred.max(1)
            for p, l in zip(pred.tolist(), label.tolist()):
                if p == 1 and l == 1:
                    tp += 1
                elif p == 1 and l == 0:
                    fp += 1 # predicted positive, actually negative
                elif p == 0 and l == 1:
                    fn += 1 # predicted negative, actually positive
                elif p == 0 and l == 0:
                    tn += 1
        print(tn, tp, fn, fp)
        den = (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)
        if den == 0:
            # degenerate confusion matrix: report the raw numerator,
            # which is 0 in that case
            den = 1
        return (tp * tn - fp * fn) / (den ** 0.5)

    if task == 'stsb':
        x, y = [], []
        for pred, label in zip(preds, labels):
            x.extend(pred.tolist())
            y.extend(label.tolist())
        # each prediction row holds a single value
        x = [row[0] for row in x]
        score, _ = stats.spearmanr(x, y)
        return score

    raise ValueError("no metric defined for task {!r}".format(task))

# Script entry point: only runs when the file is executed directly.
if __name__ == '__main__':

    # Hand main() to python-fire, which builds the command-line interface
    # from it.
    fire.Fire(main)
