using Downloads
using ZipFile
using ProjectRoot
using JLD2
using Images, ImageIO, ImageTransformations
using Metalhead
using Flux
using Statistics
using Random
using FileIO
using Colors
using Glob
using ProgressMeter
using Base.Threads: @threads
using OneHotArrays
using Pipe

# --- Configuration ---
# Root of the extracted Tiny ImageNet dataset. `download_tiny_imagenet` extracts the
# archive into `data/tiny-imagenet-200`; the extra nested folder presumably comes
# from the archive's own top-level directory.
const TINY_IMAGENET_DIR = @projectroot("data", "tiny-imagenet-200", "tiny-imagenet-200")

# The number of total classes, tasks, and images per class.
const NUM_TOTAL_CLASSES = 21 # for testing
const NUM_TASKS = 3 # for testing
# The classes are split into NUM_TASKS equal groups, so the division must be exact.
# Derived rather than hard-coded so it can never disagree with the two values above.
NUM_TOTAL_CLASSES % NUM_TASKS == 0 || throw(ArgumentError(
    "NUM_TOTAL_CLASSES ($NUM_TOTAL_CLASSES) must be divisible by NUM_TASKS ($NUM_TASKS)"))
const CLASSES_PER_TASK = NUM_TOTAL_CLASSES ÷ NUM_TASKS
const NUM_IMGS_CLASS = 500

# The path to save the extracted features.
# It will look like `data/features_21_3_7_500.jld2`
const FEATURES_SAVE_PATH = @projectroot("data",
    "features_$(NUM_TOTAL_CLASSES)_$(NUM_TASKS)_$(CLASSES_PER_TASK)_$(NUM_IMGS_CLASS).jld2")

# The total number of images we will process and use in the experiment.
const NUM_TOTAL_IMAGES = NUM_TOTAL_CLASSES * NUM_IMGS_CLASS

# The ratio of training images to total images.
const TRAIN_SPLIT_RATIO = 0.9

# The dimension of the features extracted from ResNet-18.
const DIM = 512 # ResNet-18 feature dimension

# The number of images per task.
const NUM_IMGS_PER_TASK = NUM_IMGS_CLASS * CLASSES_PER_TASK

# The number of training and testing images per class.
# (`Int` throws an InexactError if the split does not give a whole number.)
const NUM_TRAIN_PER_CLASS = NUM_IMGS_CLASS * TRAIN_SPLIT_RATIO |> Int
const NUM_TEST_PER_CLASS = NUM_IMGS_CLASS - NUM_TRAIN_PER_CLASS

# The number of training and testing images per task.
const NUM_TRAIN_PER_TASK = NUM_TRAIN_PER_CLASS * CLASSES_PER_TASK
const NUM_TEST_PER_TASK = NUM_TEST_PER_CLASS * CLASSES_PER_TASK

# ========================= federated settings =================================
# The number of classes in federated settings.
const NUM_CLS_FEDERATED = 10
# The number of clients in federated settings.
const NUM_CLIENTS = 10

# The path to save the federated dataset.
# It will look like `data/federated_10_10.jld2`, where the first `10` is the
# number of classes and the second `10` is the number of clients.
const FEDERATED_SAVE_PATH = @projectroot("data",
    "federated_$(NUM_CLS_FEDERATED)_$(NUM_CLIENTS).jld2")

# The side length (in pixels) of the original square images.
const DIM_IMG = 64
# The number of training images in each client.
const NUM_TR_IMGS_CLIENT = NUM_CLS_FEDERATED * NUM_IMGS_CLASS / NUM_CLIENTS |> Int
# The number of training images that belong to the same class in each client.
const NUM_TR_IMGS_CLS_CLIENT = NUM_IMGS_CLASS / NUM_CLIENTS |> Int
# The number of validation images per class.
const NUM_VAL_IMGS_CLS = 50
# The number of validation images in each client.
const NUM_VAL_IMGS_CLIENT = NUM_CLS_FEDERATED * NUM_VAL_IMGS_CLS / NUM_CLIENTS |> Int
# The number of validation images that belong to the same class in each client.
const NUM_VAL_IMGS_CLS_CLIENT = NUM_VAL_IMGS_CLS / NUM_CLIENTS |> Int

