import copy

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm

import federated
import utils
from .base import *


class SISAController(FederatedController):
    """Federated controller that trains one independent model per client.

    Every client keeps its own copy of the initial server model and trains it
    only on its own ``"train"`` loader; this class never averages the
    per-client models. Evaluation treats them as an ensemble: the outputs of
    all retained client models are summed with weights proportional to each
    client's training-set size, and the arg-max over dim 1 is the prediction.
    ``leave`` discards a client's loader and model, so later predictions no
    longer use that client's model.
    """

    def __init__(self,
                 server_model: nn.Module,
                 client_loaders: dict[int, dict[str, DataLoader]],
                 eval_loaders: dict[str, DataLoader] = None,
                 config=None,
                 ):
        """Create the controller and one model copy per initial client.

        Args:
            server_model: model forwarded to the base controller.
            client_loaders: maps client id -> dict of loaders; this class
                uses each client's ``"train"`` loader.
            eval_loaders: loaders keyed by split; ``evaluate`` reads the
                ``"train"`` and ``"test"`` entries.
            config: forwarded to the base controller; ``train`` later reads
                ``self.config.local`` and ``self.config.federated.log_frequency``
                (presumably the base stores it as ``self.config`` — verify).
        """
        super().__init__(server_model, client_loaders, eval_loaders, config)

        # NOTE(review): ``self.init_server_model`` is assumed to be set by the
        # base controller (not visible here). Each client gets its own deep
        # copy so training one client never touches another client's weights.
        self.client_models = {
            client_id: copy.deepcopy(self.init_server_model)
            for client_id in client_loaders
        }

    def train(self, num_rounds):
        """Train every client's model locally for ``num_rounds`` rounds.

        Each round runs ``federated.train_model`` on each client's own model
        with that client's ``"train"`` loader, then evaluates the ensemble
        (``evaluate`` also writes the metrics with ``utils.log_stats``).
        Metrics are printed every ``config.federated.log_frequency`` rounds
        and after the final round.
        """
        for rnd in tqdm(range(num_rounds), desc="train_sisa"):
            for client_id, loaders in self.client_loaders.items():
                federated.train_model(self.client_models[client_id],
                                      loaders["train"],
                                      config=self.config.local,
                                      verbose=True,
                                      device=self.device)
            metrics = self.evaluate()
            is_log_round = (rnd + 1) % self.config.federated.log_frequency == 0
            if is_log_round or (rnd + 1) == num_rounds:
                print("round:", rnd)
                for k, v in metrics.items():
                    print("{}: {:.4f}".format(k, v))

    def leave(self, unlearned_client):
        """Remove ``unlearned_client`` and re-evaluate the remaining ensemble.

        Calls the base ``leave`` first, then drops the client's loaders and
        its model from this controller.
        """
        super().leave(unlearned_client)

        # NOTE(review): assumes the base ``leave`` does not itself remove the
        # entry from ``self.client_loaders``; otherwise this pop raises KeyError.
        self.client_loaders.pop(unlearned_client)
        self.client_models.pop(unlearned_client)
        self.evaluate()

    def join(self, client_id, dataloader, attach_to=None):
        """Add a new client with a freshly initialised model, then re-evaluate.

        NOTE(review): the base ``join`` presumably registers ``dataloader``
        under ``client_id`` in ``self.client_loaders`` — not visible here.
        Only clients present in ``self.client_loaders`` take part in the
        ensemble evaluated below.
        """
        super().join(client_id, dataloader, attach_to)

        self.client_models[client_id] = copy.deepcopy(self.init_server_model)
        self.evaluate()

    def evaluate(self):
        """Evaluate the weighted ensemble on the ``"train"`` and ``"test"`` eval loaders.

        Each client's ensemble weight is the size of its ``"train"`` dataset
        divided by the total over all current clients. The metrics are
        written with ``utils.log_stats(self.log_path, ...)``.

        Returns:
            ``{"train_acc": float, "test_acc": float}``, accuracies in percent.

        Raises:
            ValueError: if there are no clients or all client training
                datasets are empty (weights would be undefined).
        """
        super().prepare_evaluation()

        client_sizes = {
            client_id: len(loaders["train"].dataset)
            for client_id, loaders in self.client_loaders.items()
        }
        fed_samples = sum(client_sizes.values())
        if fed_samples == 0:
            # Previously a bare ZeroDivisionError (e.g. after the last client leaves).
            raise ValueError(
                "cannot evaluate SISA ensemble: no clients with training samples")
        client_weighted = {k: v / fed_samples for k, v in client_sizes.items()}

        train_acc = self.check_sisa_accuracy(self.eval_loaders["train"],
                                             client_weighted)
        test_acc = self.check_sisa_accuracy(self.eval_loaders["test"],
                                            client_weighted)
        metrics = {"train_acc": train_acc,
                   "test_acc": test_acc}
        utils.log_stats(self.log_path, **metrics)
        return metrics

    def check_sisa_accuracy(self, dataloader, pred_weights):
        """Return the ensemble accuracy (percent) of all retained clients on ``dataloader``.

        For each ``(x, y)`` batch, every client model's output is multiplied
        by ``pred_weights[client_id]`` and summed. The prediction is the
        arg-max over dim 1 of that sum, and it is compared with ``y``.

        Args:
            dataloader: yields ``(x, y)`` batches; ``y`` presumably holds integer
                class labels matching dim 1 of the model output — verify.
            pred_weights: maps every client id in ``self.client_loaders`` to
                the scalar weight applied to that client's model output.

        Returns:
            ``100 * correct / total`` as a float.

        Raises:
            ValueError: if there are no client models or ``dataloader``
                yields no samples.
        """
        client_ids = list(self.client_loaders)
        if not client_ids:
            raise ValueError("cannot evaluate SISA ensemble: no client models")
        models = [self.client_models[client_id] for client_id in client_ids]

        # Record each model's mode so evaluation does not leave the models in
        # eval mode (dropout/batch-norm frozen) for the next round of local
        # training.
        was_training = [model.training for model in models]
        for model in models:
            model.eval()

        num_correct = 0
        num_samples = 0
        try:
            with torch.no_grad():
                for x, y in dataloader:
                    x = x.to(device=self.device)
                    y = y.to(device=self.device)
                    scores = None
                    # ensemble predictions from all (retained) clients
                    for client_id, model in zip(client_ids, models):
                        # Each model is moved back to the CPU after its forward
                        # pass, presumably to bound device memory when there
                        # are many clients.
                        model.to(self.device)
                        weighted = pred_weights[client_id] * model(x)
                        model.to("cpu")
                        scores = weighted if scores is None else scores + weighted
                    _, predictions = scores.max(1)
                    num_correct += (predictions == y).sum()
                    num_samples += predictions.size(0)
        finally:
            for model, mode in zip(models, was_training):
                model.train(mode)

        if num_samples == 0:
            raise ValueError("cannot compute accuracy: dataloader yielded no samples")
        return float(num_correct) / float(num_samples) * 100