import torch
from torch import nn
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import DataLoader
from loader import *
from thop import profile
from thop import clever_format

from models.S2_KAUN import Model
# from models.S2_MLP import Model
# from models.U_KAN import Model
# from models.Rolling_UNet import Model
# from models.U_Net import Model
# from models.UNeXt import Model
# from models.DPMNet import Model
# from models.OCTA_Net import Model
# from models.TransUNet import Model

# Checkpoint evaluated by main(). A relative path resolves against the current
# working directory. NOTE(review): this path is independent of
# config.work_dir — confirm it points at the run you intend to test.
model_path = 'tmp/unet/checkpoints/best.pth'

from engine import *
import os
import sys
# Restrict this process to GPU 0. NOTE(review): the variable only takes effect
# if CUDA has not been initialised yet, and the torch/model/loader imports
# above run first — confirm none of them touch CUDA at import time.
os.environ["CUDA_VISIBLE_DEVICES"] = "0" # "0, 1, 2, 3"

from utils import *
from configs.config_setting import setting_config

import warnings
# Suppress all warnings issued after this point, for the whole process.
warnings.filterwarnings("ignore")


def _report_model_complexity(model, input_channels, image_size=256):
    """Print FLOPs/params (via thop), one forward-pass time and trainable params.

    This is a runnable version of the profiling snippet that used to sit
    commented out in ``main``. The dummy input is created on CPU, so call this
    before moving ``model`` to the GPU. The model's train/eval mode is
    restored afterwards.

    Args:
        model: The network to profile.
        input_channels: Number of channels of the dummy input image.
        image_size: Height and width of the square dummy input.
    """
    import time

    dummy = torch.randn(1, input_channels, image_size, image_size)
    flops, params = profile(model, inputs=(dummy,))
    flops, params = clever_format([flops, params], "%.3f")
    print(f"FLOPs: {flops}, Parameters: {params}")

    # Single un-warmed forward pass, as in the original snippet. Treat it as a
    # rough indication, not a benchmark. eval() avoids batch-norm errors with
    # batch size 1.
    was_training = model.training
    model.eval()
    start = time.perf_counter()
    with torch.no_grad():
        model(dummy)
    inference_time_ms = (time.perf_counter() - start) * 1000
    model.train(was_training)
    print(f"Inference time: {inference_time_ms:.3f} ms")

    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6
    print(f"Trainable Parameters: {num_params:.3f} M")


def main(config, checkpoint_path=None, profile_model=False):
    """Evaluate a trained checkpoint on the test split.

    Builds ``Model`` from ``config.model_config``, loads the weights stored at
    ``checkpoint_path`` and runs ``predict_one_epoch`` over the test dataset.
    The ``checkpoints`` and ``outputs`` sub-directories of ``config.work_dir``
    are created if missing. A 'test' logger is obtained for
    ``<work_dir>/log``.

    Args:
        config: Experiment configuration. This function reads ``work_dir``,
            ``seed``, ``model_config``, ``data_path``, ``num_workers`` and
            ``criterion``. The whole object is also forwarded to
            ``log_config_info`` and ``predict_one_epoch``.
        checkpoint_path: State-dict file to evaluate. Defaults to the
            module-level ``model_path``, looked up at call time.
        profile_model: If True, print FLOPs, parameter counts and a single
            forward-pass timing before evaluation.

    Returns:
        Whatever ``predict_one_epoch`` returns.
    """
    # ---------- Logger and output directories ----------
    sys.path.append(config.work_dir + '/')
    log_dir = os.path.join(config.work_dir, 'log')
    checkpoint_dir = os.path.join(config.work_dir, 'checkpoints')
    outputs = os.path.join(config.work_dir, 'outputs')
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    for directory in (checkpoint_dir, outputs):
        os.makedirs(directory, exist_ok=True)

    global logger
    logger = get_logger('test', log_dir)
    log_config_info(config, logger)

    # ---------- GPU / reproducibility ----------
    set_seed(config.seed)
    gpu_ids = [0]  # [0, 1, 2, 3]
    torch.cuda.empty_cache()

    # ---------- Model ----------
    model_cfg = config.model_config
    model = Model(num_classes=model_cfg['num_classes'],
                  input_channels=model_cfg['input_channels'])
    if profile_model:
        # Profile on CPU, before the model is moved to the GPU.
        _report_model_complexity(model, model_cfg['input_channels'])
    model = torch.nn.DataParallel(model.cuda(), device_ids=gpu_ids,
                                  output_device=gpu_ids[0])

    # ---------- Test data ----------
    test_dataset = dataset_loader(path_Data=config.data_path, train=False, Test=True)
    # Evaluation must see every sample, so never drop a trailing partial batch.
    test_loader = DataLoader(test_dataset,
                             batch_size=1,
                             shuffle=False,
                             pin_memory=True,
                             num_workers=config.num_workers,
                             drop_last=False)
    criterion = config.criterion

    # ---------- Evaluation ----------
    if checkpoint_path is None:
        checkpoint_path = model_path
    print('#----------Testing----------#')
    # Load onto CPU so a checkpoint saved on any device works. load_state_dict
    # then copies the tensors onto the model's (GPU) parameters.
    best_weight = torch.load(checkpoint_path, map_location='cpu')
    model.module.load_state_dict(best_weight)
    # Force inference behaviour for dropout/batch-norm. This is harmless if
    # predict_one_epoch also switches to eval mode.
    model.eval()
    return predict_one_epoch(
        test_loader,
        model,
        criterion,
        logger,
        config,
    )

if __name__ == "__main__":
    # Script entry point: evaluate using the shared default settings.
    config = setting_config
    main(config=config)