"""
    download_tiny_imagenet()

- Input: None
- Output: None

Download the Tiny ImageNet archive and extract it into `data/tiny-imagenet-200`
(the dataset root `TINY_IMAGENET_DIR` lives inside it).

Both steps are staged. The archive is downloaded to `<zip>.part`, and extraction
happens in `<dir>.partial`; each is renamed to its final name only once complete.
An interrupted run therefore never leaves a truncated zip or a half-extracted
directory that a later call would mistake for finished work. The archive is
deleted after a successful extraction.
"""
function download_tiny_imagenet()
    # setup download URL and paths
    url = "http://cs231n.stanford.edu/tiny-imagenet-200.zip"
    zip_path = @projectroot("tiny-imagenet-200.zip")
    extract_dir = @projectroot("data", "tiny-imagenet-200")

    # download the dataset
    if isdir(extract_dir)
        println("Data exists, done.")
        return
    elseif isfile(zip_path)
        println("Zip file exists, skipping download.")
    else
        println("Downloading Tiny ImageNet dataset...")
        # Download under a temporary name first, so an interrupted download is not
        # later mistaken for a complete archive by the `isfile` check above.
        part_path = zip_path * ".part"
        Downloads.download(url, part_path)
        mv(part_path, zip_path; force=true)
        println("Done.")
    end

    # unzip the dataset.
    println("Unzipping Tiny ImageNet dataset...")
    # Extract into a staging directory and rename it at the end, so a crash
    # mid-extraction does not leave a partial `extract_dir` that the `isdir`
    # check above would treat as a finished dataset.
    staging_dir = extract_dir * ".partial"
    rm(staging_dir; force=true, recursive=true)
    mkpath(staging_dir)
    # Every extracted path must stay below this prefix. This guards against
    # "zip slip" entries such as `../../x` or absolute names; the archive is
    # fetched over plain HTTP.
    staging_prefix = joinpath(normpath(staging_dir), "")

    zip_files = ZipFile.Reader(zip_path)
    try
        for file in zip_files.files
            dest_path = normpath(joinpath(staging_dir, file.name))
            startswith(dest_path, staging_prefix) || error(
                "Refusing to extract $(repr(file.name)): it escapes $staging_dir")
            if endswith(file.name, "/") || endswith(file.name, "\\")
                # Directory entry.
                mkpath(dest_path)
            else
                # File entry: make sure its parent directory exists, then write it.
                mkpath(dirname(dest_path))
                write(dest_path, read(file))
            end
        end
    finally
        # Always release the archive handle, even if extraction fails.
        close(zip_files)
    end
    mv(staging_dir, extract_dir)
    println("Done.")

    # remove the zip file
    rm(zip_path; force=true)
end

# Load the image at `path` as a Float32 array of shape (width, height, 3).
# `channelview` yields (channel, height, width) for colour images and
# (height, width) for grayscale ones; both are brought to the same layout.
function _load_image_whc(path::AbstractString)
    img = Float32.(channelview(load(path)))
    if ndims(img) == 2
        # Grayscale: transpose to (width, height) like the colour branch,
        # then replicate the single channel three times.
        return repeat(permutedims(img, (2, 1)), 1, 1, 3)
    end
    return permutedims(img, (3, 2, 1))
end

