# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Partly revised by YZ @UCL&Moorfields
# --------------------------------------------------------

import os
from torchvision import datasets, transforms
from timm.data import create_transform
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
import util.misc as misc   
import torch

# Per-dataset channel normalisation statistics, keyed by dataset name.
# Each value is a pair of 3-element lists; presumably ([mean_c0, mean_c1, mean_c2],
# [std_c0, std_c1, std_c2]) in the images' channel order (likely RGB) — TODO confirm
# how and on which split these were computed.
# NOTE(review): build_transform below normalises with IMAGENET_DEFAULT_MEAN /
# IMAGENET_DEFAULT_STD rather than these values; check whether any caller uses them.
PRESET_STATS = {
    "APTOS":([0.3996780514717102, 0.21280449628829956, 0.06569761782884598], [0.2825203537940979, 0.15452957153320312, 0.07784931361675262]),
    "IDRiD":([0.4525451362133026, 0.22106815874576569, 0.07328230887651443], [0.30778563022613525, 0.16519124805927277, 0.0865519568324089]),
    "Messidor2":([0.4852456748485565, 0.22558310627937317, 0.07774920761585236], [0.2954258322715759, 0.1485498547554016, 0.06574897468090057]),
}

def build_dataset(is_train, args, domain=None):
    """Create a torchvision ``ImageFolder`` for one split of the data.

    Args:
        is_train: split name used as the directory name ('train', 'val' or
            'test'); it also selects the transform via ``build_transform``.
        args: namespace providing ``data_path`` plus the transform settings.
        domain: optional sub-directory under ``data_path``. When None, the split
            directory sits directly under ``data_path``.

    Returns:
        ``datasets.ImageFolder`` rooted at ``<data_path>[/<domain>]/<split>``.
        If the 'test' directory does not exist, the 'val' directory of the same
        domain is used instead and a warning is printed.
    """
    transform = build_transform(is_train, args)

    def _split_dir(split):
        # Layout is <data_path>/<split> or <data_path>/<domain>/<split>.
        if domain is None:
            return os.path.join(args.data_path, split)
        return os.path.join(args.data_path, domain, split)

    root = _split_dir(is_train)
    # Datasets shipped without a held-out test split are evaluated on val.
    if is_train == 'test' and not os.path.exists(root):
        print(f"WARNING: Test directory not found for domain {domain}, using val dataset as test")
        root = _split_dir('val')

    return datasets.ImageFolder(root, transform=transform)

def _resolve_domain_list(args):
    """Return the cleaned list of domain names configured on ``args``.

    ``args.domains`` takes precedence; otherwise the comma-separated
    ``args.dataset`` is used. A string value is split on commas, surrounding
    whitespace is stripped, and empty entries (e.g. from a trailing comma) are
    dropped.

    Raises:
        ValueError: if no non-empty domain name is configured.
    """
    raw = getattr(args, 'domains', None) or getattr(args, 'dataset', None)
    if not raw:
        raise ValueError("No domains specified: set --domains or --dataset")
    if isinstance(raw, str):
        # Iterating a bare string would otherwise yield single characters.
        raw = raw.split(',')
    names = [name.strip() if isinstance(name, str) else name for name in raw]
    names = [name for name in names if name]
    if not names:
        raise ValueError("No domains specified: set --domains or --dataset")
    return names


def _build_domain_splits(args, domain):
    """Build the (train, val, test) ImageFolder datasets for one domain.

    The test split is loaded from its own directory when it exists. Otherwise
    the val dataset object itself is reused as the test set.
    """
    dataset_train = build_dataset(is_train='train', args=args, domain=domain)
    dataset_val = build_dataset(is_train='val', args=args, domain=domain)

    test_path = os.path.join(args.data_path, domain, 'test') if domain else os.path.join(args.data_path, 'test')
    if os.path.exists(test_path):
        dataset_test = build_dataset(is_train='test', args=args, domain=domain)
        print(f"INFO: Using separate test dataset for domain: {domain}")
    else:
        # Reuse the already-built val dataset instead of calling
        # build_dataset('test'), whose fallback would rescan the val directory.
        dataset_test = dataset_val
        print(f"INFO: No test directory found for domain {domain}, using val dataset as test")
    return dataset_train, dataset_val, dataset_test


def _make_loader(dataset, sampler, args, drop_last):
    """Wrap ``dataset`` in a DataLoader using the batch/worker settings in ``args``."""
    return torch.utils.data.DataLoader(
        dataset, sampler=sampler,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=drop_last,
    )


