import os
import json
import torch
import random
from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score
import time
from torch.utils.data import Dataset, DataLoader
import numpy as np
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix
from transformers import AutoTokenizer, get_linear_schedule_with_warmup
from collections import defaultdict
import torch
import random
from tqdm import tqdm
import numpy as np
from model import PretrainModel
from util import seed_torch, get_dataset
import scipy.stats
import argparse
from sklearn.metrics import matthews_corrcoef


def train(model,
          dataloader,
          optimizer,
          epochs,
          model_save_path,
          criterion,
          info_dict=None,
          loss_acc_dict=None,
          scheduler=None,
          dev_dataloader=None,
          save_steps=-1,
          track_info_per_iter=-1,
          track_info=False,
          verbose=True):
    """Train ``model`` on ``dataloader`` for ``epochs`` epochs, optionally tracking statistics.

    Reads the module-level globals ``device``, ``dataset_name`` and ``model_name`` (assigned in
    the ``__main__`` section) and calls ``evaluate_dev``, which itself reads ``criterion``.

    Args:
        model: module called as ``model(input_ids=..., attention_mask=..., token_type_ids=...)``;
            must provide ``compute_information_bp_fast`` when ``track_info`` is True.
        dataloader: yields ``(token, mask, token_type_ids, label)`` batches.
        optimizer: optimizer over ``model``'s parameters.
        epochs: number of passes over ``dataloader``.
        model_save_path: directory for pickled model checkpoints (created if missing).
        criterion: loss applied as ``criterion(predictions, labels)``.
        info_dict: dict of lists receiving ``compute_information_bp_fast`` results; a fresh
            ``defaultdict(list)`` is created when None and ``track_info`` is True.
        loss_acc_dict: dict of lists receiving loss/accuracy/F1/L2-norm history; created like
            ``info_dict`` when None.
        scheduler: optional LR scheduler stepped after every optimizer step.
        dev_dataloader: optional loader evaluated whenever info is tracked (per-epoch tracking
            additionally re-evaluates ``dataloader``).
        save_steps: when > 0, save ``{model_name}_iteration_{n}.pkl`` every ``save_steps``
            iterations; when -1, save ``{model_name}.pkl`` once after the final epoch.
        track_info_per_iter: when > 0, track every that many iterations; when -1, track once
            per epoch.
        track_info: master switch for all tracking.
        verbose: print per-batch training metrics and per-epoch info.

    Returns:
        ``(info_dict, loss_acc_dict)``.
    """
    if track_info:
        # Tracking appends into these containers; don't crash on the None defaults.
        if info_dict is None:
            info_dict = defaultdict(list)
        if loss_acc_dict is None:
            loss_acc_dict = defaultdict(list)

    iteration = 0

    for epoch in range(1, epochs + 1):
        # evaluate_dev() (used during tracking) switches the model to eval mode, so training
        # mode must be re-enabled every epoch or dropout etc. would stay disabled.
        model.train()

        for idx, (token, mask, token_type_ids, label) in enumerate(dataloader):
            iteration += 1

            optimizer.zero_grad()
            start_time = time.time()

            predicted_label = model(input_ids=token.to(device), attention_mask=mask.to(device),
                                    token_type_ids=token_type_ids.to(device))
            if dataset_name == 'stsb':
                # Regression head: drop the trailing dim (presumably (batch, 1)) for MSELoss.
                predicted_label = torch.squeeze(predicted_label, dim=-1)

            loss = criterion(predicted_label, label.to(device))
            loss.backward()

            # Clip the total gradient norm to 5.0.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)

            optimizer.step()

            if scheduler is not None:
                scheduler.step()

            # NOTE(review): argmax over dim 1 assumes (batch, num_class) scores; the squeezed
            # 1-D 'stsb' output would fail here.
            y_pred = predicted_label.detach().cpu().argmax(1).numpy()
            y_true = label.detach().cpu().numpy()

            # Python float (not the tensor from loss.data) so the history stays JSON-serializable.
            loss_value = loss.item()

            acc = accuracy_score(y_true, y_pred)
            p_value = precision_score(y_true, y_pred, average='macro')
            r_value = recall_score(y_true, y_pred, average='macro')
            # Harmonic mean of macro P and R; 0.0 instead of dividing by zero when both are 0.
            p_plus_r = p_value + r_value
            f1 = 2 * p_value * r_value / p_plus_r if p_plus_r > 0 else 0.0

            cost_time = time.time() - start_time

            # Check track_info_per_iter > 0 first so the modulo can never divide by zero.
            if track_info and track_info_per_iter > 0 and iteration % track_info_per_iter == 0:
                loss_acc_dict["tr_loss"].append(loss_value)
                loss_acc_dict["tr_acc"].append(acc)

                # Same call shape as the per-epoch tracking below.
                info = model.compute_information_bp_fast(dataloader, criterion)
                for k, v in info.items():
                    info_dict[k].append(v)

                print("iteration/epoch: {}/{}, info: {}".format(iteration, epoch, info))

                if dev_dataloader is not None:
                    dev_acc, dev_loss, dev_f1 = evaluate_dev(model, dev_dataloader)
                    loss_acc_dict["va_acc"].append(dev_acc)
                    loss_acc_dict["va_loss"].append(dev_loss)
                    loss_acc_dict["va_f1"].append(dev_f1)

                # Tracking may have left the model in eval mode mid-epoch; resume training mode.
                model.train()

            if save_steps > 0 and iteration % save_steps == 0:
                os.makedirs(model_save_path, exist_ok=True)
                torch.save(model, os.path.join(model_save_path, "{}_iteration_{}.pkl".format(model_name, iteration)))

            if verbose:
                print(
                    '| epoch:{:3d} | batch:{:4d}/{:4d} | lr:{:10.9f} |train_loss:{:7.2f} | train_acc:{:7.3f} |train_f1:{:7.3f}|train_P:{:7.3f}|train_R:{:7.3f}| time: {:4.2f}s'.format(
                        epoch, idx, len(dataloader), optimizer.param_groups[0]['lr'], loss_value, acc, f1, p_value,
                        r_value, cost_time))

        if save_steps == -1 and epoch == epochs:
            os.makedirs(model_save_path, exist_ok=True)
            torch.save(model, os.path.join(model_save_path, "{}.pkl".format(model_name)))

        # Track info once per epoch.
        if track_info and track_info_per_iter == -1:
            info = model.compute_information_bp_fast(dataloader, criterion)
            for k, v in info.items():
                info_dict[k].append(v)
            if verbose:
                print("epoch: {}, info: {}".format(epoch, info))

            # Sum of the per-parameter L2 norms (not the L2 norm of all weights jointly).
            l2_norm = sum(param.data.norm(2) for param in model.parameters())
            loss_acc_dict["l2_norm"].append(float(l2_norm))

            if dev_dataloader is not None:
                dev_acc, dev_loss, dev_f1 = evaluate_dev(model, dev_dataloader)
                tr_acc, tr_loss, tr_f1 = evaluate_dev(model, dataloader)

                loss_acc_dict["va_acc"].append(dev_acc)
                loss_acc_dict["va_loss"].append(dev_loss)
                loss_acc_dict["va_f1"].append(dev_f1)

                loss_acc_dict["tr_acc"].append(tr_acc)
                loss_acc_dict["tr_loss"].append(tr_loss)
                loss_acc_dict["tr_f1"].append(tr_f1)

    return info_dict, loss_acc_dict


