import os
import sys

import numpy as np
import torch
import torch.optim as optim
from torch import nn as nn

sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))


from datasets.clothing1mpp import Clothing1mPP
from train_loops.loss_factory import LDAMLoss
from trainer.base import Trainer
from utils import get_configs, get_datasets, get_logger, get_model, seed_everything


def get_optimizer(config, model):
    """Create the optimizer named by ``config["train"]["optimizer_type"]``.

    Supported names:
      * ``"sgd"``     -- SGD with momentum 0.9 and weight decay 5e-4.
      * ``"adam"``    -- Adam with weight decay 5e-4.
      * ``"rmsprop"`` -- RMSprop with only the learning rate overridden.

    Args:
        config: Configuration mapping whose ``"train"`` section provides
            ``"optimizer_type"`` and ``"learning_rate"``.
        model: Module whose ``parameters()`` are handed to the optimizer.

    Returns:
        The constructed ``torch.optim`` optimizer.

    Raises:
        ValueError: If the optimizer name is not one of the supported ones.
    """
    train_cfg = config["train"]
    kind = train_cfg["optimizer_type"]
    lr = train_cfg["learning_rate"]

    if kind == "sgd":
        return optim.SGD(
            model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4
        )
    if kind == "adam":
        return optim.Adam(model.parameters(), lr=lr, weight_decay=5e-4)
    if kind == "rmsprop":
        # No weight decay here, unlike the other two choices.
        return optim.RMSprop(model.parameters(), lr=lr)
    raise ValueError(f"Unsupported optimizer type: {kind}")


def get_scheduler(config, optimizer):
    """Create the LR scheduler named by ``config["train"]["scheduler_type"]``.

    Supported names:
      * ``"step"``   -- ``StepLR`` using ``scheduler_step_size`` and
        ``scheduler_gamma`` from the ``"train"`` section.
      * ``"cosine"`` -- ``CosineAnnealingLR`` using ``scheduler_T_max``.

    Args:
        config: Configuration mapping with a ``"train"`` section.
        optimizer: Optimizer whose learning rate the scheduler will drive.

    Returns:
        The constructed ``torch.optim.lr_scheduler`` instance.

    Raises:
        ValueError: If the scheduler name is not one of the supported ones.
    """
    train_cfg = config["train"]
    kind = train_cfg["scheduler_type"]

    if kind == "step":
        gamma = train_cfg["scheduler_gamma"]
        step_size = train_cfg["scheduler_step_size"]
        return optim.lr_scheduler.StepLR(
            optimizer, step_size=step_size, gamma=gamma
        )
    if kind == "cosine":
        t_max = train_cfg["scheduler_T_max"]
        return optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=t_max)
    raise ValueError(f"Unsupported scheduler type: {kind}")


def main(config):
    """Build data, class-weighted LDAM loss, model and optimizer, then train.

    Args:
        config: Nested configuration mapping (as returned by ``get_configs``)
            with ``"general"``, ``"data"`` and ``"train"`` sections.

    Raises:
        ValueError: If ``config["data"]["dataset_name"]`` is not supported, or
            if any class has a non-positive sample count (which would make the
            class-balanced weights infinite and the normalized weights NaN).
    """
    print("==> Preparing data..")
    dataset_name = config["data"]["dataset_name"]
    if dataset_name == "clothing1mpp":
        train_dataset_full = Clothing1mPP(
            config["data"]["root"], config["data"]["image_size"], split="train"
        )
    else:
        # Fail fast instead of hitting a NameError on `train_dataset_full` below.
        raise ValueError(f"Unsupported dataset: {dataset_name}")

    # NOTE(review): this dataset is only used for its class counts; the actual
    # train split comes from `get_datasets` below. Confirm both apply the same
    # seed / imbalance sampling.
    cls_num_list = train_dataset_full.get_cls_num_list(
        seed=config["general"]["np_seed"],
        imbalance_factor=config["data"]["imbalance_factor"],
    )

    # Class-balanced weights from the "effective number of samples":
    #   w_c ∝ (1 - beta) / (1 - beta ** n_c), rescaled to sum to num_classes.
    # beta = 0 yields uniform weights (all 1.0); beta = 0.9999 up-weights
    # rare classes.
    betas = [0, 0.9999]
    # NOTE(review): beta is picked once from the *total* epoch count
    # (< 160 -> 0, otherwise 0.9999) and never changes during training;
    # confirm this is intended. Clamp so runs with >= 320 epochs don't
    # IndexError.
    idx = min(config["train"]["epochs"] // 160, len(betas) - 1)
    beta = betas[idx]

    counts = np.asarray(cls_num_list, dtype=np.float64)
    if np.any(counts <= 0):
        # A zero count gives effective_num == 0 -> inf weight -> NaN after
        # normalization, which would silently poison the loss.
        raise ValueError(
            f"cls_num_list must contain only positive class counts, got {cls_num_list}"
        )
    effective_num = 1.0 - np.power(beta, counts)
    per_cls_weights = (1.0 - beta) / effective_num
    per_cls_weights = per_cls_weights / np.sum(per_cls_weights) * len(cls_num_list)
    per_cls_weights = torch.as_tensor(per_cls_weights, dtype=torch.float32).cuda()

    # Initialize the LDAM loss with the per-class weights computed above.
    criterion = LDAMLoss(
        cls_num_list=cls_num_list, max_m=0.5, s=30, weight=per_cls_weights
    ).cuda()

    trainset, testset, valset, num_classes = get_datasets(config)

    model = get_model(config, num_classes)
    optimizer = get_optimizer(config, model)
    scheduler = get_scheduler(config, optimizer)
    logger = get_logger(config)
    trainer = Trainer(
        config,
        model,
        logger,
        trainset,
        testset,
        criterion,
        optimizer,
        scheduler,
        valset,
    )
    trainer.run()


if __name__ == "__main__":
    # Script entry point: load the configuration, call seed_everything with the
    # configured seed, then run training.
    configs = get_configs()
    np_seed = configs["general"]["np_seed"]
    seed_everything(np_seed)
    main(configs)