"""
    tr_data, tr_lbls, val_data, val_lbls, wnids = create_federated_dataset()

- Input: None
- Output:
    - `tr_data::Array{Float32, 5}`: A tensor storing the training dataset.
        `tr_data[:, :, :, j, i]` is the j-th image in client i, a
        `DIM_IMG x DIM_IMG x 3` tensor in (width, height, channel) order.
    - `tr_lbls::Array{Float16, 2}`: A matrix storing the training labels.
        `tr_lbls[j, i]` is the label of the j-th image in client i, an integer
        value in `1:NUM_CLS_FEDERATED` (stored as `Float16`).
    - `val_data::Array{Float32, 5}`: The validation dataset, indexed like `tr_data`.
    - `val_lbls::Array{Float16, 2}`: The validation labels, indexed like `tr_lbls`.
    - `wnids::Vector{String}`: The wnids of the selected classes; label `k`
        corresponds to `wnids[k]`.

Create a federated dataset based on the Tiny ImageNet dataset.
We pick the first `NUM_CLS_FEDERATED` classes (in directory order), and for each
class we use its first `NUM_IMGS_CLASS` training images and first
`NUM_VAL_IMGS_CLS` validation images.
There will be `NUM_CLIENTS` clients, and the images are distributed evenly among them.
In the end, each client will have `NUM_TR_IMGS_CLIENT` training images, of which
each class contributes `NUM_TR_IMGS_CLS_CLIENT` images. Each client will also have
`NUM_VAL_IMGS_CLIENT` validation images, of which each class contributes
`NUM_VAL_IMGS_CLS_CLIENT` images.
Inside a client, the images of class `k` occupy one contiguous block of slots.

If `FEDERATED_SAVE_PATH` exists, its contents are loaded and returned. Otherwise
the dataset is built (downloading Tiny ImageNet if needed) and saved there.
"""
function create_federated_dataset()
    # Reuse the cached dataset if it has already been built.
    if isfile(FEDERATED_SAVE_PATH)
        data = JLD2.load(FEDERATED_SAVE_PATH)
        return data["tr_data"], data["tr_lbls"], data["val_data"], data["val_lbls"],
               data["wnids"]
    end

    # Ensure the Tiny ImageNet dataset has been downloaded.
    download_tiny_imagenet()

    # Selected classes; class label k <-> wnids[k].
    wnids = readdir(joinpath(TINY_IMAGENET_DIR, "train"))[1:NUM_CLS_FEDERATED]

    # Validation annotations: tab-separated lines whose field 1 is the image file
    # name and field 2 its wnid. Blank lines are ignored.
    val_anno_file = joinpath(TINY_IMAGENET_DIR, "val", "val_annotations.txt")
    val_anno = [split(line, "\t") for line in readlines(val_anno_file) if !isempty(line)]

    # tr_data[:, :, :, j, i] / tr_lbls[j, i]: j-th training image / label of client i.
    tr_data = zeros(Float32, DIM_IMG, DIM_IMG, 3, NUM_TR_IMGS_CLIENT, NUM_CLIENTS)
    tr_lbls = zeros(Float16, NUM_TR_IMGS_CLIENT, NUM_CLIENTS)
    # Same layout for the validation split.
    val_data = zeros(Float32, DIM_IMG, DIM_IMG, 3, NUM_VAL_IMGS_CLIENT, NUM_CLIENTS)
    val_lbls = zeros(Float16, NUM_VAL_IMGS_CLIENT, NUM_CLIENTS)

    # Class-major list of training images, exactly NUM_IMGS_CLASS per class, so
    # image j belongs to class cld(j, NUM_IMGS_CLASS).
    tr_img_paths = reduce(vcat, [
        glob("*.JPEG", joinpath(TINY_IMAGENET_DIR, "train", wnid, "images"))[1:NUM_IMGS_CLASS]
        for wnid in wnids])

    # Each iteration writes a distinct (slot, client) cell, so threads never collide.
    @threads for j in eachindex(tr_img_paths)
        cls = cld(j, NUM_IMGS_CLASS)
        pos_in_cls = j - (cls - 1) * NUM_IMGS_CLASS
        # Within a class, consecutive runs of NUM_TR_IMGS_CLS_CLIENT images go to
        # clients 1, 2, ...; inside a client, class `cls` occupies the block of
        # slots starting after (cls - 1) * NUM_TR_IMGS_CLS_CLIENT.
        client = cld(pos_in_cls, NUM_TR_IMGS_CLS_CLIENT)
        slot = pos_in_cls - (client - 1) * NUM_TR_IMGS_CLS_CLIENT +
               (cls - 1) * NUM_TR_IMGS_CLS_CLIENT

        tr_data[:, :, :, slot, client] = _load_image_whc(tr_img_paths[j])
        # Integer class label (not one-hot).
        tr_lbls[slot, client] = cls
    end

    for (cls, wnid) in enumerate(wnids)
        # File names of this class's validation images (first NUM_VAL_IMGS_CLS).
        val_names = [fields[1] for fields in val_anno if fields[2] == wnid]
        val_paths = joinpath.(TINY_IMAGENET_DIR, "val", "images",
                              val_names[1:NUM_VAL_IMGS_CLS])

        @threads for j in eachindex(val_paths)
            # Same distribution scheme as the training split.
            client = cld(j, NUM_VAL_IMGS_CLS_CLIENT)
            slot = j - (client - 1) * NUM_VAL_IMGS_CLS_CLIENT +
                   (cls - 1) * NUM_VAL_IMGS_CLS_CLIENT

            val_data[:, :, :, slot, client] = _load_image_whc(val_paths[j])
            val_lbls[slot, client] = cls
        end
    end

    # Save the images, labels and class names to a JLD2 file.
    JLD2.save(FEDERATED_SAVE_PATH, Dict(
        "tr_data" => tr_data,
        "tr_lbls" => tr_lbls,
        "val_data" => val_data,
        "val_lbls" => val_lbls,
        "wnids" => wnids
    ))

    return tr_data, tr_lbls, val_data, val_lbls, wnids
end

