# -*- coding: utf-8 -*-
import argparse
import json
import os
import random

import numpy as np
import torch
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer
from transformers import AutoConfig
from transformers import Trainer, TrainingArguments
from transformers import DataCollatorWithPadding
from datasets import Dataset
from sklearn.metrics import f1_score

from data_utils import get_all_data
from sentence_model import SentenceModel
from abduction_model import SentenceAbduction
from logger import print_log, ABLLogger
import warnings
warnings.filterwarnings("ignore")

def get_labels():
    """Read the tag vocabulary from ``data/tags.txt`` and index it.

    The file is read as UTF-8 with one tag per line. Surrounding whitespace
    is stripped, and blank lines (for example a trailing empty line at the
    end of the file) are skipped so that they do not turn into a spurious
    empty-string label.

    Returns:
        dict: tag string -> zero-based position among the non-blank lines.
    """
    tag_path = "data/tags.txt"
    with open(tag_path, 'r', encoding='utf-8') as fin:
        stripped_lines = (line.strip() for line in fin)
        # Consume the generator while the file is still open.
        labels = [tag for tag in stripped_lines if tag]
    return {tag: index for index, tag in enumerate(labels)}

# Label vocabulary, loaded once at import time: name -> index and its inverse.
label2id = get_labels()
id2label = {index: name for name, index in label2id.items()}

def trans_label(label):
    """Encode an iterable of label names as a multi-hot float vector.

    Args:
        label: iterable of label names; each must be a key of ``label2id``.

    Returns:
        numpy.ndarray: float array with one entry per known label, 1.0 at
        the positions of the given names and 0.0 elsewhere.
    """
    multi_hot = np.zeros(len(label2id))
    for name in label:
        multi_hot[label2id[name]] = 1.0
    return multi_hot

def trans_vec(vec):
    """Decode a multi-hot vector back into a list of label names.

    Args:
        vec: iterable of values; position ``i`` corresponds to ``id2label[i]``.

    Returns:
        list: names of the positions whose value compares equal to 1, in
        index order.
    """
    return [id2label[position] for position, flag in enumerate(vec) if flag == 1]

def to_hf_dataset(
    data,
    train_ahs,
    train_label_key,
    pretrain_ahs=None,
    pretrain_label_key=None,
    text_key="summary",
):
    """Build a ``datasets.Dataset`` with "text" and "labels" columns.

    Samples are visited in the order of ``data``. A sample whose "ah" is in
    ``train_ahs`` takes its labels from ``train_label_key``; otherwise, if
    ``pretrain_ahs`` is non-empty and contains the sample's "ah", the labels
    come from ``pretrain_label_key``. All other samples are left out.

    Args:
        data: iterable of sample dicts holding "ah", the text field and the
            label field(s).
        train_ahs: collection of "ah" values to include; ``None`` selects
            every sample in ``data``.
        train_label_key: key of the label vector for ``train_ahs`` samples.
        pretrain_ahs: optional second collection of "ah" values.
        pretrain_label_key: key of the label vector for ``pretrain_ahs``
            samples.
        text_key: key of the input text.

    Returns:
        Dataset: one row per selected sample, labels converted to floats.
    """
    if train_ahs is None:
        train_ahs = {sample["ah"] for sample in data}

    columns = {"text": [], "labels": []}
    for sample in data:
        ah = sample["ah"]
        if ah in train_ahs:
            label_key = train_label_key
        elif pretrain_ahs and ah in pretrain_ahs:
            label_key = pretrain_label_key
        else:
            continue
        columns["text"].append(sample[text_key])
        columns["labels"].append([float(value) for value in sample[label_key]])
    return Dataset.from_dict(columns)

def get_model(model_path):
    """Load a multi-label sequence classifier from ``model_path``.

    The classification head is configured from the module-level label maps
    (``label2id`` / ``id2label``) with the multi-label problem type.

    Args:
        model_path: model name or checkpoint directory accepted by
            ``from_pretrained``.

    Returns:
        The model returned by
        ``AutoModelForSequenceClassification.from_pretrained``, placed on
        CUDA when available and on CPU otherwise.
    """
    cfg = AutoConfig.from_pretrained(
        model_path,
        # Size the head from the tag file rather than a hard-coded 8 so the
        # label count cannot disagree with id2label / label2id.
        num_labels=len(label2id),
        problem_type="multi_label_classification",
        id2label=id2label,
        label2id=label2id,
    )
    # Fall back to CPU so that loading does not fail on hosts without a GPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return AutoModelForSequenceClassification.from_pretrained(
        model_path,
        config=cfg,
        device_map=device,
    )

