import torch
from transformers import BertTokenizer
from datasets import load_dataset, DatasetDict
import numpy as np
import os
import argparse
import time
import math
from torch.utils.data import DataLoader
from methods import tfbo, adabio
import random
def random_seed(value):
    """Seed the random number generators used by this script.

    Switches cuDNN to deterministic mode, then passes ``value`` to the
    PyTorch (``torch.manual_seed`` and ``torch.cuda.manual_seed``), NumPy
    and Python ``random`` seeding functions.

    Args:
        value: Seed forwarded unchanged to every seeding function.
    """
    torch.backends.cudnn.deterministic = True
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        np.random.seed,
        random.seed,
    )
    for seed_fn in seeders:
        seed_fn(value)



def _build_parser():
    """Return the argument parser holding every experiment hyperparameter."""
    parser = argparse.ArgumentParser()

    # NOTE(review): --data is parsed but the loader below always loads "trec".
    parser.add_argument("--data", default='trec', type=str,
                        help="dataset for meta-learning", )

    parser.add_argument("--save_direct", default='trec/tuned', type=str,
                        help="Path to save file")

    # `choices` rejects an unknown method at parse time instead of failing
    # later with an unbound `learner`.
    parser.add_argument("--methods", default='adabio', type=str,
                        choices=['tfbo', 'adabio'],
                        help="choise method [tfbo, adabio]")

    parser.add_argument("--num_labels", default=6, type=int,
                        help="Number of class for classification")

    parser.add_argument("--hidden_size", default=768, type=int,
                        help="hidden size of bert")

    parser.add_argument("--epochs", default=50, type=int,
                        help="Number of outer interation")

    parser.add_argument("--k_spt", default=25, type=int,
                        help="Number of support samples per task")

    parser.add_argument("--outer_batch_size", default=256, type=int,
                        help="Batch of task size")

    parser.add_argument("--inner_batch_size", default=256, type=int,
                        help="Training batch size in inner iteration")

    parser.add_argument("--outer_update_lr", default=1e-3, type=float,
                        help="Meta learning rate")

    parser.add_argument("--inner_update_lr", default=1e-3, type=float,
                        help="Inner update learning rate")

    parser.add_argument("--neumann_lr", default=1e-3, type=float,
                        help="update for hessian")

    parser.add_argument("--hessian_q", default=5, type=int,
                        help="Q steps for hessian-inverse-vector product")

    parser.add_argument("--gamma", default=1e-3, type=float,
                        help="clipping threshold")

    parser.add_argument("--seed", default=42, type=int,
                        help="random seed")

    parser.add_argument("--beta", default=0.90, type=float,
                        help="momentum parameters")

    parser.add_argument("--eta", default=0.1, type=float,
                        help="learning rate")

    parser.add_argument("--alpha", default=0.1, type=float,
                        help="hyperparameter for x in adabio")

    return parser


def _build_learner(args):
    """Apply the per-method hyperparameters to ``args`` and build the learner.

    NOTE(review): the assignments below overwrite any value given on the
    command line for those options; this matches the previous behaviour and
    looks like a set of tuned settings -- confirm before changing.

    Raises:
        ValueError: if ``args.methods`` is neither ``'tfbo'`` nor ``'adabio'``.
    """
    if args.methods == 'tfbo':
        args.inner_update_lr = 1e-1
        args.outer_update_lr = 1e-2
        args.hessian_q = 3
        return tfbo.Learner(args)

    if args.methods == 'adabio':
        args.inner_update_lr = 0.5
        args.outer_update_lr = 1e-5
        args.alpha = 1.0
        args.beta = 0.9
        args.gamma = 1e-1
        args.hessian_q = 3
        return adabio.Learner(args)

    raise ValueError(f"unknown method {args.methods!r}; expected 'tfbo' or 'adabio'")


