# THIS IS THE SOURCE CODE FOR ICLR 2024 SUBMISSION
# Unmasking and Improving Data Credibility: A Study with Datasets for Language Model Safety
#

import os
import functools
import yaml
import tqdm
import argparse
import logging

import numpy as np
import pandas as pd
import torch
import torcheval
from torcheval.metrics.functional import binary_auroc, binary_f1_score, binary_accuracy, multiclass_accuracy
from torcheval.metrics import BinaryAUROC
import pdb
import datasets
import models
from constants import *
from torch.utils.data import WeightedRandomSampler
import random
from torch.nn.parallel import DataParallel

def parse_args():
    """Parse the benchmark's command-line options from ``sys.argv``.

    Returns:
        argparse.Namespace with three attributes:
        ``config`` (value of ``--config``, ``None`` when omitted),
        ``test`` and ``tensorboard`` (booleans, ``False`` unless the flag is given).
    """
    cli = argparse.ArgumentParser(prog="Benchmark")
    cli.add_argument('--config')
    # Both switches are plain on/off flags.
    for switch in ('--test', '--tensorboard'):
        cli.add_argument(switch, action='store_true')
    namespace = cli.parse_args()
    return namespace



def train(config, train_dataset, tfbd_writer):
    """Fine-tune the configured text classifier and save it to disk.

    Args:
        config: Run configuration (nested dict). Reads ``config['model']``
            (``name`` and ``args``), ``config['device']``, ``config['balance']``,
            ``config['save_dir']``, ``config['training']`` (``train_label``,
            ``batch_size``, ``learning_rate``, ``epoch``) and ``config['logging']``
            (``print_every``, ``log_name``). The dict is not modified.
        train_dataset: Dataset whose collated batches contain a ``'text'`` mapping
            of model inputs and a tensor under the training label key. When
            ``config['balance']`` is truthy it must also expose
            ``.dataset.clean_labels`` and ``.indices``.
        tfbd_writer: TensorBoard writer, or ``None`` to disable TensorBoard logging.

    Returns:
        Tuple ``(model, tfbd_writer)``. ``model`` is wrapped in ``DataParallel``
        when more than one GPU is visible.
    """
    train_label_key = config['training']['train_label']
    batch_size = config['training']['batch_size']
    print_every = config['logging']['print_every']
    device = config['device']

    # Read (rather than pop) the model section so the caller's config stays intact.
    model_config = config['model']
    model_name = model_config['name']
    model = models.TextualModel.from_config(model_name, **model_config['args'])
    logging.info("model loaded...")
    model.to(device)
    # NOTE(review): model.train() is never called, so training runs in whatever
    # mode from_config() returns — confirm that this is the intended mode.

    # Sequence length shrinks as the batch grows: (128 * 192) // batch_size,
    # clamped to the range [128, 512].
    max_length = max(min((128 * 192) // batch_size, 512), 128)
    # The collator is taken from the bare model, before any DataParallel wrapping.
    data_collator = functools.partial(model.data_collator, features=['text'], max_length=max_length)
    collate_fn = lambda batch: data_collator(torch.utils.data.default_collate(batch))

    n_gpu = torch.cuda.device_count()
    if n_gpu > 1:
        print("We have", n_gpu, "GPUs!")
        model = DataParallel(model, device_ids=list(range(n_gpu)))

    loader_kwargs = dict(
        batch_size=batch_size,
        collate_fn=collate_fn,
        num_workers=1,
        pin_memory=True,
    )
    if config['balance']:
        from collections import Counter

        # Inverse-frequency weights so every class is drawn equally often.
        # The frequencies always come from the clean labels, whatever
        # train_label_key is.
        split_labels = train_dataset.dataset.clean_labels[train_dataset.indices]
        class_count = Counter(split_labels)
        weights = [1.0 / class_count[label] for label in split_labels]
        sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)
        train_dataloader = torch.utils.data.DataLoader(
            train_dataset, shuffle=False, sampler=sampler, **loader_kwargs
        )
    else:
        train_dataloader = torch.utils.data.DataLoader(
            train_dataset, shuffle=True, **loader_kwargs
        )

    step = 0
    optimizer = torch.optim.AdamW(model.parameters(), lr=config['training']['learning_rate'])

    for e in range(config['training']['epoch']):
        # Running sums over the current logging window.
        acc = 0.0
        f1 = 0.0
        n_batch = 0.0
        for i, batch in enumerate(tqdm.tqdm(train_dataloader, desc="training batch...")):
            optimizer.zero_grad()
            if n_gpu > 1:
                # DataParallel scatters the inputs to the replicas itself.
                inputs = {k: v for k, v in batch['text'].items()}
                labels = batch[train_label_key].long()
            else:
                inputs = {k: v.to(device) for k, v in batch['text'].items()}
                labels = batch[train_label_key].long().to(device)

            output = model(**inputs, labels=labels)
            loss = output.loss
            if n_gpu > 1:
                # One loss value per replica; reduce to a scalar.
                loss = loss.mean()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            step += 1

            # Metrics are for monitoring only, so keep them out of the autograd graph.
            pred = torch.nn.functional.softmax(output.logits.detach(), dim=1)
            metric_labels = labels.to(device)
            if pred.shape[1] == 2:
                # Binary case: score with the probability of class 1.
                pred = pred[:, 1]
                acc += binary_accuracy(pred, metric_labels).item()
                f1 += binary_f1_score(pred, metric_labels).item()
            else:
                acc += multiclass_accuracy(pred, metric_labels).item()
            n_batch += 1

            if i % print_every == print_every - 1:
                mean_acc = acc / n_batch
                if f1 > 0:
                    print(f"Epoch: {e:04d}, Step: {i:04d}, Loss: {loss.item():.4f}, ACC: {mean_acc}, F1: {f1/n_batch}", flush=True)
                else:
                    print(f"Epoch: {e:04d}, Step: {i:04d}, Loss: {loss.item():.4f}, ACC: {mean_acc}", flush=True)

                # Log BEFORE resetting the window; logging after the reset would
                # always record an accuracy of 0.
                if tfbd_writer is not None:
                    tfbd_writer.add_scalars('Loss/train', {model_name: loss.item()}, step)
                    tfbd_writer.add_scalars('Accuracy/train', {model_name: mean_acc}, step)

                acc = 0.0
                n_batch = 0.0
                f1 = 0.0

    # The checkpoint name is the log name with its last four characters
    # (e.g. ".log") replaced by ".pt"; the script's test mode rebuilds the same path.
    save_path = config['save_dir'] + config['logging']["log_name"][:-4] + '.pt'
    torch.save(model, save_path)
    print(f"model is saved to {save_path}")
    return model, tfbd_writer


def test(config, model, test_dataset, step, e):
    """Evaluate ``model`` on ``test_dataset`` and log the configured metrics.

    Args:
        config: Run configuration (nested dict). Reads ``config['device']``,
            ``config['testing']`` (``test_label``, ``metrics``),
            ``config['training']['batch_size']`` and, for the TensorBoard series
            name, ``config['model']['name']`` when present. The dict is not modified.
        model: Trained model exposing ``data_collator`` and returning an output
            with ``loss`` and ``logits``.
        test_dataset: Dataset whose collated batches contain a ``'text'`` mapping
            of model inputs and a tensor under the test label key.
        step: Global step used as the x-value for TensorBoard scalars.
        e: Epoch number, used only in the log line.

    Returns:
        None. Metric values are written through ``logging.info`` and, when a
        module-level ``tfbd_writer`` exists and is not ``None``, to TensorBoard.
    """
    test_label_key = config['testing']['test_label']
    device = config['device']
    batch_size = config['training']['batch_size']

    # Non-mutating read; the section may be absent if another caller removed it.
    model_config = config.get('model')
    model_name = model_config['name'] if model_config else type(model).__name__

    # Same length rule as training: (128 * 192) // batch_size clamped to [128, 512].
    max_length = max(min((128 * 192) // batch_size, 512), 128)
    data_collator = functools.partial(model.data_collator, features=['text'], max_length=max_length)
    collate_fn = lambda batch: data_collator(torch.utils.data.default_collate(batch))
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collate_fn,
    )

    # Put the model on the evaluation device and in inference mode, so that
    # dropout / batch-norm layers do not behave as in training, whatever mode
    # the model was in when it was saved.
    model.to(device)
    model.eval()

    # One torcheval metric object per name listed in the config.
    metrics = {m: getattr(torcheval.metrics, m)(device=device) for m in config["testing"]['metrics']}
    test_loss = 0.
    total_num = 0
    for batch in tqdm.tqdm(test_dataloader, desc="testing batch..."):
        with torch.no_grad():
            inputs = {k: v.to(device) for k, v in batch['text'].items()}
            labels = batch[test_label_key].long().to(device)
            output = model(**inputs, labels=labels)
            pred = torch.nn.functional.softmax(output.logits, dim=1)
            if pred.shape[1] == 2:
                # Binary case: score with the probability of class 1.
                pred = pred[:, 1]

            for metric in metrics.values():
                metric.update(pred, labels)

        # NOTE(review): this sums one loss value per batch but divides by the
        # number of samples below. If output.loss is a per-batch mean, the
        # reported value is under-scaled — confirm the reduction.
        test_loss += output.loss.item()
        total_num += batch[test_label_key].shape[0]

    # Compute every metric once and reuse the values for both sinks.
    results = {name: metric.compute().item() for name, metric in metrics.items()}

    # The writer is created at script level; look it up defensively so this
    # function also works when no writer has been set up.
    writer = globals().get('tfbd_writer')
    if writer is not None:
        mean_loss = test_loss / total_num if total_num else float('nan')
        writer.add_scalars('Loss/test', {model_name: mean_loss}, step)
        for name, value in results.items():
            writer.add_scalars(f'{name}/test', {model_name: value}, step)

    result = "\t".join([f"{k}: {v:.4f}" for k, v in results.items()])
    logging.info(f"Epoch: {e:04d}\t" + result)
    


if __name__ == "__main__":
    # Script entry point: load the YAML config, build the dataset splits, then
    # either train a model or evaluate a previously saved one (--test).
    #
    # `args` and `tfbd_writer` are deliberately kept as module-level names:
    # functions in this module may look them up as globals.

    args = parse_args()

    with open(args.config, mode='rt', encoding='utf-8') as f:
        config = yaml.safe_load(f)
        print(config)

    # Test runs and training runs log to different files in the same directory.
    log_name = config['testing']['log_name'] if args.test else config['logging']['log_name']
    log_path = f"logs/run-2/{log_name}"
    # logging.basicConfig(filename=...) fails if the directory does not exist.
    os.makedirs(os.path.dirname(log_path), exist_ok=True)
    logging.basicConfig(filename=log_path, format='%(asctime)s - %(message)s', level=logging.INFO)

    # Seeded generator so the train/test split is reproducible.
    random_generator = torch.Generator().manual_seed(config['training']['seed'])

    config['device'] = torch.device(config['device'])

    if args.tensorboard:
        from torch.utils.tensorboard import SummaryWriter
        tfbd_writer = SummaryWriter(
            log_dir=os.path.join(
                config['logging']['log_dir'],
                config['dataset']['name'],
                config['logging']['run_name'],
                f"train_{config['training']['train_label']}"
            ),
            filename_suffix=config['run']
        )
    else:
        tfbd_writer = None

    # load the dataset
    data_config = config.pop('dataset', None)
    dataset = datasets.load_dataset_from_config(data_config['name'], **data_config['args'])
    if max(dataset.clean_labels) > 1:
        # Collapse multi-level labels to binary: values <= 1 become 0, values > 1 become 1.
        dataset.clean_labels[dataset.clean_labels <= 1.0] = 0
        dataset.clean_labels[dataset.clean_labels > 1.0] = 1
        dataset.raw_labels[dataset.raw_labels <= 1.0] = 0
        dataset.raw_labels[dataset.raw_labels > 1.0] = 1
        import collections
        print(f'make it binary: \n Clean: {collections.Counter(dataset.clean_labels)}\nRaw: {collections.Counter(dataset.raw_labels)}')

    # 80/20 train/test split.
    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size], random_generator)

    # FILTER_LABEL: filter the split with drop_disagreement, then use the clean labels.
    if config['training']['train_label'] == FILTER_LABEL:
        datasets.drop_disagreement(train_dataset)
        config['training']['train_label'] = CLEAN_LABEL
        print('The size of filtered dataset is ', len(train_dataset))

    if config['testing']['test_label'] == FILTER_LABEL:
        datasets.drop_disagreement(test_dataset)
        config['testing']['test_label'] = CLEAN_LABEL
        print('The size of filtered dataset is ', len(test_dataset))

    # BEST_LABEL: where the clean and raw labels disagree, overwrite the clean
    # label with the label derived from the stored ChatGPT output.
    if config['testing']['test_label'] == BEST_LABEL:
        # SECURITY: torch.load unpickles arbitrary objects; only load trusted files.
        chatGPT_rec = torch.load(f"Jigsaw_test_thre07_chatGPT_test_rec.pt")
        chatGPT_idx = {}
        for rec in chatGPT_rec:
            # NOTE(review): str.strip() treats its argument as a SET of characters,
            # not as a prefix, so it may also eat leading/trailing letters of the
            # answer itself. Kept as-is to avoid changing the derived labels —
            # confirm against the stored outputs.
            answer = rec["chatGPT_out"].strip('[ChatGPT]: ')
            # An answer starting with "B" maps to label 0, anything else to 1.
            chatGPT_idx[rec["index"]] = 0 if answer.startswith("B") else 1

        cnt_agree = 0
        cnt_all = 0
        base = test_dataset.dataset
        # NOTE(review): the raw_idx values are used directly as positions into
        # clean_labels / raw_labels — presumably they coincide; verify.
        for i in base.raw_idx[test_dataset.indices]:
            if i in chatGPT_idx:
                if base.clean_labels[i] != base.raw_labels[i]:
                    if chatGPT_idx[i] == base.clean_labels[i]:
                        cnt_agree += 1
                    base.clean_labels[i] = chatGPT_idx[i]
                    cnt_all += 1

        # Guard against the case where nothing was overridden.
        ratio = cnt_agree / cnt_all if cnt_all else float('nan')
        print(f"ratio of agreements: {cnt_agree}/{cnt_all} = {ratio}")
        config['testing']['test_label'] = CLEAN_LABEL

    if args.test:
        # Same path rule as the one used when the model was saved after training.
        model_path = config['save_dir'] + config['logging']["log_name"][:-4] + '.pt'
        # SECURITY: the checkpoint is a fully pickled model object; only load
        # checkpoints produced by a trusted training run.
        model = torch.load(model_path)

        # Unwrap multi-GPU checkpoints so the bare model's attributes are reachable.
        if isinstance(model, DataParallel):
            model = model.module
        test(config, model, test_dataset, step=-1, e=-1)
    else:
        model, tfbd_writer = train(config, train_dataset, tfbd_writer)