def get_training_arguments(output_dir, lr, epoch):
    """Build the ``TrainingArguments`` shared by every Trainer in this file.

    Args:
        output_dir: directory for checkpoints (``predict`` passes ``None``).
        lr: learning rate.
        epoch: number of training epochs.

    Returns:
        TrainingArguments: batch size 32 per device, weight decay 0.01,
        cosine LR schedule, one checkpoint saved per epoch with only the most
        recent one kept, step logging disabled and no reporting integration.
    """
    return TrainingArguments(
        output_dir=output_dir,
        learning_rate=lr,
        per_device_train_batch_size=32,
        num_train_epochs=epoch,
        weight_decay=0.01,
        save_strategy="epoch",
        save_total_limit=1,
        # logging_strategy="steps",
        # logging_steps=10,
        logging_strategy="no",
        # warmup_ratio=0.1,
        lr_scheduler_type="cosine",
        # lr_scheduler_type="constant",
        # The string "none" switches every reporting integration off.
        # Passing the None object does not: Transformers v4 replaces it with
        # "all", which enables whatever trackers happen to be installed.
        report_to="none",
        # disable_tqdm=True,
    )

def pretrain(args, model_path, data, pretrain_ahs):
    """Fine-tune the classifier on the pretraining samples.

    The samples selected by ``pretrain_ahs`` are trained with their "attr"
    labels for ``args.pre_epoch`` epochs at learning rate ``args.pre_lr``.
    Checkpoints are written to the ``ckpt`` directory.

    Args:
        args: namespace providing ``pre_lr`` and ``pre_epoch``.
        model_path: model name or checkpoint directory to start from.
        data: iterable of sample dicts (see ``to_hf_dataset``).
        pretrain_ahs: "ah" values of the samples to train on.

    Returns:
        str: checkpoint path chosen by ``get_last_model`` inside ``ckpt``.
    """
    output_dir = "ckpt"
    tok = AutoTokenizer.from_pretrained(model_path)

    def tokenize_batch(batch):
        return tok(batch["text"], truncation=True)

    raw_dataset = to_hf_dataset(data, pretrain_ahs, "attr")
    tokenized = raw_dataset.map(tokenize_batch, batched=True, desc="")

    trainer = Trainer(
        model=get_model(model_path),
        args=get_training_arguments(output_dir, args.pre_lr, args.pre_epoch),
        train_dataset=tokenized,
        processing_class=tok,
        data_collator=DataCollatorWithPadding(tokenizer=tok),
    )
    trainer.train()
    return get_last_model(output_dir)

def predict(args, model_path, data, current_level_ahs):
    """Run the classifier on selected samples and store its outputs in place.

    For every selected sample two keys are written:
      * "pred_prob": numpy array holding the sigmoid of the logits;
      * "pred_attr": list of 0/1 ints, 1 where the logit is greater than 0.

    Args:
        args: namespace providing ``lr`` and ``epoch`` (only used to build
            the TrainingArguments; no training happens here).
        model_path: model name or checkpoint directory to load.
        data: iterable of sample dicts; it is iterated more than once.
        current_level_ahs: collection of "ah" values to predict for;
            ``None`` selects every sample.

    Raises:
        RuntimeError: if the number of prediction rows differs from the
            number of selected samples.
    """
    # Use the same selection rule as to_hf_dataset: None means "everything",
    # while an empty collection means "nothing". A truthiness test would treat
    # an empty collection as "everything" and mis-index the predictions.
    if current_level_ahs is None:
        selected = list(data)
    else:
        selected = [dp for dp in data if dp["ah"] in current_level_ahs]
    if not selected:
        # Nothing to predict; skip loading the model altogether.
        return

    tokenizer = AutoTokenizer.from_pretrained(model_path)

    def preprocess(exa):
        return tokenizer(exa["text"], truncation=True)

    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
    abl_data_hf = to_hf_dataset(data, current_level_ahs, "attr").map(preprocess, batched=True, desc="")
    model = get_model(model_path)
    training_args = get_training_arguments(None, args.lr, args.epoch)
    trainer = Trainer(
        model=model,
        args=training_args,
        processing_class=tokenizer,
        data_collator=data_collator
    )
    output = trainer.predict(abl_data_hf).predictions  # one row of logits per sample
    if len(output) != len(selected):
        raise RuntimeError(
            f"Got {len(output)} prediction rows for {len(selected)} selected samples."
        )
    # Rows come back in dataset order, which is the order of `selected`.
    for datapoint, logit in zip(selected, output):
        datapoint["pred_prob"] = torch.sigmoid(torch.tensor(logit)).numpy()
        datapoint["pred_attr"] = np.where(logit > 0, 1, 0).astype(int).tolist()

