import torch
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau

from config import get_config
from data import get_source_domains
from model import get_model, save_model, print_model_info
from utils import set_seed, train_epoch, get_accuracy, get_optimizer


# Read the pretraining settings once, at import time, and echo them so every run's
# log records the configuration it used. train(), main() and pretrain_fc() all read
# these settings through the module-level ``cf``.
_pretrain_config_files = ["config/pretrain.yaml"]
cf = get_config(_pretrain_config_files)
cf.print()

def train(model, optimizer, tr, ts):
    """Train ``model`` on ``tr``, evaluating on ``ts`` after every epoch.

    Runs for at most ``cf.train.epochs`` epochs with batches of
    ``cf.train.batch_size``. After each epoch the test loss drives a
    ``ReduceLROnPlateau`` scheduler (factor 0.1, patience 2, min_lr 1e-8).
    Training stops early once the first parameter group's learning rate falls
    below 1e-6.

    Args:
        model: network to train; its parameters are updated in place through
            ``optimizer``. Callers in this file move it to ``cf.device`` first.
        optimizer: optimizer over ``model``'s parameters; its learning rate is
            modified by the scheduler.
        tr: training dataset (anything ``DataLoader`` accepts).
        ts: evaluation dataset; its loss is the scheduler's plateau signal.
    """
    # Shuffle the training set so each epoch sees a fresh batch order. With the
    # DataLoader default (shuffle=False), SGD would iterate the data in storage
    # order every epoch. Evaluation order does not matter, so ts is not shuffled.
    tr_loader = DataLoader(tr, batch_size=cf.train.batch_size, shuffle=True)
    ts_loader = DataLoader(ts, batch_size=cf.train.batch_size)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=2, threshold=1e-4, cooldown=0, min_lr=1e-8)
    for epoch in range(cf.train.epochs):
        train_loss = train_epoch(tr_loader, model, optimizer, cf.device)
        test_acc, test_loss = get_accuracy(ts_loader, model, cf.device)
        lr = optimizer.param_groups[0]['lr']
        print(f"Epoch {epoch+1}, train loss: {train_loss:.4f}, test acc: {test_acc:.4f}, test loss: {test_loss:.4f} lr: {lr:.6f}")
        scheduler.step(test_loss)
        # The scheduler may have just lowered the lr. Once it is this small,
        # further epochs are not worth running, so stop early.
        if optimizer.param_groups[0]['lr'] < 1e-6:
            break

def main():
    """Pretrain every (gradual domain, architecture) pair on its source domain.

    For each of ``rotate_mnist``, ``color_mnist`` and ``portraits``, the source
    train/test split is loaded once. A ``cnn``, ``vgg`` and ``resnet`` model is
    then trained on it and saved with ``save_model``.
    """
    gradual_domains = ["rotate_mnist", "color_mnist", "portraits"]
    models_name = ["cnn", "vgg", "resnet"]
    set_seed(cf.seed)
    for data_name in gradual_domains:
        # The source split depends only on the domain, so load it once and
        # reuse it for every architecture instead of reloading it per model.
        tr, ts = get_source_domains(data_name)
        for model_name in models_name:
            # Re-seed before each model so its initialisation and data order do
            # not depend on which models happened to be trained before it.
            set_seed(cf.seed)
            model = get_model(data_name, model_name).to(cf.device)
            print(f"Training {data_name} {model.name} model")
            optimizer = get_optimizer(model, cf.train.optimizer, cf.train.learning_rate, cf.train.weight_decay, cf.train.momentum)

            train(model, optimizer, tr, ts)
            save_model(model, data_name, model_name)
        
def pretrain_fc():
    """Pretrain the fully connected ``fc`` model on the ``covertype`` source domain and save it.

    Uses the same optimizer settings (``cf.train``) and training loop as
    ``main``, and seeds the RNGs the same way so the run is reproducible.
    """
    data_name = "covertype"
    model_name = "fc"
    # main() seeds before loading data and building models; without this call
    # the covertype run would not be reproducible.
    set_seed(cf.seed)
    tr, ts = get_source_domains(data_name)
    model = get_model(data_name, model_name).to(cf.device)
    print(f"Training {data_name} {model.name} model")
    optimizer = get_optimizer(model, cf.train.optimizer, cf.train.learning_rate, cf.train.weight_decay, cf.train.momentum)
    train(model, optimizer, tr, ts)
    save_model(model, data_name, model_name)
        
if __name__ == "__main__":
    # Script entry point: only the covertype FC model is pretrained by default.
    # Call main() instead to pretrain the cnn/vgg/resnet models on the
    # rotate_mnist, color_mnist and portraits domains.
    pretrain_fc()