# Highest accuracy returned so far by evaluate_dev(); it is updated there through `global`.
# NOTE(review): evaluate_dev() is also called on the training split, so this value can
# reflect training accuracy rather than dev accuracy.
best_eval_acc = 0


def evaluate_dev(model, dataloader):
    """Evaluate ``model`` on ``dataloader``, print a report and return summary metrics.

    Reads the module-level globals ``device`` and ``criterion`` (assigned in the ``__main__``
    section) and updates the module-level ``best_eval_acc``. The model is switched to eval mode
    and left there; callers that continue training must call ``model.train()`` themselves.

    Args:
        model: module called positionally with ``(token, mask, token_type_ids)``, returning
            scores whose argmax over dim 1 is the predicted class.
        dataloader: yields ``(token, mask, token_type_ids, label)`` batches.

    Returns:
        ``(accuracy, mean per-batch loss, f1)`` where ``f1`` is the harmonic mean of the
        macro-averaged precision and recall (0.0 when both are 0).
    """
    global best_eval_acc

    model.eval()
    y_true_list = []
    y_pred_list = []
    loss_list = []

    # Inference only: disabling autograd avoids building graphs and saves memory.
    with torch.no_grad():
        for token, mask, token_type_ids, label in tqdm(dataloader):
            predicted_label = model(token.to(device), mask.to(device), token_type_ids.to(device))

            y_pred = predicted_label.cpu().argmax(1).numpy()
            y_true = label.cpu().numpy()

            # squeeze(-1) only has an effect when the last dim has size 1 (e.g. a regression head).
            loss = criterion(torch.squeeze(predicted_label, dim=-1), label.to(device))

            y_true_list.extend(y_true.tolist())
            y_pred_list.extend(y_pred.tolist())
            loss_list.append(loss.item())

    acc = accuracy_score(y_true_list, y_pred_list)
    p_value = precision_score(y_true_list, y_pred_list, average='macro')
    r_value = recall_score(y_true_list, y_pred_list, average='macro')
    # Guard the harmonic mean against 0/0 when precision and recall are both 0.
    p_plus_r = p_value + r_value
    f1 = 2 * p_value * r_value / p_plus_r if p_plus_r > 0 else 0.0

    best_eval_acc = max(best_eval_acc, acc)

    print('-' * 100)
    print("| dev_acc:{:6.4f} | dev_f1:{:6.4f} | dev_P:{:6.4f} | dev_R:{:6.4f} | best_acc:{:6.4f}".format(
        acc, f1, p_value, r_value, best_eval_acc))
    print('-' * 100)

    print(classification_report(y_true=y_true_list, y_pred=y_pred_list))
    print(confusion_matrix(y_true_list, y_pred_list))

    return acc, np.mean(loss_list), f1