def cal_num_tokens(data, model_path, pretrain_ahs, abduced_ahs):
    """Count the tokens of all samples used in one training pass.

    Args:
        data: iterable of sample dicts with "ah" and "summary".
        model_path: model name or checkpoint directory whose tokenizer is used.
        pretrain_ahs: set of "ah" values of the pretraining samples.
        abduced_ahs: iterable of "ah" values of the abduced samples.

    Returns:
        int: total number of ``input_ids`` (after truncation) over the
        samples whose "ah" is in either collection.
    """
    tok = AutoTokenizer.from_pretrained(model_path)
    counted_ahs = pretrain_ahs.union(abduced_ahs)
    return sum(
        len(tok(sample["summary"], truncation=True)["input_ids"])
        for sample in data
        if sample["ah"] in counted_ahs
    )

def train(
    args,
    ti,
    model_path,
    data,
    pretrain_ahs,
    current_level_ahs_filtered,
):
    """Fine-tune the classifier on abduced samples plus the pretraining samples.

    Samples in ``current_level_ahs_filtered`` are trained with their
    "abduced_attr" labels; samples in ``pretrain_ahs`` with their "attr"
    labels. Training runs for ``args.epoch`` epochs at learning rate
    ``args.lr`` and checkpoints go to ``ckpt_<ti>``.

    NOTE(review): the same ``ckpt_<ti>`` directory is reused whenever the
    same ``ti`` is passed again, and ``get_last_model`` picks the highest step
    number found there. Confirm a checkpoint left by an earlier call cannot
    be selected by mistake.

    Args:
        args: namespace providing ``lr`` and ``epoch``.
        ti: tag appended to the checkpoint directory name.
        model_path: model name or checkpoint directory to start from.
        data: iterable of sample dicts (see ``to_hf_dataset``).
        pretrain_ahs: "ah" values trained with ground-truth "attr" labels.
        current_level_ahs_filtered: "ah" values trained with "abduced_attr".

    Returns:
        str: checkpoint path chosen by ``get_last_model`` in ``ckpt_<ti>``.
    """
    output_dir = f"ckpt_{ti}"
    tok = AutoTokenizer.from_pretrained(model_path)

    def tokenize_batch(batch):
        return tok(batch["text"], truncation=True)

    mixed_dataset = to_hf_dataset(
        data,
        current_level_ahs_filtered,
        "abduced_attr",
        pretrain_ahs,
        "attr",
    )
    tokenized = mixed_dataset.map(tokenize_batch, batched=True, desc="")

    trainer = Trainer(
        model=get_model(model_path),
        args=get_training_arguments(output_dir, args.lr, args.epoch),
        train_dataset=tokenized,
        processing_class=tok,
        data_collator=DataCollatorWithPadding(tokenizer=tok),
    )
    trainer.train()
    return get_last_model(output_dir)