"""
    img_array::Array{Float32, 4} = preprocess_image(path::String)

- Input:
    - `path`: The path to the image file.
- Output:
    - `img_array`: A 4D `Float32` array of size (224, 224, 3, 1), laid out as
      (width, height, channel, batch).

Prepare the image at `path` as a single-image batch for ResNet-18. The image is
resized to 224x224 pixels and converted to `Float32` values in [0, 1]. It is put in
(width, height, channel) order, with grayscale images replicated to three channels.
Finally, a batch dimension of size 1 is appended.

No per-channel mean/std normalisation is applied.
NOTE(review): pretrained Metalhead weights are commonly documented as expecting
ImageNet mean/std normalisation. Adding it would change the features cached in
`FEATURES_SAVE_PATH`, so confirm before changing.
"""
function preprocess_image(path::String)
    # `load` returns fixed-point pixel values (typically N0f8), so converting to
    # Float32 below already gives values in [0, 1]. ResNet-18 takes 224x224 inputs.
    resized = imresize(load(path), (224, 224))
    pixels = Float32.(channelview(resized))

    if ndims(pixels) == 2
        # Grayscale: `channelview` gives (height, width). Transpose to (width,
        # height) so the layout matches colour images, then replicate the channel.
        whc = repeat(permutedims(pixels, (2, 1)), 1, 1, 3)
    else
        # Colour: `channelview` gives (channel, height, width).
        whc = permutedims(pixels, (3, 2, 1))
    end

    # Append a batch dimension of size 1.
    return reshape(whc, size(whc)..., 1)
end

"""
    features, labels, wnids = extract_features()

- Input: None
- Output:
    - `features`: A `Float32` matrix of size (DIM, NUM_TOTAL_IMAGES). Column
                  `(i - 1) * NUM_IMGS_CLASS + j` is the ResNet-18 feature of
                  image `j` of class `i`.
    - `labels`:   A `Float16` vector of length NUM_TOTAL_IMAGES holding the class
                  index (`1:NUM_TOTAL_CLASSES`) of each column of `features`.
    - `wnids`:    The class names (wnids); class index `i` is `wnids[i]`.

Extract features from the Tiny ImageNet training images using a pretrained
ResNet-18 with its last stage removed. The first NUM_TOTAL_CLASSES classes (in
directory order) are used, with the first NUM_IMGS_CLASS images of each.
Results are cached in `FEATURES_SAVE_PATH`; if that file exists it is loaded
instead of recomputing.
"""
function extract_features()
    # Reuse previously extracted features if available.
    if isfile(FEATURES_SAVE_PATH)
        println("Loading pre-extracted features from $(FEATURES_SAVE_PATH)...")
        data = JLD2.load(FEATURES_SAVE_PATH)
        return data["features"], data["labels"], data["wnids"]
    end

    # Ensure the dataset has been downloaded.
    download_tiny_imagenet()

    println("Preprocessing Tiny ImageNet...")

    # Pretrained ResNet-18 without its final (classification) stage.
    model = ResNet(18; pretrain=true)
    feature_extractor = Chain(model.layers[1:end-1]...)
    # Inference only: pin the extractor to test mode instead of relying on
    # Flux's automatic train/test detection.
    Flux.testmode!(feature_extractor)

    wnids = readdir(joinpath(TINY_IMAGENET_DIR, "train"))[1:NUM_TOTAL_CLASSES]

    features = zeros(Float32, DIM, NUM_TOTAL_IMAGES)
    labels = zeros(Float16, NUM_TOTAL_IMAGES)

    for (i, wnid) in enumerate(wnids)
        class_path = joinpath(TINY_IMAGENET_DIR, "train", wnid, "images")
        img_paths = glob("*.JPEG", class_path)[1:NUM_IMGS_CLASS]
        # Columns of `features` reserved for this class start after `offset`.
        offset = (i - 1) * NUM_IMGS_CLASS
        desc = "[$i/$NUM_TOTAL_CLASSES] Images in $wnid"

        # Each iteration writes its own column, so threads never collide.
        @showprogress desc=desc @threads for j in eachindex(img_paths)
            activations = feature_extractor(preprocess_image(img_paths[j]))
            # Average over dims 1 and 2 (presumably the spatial dims of the
            # backbone output) to get one DIM-long vector per image.
            features[:, offset + j] = vec(mean(activations; dims=(1, 2)))
            labels[offset + j] = i
        end
    end

    println("Saving extracted features to $(FEATURES_SAVE_PATH)...")

    # Save the features, labels and class names to a JLD2 file.
    JLD2.save(FEATURES_SAVE_PATH, Dict(
        "features" => features,
        "labels" => labels,
        "wnids" => wnids
    ))

    println("Done.")

    return features, labels, wnids
end

