# -*- coding: utf-8 -*-
import argparse
import json
import os
import random

import numpy as np
import torch
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer
from transformers import AutoConfig
from transformers import Trainer, TrainingArguments
from transformers import DataCollatorWithPadding
from datasets import Dataset
from sklearn.metrics import f1_score

from data_utils import get_all_data
from sentence_model import SentenceModel
from abduction_model import SentenceAbduction
from logger import print_log, ABLLogger
import warnings
warnings.filterwarnings("ignore")

def get_labels(tag_path="data/tags.txt"):
    """Build the label-name -> index mapping from a tag list file.

    The file is read as UTF-8 with one label per line. Surrounding whitespace
    is stripped and a label's index is its position among the non-blank lines.

    Args:
        tag_path: Path of the tag file. The default is the path this function
            previously hard-coded, so no-argument calls behave as before.

    Returns:
        dict mapping each label string to its integer index.
    """
    with open(tag_path, "r", encoding="utf-8") as fin:
        stripped = (line.strip() for line in fin)
        # Skip blank lines: otherwise "" is registered as a label and shifts
        # the index of every label that follows it.
        labels = [label for label in stripped if label]
    return {label: index for index, label in enumerate(labels)}

# Label vocabulary, loaded once when the module is imported: name -> index and
# the inverse index -> name lookup used when reporting per-label scores.
label2id = get_labels()
id2label = {index: name for name, index in label2id.items()}

def trans_label(label):
    """Encode an iterable of label names as a multi-hot float vector.

    The vector has one slot per entry of ``label2id``; the slot of every name
    in ``label`` is set to 1.0 and all others stay 0.0.
    """
    multi_hot = np.zeros(len(label2id))
    for name in label:
        multi_hot[label2id[name]] = 1.0
    return multi_hot

def trans_vec(vec):
    """Decode a multi-hot vector into the list of label names whose slot equals 1."""
    return [id2label[position] for position, flag in enumerate(vec) if flag == 1]

def to_hf_dataset(
    data,
    train_ahs,
    train_label_key,
    pretrain_ahs=None,
    pretrain_label_key=None,
    text_key="summary",
):
    """Collect selected datapoints into a ``Dataset`` with "text" and "labels" columns.

    A datapoint whose "ah" is in ``train_ahs`` contributes its
    ``train_label_key`` values; otherwise, if ``pretrain_ahs`` is non-empty
    and contains its "ah", it contributes its ``pretrain_label_key`` values.
    All other datapoints are left out. Passing ``train_ahs=None`` selects
    every datapoint. Label values are converted to ``float``.
    """
    if train_ahs is None:
        train_ahs = {row["ah"] for row in data}

    texts = []
    targets = []
    for row in data:
        ah = row["ah"]
        if ah in train_ahs:
            label_key = train_label_key
        elif pretrain_ahs and ah in pretrain_ahs:
            label_key = pretrain_label_key
        else:
            continue
        texts.append(row[text_key])
        targets.append([float(value) for value in row[label_key]])

    assert len(texts) == len(targets)
    return Dataset.from_dict({"text": texts, "labels": targets})

def get_model(model_path):
    """Load a multi-label sequence classifier from ``model_path``.

    The classification head is sized from the module-level label maps rather
    than a fixed number, so it stays consistent with the ``id2label`` /
    ``label2id`` mappings passed alongside it.

    Args:
        model_path: Hub id or local directory accepted by ``from_pretrained``.

    Returns:
        The loaded ``AutoModelForSequenceClassification`` instance.
    """
    cfg = AutoConfig.from_pretrained(
        model_path,
        # Derive the head size from the label maps so the two cannot disagree.
        num_labels=len(id2label),
        problem_type="multi_label_classification",
        id2label=id2label,
        label2id=label2id,
    )
    # Use the GPU when one is present, otherwise fall back to CPU instead of
    # failing on hosts without CUDA.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return AutoModelForSequenceClassification.from_pretrained(
        model_path,
        config=cfg,
        device_map=device,
    )