def train_sentence_model(sentence, data, ahs, train_ahs=None):
    """Fit the sentence model on money/attribute features against "month".

    Samples whose "ah" is in ``ahs`` contribute their ground-truth "attr";
    samples that are only in ``train_ahs`` contribute the classifier's
    "pred_attr" instead. Everything else is skipped.

    Args:
        sentence: object exposing ``fit(money, attrs, labels, 3)`` and
            ``show_param()``.
        data: iterable of sample dicts with "ah", "money", "month", "attr"
            (and "pred_attr" for ``train_ahs`` samples).
        ahs: collection of "ah" values using ground truth; ``None`` selects
            every sample.
        train_ahs: optional collection of "ah" values using predictions.
    """
    if ahs is None:
        ahs = {sample["ah"] for sample in data}

    money, attrs, labels = [], [], []
    for sample in data:
        ah = sample["ah"]
        if ah in ahs:
            attr_key = "attr"
        elif train_ahs and ah in train_ahs:
            attr_key = "pred_attr"
        else:
            continue
        money.append([sample["money"]])
        labels.append(sample["month"])
        attrs.append(sample[attr_key])

    sentence.fit(money, attrs, labels, 3)
    sentence.show_param()


def eval_predict(args, sentence, test_data, ahs, title="Test"):
    """Log and return attribute F1 scores and sentence-model errors.

    Args:
        args: not used in this function; kept for the existing call sites.
        sentence: object whose ``test(moneys, pred_attrs, months)`` returns a
            4-tuple that starts with ``(mae, mse)``.
        test_data: iterable of sample dicts with "ah", "month", "money",
            "pred_attr" and "attr".
        ahs: collection of "ah" values to evaluate; ``None`` selects all.
        title: label printed in the log header.

    Returns:
        dict: "f1" (F1 over all flattened label entries, logged as
        "micro f1"), one "<label> f1" entry per label column, "mae", "mse".
    """
    print_log(f"----- {title} -----", logger="current")
    if ahs is None:
        ahs = {sample["ah"] for sample in test_data}

    # One tuple per evaluated sample: (month, [money], pred_attr, attr).
    rows = [
        (sample["month"], [sample["money"]], sample["pred_attr"], sample["attr"])
        for sample in test_data
        if sample["ah"] in ahs
    ]
    months = [row[0] for row in rows]
    moneys = [row[1] for row in rows]
    pred_attrs = [row[2] for row in rows]
    attrs = [row[3] for row in rows]

    pred = np.array(pred_attrs, dtype=int)
    gt = np.array(attrs, dtype=int)

    total_f1 = f1_score(gt.flatten(), pred.flatten())
    print_log(f"micro f1: {total_f1:.4f}", logger="current")
    result = {"f1": total_f1}

    for column in range(gt.shape[-1]):
        column_f1 = f1_score(gt[:, column], pred[:, column])
        label_name = id2label[column]
        print_log(f"{label_name} f1: {column_f1:.4f}", logger="current")
        result[f"{label_name} f1"] = column_f1

    mae, mse, _, _ = sentence.test(moneys, pred_attrs, months)
    result["mae"] = mae
    result["mse"] = mse
    print_log("----------------", logger="current")
    return result

def eval_abduction(train_data, abduced_ahs):
    """Log how well the abduced attributes match the ground truth.

    Only samples whose "ah" is in ``abduced_ahs`` are compared. The overall
    F1 over all flattened label entries is logged as "micro f1", followed by
    one F1 line per label column. Nothing is returned.

    Args:
        train_data: iterable of sample dicts with "ah", "abduced_attr", "attr".
        abduced_ahs: collection of "ah" values to evaluate.
    """
    print_log("----- Abdu -----", logger="current")

    # One (abduced_attr, attr) pair per evaluated sample.
    pairs = [
        (sample["abduced_attr"], sample["attr"])
        for sample in train_data
        if sample["ah"] in abduced_ahs
    ]
    pred = np.array([pair[0] for pair in pairs], dtype=int)
    gt = np.array([pair[1] for pair in pairs], dtype=int)

    total_f1 = f1_score(gt.flatten(), pred.flatten())
    print_log(f"micro f1: {total_f1:.4f}", logger="current")

    for column in range(gt.shape[-1]):
        column_f1 = f1_score(gt[:, column], pred[:, column])
        print_log(f"{id2label[column]}: {column_f1:.4f}", logger="current")
    print_log("----------------", logger="current")

