import random
import torch
import torch.optim as optim

from models import FNetwork, PNetwork
from trainers import pretrainer
from evals import test_few_shot, test_few_shot_ensemble
from data import get_loaders
from common import parse_args, set_seed


# Masked ratios used in ensemble mode; one fresh model pair is pretrained per ratio.
ENSEMBLE_MASKED_RATIOS = (0.1, 0.2, 0.3, 0.4, 0.5)


def _build_models(P):
    """Build a fresh FNetwork/PNetwork pair on ``P.device`` and a joint Adam optimizer.

    Returns:
        ``(f_model, p_model, optimizer)``; the optimizer updates the parameters
        of both networks with learning rate ``P.learning_rate``.
    """
    f_model = FNetwork(P.input_dim, P.hidden_dim).to(P.device)
    p_model = PNetwork(P.hidden_dim, P.input_dim, P.embed_dim).to(P.device)
    optimizer = optim.Adam(
        list(f_model.parameters()) + list(p_model.parameters()),
        lr=P.learning_rate,
    )
    return f_model, p_model, optimizer


def _release_cuda_cache():
    """Release cached, unused CUDA allocator memory when CUDA is available."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def main():
    """Pretrain the F/P networks and print the average few-shot test accuracy.

    Configuration comes from ``parse_args()``. A non-zero ``P.seed`` seeds the
    RNGs via ``set_seed`` and doubles as ``P.index``; otherwise ``P.index`` is
    drawn at random from [1, 1000000].

    In ensemble mode (``P.ensemble``) a fresh model pair is pretrained for each
    ratio in ``ENSEMBLE_MASKED_RATIOS``. ``P.masked_ratio`` is overwritten on
    every iteration and keeps the last ratio afterwards. Then
    ``test_few_shot_ensemble`` is evaluated. Otherwise a single model pair is
    pretrained with ``P.masked_ratio`` and evaluated with ``test_few_shot``.
    The resulting average accuracy is printed with four decimals.
    """
    P = parse_args()
    if P.seed != 0:
        set_seed(P.seed)
        P.index = P.seed
    else:
        # No fixed seed: pick a random run index instead
        # (presumably used to tag run outputs — TODO confirm).
        P.index = random.randint(1, 1000000)

    P.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_loader, val_loader, test_loader, feature_groups = get_loaders(
        P, P.dataset, P.batch_size, P.seed
    )
    print("Train size:", len(train_loader.dataset))
    print("Val   size:", len(val_loader.dataset))
    print("Test  size:", len(test_loader.dataset))

    if P.ensemble:
        print("Ensemble mode")
        for ratio in ENSEMBLE_MASKED_RATIOS:
            P.masked_ratio = ratio
            # Models are built per ensemble member only; nothing is pre-built
            # and discarded outside this loop.
            f_model, p_model, optimizer = _build_models(P)

            print(f"[Pretrain] masked_ratio={ratio}, filling mode={P.fill_mode}")
            pretrainer(
                P, f_model, p_model, optimizer, train_loader, val_loader, feature_groups
            )

            # Drop this member's references before building the next pair so the
            # allocator can hand its memory back.
            del f_model, p_model, optimizer
            _release_cuda_cache()

        # NOTE(review): no models are passed here, so evaluation presumably
        # relies on state saved by pretrainer — confirm.
        avg_accuracy = test_few_shot_ensemble(P, test_loader.dataset)
    else:
        print("Single model mode")
        f_model, p_model, optimizer = _build_models(P)
        print(f"[Pretrain] masked_ratio={P.masked_ratio}, filling mode={P.fill_mode}")
        pretrainer(
            P, f_model, p_model, optimizer, train_loader, val_loader, feature_groups
        )

        avg_accuracy = test_few_shot(P, P.masked_ratio, test_loader.dataset)

    print(f"{avg_accuracy:.4f}")


if __name__ == "__main__":
    main()
