import torchvision
import torch
import os
import ray

import sys
sys.path.append('..') 

from utils import fix_all_seed
from data import transform_emnist, divide_dataset_indices, subset_by_indices
from msppm import MSPPM, sample_workers_iid
from models import CNN21, CNN22
from prox import CohortOptimizer, CohortOptimizerLBFGS, CohortOptimizerFedAvg

# Directory for CSV logs. launch_experiment prepends it verbatim to each log
# filename, so keep the trailing slash.
FOLDER = "logs/"

# exist_ok=True avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(FOLDER, exist_ok=True)

def _build_log_filename(folder, run_params, optimizer, optimizer_hparams, cohort_optimizer):
    """Return the CSV log path that encodes every hyper-parameter of one run.

    The result is ``folder`` followed by dash-separated ``key:value`` segments:
    first the entries of ``run_params`` (in dict order), then
    ``optimizer:<optimizer>``, then the entries of ``optimizer_hparams``, then
    ``cohort_optimizer:<cohort_optimizer>``, and finally ``.csv``.
    An empty ``optimizer_hparams`` contributes no segment, so no stray ``--``
    is produced.
    """
    parts = [f"{key}:{value}" for key, value in run_params.items()]
    parts.append(f"optimizer:{optimizer}")
    parts.extend(f"{key}:{value}" for key, value in optimizer_hparams.items())
    parts.append(f"cohort_optimizer:{cohort_optimizer}")
    # `folder` is prepended verbatim (no separator), so callers include a trailing "/".
    return folder + "-".join(parts) + ".csv"


# @ray.remote  # disabled: the driver loop calls launch_experiment synchronously
def launch_experiment(minibatch_size, local_epochs, optimizer, optimizer_hparams,
                      cohort_optimizer, folder, device="cuda"):
    """Train MSPPM on EMNIST (``split='balanced'``) for one hyper-parameter configuration.

    Args:
        minibatch_size: Passed to MSPPM as ``batch_size``.
        local_epochs: Passed to MSPPM as ``local_epochs``.
        optimizer: Name of the torch optimizer used by the cohort optimizer;
            one of ``"Adam"`` or ``"LBFGS"``.
        optimizer_hparams: Keyword hyper-parameters for that optimizer
            (e.g. ``{"lr": 0.001}``). They are also encoded in the log filename.
        cohort_optimizer: Name of the cohort optimizer; one of ``"Base"``,
            ``"LBFGS"`` or ``"FedAvg"``.
        folder: Prefix prepended verbatim to the log filename, e.g. ``"logs/"``.
        device: Device string for the model and MSPPM (default ``"cuda"``).

    Raises:
        KeyError: If ``optimizer`` or ``cohort_optimizer`` is not a known name.
    """
    OPTIMIZERS = {
        "Adam": torch.optim.Adam,
        "LBFGS": torch.optim.LBFGS,
    }
    COHORT_OPTIMIZERS = {
        "Base": CohortOptimizer,
        "LBFGS": CohortOptimizerLBFGS,
        "FedAvg": CohortOptimizerFedAvg,
    }

    # Fail fast with a readable message, before any dataset download.
    if optimizer not in OPTIMIZERS:
        raise KeyError(f"unknown optimizer {optimizer!r}; expected one of {sorted(OPTIMIZERS)}")
    if cohort_optimizer not in COHORT_OPTIMIZERS:
        raise KeyError(
            f"unknown cohort optimizer {cohort_optimizer!r}; expected one of {sorted(COHORT_OPTIMIZERS)}"
        )

    # Scalar run settings. Their insertion order fixes the order of the
    # segments in the log filename.
    run_params = {
        "seed": 42,
        "minibatch_size": minibatch_size,
        "batch_size": 1024,
        "epochs": 200,
        "local_epochs": local_epochs,
        "workers_count": 1000,
        "log_every": 5,
    }

    cohort_optimizer_hparams = {
        "optimizer": OPTIMIZERS[optimizer],
        "optimizer_hparams": optimizer_hparams,
        "batch_size": run_params["batch_size"],
    }

    logs_filename = _build_log_filename(
        folder, run_params, optimizer, optimizer_hparams, cohort_optimizer
    )

    dataset_train = torchvision.datasets.EMNIST(root='./data', split='balanced', train=True, download=True, transform=transform_emnist)
    dataset_test = torchvision.datasets.EMNIST(root='./data', split='balanced', train=False, download=True, transform=transform_emnist)
    num_classes = len(dataset_train.classes)

    # Seed after loading the datasets but before splitting them among workers
    # and initialising the model, so both of those steps are reproducible.
    fix_all_seed(run_params["seed"])
    workers_data_indices = divide_dataset_indices(dataset_train, run_params["workers_count"], num_chunks=2)
    model = CNN22(num_classes=num_classes).to(device)

    algorithm = MSPPM(
        workers_count=run_params["workers_count"],
        batch_size=run_params["minibatch_size"],
        epochs=run_params["epochs"],
        seed=run_params["seed"],
        sampler=sample_workers_iid,
        model=model,
        loss=torch.nn.CrossEntropyLoss(reduction='sum'),
        cohort_optimizer=COHORT_OPTIMIZERS[cohort_optimizer],
        cohort_optimizer_hparams=cohort_optimizer_hparams,
        local_epochs=run_params["local_epochs"],
        device=device,
        plot=True,
        tqdm=True,
        log_every=run_params["log_every"],
        logs_filename=logs_filename,
    )

    algorithm.train(dataset_train, workers_data_indices, dataset_test)


# NOTE(review): launch_experiment is not decorated with @ray.remote, so Ray is
# initialised here but every experiment runs sequentially in this process.
ray.init()
minibatch_sizes = [10, 100]
local_epochs = [5, 10, 50]
# Each entry: (optimizer name, optimizer hyper-parameters, cohort optimizer name).
optimizers = [
    ("Adam", {"lr": 0.001}, "Base"),
    ("Adam", {"lr": 0.001}, "FedAvg"),
    ("LBFGS", {"lr": 0.1, "max_iter": 1}, "LBFGS"),
    ("LBFGS", {"lr": 0.1, "max_iter": 5}, "LBFGS"),
    ("LBFGS", {"lr": 0.1, "max_iter": 10}, "LBFGS"),
    ("LBFGS", {"lr": 0.1, "max_iter": 20}, "LBFGS"),
]
try:
    for optimizer_name, optimizer_hparams, cohort_optimizer_name in optimizers:
        for local_epoch in local_epochs:
            for minibatch_size in minibatch_sizes:
                # Skip the (local_epoch=1, minibatch_size=1) combination. Neither
                # value appears in the grids above, so this never triggers today.
                if local_epoch == 1 and minibatch_size == 1:
                    continue
                launch_experiment(minibatch_size, local_epoch, optimizer_name,
                                  optimizer_hparams, cohort_optimizer_name, FOLDER)
finally:
    # Release Ray resources even if an experiment raises.
    ray.shutdown()