from utils.arguments import get_args
from utils.utils import *
from utils.transforms import *
from utils.get_data import *
from models.resnet import *
from models.cnn_digits import *
from core import pretrain, train
import torch.utils.data as data
import warnings
import time

warnings.filterwarnings("ignore", category=UserWarning) 

if __name__ == '__main__':

    args = get_args()

    print(args.title)

    if torch.cuda.is_available():
        args.device = torch.device("cuda")
    else:
        raise RuntimeError("CUDA is not available on this system!")
    
    start_time = time.time()

    if os.path.exists(args.log_file):
        i = 1
        while os.path.exists(f"{args.log_file[:-4]}({i}).txt"):
            i += 1
        args.log_file = f"{args.log_file[:-4]}({i}).txt"
        with open(args.log_file, 'w') as f:
            pass

    write_log(args, f'Device: {args.device}, Count of using GPUs: {torch.cuda.device_count()}')
    
    init_random_seed(args.seed)

    root = '{}/{}/'.format(args.data_dir, args.dataset)
    
    domains = args.domains[args.dataset]    # ex) [a, c, p, s]
    domain_order = []
    for i in args.order:    # ex) [2, 0, 1, 3]
        domain_order.append(domains[i])    # ex) [p, a, c, s]
    
    if args.transforms == 'simple':
        transform_train = simple_transform_train(args.input_size, args.dataset)
    elif args.transforms == 'simpleV2':
        transform_train = simpleV2_transform_train(args.input_size, args.dataset)
    transform_test = simple_transform_test(args.input_size, args.dataset)

    if args.dataset == 'PACS':
        # build datalist
        src_datalist_train = []
        for i in range(args.num_classes):
            src_datalist_train.append(get_pacs_image_dirs_label(root=root, dname=domain_order[0], split='train', label=i))
        
        src_datalist_pretrain = get_pacs_image_dirs(root=root, dname=domain_order[0], split='train')
        src_datalist_valid = get_pacs_image_dirs(root=root, dname=domain_order[0], split='crossval')
        ul1_datalist_train = get_pacs_image_dirs(root=root, dname=domain_order[1], split='train')
        ul2_datalist_train = get_pacs_image_dirs(root=root, dname=domain_order[2], split='train')
        ul1_datalist_valid = get_pacs_image_dirs(root=root, dname=domain_order[1], split='crossval')
        ul2_datalist_valid = get_pacs_image_dirs(root=root, dname=domain_order[2], split='crossval')
        test_datalist = get_pacs_image_dirs(root=root, dname=domain_order[3], split='test')

        # build dataset
        src_dataset_train = []
        for i in range(args.num_classes):
            src_dataset_train.append(base_dataset(src_datalist_train[i], transform=transform_train, dl=0, labeled=True))

        src_dataset_pretrain = base_dataset(src_datalist_pretrain, transform=transform_train, dl=0, labeled=True)
        if args.proto_notrans:
            src_dataset_proto = base_dataset(src_datalist_pretrain, transform=transform_test, dl=0, labeled=True)
        else:
            src_dataset_proto = base_dataset(src_datalist_pretrain, transform=transform_train, dl=0, labeled=True)
        src_dataset_valid = base_dataset(src_datalist_valid, transform=transform_test, dl=0, labeled=True)
        ul1_dataset_train = base_dataset(ul1_datalist_train, transform=transform_train, dl=1, labeled=False)
        ul2_dataset_train = base_dataset(ul2_datalist_train, transform=transform_train, dl=2, labeled=False)
        ul_dataset_train = ul1_dataset_train + ul2_dataset_train
        if args.proto_notrans:
            ul1_dataset_train_proto = base_dataset(ul1_datalist_train, transform=transform_test, dl=1, labeled=True)
            ul2_dataset_train_proto = base_dataset(ul2_datalist_train, transform=transform_test, dl=2, labeled=True)
        else:
            ul1_dataset_train_proto = base_dataset(ul1_datalist_train, transform=transform_train, dl=1, labeled=True)
            ul2_dataset_train_proto = base_dataset(ul2_datalist_train, transform=transform_train, dl=2, labeled=True)
        ul1_dataset_valid = base_dataset(ul1_datalist_valid, transform=transform_test, dl=1, labeled=True)
        ul2_dataset_valid = base_dataset(ul2_datalist_valid, transform=transform_test, dl=2, labeled=True)
        test_dataset = base_dataset(test_datalist, transform=transform_test, dl=0, labeled=True)

    else:
        # build datalist
        src_datalist_train = []
        for i in range(args.num_classes):
            src_datalist_train.append(get_image_dirs_label(root=root, dname=domain_order[0], split='train', label=i))
        
        src_datalist_pretrain = get_image_dirs(root=root, dname=domain_order[0], split='train')
        src_datalist_valid = get_image_dirs(root=root, dname=domain_order[0], split='val')
        ul1_datalist_train = get_image_dirs(root=root, dname=domain_order[1], split='train')
        ul2_datalist_train = get_image_dirs(root=root, dname=domain_order[2], split='train')
        ul1_datalist_valid = get_image_dirs(root=root, dname=domain_order[1], split='val')
        ul2_datalist_valid = get_image_dirs(root=root, dname=domain_order[2], split='val')
        test_datalist_train = get_image_dirs(root=root, dname=domain_order[3], split='train')
        test_datalist_valid = get_image_dirs(root=root, dname=domain_order[3], split='val')

        # build dataset
        src_dataset_train = []
        for i in range(args.num_classes):
            src_dataset_train.append(base_dataset(src_datalist_train[i], transform=transform_train, dl=0, labeled=True))

        src_dataset_pretrain = base_dataset(src_datalist_pretrain, transform=transform_train, dl=0, labeled=True)
        if args.proto_notrans:
            src_dataset_proto = base_dataset(src_datalist_pretrain, transform=transform_test, dl=0, labeled=True)
        else:
            src_dataset_proto = base_dataset(src_datalist_pretrain, transform=transform_train, dl=0, labeled=True)
        src_dataset_valid = base_dataset(src_datalist_valid, transform=transform_test, dl=0, labeled=True)
        ul1_dataset_train = base_dataset(ul1_datalist_train, transform=transform_train, dl=1, labeled=False)
        ul2_dataset_train = base_dataset(ul2_datalist_train, transform=transform_train, dl=2, labeled=False)
        ul_dataset_train = ul1_dataset_train + ul2_dataset_train
        if args.proto_notrans:
            ul1_dataset_train_proto = base_dataset(ul1_datalist_train, transform=transform_test, dl=1, labeled=True)
            ul2_dataset_train_proto = base_dataset(ul2_datalist_train, transform=transform_test, dl=2, labeled=True)
        else:
            ul1_dataset_train_proto = base_dataset(ul1_datalist_train, transform=transform_train, dl=1, labeled=True)
            ul2_dataset_train_proto = base_dataset(ul2_datalist_train, transform=transform_train, dl=2, labeled=True)
        ul1_dataset_valid = base_dataset(ul1_datalist_valid, transform=transform_test, dl=1, labeled=True)
        ul2_dataset_valid = base_dataset(ul2_datalist_valid, transform=transform_test, dl=2, labeled=True)
        test_dataset_train = base_dataset(test_datalist_train, transform=transform_test, dl=0, labeled=True)
        test_dataset_valid = base_dataset(test_datalist_valid, transform=transform_test, dl=0, labeled=True)
        test_dataset = test_dataset_train + test_dataset_valid

    # build dataloader
    src_data_loader_train = []
    for i in range(args.num_classes):
        src_data_loader_train.append(InfiniteDataLoader(src_dataset_train[i], batch_size=1, num_workers=0, shuffle=True))

    src_data_loader_pretrain = data.DataLoader(src_dataset_pretrain, batch_size=args.batch_size, num_workers=0, shuffle=True, drop_last=True)
    src_data_loader_proto = data.DataLoader(src_dataset_proto, batch_size=args.batch_size, num_workers=0, shuffle=False, drop_last=False)
    src_data_loader_valid = data.DataLoader(src_dataset_valid, batch_size=args.batch_size, num_workers=0, shuffle=False, drop_last=False)
    ul_data_loader_train = data.DataLoader(ul_dataset_train, batch_size=args.batch_size, num_workers=0, shuffle=True, drop_last=True)
    ul1_data_loader_valid = data.DataLoader(ul1_dataset_valid, batch_size=args.batch_size, num_workers=0, shuffle=False, drop_last=False)
    ul2_data_loader_valid = data.DataLoader(ul2_dataset_valid, batch_size=args.batch_size, num_workers=0, shuffle=False, drop_last=False)
    test_data_loader = data.DataLoader(test_dataset, batch_size=args.batch_size, num_workers=0, shuffle=False, drop_last=False)
    
    # load models 
    if args.dataset == 'Digits':
        src_featurizer = init_model(net=CNN_Digits(), restore=args.src_featurizer_restore)
        dg_featurizer = init_model(net=CNN_Digits(), restore=args.src_featurizer_restore)
    else:
        src_featurizer = init_model(net=resnet18(args), restore=args.src_featurizer_restore)
        dg_featurizer = init_model(net=resnet18(args), restore=args.src_featurizer_restore)
    feature_dim = src_featurizer.fdim
    
    src_classifier = init_model(net=Classifier(args, feature_dim, args.classifier_temp), restore=args.src_classifier_restore)
    dg_classifier = init_model(net=Classifier(args, feature_dim, args.classifier_temp), restore=args.src_classifier_restore)

    # pretrain source model
    if not (src_featurizer.restored and src_classifier.restored):
        print("=== Pretraining on source domain ===")

        src_featurizer, src_classifier = pretrain.train_src(args, src_featurizer, src_classifier, src_data_loader_pretrain, src_data_loader_valid, ul1_data_loader_valid, ul2_data_loader_valid, test_data_loader)

    # eval source model
    """
    print("=== Evaluating on source domain ===")
    pretrain.eval_src(src_featurizer, src_classifier, src_data_loader_valid, True)
    print("=== Evaluating on unlabeled domain1 ===")
    pretrain.eval_src(src_featurizer, src_classifier, ul1_data_loader_valid, True)
    print("=== Evaluating on unlabeled domain2 ===")
    pretrain.eval_src(src_featurizer, src_classifier, ul2_data_loader_valid, True)
    print("=== Evaluating on test domain ===")
    pretrain.eval_src(src_featurizer, src_classifier, test_data_loader, True)
    """

    if not (dg_featurizer.restored and dg_classifier.restored):
        dg_featurizer.load_state_dict(src_featurizer.state_dict())
        dg_classifier.load_state_dict(src_classifier.state_dict())

    src_acc, ul1_acc, ul2_acc, test_acc, lam1_list, lam2_list, ul1_pl_acc, ul2_pl_acc = train.train(args, dg_featurizer, dg_classifier, src_data_loader_proto, src_data_loader_train, src_data_loader_valid, ul1_dataset_train_proto, ul2_dataset_train_proto, ul1_data_loader_valid, ul2_data_loader_valid, test_data_loader, transform_train)

    make_plot(args, src_acc, ul1_acc, ul2_acc, test_acc, lam1_list, lam2_list, ul1_pl_acc, ul2_pl_acc, mid=False)

end_time = time.time()
print_time(start_time, end_time)