import logging
import torch.optim.lr_scheduler
from src.algo.center_server import FedAvgCenterServer
from src.models import Model
from torch.utils.data import DataLoader
from src.utils import create_datasets
from src.algo import Algo

log = logging.getLogger(__name__)


class Centralized(Algo):
    """Centralized (non-federated) training baseline.

    The whole training set is loaded as a single shard and a single model is
    trained on it. The model is wrapped in a ``FedAvgCenterServer`` so the
    same server-side validation/analysis hooks used by the federated
    algorithms can be reused. Each training iteration is one ``Algo.train``
    call over the training loader followed by one step of a cosine-annealing
    learning-rate schedule whose period equals the requested number of
    iterations.
    """

    def __init__(self, model_info, params, device: str, dataset: str,
                 output_suffix: str, savedir: str, writer=None, num_workers: int = 6):
        """Build the data loaders, the model and the wrapping center server.

        Args:
            model_info: Model description forwarded to ``Model``.
            params: Algorithm configuration; reads ``loss.type``, ``batch_size``
                and ``center_server.args`` here (plus whatever ``Algo`` reads).
            device: Device forwarded to ``Algo`` and ``FedAvgCenterServer``.
            dataset: Dataset configuration forwarded to ``Algo``.
                NOTE(review): annotated as ``str``, but ``self._dataset.args``
                is accessed below — confirm the real type.
            output_suffix: Forwarded to ``Algo``.
            savedir: Forwarded to ``Algo``.
            writer: Optional writer forwarded to ``Algo``.
            num_workers: Worker processes for both the train and test
                ``DataLoader`` (default 6, the previously hard-coded value).

        Raises:
            ValueError: If ``params.loss.type`` is not ``"crossentropy"``.
        """
        # Explicit raise instead of ``assert`` so the check is not stripped
        # when Python runs with ``-O``.
        if params.loss.type != "crossentropy":
            raise ValueError("Loss function for centralized algorithm must be crossentropy")
        super().__init__(model_info, params, device, dataset, output_suffix, savedir, writer)
        self._batch_size = params.batch_size

        # Split into a single shard: training_set[0] is then the entire training data.
        # NOTE(review): meaning of the 4th return value (stored as
        # _excluded_from_test) is not visible here — confirm in create_datasets.
        training_set, test_dataset, _, self._excluded_from_test = create_datasets(self._dataset, 1)
        test_loader = DataLoader(test_dataset, num_workers=num_workers, batch_size=self._batch_size, shuffle=False)

        model = Model(model_info, self._dataset.args.num_classes)

        self._train_loader = DataLoader(training_set[0], num_workers=num_workers,
                                        batch_size=self._batch_size, shuffle=True)
        self._center_server = FedAvgCenterServer(model, test_loader, device,
                                                 analyzer=self._analyzer.module_analyzer('server'),
                                                 **params.center_server.args)
        # Created lazily in _fit, which knows the total number of iterations.
        self._scheduler = None
        self._optim = None

    def _validate(self):
        """Run the server-side 'validation' analysis at the current iteration."""
        self._analyzer.module_analyzer('server')('validation', server=self._center_server, loss_fn=self._loss_fn,
                                                 s_round=self._iteration)

    def _fit(self, iterations):
        """Train the center-server model, validating before and after every iteration.

        A fresh optimizer and a ``CosineAnnealingLR`` scheduler with
        ``T_max=iterations`` are created on every call.

        Args:
            iterations: Iteration budget passed to ``self._next_iter`` and used
                as the cosine schedule period. Presumably the loop runs until
                ``self._iteration`` reaches this value — confirm in ``Algo``.
        """
        model = self._center_server.model
        model.to(self._device)
        self._loss_fn.to(self._device)
        model.train()
        self._optim = self._optimizer(model.parameters(), **self._optimizer_args)
        self._scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self._optim, iterations)

        # Baseline measurement before any training step.
        self._validate()
        while self._next_iter(iterations):
            self.train_step()
            self._validate()
        log.info("Training completed")

    def train_step(self):
        """Run one ``Algo.train`` pass over the training loader, then step the LR schedule.

        Requires ``_fit`` to have created ``self._optim`` and ``self._scheduler``.
        """
        model = self._center_server.model
        Algo.train(model, self._device, self._optim, self._loss_fn, self._train_loader)
        self._scheduler.step()