def get_training_arguments(output_dir, lr, epoch):
    """Build the ``TrainingArguments`` shared by pretraining, training and prediction.

    Args:
        output_dir: Directory for checkpoints (callers may pass ``None``).
        lr: Learning rate.
        epoch: Number of training epochs.
    """
    settings = dict(
        output_dir=output_dir,
        learning_rate=lr,
        per_device_train_batch_size=32,
        num_train_epochs=epoch,
        weight_decay=0.01,
        # One checkpoint per epoch, keeping at most one on disk.
        save_strategy="epoch",
        save_total_limit=1,
        logging_strategy="no",
        lr_scheduler_type="cosine",
        # NOTE(review): None defers to the library default for reporting
        # integrations; if the intent is to disable them, "none" may be the
        # value to pass — confirm against the installed transformers version.
        report_to=None,
    )
    return TrainingArguments(**settings)

def pretrain(args, model_path, data, pretrain_ahs):
    """Fine-tune the classifier on the ground-truth "attr" labels of ``pretrain_ahs``.

    Args:
        args: Parsed CLI namespace; ``pre_lr`` and ``pre_epoch`` are read.
        model_path: Model/tokenizer location to start from.
        data: List of datapoint dicts.
        pretrain_ahs: "ah" identifiers of the datapoints to train on.

    Returns:
        Path of the newest checkpoint directory written under ``ckpt``.
    """
    import shutil  # local import: only needed for the stale-checkpoint cleanup

    tokenizer = AutoTokenizer.from_pretrained(model_path)

    def preprocess(exa):
        return tokenizer(exa["text"], truncation=True)

    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
    pre_data_hf = to_hf_dataset(data, pretrain_ahs, "attr").map(preprocess, batched=True, desc="")
    model = get_model(model_path)

    ckpt_dir = "ckpt"
    # get_last_model() picks the checkpoint with the largest step number. A
    # checkpoint left by an earlier run with more steps would therefore be
    # returned instead of the one trained here, so drop leftovers first. The
    # model and tokenizer are already in memory at this point.
    if os.path.isdir(ckpt_dir):
        for entry in os.listdir(ckpt_dir):
            if entry.startswith("checkpoint-"):
                shutil.rmtree(os.path.join(ckpt_dir, entry), ignore_errors=True)

    training_args = get_training_arguments(ckpt_dir, args.pre_lr, args.pre_epoch)
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=pre_data_hf,
        processing_class=tokenizer,
        data_collator=data_collator,
    )
    trainer.train()
    return get_last_model(ckpt_dir)

def predict(args, model_path, data, current_level_ahs):
    """Run the classifier on selected datapoints and store predictions on them.

    Each selected datapoint gets two new keys:
    ``"pred_prob"`` (sigmoid of the logits, as a numpy array) and
    ``"pred_attr"`` (list of 0/1 ints, 1 where the logit is positive).

    Args:
        args: Parsed CLI namespace; ``lr`` and ``epoch`` are forwarded to the
            shared training-argument builder.
        model_path: Model/tokenizer location to load.
        data: List of datapoint dicts; mutated in place.
        current_level_ahs: Collection of "ah" identifiers to predict for, or
            ``None`` to predict for every datapoint.

    Raises:
        RuntimeError: If the number of predictions differs from the number of
            selected datapoints.
    """
    # Select with the same rule to_hf_dataset() applies (None means all), so
    # an empty collection selects nothing here as well instead of everything.
    if current_level_ahs is None:
        selected = list(data)
    else:
        selected = [dp for dp in data if dp["ah"] in current_level_ahs]
    if not selected:
        return

    tokenizer = AutoTokenizer.from_pretrained(model_path)

    def preprocess(exa):
        return tokenizer(exa["text"], truncation=True)

    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
    abl_data_hf = to_hf_dataset(data, current_level_ahs, "attr").map(preprocess, batched=True, desc="")
    model = get_model(model_path)
    training_args = get_training_arguments(None, args.lr, args.epoch)
    trainer = Trainer(
        model=model,
        args=training_args,
        processing_class=tokenizer,
        data_collator=data_collator,
    )
    output = trainer.predict(abl_data_hf).predictions  # one row of logits per selected datapoint
    if len(output) != len(selected):
        raise RuntimeError(
            f"got {len(output)} predictions for {len(selected)} selected datapoints"
        )

    for datapoint, logit in zip(selected, output):
        datapoint["pred_prob"] = torch.sigmoid(torch.tensor(logit)).numpy()
        datapoint["pred_attr"] = np.where(logit > 0, 1, 0).astype(int).tolist()