def _build_loaders():
    """Load TREC, tokenize it and return (train, validation, test) loaders."""
    # load data
    data = load_dataset("trec", trust_remote_code=True)
    # Split training set into train + validation
    train_valid = data["train"].train_test_split(test_size=0.2, seed=42)
    dataset = DatasetDict({
        "train": train_valid["train"],
        "validation": train_valid["test"],
        "test": data["test"]
    })
    # load tokenizer
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

    # data preprocessing function
    def preprocess_data(examples):
        encodings = tokenizer(examples['text'], truncation=True, padding="max_length", max_length=128)
        encodings["labels"] = examples["coarse_label"]
        return encodings

    encoded_dataset = dataset.map(preprocess_data, batched=True)
    encoded_dataset.set_format("torch", columns=["input_ids", "attention_mask", "labels"])
    train_loader = DataLoader(encoded_dataset["train"], batch_size=64, shuffle=True)
    val_loader = DataLoader(encoded_dataset["validation"], batch_size=64, shuffle=True)
    test_loader = DataLoader(encoded_dataset["test"], batch_size=64)
    return train_loader, val_loader, test_loader


def _save_results(args, save_path, file_name, acc_train, acc_test,
                  loss_train, loss_test, total_time):
    """Write the text summary and the ``torch.save`` archive under ``save_path``."""
    # makedirs (not mkdir): the target is nested, e.g. logs/trec/tuned/adabio.
    os.makedirs(save_path, exist_ok=True)
    target = os.path.join(save_path, file_name)
    # Context manager closes the handle even if the write fails.
    with open(target + '.txt', 'w') as files:
        files.write(str({'Exp configuration': str(args), 'AVG Train ACC': str(acc_train),
                         'AVG Test ACC': str(acc_test), 'AVG Train LOSS': str(loss_train),
                         'AVG Test LOSS': str(loss_test),
                         'time': total_time}))
    torch.save((acc_train, acc_test, loss_train, loss_test), target)


def main():
    """Train the selected learner on TREC, evaluate each epoch and save the logs.

    Parses the command line, seeds the RNGs, builds the learner and the data
    loaders, runs ``args.epochs`` rounds of training followed by evaluation on
    the test loader, and writes the per-epoch metrics under
    ``logs/<save_direct>/<methods>/``.
    """
    args = _build_parser().parse_args()
    random_seed(args.seed)
    st = time.time()

    learner = _build_learner(args)

    # Create the output directory up front so that a bad path fails before
    # training starts rather than after the last epoch.
    os.makedirs('logs/' + os.path.join(args.save_direct, args.methods), exist_ok=True)

    train_loader, val_loader, test_loader = _build_loaders()

    acc_test = []
    loss_test = []
    acc_train = []
    loss_train = []
    print(args)
    for epoch in range(args.epochs):
        print(f'Meta learning method: {args.methods}')
        print(f"[epoch/epochs]:{epoch}/{args.epochs}")
        print(f'\n---------Epoch: {epoch} -----------')
        # NOTE(review): training unpacks (acc, loss) while evaluate() below
        # unpacks (loss, acc) -- confirm against the Learner implementations.
        acc, loss = learner(train_loader, val_loader)
        acc_train.append(float(format(np.mean(acc), '.4f')))
        loss_train.append(float(format(np.mean(loss), '.4f')))
        print(f'{args.methods}, Epoch: {epoch}, Train Loss: {loss_train}')
        print(f'{args.methods}, Epoch: {epoch}, Train Acc: {acc_train}')

        print("---------- Testing Mode -------------")
        loss, acc = learner.evaluate(test_loader)
        acc_test.append(float(format(acc, '.4f')))
        loss_test.append(float(format(loss, '.4f')))
        print(f'{args.methods}, Epoch: {epoch}, Test Loss: {loss_test}')
        print(f'{args.methods}, Epoch: {epoch}, Test Acc: {acc_test}')

    file_name = f'{args.methods}_outlr{args.outer_update_lr}_inlr{args.inner_update_lr}_seed{args.seed}_alpha{args.alpha}'
    # The recorded configuration carries the method-specific sub-directory.
    args.save_direct = os.path.join(args.save_direct, args.methods)
    save_path = 'logs/' + args.save_direct
    total_time = (time.time() - st) / 3600
    _save_results(args, save_path, file_name, acc_train, acc_test,
                  loss_train, loss_test, total_time)
    print(args)
    print(f'time:{total_time} h')


# Run the experiment only when this file is executed as a script, so that
# importing the module does not start it.
if __name__ == "__main__":
    main()
