import numpy as np
from utilities.Logger import Logger
from utilities.train import *
from utilities.GLOBAL_VALUE import SEEDS


# Parse the run configuration and set up logging for this script.
args = get_args()
logger = Logger(args)
# The device is pinned to CPU here; no availability check is performed.
device = torch.device("cpu")
print("\033[92m" + f"Device: {device}" + "\033[0m")

# Take the first `num_seed` entries of the predefined seed list.
if args["num_seed"] > len(SEEDS):
    raise ValueError(
        "Number of seeds specified is greater than the seed list, append more custom seed into the list"
    )
seeds = SEEDS[: args["num_seed"]]

# Input size is overridden to 32 before the transform and encoder are built.
args["input_size"] = 32
init_transform = get_transform(
    args["augment"], args["input_size"], args["model"]
)
encoder_past = load_model(args, device)

# Every method whose parameter count is compared by this script.
all_models = [
    "lwp",
    "lwf",
    "der",
    "derpp",
    "er",
    "fdr",
    "gss",
    "si",
    "single",
    "mtl",
    "pcgrad",
    "imtl",
    "nashmtl",
]

# Select the prediction targets for the configured dataset. An unknown
# dataset is rejected here instead of surfacing later as a NameError on
# `prediction_targets`.
if args["dataset"] == "celeba":
    prediction_targets = CelebADataset.prediction_targets[:10]
elif args["dataset"] == "physiq":
    prediction_targets = PhysiQDataset.prediction_targets
elif args["dataset"] == "fairface":
    prediction_targets = FairFaceDataset.prediction_targets
else:
    raise ValueError(
        f"Unsupported dataset {args['dataset']!r}; "
        "expected one of 'celeba', 'physiq', 'fairface'"
    )

# Task count used by the analytic parameter estimates.
# NOTE(review): hard-coded to 10; this may differ from
# len(prediction_targets) for datasets other than celeba -- confirm.
n_tasks = 10

# Collect one "total parameter" figure per model in `all_models`.
#
# - "lwp" and "single" are computed analytically from the encoder size and
#   the size of one linear prediction head.
# - The continual-learning (CL) models are built, loaded from a checkpoint,
#   and measured by counting non-zero parameter entries.
# - Every remaining model is treated as multi-task (MTL): built, loaded,
#   and measured by non-zero entries plus the predictor head sizes.

# Size of a single 512 -> 2 linear head; independent of the model.
n_pred = sum(p.numel() for p in torch.nn.Linear(512, 2).parameters())

# Methods that go through `get_models` and the CL checkpoint layout.
cl_models = ("lwf", "der", "derpp", "er", "fdr", "gss", "si")

alls = []
for each_model in all_models:
    args["model"] = each_model
    # Recomputed on every iteration because model construction below
    # receives the shared encoder and might touch it.
    n_backbone = sum(p.numel() for p in encoder_past.parameters())

    if each_model == "lwp":
        # One shared backbone plus one head per task.
        total = n_backbone + n_tasks * n_pred
        print(f"Model: {each_model}, total params: {total}")
    elif each_model == "single":
        # One full backbone + head per task. The printed value now matches
        # the value that is recorded.
        total = (n_backbone + n_pred) * n_tasks
        print(f"Model: {each_model}, total params: {total}")
    elif each_model in cl_models:
        model = get_models(
            args,
            encoder_past,
            cls_output_dim=args["cls_output_dim"],
            lr=args["lr"],
            input_size=args["input_size"],
            dataset_name=args["dataset"],
            buffer_size=args["buffer_size"],
            num_tasks=len(prediction_targets),
            z_dim=args["z_dim"],
            n_epochs=args["epochs"],
            device=device,
        )
        # NOTE(review): the checkpoint name hard-codes "Straight_Hair",
        # which looks dataset-specific -- confirm for non-celeba runs.
        path = f"./saved_models/CL/{args['dataset']}/{args['model']}/"
        path += f"{SEEDS[0]}_Straight_Hair_best_enc.pt"
        print(path)
        # weights_only=False unpickles arbitrary objects; only load trusted
        # checkpoints. strict=False tolerates missing/unexpected keys.
        model.load_state_dict(
            torch.load(path, weights_only=False, map_location="cpu"),
            strict=False,
        )
        # int() turns the 0-dim tensor sum into a plain Python integer so
        # that `alls` holds a single numeric type.
        total = int(sum(torch.count_nonzero(p) for p in model.parameters()))
        print(f"Model: {each_model}, total params: {total}")
    else:
        model = get_models_MTL(
            args,
            encoder_past,
            tasks_name_to_cls_num={
                task_name: 2 for task_name in prediction_targets
            },
            z_dim=args["z_dim"],
            cls_output_dim=args["cls_output_dim"],
        )
        path = f"./saved_models/MTL/{args['dataset']}/{args['model']}/"
        path += f"{SEEDS[0]}_best_enc.pt"
        # weights_only=False unpickles arbitrary objects; only load trusted
        # checkpoints.
        model.load_state_dict(
            torch.load(path, weights_only=False, map_location="cpu")
        )

        # NOTE(review): if `model.predictors` are registered submodules,
        # their parameters are already part of `model.parameters()` and
        # are counted a second time here -- confirm this is intended.
        n_nonzero = int(
            sum(torch.count_nonzero(p) for p in model.parameters())
        )
        n_heads = sum(
            p.numel()
            for predictor in model.predictors.values()
            for p in predictor.parameters()
        )
        total = n_nonzero + n_heads
        print(f"Model: {each_model}, total params: {total}")
    alls.append(total)

# Stack the per-model totals into one array so they can be rescaled together.
alls = np.array(alls)
# Largest total across the models; it serves as the scaling reference.
min_max = np.max(alls)
print(min_max)
# Rescale so that the largest total maps to 1000.
alls = (alls / min_max) * 1000.0
# Report each rescaled value rounded to 3 decimal places.
print(list(map(lambda value: round(value, 3), alls)))