def cal_num_tokens(data, model_path, pretrain_ahs, abduced_ahs):
    """Count tokenizer input ids over the "summary" of selected datapoints.

    A datapoint is counted when its "ah" is in the union of ``pretrain_ahs``
    and ``abduced_ahs``. Texts are tokenized with truncation enabled, using the
    tokenizer loaded from ``model_path``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    counted_ahs = pretrain_ahs.union(abduced_ahs)
    return sum(
        len(tokenizer(row["summary"], truncation=True)["input_ids"])
        for row in data
        if row["ah"] in counted_ahs
    )

def train(
    args,
    ti,
    model_path,
    data,
    pretrain_ahs,
    current_level_ahs_filtered,
):
    """Fine-tune the classifier on abduced labels plus the pretraining labels.

    Datapoints in ``current_level_ahs_filtered`` are trained on their
    "abduced_attr" labels; datapoints in ``pretrain_ahs`` on their "attr"
    labels.

    Args:
        args: Parsed CLI namespace; ``lr`` and ``epoch`` are read.
        ti: Tag used to name the checkpoint directory (``ckpt_<ti>``).
        model_path: Model/tokenizer location to start from.
        data: List of datapoint dicts.
        pretrain_ahs: "ah" identifiers trained on ground-truth labels.
        current_level_ahs_filtered: "ah" identifiers trained on abduced labels.

    Returns:
        Path of the newest checkpoint directory written under ``ckpt_<ti>``.
    """
    import shutil  # local import: only needed for the stale-checkpoint cleanup

    tokenizer = AutoTokenizer.from_pretrained(model_path)

    def preprocess(exa):
        return tokenizer(exa["text"], truncation=True)

    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    abl_data_hf = to_hf_dataset(
        data, current_level_ahs_filtered, "abduced_attr", pretrain_ahs, "attr"
    ).map(preprocess, batched=True, desc="")

    model = get_model(model_path)
    ckpt_dir = f"ckpt_{ti}"
    # The same directory is reused whenever ``ti`` repeats, and
    # get_last_model() picks the checkpoint with the largest step number. A
    # checkpoint from an earlier call that ran more steps would therefore be
    # returned instead of the one trained here, so drop leftovers first. The
    # model and tokenizer (possibly loaded from this very directory) are
    # already in memory at this point.
    if os.path.isdir(ckpt_dir):
        for entry in os.listdir(ckpt_dir):
            if entry.startswith("checkpoint-"):
                shutil.rmtree(os.path.join(ckpt_dir, entry), ignore_errors=True)

    training_args = get_training_arguments(ckpt_dir, args.lr, args.epoch)
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=abl_data_hf,
        processing_class=tokenizer,
        data_collator=data_collator,
    )
    trainer.train()
    return get_last_model(ckpt_dir)


def train_sentence_model(sentence, data, ahs, train_ahs=None):
    """Fit ``sentence`` on the money/attribute/month values of selected datapoints.

    Datapoints whose "ah" is in ``ahs`` contribute their "attr" attributes;
    otherwise, datapoints whose "ah" is in a non-empty ``train_ahs``
    contribute their "pred_attr" attributes. ``ahs=None`` selects every
    datapoint. After fitting, the model's parameters are shown via
    ``sentence.show_param()``.
    """
    if ahs is None:
        ahs = {row["ah"] for row in data}

    money, attrs, labels = [], [], []
    for row in data:
        ah = row["ah"]
        if ah in ahs:
            attr_key = "attr"
        elif train_ahs and ah in train_ahs:
            attr_key = "pred_attr"
        else:
            continue
        money.append([row["money"]])
        labels.append(row["month"])
        attrs.append(row[attr_key])

    sentence.fit(money, attrs, labels, 3)
    sentence.show_param()


def eval_predict(args, sentence, test_data, ahs, title="Test"):
    """Log and return attribute F1 scores and sentence-model errors.

    Compares each selected datapoint's "pred_attr" with its "attr", logging
    the F1 over all flattened entries and one F1 per label column, then calls
    ``sentence.test`` on the predicted attributes. ``ahs=None`` selects every
    datapoint. ``args`` is accepted but not read.

    Returns:
        dict with "f1", one "<label> f1" entry per label, "mae" and "mse".
    """
    print_log(f"----- {title} -----", logger="current")
    if ahs is None:
        ahs = {row["ah"] for row in test_data}

    moneys, pred_attrs, months, attrs = [], [], [], []
    for row in test_data:
        if row["ah"] in ahs:
            months.append(row["month"])
            moneys.append([row["money"]])
            pred_attrs.append(row["pred_attr"])
            attrs.append(row["attr"])

    pred = np.array(pred_attrs, dtype=int)
    gt = np.array(attrs, dtype=int)

    total_f1 = f1_score(gt.flatten(), pred.flatten())
    print_log(f"micro f1: {total_f1:.4f}", logger="current")
    result = {"f1": total_f1}

    for column in range(gt.shape[-1]):
        column_f1 = f1_score(gt[:, column], pred[:, column])
        label_name = id2label[column]
        print_log(f"{label_name} f1: {column_f1:.4f}", logger="current")
        result[f"{label_name} f1"] = column_f1

    mae, mse, _, _ = sentence.test(moneys, pred_attrs, months)
    result["mae"] = mae
    result["mse"] = mse
    print_log("----------------", logger="current")
    return result

def eval_abduction(train_data, abduced_ahs):
    """Log how well "abduced_attr" matches "attr" on the abduced datapoints.

    Logs the F1 over all flattened entries followed by one F1 per label
    column. Only datapoints whose "ah" is in ``abduced_ahs`` are considered.
    """
    print_log("----- Abdu -----", logger="current")
    pairs = [
        (row["abduced_attr"], row["attr"])
        for row in train_data
        if row["ah"] in abduced_ahs
    ]
    pred = np.array([abduced for abduced, _ in pairs], dtype=int)
    gt = np.array([truth for _, truth in pairs], dtype=int)

    total_f1 = f1_score(gt.flatten(), pred.flatten())
    print_log(f"micro f1: {total_f1:.4f}", logger="current")

    for column in range(gt.shape[-1]):
        column_f1 = f1_score(gt[:, column], pred[:, column])
        print_log(f"{id2label[column]}: {column_f1:.4f}", logger="current")
    print_log("----------------", logger="current")

def get_last_model(path):
    """Return the ``checkpoint-<step>`` entry of ``path`` with the largest step.

    Entries whose name does not start with ``checkpoint-`` or whose suffix
    after the last ``-`` is not an integer are ignored.

    Args:
        path: Directory to scan.

    Returns:
        ``os.path.join(path, name)`` for the selected entry.

    Raises:
        FileNotFoundError: If ``path`` contains no usable checkpoint entry.
    """
    best_step = None
    best_name = None
    for name in os.listdir(path):
        if not name.startswith("checkpoint-"):
            continue
        try:
            step = int(name.split("-")[-1])
        except ValueError:
            # e.g. "checkpoint-tmp": not a numbered checkpoint, skip it.
            continue
        # Start from None rather than 0 so that step 0 is selectable too.
        if best_step is None or step > best_step:
            best_step, best_name = step, name
    if best_name is None:
        # Fail here with a clear message instead of handing back a path that
        # holds no checkpoint.
        raise FileNotFoundError(f"no 'checkpoint-<step>' entry found in {path!r}")
    return os.path.join(path, best_name)

def main(args):
    """Run pretraining followed by the predict / abduce / retrain loop.

    Steps, in order:
      1. Load data, pretrain the classifier and fit the sentence model on the
         pretraining split, then evaluate on that split and on the test set.
      2. For ``args.loops`` iterations: predict on the remaining training
         datapoints, abduce labels, retrain the classifier, refit the
         sentence model, and evaluate on train and test data.
      3. Write the list of per-iteration test results to ``res.json`` in the
         current logger's ``log_dir``.
    """
    train_data, levels_dict, test_data = get_all_data(args.pretrain_lvl)
    every_ah = [row["ah"] for row in train_data]

    sentence = SentenceModel()
    abductor = SentenceAbduction(
        sentence,
        args.rule_filename,
        args.word_match,
        args.strong_conf,
    )

    pretrain_ahs = levels_dict[args.pretrain_lvl]
    train_ahs = {ah for ah in every_ah if ah not in pretrain_ahs}
    print(len(pretrain_ahs), len(train_ahs))

    # --- Pretraining phase: classifier first, then the sentence model. ---
    model_path = pretrain(args, args.model_path, train_data, pretrain_ahs)
    train_sentence_model(sentence, train_data, pretrain_ahs)

    # Evaluate the pretrained models on the pretraining split ...
    predict(args, model_path, train_data, pretrain_ahs)
    eval_predict(args, sentence, train_data, pretrain_ahs, "Eval")
    # ... and on the whole test set; this is the baseline record.
    predict(args, model_path, test_data, None)
    baseline = eval_predict(args, sentence, test_data, None)
    total_tokens = 0
    baseline["tokens"] = total_tokens
    records = [baseline]

    lvl = 9  # use all data
    print_log(f"===== Level {lvl} Num samples {len(train_ahs)} =====", logger="current")
    for loop_no in range(1, args.loops + 1):
        print_log(f"----- [Level|loop] [{lvl}|{loop_no}] -----", logger="current")

        # Fresh predictions on the training datapoints feed the abduction step.
        predict(args, model_path, train_data, train_ahs)
        abductor.set_predict_model(sentence)
        abduced_ahs = abductor.abduce_data(train_data, train_ahs, args.max_revision, lvl)
        eval_abduction(train_data, abduced_ahs)

        # Retrain the classifier and account for the tokens it consumed.
        model_path = train(args, lvl, model_path, train_data, pretrain_ahs, abduced_ahs)
        total_tokens += cal_num_tokens(train_data, model_path, pretrain_ahs, abduced_ahs) * args.epoch

        # Re-predict with the new classifier, refit and evaluate on train data.
        predict(args, model_path, train_data, train_ahs)
        train_sentence_model(sentence, train_data, pretrain_ahs, train_ahs)
        eval_predict(args, sentence, train_data, train_ahs, "Eval")

        # Evaluate the new checkpoint on the test set and record the result.
        predict(args, model_path, test_data, None)
        loop_result = eval_predict(args, sentence, test_data, None)
        loop_result["tokens"] = total_tokens
        records.append(loop_result)

    log_dir = ABLLogger.get_current_instance().log_dir
    with open(os.path.join(log_dir, "res.json"), "w") as fout:
        json.dump(records, fout)


def get_args():
    """Parse the command-line options and return the resulting namespace.

    ``word_match`` is forced to ``True`` after parsing, regardless of whether
    ``--word_match`` was given on the command line.
    """
    option_specs = (
        ("--model_path", {"default": "google-bert/bert-base-chinese"}),
        ("--rule_filename", {"default": "rule_file.txt"}),
        ("--pretrain_lvl", {"type": int, "default": 1}),
        ("--loops", {"type": int, "default": 12, "help": "Loops"}),
        ("--pre_epoch", {"type": int, "default": 15}),
        ("--pre_lr", {"type": float, "default": 5e-5}),
        ("--epoch", {"type": int, "default": 3, "help": "Epochs per train"}),
        ("--lr", {"type": float, "default": 5e-5}),
        ("--word_match", {"action": "store_true"}),
        ("--max_revision", {"type": int, "default": 3}),
        ("--strong_conf", {"action": "store_true"}),
        ("--seed", {"type": int, "default": 0}),
    )

    parser = argparse.ArgumentParser()
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)

    parsed = parser.parse_args()
    # Word matching is always enabled; the CLI flag cannot switch it off.
    parsed.word_match = True
    return parsed

def set_seed(seed):
    """Seed Python's ``random``, NumPy, and PyTorch (CPU and CUDA) with ``seed``."""
    # NOTE(review): Trainer may reseed from TrainingArguments.seed when it is
    # constructed — confirm whether --seed should also be forwarded there.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for seeder in seeders:
        seeder(seed)

if __name__ == '__main__':
    # Script entry point: parse options, seed the RNGs, log the configuration
    # that is about to be used, then run the experiment.
    cli_args = get_args()
    set_seed(cli_args.seed)
    print_log(str(cli_args), logger="current")
    main(cli_args)