def get_last_model(path):
    """Return the path of the highest-numbered checkpoint inside ``path``.

    Only entries named ``checkpoint-<digits>`` (the number being the text
    after the last "-") are considered. Other ``checkpoint-*`` entries, such
    as ``checkpoint-tmp``, are ignored rather than crashing the integer
    parse. Step 0 is a valid candidate.

    Args:
        path: directory to scan.

    Returns:
        str: ``os.path.join(path, <entry with the largest step number>)``.

    Raises:
        FileNotFoundError: if ``path`` holds no ``checkpoint-<digits>`` entry.
            Failing here is clearer than handing ``path`` itself back to a
            later ``from_pretrained`` call.
    """
    best_step = -1  # below any real step so that "checkpoint-0" can win
    best_name = None
    for entry in os.listdir(path):
        if not entry.startswith("checkpoint-"):
            continue
        suffix = entry.split("-")[-1]
        if not suffix.isdecimal():
            continue
        step = int(suffix)
        if step > best_step:
            best_step, best_name = step, entry
    if best_name is None:
        raise FileNotFoundError(
            f"No 'checkpoint-<step>' entry found in {path!r}."
        )
    return os.path.join(path, best_name)

def main(args):
    """Run the whole experiment: pretrain, then loop predict/abduce/retrain.

    Steps, in the order performed below:
      1. Load the data with ``get_all_data`` and fine-tune both the text
         classifier and the ``SentenceModel`` on the samples of
         ``args.pretrain_lvl``; evaluate on those samples and on the test data.
      2. For every round and every remaining level (ascending order), repeat
         predict -> abduce -> retrain, evaluating on the level's own samples
         and on the test data after each loop.
      3. Write the collected test results (one dict per evaluation, each with
         a cumulative "tokens" count) to ``res.json`` in the logger's log
         directory.

    Args:
        args: parsed command-line namespace as produced by ``get_args``.
    """
    pretrain_lvl = args.pretrain_lvl
    # levels_dict maps a level to the "ah" values of its samples (it is
    # indexed by level and the result is passed on as an ah collection below).
    train_data, levels_dict, test_data = get_all_data(pretrain_lvl)
    # Levels handled by the abductive loop: every level except the pretrain one.
    abl_lvls = list(sorted(levels_dict.keys()))
    abl_lvls.remove(pretrain_lvl)

    sentence = SentenceModel()
    abductor = SentenceAbduction(
        sentence,
        args.rule_filename,
        args.word_match,
        args.strong_conf,
    )

    model_path = args.model_path
    pretrain_ahs = levels_dict[pretrain_lvl]
    # Fine-tune the classifier on the pretrain level; from here on model_path
    # points at the most recent checkpoint instead of the initial model.
    model_path = pretrain(args, model_path, train_data, pretrain_ahs)
    # Fit the sentence model on the ground-truth attributes of the pretrain level.
    train_sentence_model(sentence, train_data, pretrain_ahs)
    # Evaluate the pretrained classifier on the samples it was trained on.
    predict(args, model_path, train_data, pretrain_ahs)
    eval_predict(args, sentence, train_data, pretrain_ahs, "Eval")
    # Evaluate the pretrained classifier on all test samples (None = all).
    predict(args, model_path, test_data, None)
    init_test_res = eval_predict(args, sentence, test_data, None)
    # Running token count of the abductive training loops only; pretraining
    # is not included, so the first record starts at 0.
    total_tokens = 0
    init_test_res["tokens"] = total_tokens

    max_revision = args.max_revision
    rounds = args.rounds
    min_loops = args.min_loops
    max_loops = args.max_loops
    # Level F1 needed to move on to the next level before max_loops is hit.
    next_level_f1_thresh = 0.7

    # One test-result dict per evaluation, starting with the pretrained model.
    records = [init_test_res]


    for r in range(rounds):
        print_log(f"##### ROUND {r+1} #####", logger="current")
        for lvl in abl_lvls:
            train_ahs = levels_dict[lvl]
            print_log(f"===== Level {lvl} Num samples {len(train_ahs)} =====", logger="current")
            # Number of completed loops on this level.
            l = 0
            while True:
                print_log(f"----- [Level|loop] [{lvl}|{l+1}] -----", logger="current")
                # Predict on this level's training samples (fills "pred_attr"/"pred_prob").
                predict(args, model_path, train_data, train_ahs)
                # Abduction: returns the ah values that were abduced. eval_abduction
                # and train read "abduced_attr" for exactly these samples, so
                # abduce_data presumably writes that key - confirm in abduction_model.
                abductor.set_predict_model(sentence)
                abduced_ahs = abductor.abduce_data(train_data, train_ahs, max_revision, lvl)
                eval_abduction(train_data, abduced_ahs)
                # Retrain on abduced labels plus ground-truth pretrain labels;
                # checkpoints for this level go to "ckpt_<lvl>".
                model_path = train(args, lvl, model_path, train_data, pretrain_ahs, abduced_ahs)
                # Tokens seen in this training run = tokens per pass * epochs.
                consumed_tokens = cal_num_tokens(train_data, model_path, pretrain_ahs, abduced_ahs) * args.epoch
                total_tokens += consumed_tokens
                # Re-predict with the new checkpoint, then refit the sentence model on
                # ground-truth attrs (pretrain level) plus predicted attrs (this level).
                predict(args, model_path, train_data, train_ahs)
                train_sentence_model(sentence, train_data, pretrain_ahs, train_ahs)
                eval_f1 = eval_predict(args, sentence, train_data, train_ahs, "Eval")["f1"]
                # Evaluate the new checkpoint on all test samples and record the result.
                predict(args, model_path, test_data, None)
                test_res = eval_predict(args, sentence, test_data, None)
                test_res["tokens"] = total_tokens
                records.append(test_res)
                l += 1
                # Early exit to the next level: needs at least min_loops loops and a
                # level F1 at or above the threshold. The highest level never exits
                # early and always runs until max_loops.
                if lvl != max(abl_lvls) and min_loops <= l < max_loops and eval_f1 >= next_level_f1_thresh:
                    print_log(f"Next Level after loop {l}.", logger="current")
                    break
                if l >= max_loops:
                    print_log(f"Max loops reached.", logger="current")
                    break
    # Persist all recorded test results next to the run's log files.
    logger = ABLLogger.get_current_instance()
    with open(os.path.join(logger.log_dir, "res.json"), "w") as f:
        json.dump(records, f)


