from utils.trainer import run_pmixup, run_baseline, pmixup_evaluate
from pos_in_important import *
from pos_augmentation import *
import transformers
from models.tmix import *

# Silence transformers' own log output below ERROR level.
transformers.logging.set_verbosity_error()
import logging

# logging.disable(logging.INFO) # disable INFO and DEBUG logging everywhere
# Here `logging` is still the stdlib module: raise the tokenizer-base logger to ERROR.
logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.ERROR)
# NOTE: this import rebinds the module-level name `logging` from the stdlib
# module to transformers' logging utilities. Every later `logging.` reference
# in this file means transformers.utils.logging, not the stdlib.
from transformers.utils import logging

# 40 is the numeric value of the ERROR level, so this appears to repeat the
# set_verbosity_error() call above.
logging.set_verbosity(40)
# Executes every time this module is imported or run; presumably the WordNet
# corpus is needed by the star-imported augmentation helpers — confirm.
nltk.download('wordnet')


def make_mini_sample(dataset, sample_size):
    """Draw up to ``sample_size`` training texts per label and cache them as CSV.

    Reads ``../dataset/{dataset}/train.csv`` (``text`` and ``label`` columns).
    For each label, in order of first appearance, it samples ``sample_size``
    texts without replacement. A label with fewer texts keeps all of them, in
    file order.

    The result is written to ``../dataset/{dataset}/train_{sample_size}.csv``
    without the index, so re-reading the cache gives the same columns as the
    returned DataFrame.

    Args:
        dataset: Name of the dataset directory under ``../dataset``.
        sample_size: Number of texts to draw per label.

    Returns:
        A DataFrame with ``text`` and ``label`` columns, grouped by label.
    """
    sample_texts, sample_labels = [], []
    train_df = pd.read_csv(f"../dataset/{dataset}/train.csv")
    # Counter preserves first-appearance order of the labels.
    for key in Counter(train_df['label']):
        # A boolean mask selects the label's rows in one pass, instead of an
        # iloc scan over every row for every label.
        texts = train_df.loc[train_df['label'] == key, 'text'].tolist()
        if 0 <= sample_size <= len(texts):
            picked = random.sample(texts, sample_size)
        else:
            # Under-populated label: keep every text instead of failing.
            picked = texts
        sample_texts += picked
        # One label per text actually picked. Using sample_size here made the
        # two columns different lengths for under-populated labels.
        sample_labels += [key] * len(picked)
    new_df = pd.DataFrame({"text": sample_texts, "label": sample_labels})
    new_df.to_csv(f"../dataset/{dataset}/train_{sample_size}.csv", index=False)
    return new_df


def _flag_enabled(value):
    """Return the truth value of a flag that may be a bool or a CLI string.

    Without an argparse ``type=`` converter, ``--eval True`` arrives as the
    string ``"True"``, which never compares equal to ``True``. This accepts
    both forms so the flag behaves as written.
    """
    if isinstance(value, str):
        return value.strip().lower() in ("true", "1", "yes", "y", "t")
    return bool(value)


def _run_experiments(args, dataset, sample_size):
    """Train every model variant for one dataset at one samples-per-class size."""
    df = pd.read_csv(f"../dataset/{dataset}/train.csv")

    val_dataframe = pd.read_csv(f"../dataset/{dataset}/test.csv")
    label_dict = get_label_dict(val_dataframe, 'label')

    # Full-data baseline: trained only when this checkpoint is missing
    # (assumes run_baseline saves it to this path — verify).
    if not os.path.exists(f"../model_weights/{dataset}/model_baseline.pt"):
        run_baseline(args, label_dict, df, dataset, val_dataframe, feature=None, condition=None)

    print(f'----- Making Dataset for {sample_size} Samples Per Class -------')
    mini_path = f"../dataset/{dataset}/train_{sample_size}.csv"
    if not os.path.exists(mini_path):
        mini_df = make_mini_sample(dataset, sample_size)
    else:
        mini_df = pd.read_csv(mini_path)

    imp_removed_path = f'../dataset/{dataset}/imp_removed_{sample_size}.csv'
    imp_list_path = f'../dataset/{dataset}/imp_list_{sample_size}.csv'
    # The return value is unused. The call is expected to write
    # imp_list_path, which is read back immediately below.
    make_important_tokens(mini_df, dataset, imp_removed_path, imp_list_path)
    imp_tokens = pd.read_csv(imp_list_path)['tokens'].tolist()

    # Few-shot baseline first. Then run_pmixup with feature "tmix" on the raw
    # mini-sample, "imp" on the importance-augmented sample, and each
    # requested POS on a POS-augmented sample.
    run_baseline(args, label_dict, mini_df, dataset, val_dataframe, feature=None, condition=None)

    run_pmixup(args, label_dict, mini_df, dataset, val_dataframe, sample_size, feature="tmix")

    auged_df = important_augmentation(mini_df, imp_tokens)
    run_pmixup(args, label_dict, auged_df, dataset, val_dataframe, sample_size, feature="imp")

    for pos in args.pos:
        pos_aug_df = pos_augmentation(mini_df, pos)
        run_pmixup(args, label_dict, pos_aug_df, dataset, val_dataframe, sample_size, feature=pos)


def _evaluate_pos_models(args):
    """Print test accuracy for each POS-feature pmixup checkpoint."""
    for dataset in args.datasets:
        for pos in args.pos:
            for sample_size in args.sample_per_class:
                val_dataframe = pd.read_csv(f"../dataset/{dataset}/test.csv")
                label_dict = get_label_dict(val_dataframe, 'label')

                model = MixText(num_labels=len(label_dict), mix_option=True).cuda()

                pretrained_dict = torch.load(f'../model_weights/{dataset}/pmixup_model_{pos}_{sample_size}.pt')
                # Drop any "module." prefix from parameter names (presumably
                # the checkpoint was saved from a DataParallel-wrapped model).
                partial_dict = {k.replace("module.", ""): v for k, v in pretrained_dict.items()}

                model.load_state_dict(partial_dict)

                val_dataset = TextDataset(val_dataframe, label_dict, "text", "label", args.max_length)
                val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False)
                _, val_acc, _ = pmixup_evaluate(model, val_dataloader)
                print(f"{dataset}_{pos}_{sample_size}_test_acc: {val_acc}")


def main(args):
    """Run all training experiments, then optionally evaluate the POS checkpoints.

    Args:
        args: argparse namespace. Fields read here and by the helpers:
            datasets, sample_per_class, pos, eval, max_length, batch_size.
            The whole namespace is also forwarded to run_baseline and
            run_pmixup.
    """
    seed_everything()

    for dataset in args.datasets:
        for sample_size in args.sample_per_class:
            _run_experiments(args, dataset, sample_size)

    ## Evaluation ##
    if _flag_enabled(args.eval):
        _evaluate_pos_models(args)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--datasets", nargs="+", default=["banking"])
    parser.add_argument("--lr", default=1e-5)
    parser.add_argument('--batch_size', default=32)
    parser.add_argument('--max_length', default=100)
    parser.add_argument('--sample_per_class', nargs="+", default=[10, 200, 2500])
    parser.add_argument('--pos', nargs="+", default=["verb"])
    parser.add_argument('--model_name', default="bert-base-uncased")
    parser.add_argument('--num_epochs', default=100)
    parser.add_argument('--eval', default=True)
    args = parser.parse_args()
    main(args)
