import numpy as np
import torch
import wandb

from utils import unravel_model_params
from logger import getLogger
from client import ServerClientBase, get_dataloader_kwargs
from utils import unravel_model_params



class Server(ServerClientBase):
    """Base federated-learning server holding a global model and its clients.

    Provides evaluation (``test`` / ``evaluate_testset``) and client
    subsampling helpers; concrete algorithms override :meth:`train`.
    """

    def __init__(self, server_model, clients, criterion, args):
        """Store the global model, the client list, the criterion and the run config.

        Args:
            server_model: global model; it is called as ``model(data)`` and
                toggled with ``eval()`` / ``train()``.
            clients: list of client objects.
            criterion: loss criterion, stored as ``self.criterion``.
            args: run configuration; this class reads ``test_batch_size``,
                ``device`` and ``num_subsample`` from it.
        """
        # NOTE(review): ServerClientBase.__init__ is not called here; kept as in
        # the original since the base class is not visible from this file.
        self.args = args
        self.criterion = criterion
        self.server_model = server_model
        self.clients = clients
        self.logger = getLogger(__name__)

    def test(self, testset, opt_loss=0.0):
        """Evaluate the server model on ``testset`` and log loss/accuracy.

        Args:
            testset: dataset to evaluate on (wrapped in a non-shuffling DataLoader).
            opt_loss: value subtracted from the measured test loss before it is
                logged (presumably a known optimal loss, so the logged value is
                a suboptimality gap — confirm with callers).
        """
        self.server_model.eval()
        try:
            testloader = torch.utils.data.DataLoader(
                testset,
                batch_size=self.args.test_batch_size,
                shuffle=False,
                **get_dataloader_kwargs(self.args.device),
            )

            test_loss, test_accuracy = self.evaluate_testset(testloader, self.server_model)
            test_loss = test_loss - opt_loss

            self.logger.info(f'Test loss: {test_loss:.6f}, Test accuracy: {test_accuracy:.4f}')
            wandb.log({'test_acc': test_accuracy, 'test_loss': test_loss}, commit=True)
        finally:
            # Always return the model to training mode, even if evaluation or
            # logging raises, so a failed test does not leave training in eval mode.
            self.server_model.train()

    def subsample(self, clients):
        """Return the clients participating in this round.

        When ``args.num_subsample`` is -1 every client is returned; otherwise
        ``num_subsample`` clients are drawn uniformly without replacement using
        the global NumPy RNG (NumPy raises ``ValueError`` if more clients are
        requested than exist).
        """
        if self.args.num_subsample == -1:
            return clients
        client_idxs = np.random.choice(len(clients), self.args.num_subsample, replace=False)
        return [clients[i] for i in client_idxs]

    def evaluate_testset(self, test_loader, model):
        """Return ``(mean_loss, accuracy)`` of ``model`` over ``test_loader``.

        Each batch loss is multiplied by the batch size before summing, which
        assumes ``self.loss`` averages over the batch (reduction='mean'); the
        total is then divided by the dataset length. ``self.loss`` and
        ``self.pred`` are not defined in this class — presumably provided by
        ``ServerClientBase``.
        """
        test_loss = 0
        correct = 0
        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(self.args.device), target.to(self.args.device)
                output = model(data)
                # Undo the batch-mean so partial final batches are weighted correctly.
                test_loss += self.loss(model, output, target).item() * len(data)
                pred = self.pred(output)
                correct += pred.eq(target.view_as(pred)).sum().item()
        num_samples = len(test_loader.dataset)
        test_loss /= num_samples
        test_accuracy = correct / num_samples
        return test_loss, test_accuracy

    def train(self, on_step_end=lambda *_: None):
        """Run the federated training loop; must be implemented by subclasses.

        Args:
            on_step_end: callback invoked after each server step. Subclasses in
                this file call it with two positional arguments, so the default
                accepts any number of them.
        """
        raise NotImplementedError


