import argparse
import os
import random
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from networks.benthiq import BenthIQ
import trainer

# Command-line configuration for training. The arguments are parsed at import
# time into the module-level `args`, which the `__main__` section below reads
# and forwards (whole) to the dataset's trainer.
# NOTE(review): several options (--root_path, --max_iterations, --max_epochs,
# --n_gpu, --use-checkpoint) are not read in this file; presumably they are
# consumed by the trainer through `args` — confirm.
parser = argparse.ArgumentParser()
parser.add_argument('--root_path', type=str,
                    default='train_test_split', help='root dir for data')
parser.add_argument('--dataset', type=str,
                    default='Coral', help='dataset name')
parser.add_argument('--num_classes', type=int,
                    default=4, help='output channel of network')
parser.add_argument('--output_dir', type=str,
                    default="output", help='output dir for model weights')
parser.add_argument('--max_iterations', type=int,
                    default=10000, help='maximum iteration number to train')
parser.add_argument('--max_epochs', type=int,
                    default=500, help='maximum epoch number to train')
parser.add_argument('--batch_size', type=int,
                    default=24, help='batch_size per gpu')
parser.add_argument('--n_gpu', type=int, default=1, help='total gpu')
# Non-zero selects deterministic cuDNN; 0 enables cuDNN benchmark mode.
parser.add_argument('--deterministic', type=int,  default=1,
                    help='whether use deterministic training')
# Reference LR for batch size 24; rescaled in `__main__` for other batch sizes.
parser.add_argument('--base_lr', type=float,  default=0.01,
                    help='segmentation network learning rate')
parser.add_argument('--img_size', type=int,
                    default=224, help='input patch size of network input')
parser.add_argument('--seed', type=int,
                    default=1234, help='random seed')
# Passed to `net.load_from` to initialize the network; no training state is resumed here.
parser.add_argument('--pretrain_ckpt', type=str,
                    default='pretrained_ckpt/swin_t.pth',
                    help='pretrained checkpoint to initialize the network from')
# Exposed as `args.use_checkpoint` (argparse converts '-' to '_').
parser.add_argument('--use-checkpoint', action='store_true',
                    help="whether to use gradient checkpointing to save memory")

args = parser.parse_args()

def setup_seed(seed):
    """Seed the random number generators used during training for reproducibility.

    Seeds, in this order: Python's ``random`` module, NumPy's global RNG,
    PyTorch's default generator, and PyTorch's CUDA generator for the
    current device.

    Args:
        seed (int): Value handed unchanged to every seeding call.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for seed_fn in seeders:
        seed_fn(seed)

if __name__ == "__main__":
    # Deterministic cuDNN kernels when --deterministic is non-zero; otherwise let
    # cuDNN benchmark/auto-tune kernels for speed at the cost of reproducibility.
    cudnn.benchmark = not args.deterministic
    cudnn.deterministic = bool(args.deterministic)

    setup_seed(args.seed)

    dataset_name = args.dataset

    # Dataset name -> training entry point. Named `trainers` so the imported
    # `trainer` module is not shadowed by this dict.
    # NOTE(review): `trainer` is the module imported at the top of the file, and a
    # module object is not callable, so the call at the end raises TypeError unless
    # this maps to a function defined inside that module (e.g. `trainer.<fn>`) — confirm.
    trainers = {'Coral': trainer}
    # Fail fast, before building the network on GPU and loading the checkpoint.
    if dataset_name not in trainers:
        raise ValueError(
            f"Unsupported dataset {dataset_name!r}; expected one of {sorted(trainers)}"
        )

    # Linearly rescale the learning rate relative to the reference batch size of 24
    # (only for batch sizes divisible by 6, as in the original logic).
    if args.batch_size != 24 and args.batch_size % 6 == 0:
        args.base_lr *= args.batch_size / 24

    # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
    os.makedirs(args.output_dir, exist_ok=True)

    net = BenthIQ(img_size=args.img_size, num_classes=args.num_classes).cuda()
    net.load_from(args.pretrain_ckpt)

    trainers[dataset_name](args, net, args.output_dir)