def get_args():
    """Parse the command-line options of this script.

    Options are registered in the order listed in ``option_table``. After
    parsing, ``word_match`` is forced to ``True`` regardless of whether
    ``--word_match`` was given on the command line.

    Returns:
        argparse.Namespace: the parsed options.
    """
    # (flag, keyword options for add_argument)
    option_table = (
        ("--model_path", dict(default="google-bert/bert-base-chinese")),
        ("--rule_filename", dict(default="rule_file.txt")),
        ("--pretrain_lvl", dict(type=int, default=1)),
        ("--rounds", dict(type=int, default=1, help="Rounds of multi-phase")),
        ("--min_loops", dict(type=int, default=1, help="Min loops per phase")),
        ("--max_loops", dict(type=int, default=5, help="Max loops per phase")),
        ("--pre_epoch", dict(type=int, default=15)),
        ("--pre_lr", dict(type=float, default=5e-5)),
        ("--epoch", dict(type=int, default=3, help="Epochs per train")),
        ("--lr", dict(type=float, default=5e-5)),
        ("--word_match", dict(action="store_true")),
        ("--max_revision", dict(type=int, default=3)),
        ("--strong_conf", dict(action="store_true")),
        ("--seed", dict(type=int, default=0)),
    )

    parser = argparse.ArgumentParser()
    for flag, options in option_table:
        parser.add_argument(flag, **options)

    namespace = parser.parse_args()
    # Word matching is always switched on, overriding the CLI flag.
    namespace.word_match = True
    return namespace

def set_seed(seed):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and CUDA) generators.

    Args:
        seed: integer seed passed unchanged to each generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for seed_fn in seeders:
        seed_fn(seed)

# Script entry point: parse the CLI options, seed the random generators,
# log the configuration of this run, then start the pipeline.
if __name__ == '__main__':
    args = get_args()
    # Seed before main() so everything created inside it starts from this seed.
    set_seed(args.seed)
    print_log(str(args), logger="current")
    main(args)