class InexactFedDRServer(Server):
    """Server running inexact FedDR with an adaptive client-refinement loop.

    At every server step each client gets a fresh stepper generator
    (``client.step(alpha, phat)``). The server then pulls one update per client
    per refinement round and recomputes ``xi``, ``mu``, ``zeta`` and
    ``epsilon_sum``. A round is accepted once
    ``epsilon_sum <= SIGMA_SQ * max(xi, zeta)`` (or immediately when
    ``args.no_refinement`` is set). The averaged point ``phat`` is then loaded
    into the server model, and ``alpha = mu / xi`` is sent to the clients at the
    next step.
    """

    # Tolerance factor sigma^2 in the inexactness test
    # epsilon_sum <= sigma^2 * max(xi, zeta).
    SIGMA_SQ = 0.99
    # A RuntimeError is raised once more than this many refinement rounds in a
    # single server step have failed the inexactness test.
    MAX_REFINEMENTS = 100

    def __init__(self, server_model, clients, criterion, args):
        super().__init__(server_model, clients, criterion, args)

    def _aggregate_round(self, client_steppers):
        """Pull one update from every client stepper and aggregate the results.

        Each stepper yields ``(client_loss, xbar_i, xbar_grad_i, s_i)``. With
        ``g = args.client_gamma``, ``n`` clients and ``d_i = s_i - g * xbar_grad_i``
        this computes (expanded form kept for numerical parity with the
        original implementation)::

            phat        = mean_i(xbar_i - g * xbar_grad_i)
            xi          = sum_i ||xbar_i - phat||^2
            mu          = sum_i <xbar_i - phat, d_i - phat>
            zeta        = sum_i ||d_i - phat||^2 / g^2
            epsilon_sum = sum_i ||d_i - xbar_i||^2

        Returns:
            ``(loss, phat, xi, mu, zeta, epsilon_sum)`` where ``loss`` is the sum
            of the client losses.
        """
        gamma = self.args.client_gamma
        num_clients = len(client_steppers)

        loss = 0.0
        phat = 0.0
        xbar_sum = 0.0
        xbar_norm_sq_sum = 0.0
        sgammagrad_sum = 0.0
        innerprod_sum = 0.0
        sgammagrad_norm_sq_sum = 0.0
        epsilon_sum = 0.0
        for stepper in client_steppers:
            client_loss, xbar_i, xbar_grad_i, s_i = next(stepper)
            loss += client_loss

            sgammagrad = s_i - gamma * xbar_grad_i
            xbar_sum += xbar_i
            xbar_norm_sq_sum += torch.norm(xbar_i)**2
            sgammagrad_norm_sq_sum += torch.norm(sgammagrad)**2
            sgammagrad_sum += sgammagrad
            innerprod_sum += xbar_i.dot(-sgammagrad)
            epsilon_sum += torch.norm(sgammagrad - xbar_i)**2
            phat += xbar_i - gamma * xbar_grad_i

        phat = phat / num_clients
        phat_norm_sq = torch.norm(phat)**2
        innerprod1 = xbar_sum.dot(phat)
        innerprod2 = phat.dot(sgammagrad_sum)
        xi = xbar_norm_sq_sum - 2 * innerprod1 + num_clients * phat_norm_sq
        mu = -innerprod2 - innerprod_sum - innerprod1 + num_clients * phat_norm_sq
        zeta = sgammagrad_norm_sq_sum - 2 * innerprod2 + num_clients * phat_norm_sq
        zeta *= 1 / (gamma**2)
        return loss, phat, xi, mu, zeta, epsilon_sum

    def train(self, on_step_end=lambda *_: None):
        """Run ``args.server_num_steps`` FedDR server steps.

        Args:
            on_step_end: callback invoked after each accepted step as
                ``on_step_end(server_steps, server_steps)``; the default accepts
                and ignores any positional arguments.

        Returns:
            ``(server_steps, server_steps)`` holding the index of the last step.

        Raises:
            AssertionError: if client subsampling is requested.
            ValueError: if ``args.server_num_steps`` is less than 1.
            RuntimeError: if a step needs more than ``MAX_REFINEMENTS`` refinement rounds.
        """
        assert self.args.num_subsample == -1, "Subsampling not supported for FedDR"
        if self.args.server_num_steps < 1:
            # Without at least one step there is no step index to return.
            raise ValueError("server_num_steps must be at least 1 for FedDR training")

        self.server_model.train()
        alpha = 0
        phat = 0.0

        # Statistics
        total_client_updates = 0
        total_comm = 0

        for server_steps in range(self.args.server_num_steps):
            # Fresh stepper per client, seeded with the latest stepsize and averaged point.
            client_steppers = [client.step(alpha, phat) for client in self.clients]
            total_comm += 1  # server -> client

            num_client_refinements = 0
            while True:
                loss, phat, xi, mu, zeta, epsilon_sum = self._aggregate_round(client_steppers)

                total_client_updates += self.args.client_num_epochs
                total_comm += 3  # client -> server
                num_client_refinements += 1

                if self.args.no_refinement or epsilon_sum <= self.SIGMA_SQ * max(xi, zeta):
                    break

                # Inexactness test failed: pull another round from the same steppers.
                self.logger.info(
                    f"refines {num_client_refinements}: sigma^2={self.SIGMA_SQ},"
                    f"eps={epsilon_sum},xi={xi},zeta={zeta}"
                )
                if num_client_refinements > self.MAX_REFINEMENTS:
                    raise RuntimeError(
                        f"The number of refinements exceeded the maximum allowance of {self.MAX_REFINEMENTS}."
                    )

            # Accept the round: load phat into the server model and update the stepsize.
            unravel_model_params(self.server_model, phat)
            # NOTE(review): xi == 0 (every xbar_i equal to phat) makes alpha inf/nan.
            alpha = mu / xi

            if server_steps % self.args.log_interval == 0:
                self.logger.info(f"Step {server_steps} ({num_client_refinements} refinements): loss {loss}")
            wandb.log({
                'alpha': alpha.item(),
                'train_loss': loss,
                # Metric key spelling kept unchanged so logged metric names stay stable.
                'num_client_refinments': num_client_refinements,
                'total_client_updates': total_client_updates,
                'total_comm': total_comm,
            }, commit=False)

            on_step_end(server_steps, server_steps)

        return server_steps, server_steps
