from trainer.train import *
from trainer.reweight_methods import *
import torch.optim as optim
from models import model
import matplotlib.pyplot as plt
import types
import os
from utils.util import *
from data_loader.cifar10 import load_cifar10
from data_loader.cifar100 import load_cifar100
from data_loader.clothing1m import load_clothing1m


# Default configuration handed to the CIFAR loaders below.
# NOTE(review): the key meanings are defined by the data_loader package, not
# by this script -- "percent"/"asym"/"instance" presumably describe the
# synthetic label-noise setup; confirm against load_cifar10/load_cifar100.
_data_loader_section = {"args": {"data_dir": "G:/datasets"}}
_trainer_section = {
    "percent": 0.4,
    "asym": False,
    "instance": False,
    "seed": 0,
    "noise_file": None,
}
config = {
    "data_loader": _data_loader_section,
    "trainer": _trainer_section,
}

# Experiment selector. Each branch below builds, for one dataset: the data
# loaders plus the validation tensors, the network, the number of epochs,
# the SGD optimizer and a MultiStepLR schedule.
# Supported values: "clothing1m", "cifar10", "cifar100".
dataset = "cifar10"
if dataset == "clothing1m":
    # This branch rebinds `config` to the flat dict passed to load_clothing1m.
    config = {
        "batch_size": 64,
        "num_batches": 2000,
        "data_dir": "G:/datasets/Clothing1M/clothing1M",
    }
    # Unlike the CIFAR loaders, this one returns no clean/noisy index lists.
    train_loader, val_loader, test_loader, X_val, y_val = load_clothing1m(
        config, num_clean_val=2000
    )
    # NOTE: from here on the name `model` (imported from `models`) refers to
    # the constructed network rather than to the imported object.
    model = model.resnet50(pretrained=True, num_classes=14)
    epochs = 10
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=0.01,
        momentum=0.9,
        weight_decay=0.0005,
    )
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[5, 7, 9], last_epoch=-1
    )
elif dataset == "cifar10":
    # NOTE(review): the positional 2000 is presumably the clean-validation
    # size (cf. num_clean_val above) -- confirm against load_cifar10.
    (train_loader, val_loader, test_loader, X_val, y_val,
     train_clean_idx, train_noise_idx) = load_cifar10(config, 2000)
    model = model.resnet34(num_classes=10)
    epochs = 150
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=0.02,
        momentum=0.9,
        weight_decay=0.001,
    )
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[40, 80, 100], last_epoch=-1
    )
elif dataset == "cifar100":
    # Same recipe as cifar10; only the loader and the class count differ.
    (train_loader, val_loader, test_loader, X_val, y_val,
     train_clean_idx, train_noise_idx) = load_cifar100(config, 2000)
    model = model.resnet34(num_classes=100)
    epochs = 150
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=0.02,
        momentum=0.9,
        weight_decay=0.001,
    )
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[40, 80, 100], last_epoch=-1
    )
else:
    # Fail fast: without this, an unknown name would only surface later as a
    # NameError on `optimizer` / `train_loader`.
    raise ValueError(
        f"Unsupported dataset {dataset!r}; "
        "expected 'clothing1m', 'cifar10' or 'cifar100'."
    )

# Prefer the GPU when CUDA is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')


# Bind `reweight_feature` to this network so it can be invoked as
# `model.reweight_method(...)`, then move the network onto the chosen device.
_bound_reweight = types.MethodType(reweight_feature, model)
model.reweight_method = _bound_reweight
model = model.to(device)

# Run training with reweighting enabled.
# `num_epochs` follows the per-dataset `epochs` value so that it stays
# consistent with the LR milestones chosen for that dataset (e.g. 10 epochs
# for clothing1m) instead of always running 150 epochs.
# NOTE(review): the meaning of the remaining keyword arguments is defined by
# `train_loop` in trainer.train; note that the device is passed via `args`.
w_list, loss, w_direction_list = train_loop(
    model, optimizer, train_loader, X_val, y_val, test_loader,
    alpha=[1/9, 0],
    num_epochs=epochs, start_epoch=0, val_feature=True,
    reweight_every=1, max_clip=1, clean_only=False,
    reweight=True, args=device, test_every=1, lr_scheduler=lr_scheduler,
    noisy_rate=None, recompute=True, val_steps=30, skip_epochs=10, correction=None,
)