def get_parse():
    """Build the command-line parser and parse ``sys.argv``.

    ``--model_path`` and ``--dataset_name`` are required: the script loads the tokenizer from
    ``model_path`` and the data for ``dataset_name`` unconditionally, so omitting either could
    only fail later with a less helpful error. All other options keep their defaults.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser(description='eval iiw')

    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--warm_up_proportion', type=float, default=0.1)
    parser.add_argument('--prior_model_epoch', type=int, default=5)
    parser.add_argument('--epochs', type=int, default=15)
    parser.add_argument('--pooling_type', type=str, default='first_last_avg')
    parser.add_argument('--model_name', type=str)
    parser.add_argument('--model_path', type=str, required=True)
    parser.add_argument('--model_state', type=str, choices=['pretrain', 'random'])
    parser.add_argument('--lr', type=float, default=0.0001)
    parser.add_argument('--scheduler_type', type=str, default='warmup', choices=['warmup', 'constant'])
    parser.add_argument('--dataset_name', type=str, required=True,
                        choices=['sst2', 'mrpc', 'qqp', 'mnli', 'mnli_mismatched', 'qnli', 'rte',
                                 'wnli'])
    parser.add_argument('--seed', type=int, default=2024)

    return parser.parse_args()


if __name__ == '__main__':
    args = get_parse()
    seed_torch(args.seed)

    # train() and evaluate_dev() read some of these names as module-level globals
    # (e.g. device, dataset_name, model_name, criterion), so they must stay at module scope.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    batch_size = args.batch_size
    warm_up_proportion = args.warm_up_proportion
    prior_model_epoch = args.prior_model_epoch
    epochs = args.epochs
    pooling_type = args.pooling_type
    model_name = args.model_name
    model_path = args.model_path
    model_state = args.model_state
    random_init = model_state == 'random'
    lr = args.lr
    scheduler_type = args.scheduler_type
    dataset_name = args.dataset_name

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    train_data, dev_data, test_data, num_class = get_dataset(dataset_name, tokenizer)
    train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False)
    # The dev split is used to train the prior model below, hence the shuffling.
    dev_dataloader = DataLoader(dev_data, batch_size=batch_size, shuffle=True)

    if dataset_name == 'stsb':
        criterion = torch.nn.MSELoss()
    else:
        criterion = torch.nn.CrossEntropyLoss()

    # Every artefact of this run lives under one directory keyed by the configuration.
    run_dir = 'model/{}_{}_{}_{}_{}'.format(model_name, dataset_name, pooling_type, model_state, scheduler_type)
    prior_model_dir = '{}/prior_model'.format(run_dir)
    train_dir = '{}/train'.format(run_dir)

    # ---- Stage 1: train the prior model on the dev split. ----
    prior_model = PretrainModel(model_path, num_class=num_class, random_init=random_init, pooling_type=pooling_type,
                                prior_model=None).to(device)
    prior_model_optimizer = torch.optim.AdamW(prior_model.parameters(), lr=lr)
    # Count the optimizer steps actually taken: len(DataLoader) includes the final partial
    # batch, whereas len(dataset) // batch_size undercounts and made the linear schedule
    # reach lr=0 before training finished.
    prior_model_total_step = len(dev_dataloader) * prior_model_epoch

    if scheduler_type == 'warmup':
        prior_model_scheduler = get_linear_schedule_with_warmup(
            prior_model_optimizer,
            num_warmup_steps=int(prior_model_total_step * warm_up_proportion),
            num_training_steps=prior_model_total_step)
    else:  # 'constant'
        prior_model_scheduler = None

    train(prior_model,
          dataloader=dev_dataloader,
          optimizer=prior_model_optimizer,
          epochs=prior_model_epoch,
          scheduler=prior_model_scheduler,
          model_save_path=prior_model_dir,
          criterion=criterion,
          info_dict=None,
          loss_acc_dict=None,
          dev_dataloader=None,
          save_steps=-1,
          track_info_per_iter=-1,
          track_info=False,
          verbose=True)

    # ---- Stage 2: train the main model against the trained prior. ----
    prior_model_path = '{}/{}.pkl'.format(prior_model_dir, model_name)
    # train() pickled the whole module into this trusted local file. torch>=2.6 defaults to
    # weights_only=True, which cannot restore a full module, so opt out explicitly.
    prior_model = torch.load(prior_model_path, map_location=device, weights_only=False)
    model = PretrainModel(model_path, num_class=num_class, random_init=random_init, pooling_type=pooling_type,
                          prior_model=prior_model.to(device)).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    total_step = len(train_dataloader) * epochs

    if scheduler_type == 'warmup':
        scheduler = get_linear_schedule_with_warmup(optimizer,
                                                    num_warmup_steps=int(total_step * warm_up_proportion),
                                                    num_training_steps=total_step)
    else:  # 'constant'
        scheduler = None

    info_dict = defaultdict(list)
    loss_acc_dict = defaultdict(list)

    # Record the metrics of the untrained model as the first history entry.
    dev_acc, dev_loss, dev_f1 = evaluate_dev(model, test_dataloader)
    tr_acc, tr_loss, tr_f1 = evaluate_dev(model, train_dataloader)

    loss_acc_dict["va_acc"].append(dev_acc)
    loss_acc_dict["va_loss"].append(dev_loss)
    loss_acc_dict["va_f1"].append(dev_f1)

    loss_acc_dict["tr_acc"].append(tr_acc)
    loss_acc_dict["tr_loss"].append(tr_loss)
    loss_acc_dict["tr_f1"].append(tr_f1)

    info = model.compute_information_bp_fast(train_dataloader, criterion)
    for k, v in info.items():
        info_dict[k].append(v)
    print(info_dict)
    print('finish computing initial iiw...')

    train(model,
          dataloader=train_dataloader,
          optimizer=optimizer,
          epochs=epochs,
          scheduler=scheduler,
          model_save_path=train_dir,
          criterion=criterion,
          info_dict=info_dict,
          loss_acc_dict=loss_acc_dict,
          dev_dataloader=test_dataloader,
          save_steps=-1,
          track_info_per_iter=-1,
          track_info=True,
          verbose=True)

    # train() only creates train_dir when it saves a checkpoint; make sure it exists.
    os.makedirs(train_dir, exist_ok=True)

    with open('{}/iiw.jsonl'.format(train_dir), 'w') as json_file:
        json.dump(info_dict, json_file)

    with open('{}/loss_acc.jsonl'.format(train_dir), 'w') as json_file:
        json.dump(loss_acc_dict, json_file)