def build_continual_dataloader(args):
    """Build one set of train/val/test DataLoaders per domain, in domain order.

    Args:
        args: namespace providing ``domains`` or ``dataset`` (see
            ``_resolve_domain_list``), ``data_path``, ``batch_size``,
            ``num_workers``, ``pin_mem``, ``distributed`` and ``dist_eval``.

    Returns:
        A list with one ``{'train': ..., 'val': ..., 'test': ...}`` dict of
        DataLoaders per domain.

    Raises:
        ValueError: if no domain names are configured.
    """
    dataloader = []
    for domain in _resolve_domain_list(args):
        dataset_train, dataset_val, dataset_test = _build_domain_splits(args, domain)

        # -------------------- samplers --------------------
        if args.distributed and misc.get_world_size() > 1:
            num_tasks = misc.get_world_size()
            global_rank = misc.get_rank()

            sampler_train = torch.utils.data.DistributedSampler(
                dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True)

            # For validation / test we mimic the logic in main_finetune
            if args.dist_eval:
                sampler_val = torch.utils.data.DistributedSampler(
                    dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)
                sampler_test = torch.utils.data.DistributedSampler(
                    dataset_test, num_replicas=num_tasks, rank=global_rank, shuffle=False)
            else:
                sampler_val = torch.utils.data.SequentialSampler(dataset_val)
                sampler_test = torch.utils.data.SequentialSampler(dataset_test)
        else:
            sampler_train = torch.utils.data.RandomSampler(dataset_train)
            sampler_val = torch.utils.data.SequentialSampler(dataset_val)
            sampler_test = torch.utils.data.SequentialSampler(dataset_test)

        # -------------------- loaders --------------------
        # Only training drops the last incomplete batch; evaluation sees every sample.
        dataloader.append({
            'train': _make_loader(dataset_train, sampler_train, args, drop_last=True),
            'val': _make_loader(dataset_val, sampler_val, args, drop_last=False),
            'test': _make_loader(dataset_test, sampler_test, args, drop_last=False),
        })

    return dataloader


def build_joint_dataloader(args):
    """Concatenate every domain's splits into joint train/val/test datasets.

    Despite its name, this returns datasets (``ConcatDataset``), not
    DataLoaders; the caller is expected to wrap them.

    Args:
        args: namespace providing ``domains`` or ``dataset`` and ``data_path``
            plus the transform settings used by ``build_dataset``.

    Returns:
        Tuple ``(joint_train, joint_val, joint_test)`` of ``ConcatDataset``.

    Raises:
        ValueError: if no domain names are configured.
    """
    datasets_train, datasets_val, datasets_test = [], [], []

    for domain in _resolve_domain_list(args):
        print(f"INFO: Building dataset for domain: {domain}")
        dataset_train, dataset_val, dataset_test = _build_domain_splits(args, domain)
        datasets_train.append(dataset_train)
        datasets_val.append(dataset_val)
        datasets_test.append(dataset_test)

    # NOTE(review): ConcatDataset does not reconcile labels. Each ImageFolder
    # indexes its own class folders, so this presumably assumes every domain
    # has the same class sub-directories — confirm for new datasets.
    joint_dataset_train = torch.utils.data.ConcatDataset(datasets_train)
    joint_dataset_val = torch.utils.data.ConcatDataset(datasets_val)
    joint_dataset_test = torch.utils.data.ConcatDataset(datasets_test)

    return joint_dataset_train, joint_dataset_val, joint_dataset_test

def build_transform(is_train, args):
    """Return the preprocessing pipeline for the given split.

    Args:
        is_train: split name; only the exact string 'train' selects the
            augmenting training pipeline, and any other value gets the
            deterministic evaluation pipeline.
        args: namespace providing ``input_size`` and, for training,
            ``color_jitter``, ``aa``, ``reprob``, ``remode`` and ``recount``.

    Returns:
        A callable transform. Both pipelines normalise with the timm ImageNet
        default mean/std.
    """
    mean, std = IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

    if is_train == 'train':
        # timm's training pipeline (original note: this should always dispatch
        # to transforms_imagenet_train), with RandAugment/AutoAugment policy
        # and random erasing controlled by args.
        return create_transform(
            input_size=args.input_size,
            is_training=True,
            color_jitter=args.color_jitter,
            auto_augment=args.aa,
            interpolation='bicubic',
            re_prob=args.reprob,
            re_mode=args.remode,
            re_count=args.recount,
            mean=mean,
            std=std,
        )

    # Evaluation: resize the shorter side so that a center crop of input_size
    # covers crop_pct of it (224/256 for small inputs, the full image otherwise).
    crop_pct = 224 / 256 if args.input_size <= 224 else 1.0
    resize_to = int(args.input_size / crop_pct)
    return transforms.Compose([
        transforms.Resize(resize_to, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(args.input_size),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