"""
    train_features_tasks, test_features_tasks, train_labels_tasks,
                    test_labels_tasks, cls_perm = create_meta_learning_tasks()

- Input: None
- Output:
    - `train_features_tasks`: `Float32` array of size
                              (DIM, NUM_TRAIN_PER_TASK, NUM_TASKS);
                              `[:, n, t]` is training sample `n` of task `t`.
    - `test_features_tasks`:  `Float32` array of size
                              (DIM, NUM_TEST_PER_TASK, NUM_TASKS), same layout.
    - `train_labels_tasks`:   `Int` array of size
                              (CLASSES_PER_TASK, NUM_TRAIN_PER_TASK, NUM_TASKS);
                              `[:, n, t]` is the one-hot label of training sample
                              `n` of task `t`, where row `k` stands for class
                              `cls_perm[k, t]`.
    - `test_labels_tasks`:    `Int` array of size
                              (CLASSES_PER_TASK, NUM_TEST_PER_TASK, NUM_TASKS),
                              same encoding.
    - `cls_perm`:             `Int` matrix of size (CLASSES_PER_TASK, NUM_TASKS);
                              column `t` lists the (global) class indices
                              assigned to task `t`.

Create meta-learning tasks. The features of every class are randomly split into
NUM_TRAIN_PER_CLASS training and NUM_TEST_PER_CLASS testing samples. The classes
are then randomly partitioned into NUM_TASKS tasks, and the samples within each
task are shuffled.
"""
function create_meta_learning_tasks()
    features, labels, _ = extract_features()

    # Per-class split: [:, :, c] holds the samples of class c.
    train_features_cls = zeros(Float32, DIM, NUM_TRAIN_PER_CLASS, NUM_TOTAL_CLASSES)
    test_features_cls = zeros(Float32, DIM, NUM_TEST_PER_CLASS, NUM_TOTAL_CLASSES)

    # Each iteration writes only its own class slice.
    @showprogress desc="Classes" @threads for i = 1:NUM_TOTAL_CLASSES
        featuresᵢ = features[:, labels .== i]
        # Random split: the first NUM_TRAIN_PER_CLASS shuffled samples train,
        # the rest test.
        idx = randperm(NUM_IMGS_CLASS)
        @views train_features_cls[:, :, i] = featuresᵢ[:, idx[1:NUM_TRAIN_PER_CLASS]]
        @views test_features_cls[:, :, i] = featuresᵢ[:, idx[NUM_TRAIN_PER_CLASS+1:end]]
    end

    # Per-task outputs; labels are one-hot encoded along the first dimension.
    train_features_tasks = zeros(Float32, DIM, NUM_TRAIN_PER_TASK, NUM_TASKS)
    test_features_tasks = zeros(Float32, DIM, NUM_TEST_PER_TASK, NUM_TASKS)
    train_labels_tasks = zeros(Int, CLASSES_PER_TASK, NUM_TRAIN_PER_TASK, NUM_TASKS)
    test_labels_tasks = zeros(Int, CLASSES_PER_TASK, NUM_TEST_PER_TASK, NUM_TASKS)

    # Random assignment of classes to tasks. The row count is explicit so that a
    # CLASSES_PER_TASK inconsistent with the label tensors above fails here.
    cls_perm = reshape(randperm(NUM_TOTAL_CLASSES), CLASSES_PER_TASK, NUM_TASKS)

    @showprogress @threads for t in 1:NUM_TASKS
        cls_idx = cls_perm[:, t]
        # Class of every sample, in the class-major order produced by
        # `reshape(..._features_cls[:, :, cls_idx], DIM, :)` below.
        train_cls = repeat(cls_idx; inner=NUM_TRAIN_PER_CLASS)
        test_cls = repeat(cls_idx; inner=NUM_TEST_PER_CLASS)
        # Shuffle samples inside the task; features and labels share the permutation.
        train_idx = randperm(NUM_TRAIN_PER_TASK)
        test_idx = randperm(NUM_TEST_PER_TASK)

        train_features_tasks[:, train_idx, t] = reshape(train_features_cls[:, :, cls_idx], DIM, :)
        test_features_tasks[:, test_idx, t] = reshape(test_features_cls[:, :, cls_idx], DIM, :)

        # One-hot rows follow the order of `cls_idx` (row k <-> class cls_idx[k]).
        train_labels_tasks[:, train_idx, t] = Int.(onehotbatch(train_cls, cls_idx))
        test_labels_tasks[:, test_idx, t] = Int.(onehotbatch(test_cls, cls_idx))
    end

    return train_features_tasks, test_features_tasks, train_labels_tasks,
           test_labels_tasks, cls_perm
end
