import os
from typing import List

from clients.base import Client
from clients.semantic_segmentation import create_semantic_segmentation_client
from clients.multi_label_classifcation import create_multi_label_classification_client
from clients.image_classification import create_image_classification_client
from clients.text_classification import create_text_classification_client


def create_exp_config_setup(
    args,
    models_path,
    dataset_path_dict,
    device,
    num_local_epochs,
    local_train_function,
    train_args,
) -> List[Client]:
    """Create the heterogeneous set of clients used in the experiments.

    A client family is enabled when its ``args.num_*_clients`` count is
    positive. An enabled family always creates one client per entry in its
    model table below. The count only switches the family on; it does not
    select how many models are used.

    Args:
        args: Experiment arguments. Read here: ``lr`` and the counts
            ``num_MLC_clients``, ``num_IC100_clients``, ``num_IC10_clients``,
            ``num_semantic_segmentation_clients`` and
            ``num_yahoo_topic_classification_clients``. Also forwarded as the
            first positional argument of every client factory.
        models_path: Root directory of the pretrained models. Each client's
            model path is ``os.path.join(models_path, <org>, <model name>)``.
        dataset_path_dict: Mapping holding the per-family dataset lists (the
            i-th entry goes to the i-th client of that family) and
            ``"public_loader"``. MLC and semantic-segmentation clients also
            read ``"root_train_image_folder"`` / ``"root_val_image_folder"``.
        device: Forwarded to every client factory.
        num_local_epochs: Local epochs for the MLC and image-classification
            clients. Semantic-segmentation and Yahoo clients always use 10.
        local_train_function: ``train_func`` for every client except the
            Yahoo text-classification clients.
        train_args: Forwarded to every client factory.

    Returns:
        The created clients, in family order: MLC, IC100, IC10, semantic
        segmentation, Yahoo topic classification.

    Raises:
        AssertionError: If the number of clients that would be created
            differs from the sum of the requested ``num_*_clients`` counts.
            This is checked before any client (and its model) is built.
    """
    # Pretrained backbones per family as (org, model name). The i-th entry is
    # paired with the i-th dataset split of that family.
    mlc_models = (
        ("google", "vit-base-patch32-224-in21k"),
        ("WinKawaks", "vit-small-patch16-224"),
        ("google", "vit-large-patch16-224-in21k"),
    )
    ic100_models = (
        ("google", "vit-base-patch16-224-in21k"),
        ("WinKawaks", "vit-small-patch16-224"),
    )
    ic10_models = (("WinKawaks", "vit-tiny-patch16-224"),)
    seg_models = (
        ("nvidia", "segformer-b0-finetuned-ade-512-512"),
        ("nvidia", "segformer-b1-finetuned-ade-512-512"),
    )
    yahoo_models = (
        ("googlebert", "bert-base-uncased"),
        ("distilbert", "distilbert-base-uncased"),
    )

    # Verify the configuration up front so a mismatch is reported before any
    # (potentially large) model is loaded. An explicit raise is used instead
    # of a bare ``assert`` so the check survives ``python -O``.
    families = (
        (args.num_MLC_clients, mlc_models),
        (args.num_IC100_clients, ic100_models),
        (args.num_IC10_clients, ic10_models),
        (args.num_semantic_segmentation_clients, seg_models),
        (args.num_yahoo_topic_classification_clients, yahoo_models),
    )
    num_requested = sum(count for count, _ in families)
    num_to_create = sum(len(models) for count, models in families if count > 0)
    if num_to_create != num_requested:
        raise AssertionError(
            f"Client configuration mismatch: {num_requested} clients requested "
            f"via args.num_*_clients but {num_to_create} would be created. Each "
            "enabled family creates a fixed number of clients "
            f"(MLC={len(mlc_models)}, IC100={len(ic100_models)}, "
            f"IC10={len(ic10_models)}, semantic_segmentation={len(seg_models)}, "
            f"yahoo_topic_classification={len(yahoo_models)})."
        )

    def model_dir(org, name):
        # Location of a pretrained backbone below ``models_path``.
        return os.path.join(models_path, org, name)

    def shared_kwargs(epochs, train_func):
        # Keyword arguments common to every client factory.
        return {
            "public_loader": dataset_path_dict["public_loader"],
            "lr": args.lr,
            "num_local_epochs": epochs,
            "device": device,
            "train_func": train_func,
            "train_args": train_args,
        }

    def image_folder_kwargs():
        # Image-folder roots. Looked up lazily, only for families that use them.
        return {
            "root_train_image_folder": dataset_path_dict["root_train_image_folder"],
            "root_val_image_folder": dataset_path_dict["root_val_image_folder"],
        }

    all_clients: List[Client] = []

    # Multi-label classification clients
    if args.num_MLC_clients > 0:
        for idx, (org, name) in enumerate(mlc_models):
            all_clients.append(
                create_multi_label_classification_client(
                    args,
                    base_models_path=model_dir(org, name),
                    train_ann_files=dataset_path_dict["MLC_clients_train_ann_files"][idx],
                    test_ann_files=dataset_path_dict["MLC_clients_test_ann_files"][idx],
                    **image_folder_kwargs(),
                    **shared_kwargs(num_local_epochs, local_train_function),
                )
            )

    # Image classification clients (100-class family first, then 10-class)
    for count, models, prefix, num_classes in (
        (args.num_IC100_clients, ic100_models, "IC100", 100),
        (args.num_IC10_clients, ic10_models, "IC10", 10),
    ):
        if not count > 0:
            continue
        for idx, (org, name) in enumerate(models):
            all_clients.append(
                create_image_classification_client(
                    args,
                    base_models_path=model_dir(org, name),
                    train_files=dataset_path_dict[f"{prefix}_clients_train_files"][idx],
                    test_files=dataset_path_dict[f"{prefix}_clients_test_files"][idx],
                    num_classes=num_classes,
                    **shared_kwargs(num_local_epochs, local_train_function),
                )
            )

    # Segmentation and text clients use a fixed epoch count instead of the
    # ``num_local_epochs`` argument (kept in a separate name so the argument
    # itself is never rebound).
    fixed_local_epochs = 10

    # Semantic segmentation clients
    if args.num_semantic_segmentation_clients > 0:
        for idx, (org, name) in enumerate(seg_models):
            all_clients.append(
                create_semantic_segmentation_client(
                    args,
                    base_models_path=model_dir(org, name),
                    train_ann_files=dataset_path_dict["semantic_seg_clients_train_ann_files"][idx],
                    test_ann_files=dataset_path_dict["semantic_seg_clients_test_ann_files"][idx],
                    **image_folder_kwargs(),
                    **shared_kwargs(fixed_local_epochs, local_train_function),
                    server_lr=args.lr,
                )
            )

    # Yahoo topic (text) classification clients
    if args.num_yahoo_topic_classification_clients > 0:
        for idx, (org, name) in enumerate(yahoo_models):
            all_clients.append(
                create_text_classification_client(
                    args,
                    num_classes=10,
                    base_models_path=model_dir(org, name),
                    train_files=dataset_path_dict["yahoo_clients_train_files"][idx],
                    test_files=dataset_path_dict["yahoo_clients_test_files"][idx],
                    # NOTE(review): these clients get a no-op train function
                    # instead of ``local_train_function`` (as in the original).
                    # Presumably intentional; confirm. A zero-argument lambda
                    # would raise TypeError if the client calls it with args.
                    **shared_kwargs(fixed_local_epochs, lambda: None),
                )
            )

    return all_clients
