import os
import sys

sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

import torch
import torch.optim as optim
from torch import nn as nn

from train_loops.loss_factory import get_loss_function
from trainer.jocor import JoCoR
from utils import get_configs, get_datasets, get_logger, get_model, seed_everything

# Run on the first CUDA device when one is available, otherwise on the CPU.
device = (
    torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
)


def get_optimizer(config, model):
    """Create the optimizer selected by ``config["train"]["optimizer_type"]``.

    Args:
        config: Nested config mapping. Reads ``config["train"]["optimizer_type"]``
            (``"sgd"``, ``"adam"`` or ``"rmsprop"``) and
            ``config["train"]["learning_rate"]``.
        model: Object whose ``parameters()`` are handed to the optimizer.

    Returns:
        The constructed ``torch.optim`` optimizer.

    Raises:
        ValueError: If the requested optimizer type is not supported.
    """
    train_cfg = config["train"]
    kind = train_cfg["optimizer_type"]
    lr = train_cfg["learning_rate"]

    # SGD and Adam both use a fixed 5e-4 weight decay; RMSprop is left at defaults.
    if kind == "sgd":
        return optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
    if kind == "adam":
        return optim.Adam(model.parameters(), lr=lr, weight_decay=5e-4)
    if kind == "rmsprop":
        return optim.RMSprop(model.parameters(), lr=lr)
    raise ValueError(f"Unsupported optimizer type: {kind}")


def get_scheduler(config, optimizer):
    """Create the LR scheduler selected by ``config["train"]["scheduler_type"]``.

    Args:
        config: Nested config mapping. Reads ``config["train"]["scheduler_type"]``
            and, depending on it, either ``scheduler_gamma`` and
            ``scheduler_step_size`` (``"step"``) or ``scheduler_T_max``
            (``"cosine"``) from ``config["train"]``.
        optimizer: The optimizer whose learning rate the scheduler adjusts.

    Returns:
        A ``StepLR`` or ``CosineAnnealingLR`` instance.

    Raises:
        ValueError: If the requested scheduler type is not supported.
    """
    train_cfg = config["train"]
    kind = train_cfg["scheduler_type"]

    if kind == "step":
        gamma = train_cfg["scheduler_gamma"]
        step_size = train_cfg["scheduler_step_size"]
        return optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)
    if kind == "cosine":
        t_max = train_cfg["scheduler_T_max"]
        return optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=t_max)
    raise ValueError(f"Unsupported scheduler type: {kind}")


def main(config):
    """Train a JoCoR network pair end to end.

    Loads the datasets, builds two networks of the configured architecture on
    ``device``, creates one optimizer covering both networks plus the configured
    loss, and hands everything to the ``JoCoR`` trainer.

    Args:
        config: Nested configuration mapping (as produced by ``get_configs``).
            ``config["train"]["loss_type"]`` is read here; the whole mapping is
            also forwarded to the dataset/model/optimizer/logger factories and
            to ``JoCoR``.
    """
    print("==> Preparing data..")
    # NOTE(review): ``valset`` is loaded but not used below — confirm whether
    # JoCoR should receive it.
    trainset, testset, valset, num_classes = get_datasets(config)

    # Two networks built from the same config; JoCoR trains them as a pair.
    model = get_model(config, num_classes)
    model2 = get_model(config, num_classes)
    model.to(device)
    model2.to(device)

    # JoCoR is given a single optimizer, so that optimizer must own the
    # parameters of BOTH networks. The reference JoCoR algorithm optimizes both
    # networks with one optimizer on their joint loss. Building it from
    # ``model`` alone left ``model2`` without any optimizer step.
    # ``nn.ModuleList`` is used purely as a container whose ``parameters()``
    # yields both models' parameters.
    # TODO(review): confirm trainer.jocor.JoCoR does not also create its own
    # optimizer for model2, which would double-step it.
    optimizer = get_optimizer(config, nn.ModuleList([model, model2]))

    # NOTE(review): ``get_scheduler`` exists in this module, but no scheduler is
    # created or passed to JoCoR here.

    # get_loss_function returns a mapping; only its "criterion" entry is used.
    full_package = {}
    criterion = get_loss_function(config["train"]["loss_type"], full_package)[
        "criterion"
    ]
    logger = get_logger(config)
    trainer = JoCoR(
        config,
        model,
        model2,
        logger,
        trainset,
        testset,
        criterion,
        optimizer,
    )
    trainer.run()


if __name__ == "__main__":
    # Load the run configuration, fix the global seed, then launch training.
    run_configs = get_configs()
    seed_everything(run_configs["general"]["np_seed"])
    main(run_configs)
