import os.path

from transformers import AutoModelForSequenceClassification, AdamW, get_scheduler
from torch.utils.data import DataLoader
from train_eval import train, setup_seed
from nlp_dataset import get_dataset, get_num_class
from utils import PathConfig
from args import get_fe_args
from utils import split_train_val, read_config


# Run options for this script (see the example commands near the bottom of the file).
args = get_fe_args()

# Shared PathConfig instance used to look up project paths and the global seed.
PC = PathConfig()

# Per-dataset configuration, read from "<dataset config path><dst_name>.yaml".
cfg = read_config(
    cfg_path="".join((PC.get_dataset_config_path(), args.dst_name, '.yaml'))
)


def load_feature_extractor(weight_path):
    """Load a saved sequence-classification model to use as a feature extractor.

    Args:
        weight_path: Checkpoint location handed to ``from_pretrained``.

    Returns:
        The loaded model, configured with ``output_hidden_states=True``.
    """
    return AutoModelForSequenceClassification.from_pretrained(
        weight_path,
        output_hidden_states=True,
    )


def train_feature_extractor(device, model_name, dataset_name, num_classes, model_save_path):
    """Fine-tune one sequence-classification model per fold and save each one.

    When ``args.fe_type == 'default'`` the fold count is forced to 1. Otherwise
    ``split_train_val`` is asked for ``args.k`` train/validation splits of the
    training set. For every fold a fresh pretrained model is fine-tuned and
    written with ``save_pretrained``:

    * fold count == 1: trained on the whole training set and saved to
      ``<model_save_path>/default``;
    * fold count > 1: trained on the first index list of that fold's split and
      saved to ``<model_save_path>/<fold index>``.

    Hyper-parameters (``batch_size``, ``lr``, ``weight_decay``, ``num_epochs``,
    ``k``, ``fe_type``) come from the module-level ``args``; the split seed
    comes from ``cfg['split_seed']``.

    Args:
        device: Device the model is moved to; also forwarded to ``train``.
        model_name: Pretrained checkpoint passed to ``from_pretrained`` and to
            ``get_dataset``.
        dataset_name: Dataset identifier passed to ``get_dataset``.
        num_classes: Number of labels for the classification head.
        model_save_path: Parent directory for the per-fold output directories.

    Raises:
        FileExistsError: If a fold's output path already exists but is not a
            directory.
    """
    # Only the first element returned by get_dataset (the training set) is used here.
    whole_train_dst, _ = get_dataset(dataset_name, model_name)

    if args.fe_type == 'default':
        # NOTE: this overwrites the shared ``args`` namespace. Kept as-is so any
        # code that reads ``args.k`` afterwards sees the same value as before.
        args.k = 1
    else:
        # Each entry is indexed as [fold][0] below for the training indices;
        # the remaining element is presumably the validation indices (unused here).
        train_val_index_list = split_train_val(
            list(range(len(whole_train_dst))),
            whole_train_dst['labels'],
            seed=cfg['split_seed'],
            k=args.k,
            val_ratio=0.2,
        )

    # Loop-invariant values, read once.
    single_model = args.k == 1
    num_epochs = args.num_epochs

    for k_index in range(args.k):
        if single_model:
            train_dst = whole_train_dst
        else:
            train_dst = whole_train_dst.select(train_val_index_list[k_index][0])

        # Start every fold from the same pretrained checkpoint.
        model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_classes)
        model.to(device)

        train_dst.set_format("torch")
        train_dataloader = DataLoader(train_dst, shuffle=True, batch_size=args.batch_size)

        optimizer = AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        # Linear schedule without warm-up, spanning every optimisation step.
        lr_scheduler = get_scheduler(
            "linear",
            optimizer=optimizer,
            num_warmup_steps=0,
            num_training_steps=num_epochs * len(train_dataloader),
        )

        # No validation/test loaders are supplied; only the trained model is kept.
        _, _, _, model = train(model, train_dataloader, None, None, optimizer, lr_scheduler, num_epochs, device)

        curr_save_path = os.path.join(model_save_path, 'default' if single_model else str(k_index))
        # exist_ok avoids the check-then-create race and fails loudly if the
        # path exists but is a regular file.
        os.makedirs(curr_save_path, exist_ok=True)
        model.save_pretrained(save_directory=curr_save_path)


def main():
    """Train the feature extractor(s) for the dataset named in ``args.dst_name``.

    Uses the fixed device ``cuda:0`` and the ``bert-base-uncased`` checkpoint,
    and saves under the path returned by ``PathConfig().get_fe_path``.
    """
    target_dataset = args.dst_name
    label_count = get_num_class(target_dataset)
    output_root = PathConfig().get_fe_path(target_dataset)

    train_feature_extractor(
        "cuda:0",
        "bert-base-uncased",
        target_dataset,
        label_count,
        model_save_path=output_root,
    )

# CUDA_VISIBLE_DEVICES=4 python train_feature_extractor_main.py --num_epochs 20 --batch_size 8 --dst_name imdb --fe_type default
# CUDA_VISIBLE_DEVICES=7 python train_feature_extractor_main.py --num_epochs 50 --batch_size 8 --dst_name newsgroups --fe_type default
# CUDA_VISIBLE_DEVICES=7 python train_feature_extractor_main.py --num_epochs 2 --batch_size 8 --dst_name newsgroups --fe_type kfold --k 5
if __name__ == "__main__":
    # Apply the project-wide seed before any training work starts.
    global_seed = PC.get_global_seed()
    setup_seed(global_seed)
    main()