import torch
from torch import nn
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import DataLoader
from loader import *

from models.S2_KAUN import Model
# from models.S2_MLP import Model
# from models.U_KAN import Model
# from models.UNeXt import Model
# from models.U_Net import Model
# from models.U_Net_plus2 import Model
# from models.UtraLight_VM_UNet import Model
# from models.Rolling_UNet import Model
# from models.U_Net import Model
# from models.UNeXt import Model
# from models.DPMNet import Model
# from models.OCTA_Net import Model
# from models.TransUNet import Model
# from models.U_Net_ultra import Model

from engine import *
import os
import sys
# Expose only GPU 0 to this process (list several ids, e.g. "0, 1, 2, 3", for
# multi-GPU). The CUDA runtime reads this variable when it initializes, so it
# only takes effect if nothing imported above has already touched CUDA.
# NOTE(review): torch initializes CUDA lazily, but confirm that the loader /
# models / engine star-imports do not create CUDA tensors at import time.
os.environ["CUDA_VISIBLE_DEVICES"] = "0" # "0, 1, 2, 3"

from utils import *
from configs.config_setting import setting_config

import warnings
# Silence every Python warning for the whole process to keep the training
# console output readable. This also hides any deprecation notices raised by
# the libraries used here.
warnings.filterwarnings("ignore")


def _print_epoch_metrics(epoch, metrics, best_dice):
    """Print one epoch's validation metrics in the script's console format.

    Args:
        epoch: Current (1-based) epoch number.
        metrics: Dict returned by ``val_one_epoch``. Reads ``mean_iou``,
            ``mean_dice``, ``mean_pixel_accuracy``, ``mean_assd`` and the
            per-class ``class_metrics`` list. Entry 0 of that list is skipped
            (presumably the background class; confirm against ``engine``).
        best_dice: Best mean Dice seen so far, shown on the last line.
    """
    per_class = metrics["class_metrics"][1:]
    print('epoch:', epoch, '\n',
          ' miou:', metrics["mean_iou"],
          ' iou:', [m["iou"] for m in per_class], '\n',
          ' mean_dice:', metrics["mean_dice"],
          ' dice:', [m["dice"] for m in per_class], '\n',
          ' mean_acc:', metrics["mean_pixel_accuracy"],
          ' acc:', [m["pixel_accuracy"] for m in per_class], '\n',
          ' mean_assd:', metrics["mean_assd"], '\n',
          ' best dice:', best_dice)


def main(config):
    """Train, validate every epoch, then test the best checkpoint.

    Artifacts are written under ``config.work_dir``:
    ``log/`` (via ``get_logger``), ``outputs/``, ``checkpoints/best.pth``
    (model weights with the lowest validation loss) and
    ``checkpoints/latest.pth`` (full training state, saved every 30 epochs
    and after the final epoch).

    Args:
        config: Settings object. Reads ``work_dir``, ``seed``, ``data_path``,
            ``batch_size``, ``num_workers``, ``model_config``
            (``num_classes``/``input_channels``), ``criterion`` and
            ``epochs``. It is also passed through to the optimizer/scheduler
            factories and the engine loops.
    """
    sys.path.append(config.work_dir + '/')
    log_dir = os.path.join(config.work_dir, 'log')
    checkpoint_dir = os.path.join(config.work_dir, 'checkpoints')
    best_path = os.path.join(checkpoint_dir, 'best.pth')
    latest_path = os.path.join(checkpoint_dir, 'latest.pth')
    outputs = os.path.join(config.work_dir, 'outputs')
    for directory in (log_dir, checkpoint_dir, outputs):
        os.makedirs(directory, exist_ok=True)

    global logger
    logger = get_logger('train', log_dir)

    log_config_info(config, logger)

    set_seed(config.seed)
    # Device indices are relative to CUDA_VISIBLE_DEVICES, e.g. [0, 1, 2, 3].
    gpu_ids = [0]
    torch.cuda.empty_cache()

    train_dataset = dataset_loader(path_Data=config.data_path, train=True)
    train_loader = DataLoader(train_dataset,
                              batch_size=config.batch_size,
                              shuffle=True,
                              pin_memory=True,
                              num_workers=config.num_workers)
    # NOTE(review): drop_last=True on the val/test loaders discards the final
    # partial batch, so metrics may not cover every sample. It is kept as-is
    # because the engine loops may rely on full batches.
    val_dataset = dataset_loader(path_Data=config.data_path, train=False)
    val_loader = DataLoader(val_dataset,
                            batch_size=config.batch_size,
                            shuffle=False,
                            pin_memory=True,
                            num_workers=config.num_workers,
                            drop_last=True)
    test_dataset = dataset_loader(path_Data=config.data_path, train=False, Test=True)
    test_loader = DataLoader(test_dataset,
                             batch_size=config.batch_size,
                             shuffle=False,
                             pin_memory=True,
                             num_workers=config.num_workers,
                             drop_last=True)

    model_cfg = config.model_config
    model = Model(num_classes=model_cfg['num_classes'],
                  input_channels=model_cfg['input_channels'])

    model = torch.nn.DataParallel(model.cuda(), device_ids=gpu_ids, output_device=gpu_ids[0])

    criterion = config.criterion
    optimizer = get_optimizer(config, model)
    scheduler = get_scheduler(config, optimizer)
    scaler = GradScaler()

    # Any finite validation loss beats +inf, so best.pth is always written
    # after the first epoch.
    min_loss = float('inf')
    start_epoch = 1
    min_epoch = 1
    best_dice = 0

    print('#----------Training----------#')
    for epoch in range(start_epoch, config.epochs + 1):

        torch.cuda.empty_cache()

        train_one_epoch(
            train_loader,
            model,
            criterion,
            optimizer,
            scheduler,
            epoch,
            logger,
            config,
            scaler=scaler
        )

        loss, metrics = val_one_epoch(
            val_loader,
            model,
            criterion,
            epoch,
            logger,
            config,
            'val'
        )

        # Dice is only tracked for reporting; checkpoint selection below
        # is driven by validation loss.
        if metrics["mean_dice"] > best_dice:
            best_dice = metrics["mean_dice"]
        _print_epoch_metrics(epoch, metrics, best_dice)

        if loss < min_loss:
            # Save the unwrapped module so the weights load without DataParallel.
            torch.save(model.module.state_dict(), best_path)
            min_loss = loss
            min_epoch = epoch

        # The loop's last epoch is config.epochs (range end is exclusive), so
        # compare against it directly to guarantee the final state is saved.
        if epoch % 30 == 0 or epoch == config.epochs:
            torch.save(
                {
                    'epoch': epoch,
                    'min_loss': min_loss,
                    'min_epoch': min_epoch,
                    'loss': loss,
                    'model_state_dict': model.module.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                }, latest_path)

    if os.path.exists(best_path):
        print('#----------Testing----------#')
        # Load onto CPU first; load_state_dict copies into the CUDA params.
        best_weight = torch.load(best_path, map_location=torch.device('cpu'))
        model.module.load_state_dict(best_weight)
        predict_one_epoch(
            test_loader,
            model,
            criterion,
            logger,
            config,
        )


if __name__ == '__main__':
    # Script entry point: run the whole train/validate/test pipeline using
    # the project-wide settings object.
    config = setting_config
    main(config=config)