import argparse
import os
import re
from copy import deepcopy
from os import listdir

import torch
from numpy import argsort
from numpy import round

from utils import add_adapters, set_active_task, freeze_base_thaw_adapters, freeze_head, instantiate_base_model

# Checkpoint files we report on, and the epoch number embedded just before the
# extension (e.g. "model_epoch12.pt" -> "12").
_CHECKPOINT_SUFFIXES = (".pt", ".pth")
_EPOCH_IN_NAME = re.compile(r"(\d+)\.pth?$")


def _checkpoint_sort_key(filename):
    """Return a key that orders checkpoint files by the epoch number in their name."""
    match = _EPOCH_IN_NAME.search(filename)
    if match is None:
        # No number before the extension: place after all numbered checkpoints,
        # ordered by name so the output stays deterministic.
        return (1, 0, filename)
    return (0, int(match.group(1)), filename)


def main(args):
    """Print the epoch and validation metrics stored in each checkpoint under ``args.path``.

    Only files ending in ``.pt`` or ``.pth`` are considered. They are visited in
    ascending numeric order of the integer immediately preceding the extension,
    so ``ckpt_2.pt`` comes before ``ckpt_10.pt``. Files without such a number
    are visited last, in name order. Each checkpoint is loaded onto the CPU and
    must be a mapping with ``"val_metrics"`` and ``"epoch"`` keys.

    Args:
        args: Parsed command-line namespace. Only ``args.path`` is used; an
            empty path means there is nothing to report.
    """
    pretrained_base_path = args.path
    if not pretrained_base_path:
        return

    checkpoint_files = sorted(
        (name for name in os.listdir(pretrained_base_path) if name.endswith(_CHECKPOINT_SUFFIXES)),
        key=_checkpoint_sort_key,
    )

    for name in checkpoint_files:
        # os.path.join works whether or not the directory has a trailing separator.
        checkpoint_base = torch.load(os.path.join(pretrained_base_path, name), map_location="cpu")
        base_metrics = checkpoint_base["val_metrics"]
        epoch = checkpoint_base["epoch"]
        print(f"Epoch: {epoch} ---  Base Validation Metrics: {base_metrics}")

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description = "Run meta-lora algorithm using Roberta")
    valid_models = [
        "roberta-large",
        "roberta-base"
    ]
    valid_models.extend([m + "-openai-detector" for m in valid_models])
    parser.add_argument('--model', metavar='m', action = "store", choices = ["roberta-large","roberta-base","roberta-large-openai"], type=str, help = "base model to use", default="roberta-large")
    parser.add_argument("--path", action = "store", type=str, help = "path to folder with pretrained models",default = "")
    parser.add_argument('--task', metavar = '-t', action= "store", choices = ["cola", "mnli", "mnli-mm", "mrpc", "qnli", "qqp", "rte", "sst2", "stsb", "wnli"], type=str, help = "glue task to use", default = "cola")
    parser.add_argument('--device', action="store",help="device to train on",default = "cuda:4")

    args = parser.parse_args()
    main(args)