from torch.utils.data import DataLoader
from torch.utils.data import ConcatDataset
from copy import deepcopy
import numpy as np
import torch

class CohortOptimizer:
    """Base class for optimizers that update a model from a cohort of datasets.

    Subclasses implement ``__call__``. The shared helper ``_add_proximal_term``
    adds the gradient of a proximal penalty weighted by ``1 / gamma``.

    Args:
        hparams: Mapping of hyper-parameters. The optional ``"gamma"`` entry
            sets the proximal weight; when it is absent ``gamma`` is
            ``np.inf``, which disables the proximal term.
        device: Stored as ``self.device`` for subclasses.
        process_count: Stored as ``self.process_count`` for subclasses.
    """

    def __init__(self, hparams, device="cpu", process_count=1):
        self.hparams = hparams
        self.device = device
        self.process_count = process_count
        # No "gamma" entry means "no proximal term" (see _add_proximal_term).
        self.gamma = hparams["gamma"] if "gamma" in hparams else np.inf

    def __call__(self, model, data_list, loss, epochs):
        """Run the cohort update; must be overridden by subclasses."""
        raise NotImplementedError

    def _add_proximal_term(self, model, initial_model_parameters):
        """Add ``(p - p0) / gamma`` to the gradient of every trainable parameter.

        ``(p - p0) / gamma`` is the gradient of ``||p - p0||^2 / (2 * gamma)``.
        A parameter whose ``.grad`` is ``None`` receives the term as a fresh
        gradient tensor; otherwise it is added in place. Nothing happens when
        ``gamma`` is ``None`` or ``np.inf``.

        Args:
            model: Module whose parameter gradients are modified.
            initial_model_parameters: Anchor tensors ``p0``, in the same order
                as ``model.parameters()``.

        Returns:
            The same ``model`` object.
        """
        disabled = self.gamma is None or self.gamma == np.inf
        if disabled:
            return model
        with torch.no_grad():
            for param, anchor in zip(model.parameters(), initial_model_parameters):
                if not param.requires_grad:
                    continue
                correction = (1 / self.gamma) * (param.data - anchor.data)
                if param.grad is None:
                    param.grad = correction
                else:
                    param.grad += correction
        return model

class CohortOptimizerProx(CohortOptimizer):
    """Cohort optimizer that applies a server step to the averaged client update.

    In each round (one iteration of ``epochs`` in ``__call__``), every dataset
    in ``data_list`` trains its own copy of the model with ``worker_optimizer``.
    The proximal term from ``_add_proximal_term`` is added to the gradients
    before each local step. The server then uses
    ``model - mean(local models)`` as a pseudo-gradient for one step of
    ``server_optimizer``.

    Hyper-parameters read from ``hparams``, in addition to the base class's:
    ``worker_optimizer`` / ``worker_optimizer_hparams`` (optimizer class and
    kwargs for local training); optional ``worker_optimizer_steps`` (local
    passes over the data, default 1); ``server_optimizer`` /
    ``server_optimizer_hparams``; and optional ``batch_size`` (default
    ``np.inf``, meaning the whole dataset is one batch).
    """

    def __init__(self, hparams, device="cpu", process_count=1):
        super().__init__(hparams, device=device, process_count=process_count)
        self.worker_optimizer = hparams["worker_optimizer"]
        self.worker_optimizer_hparams = hparams["worker_optimizer_hparams"]
        self.worker_optimizer_steps = hparams["worker_optimizer_steps"] if "worker_optimizer_steps" in hparams else 1
        self.server_optimizer = hparams["server_optimizer"]
        self.server_optimizer_hparams = hparams["server_optimizer_hparams"]
        if "batch_size" in hparams:
            self.batch_size = hparams["batch_size"]
        else:
            self.batch_size = np.inf

        torch.cuda.empty_cache()

    def _local_optimisation(self, model, data, loss, epochs):
        """Train ``model`` in place on one client's ``data``.

        Runs ``epochs`` shuffled passes of a fresh ``worker_optimizer``. The
        proximal term toward ``self.initial_model_parameters`` is added to the
        gradients before every step.

        Returns:
            The trained ``model``.
        """
        model.train()
        optimizer = self.worker_optimizer(model.parameters(), **self.worker_optimizer_hparams)
        # np.inf is the "whole dataset in one batch" sentinel.
        batch_size = len(data) if self.batch_size == np.inf else self.batch_size
        dataloader = DataLoader(data, batch_size=batch_size, shuffle=True, num_workers=self.process_count)
        for _ in range(epochs):
            for inputs, targets in dataloader:
                optimizer.zero_grad()
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                output = model(inputs)
                loss_value = loss(output, targets)
                loss_value.backward()
                model = self._add_proximal_term(model, self.initial_model_parameters)
                optimizer.step()
                del inputs, targets, output, loss_value
        return model

    def __call__(self, model, data_list, loss, epochs):
        """Run ``epochs`` server rounds over the cohort ``data_list``.

        Args:
            model: Model updated in place.
            data_list: One dataset per client.
            loss: Callable ``loss(output, targets)`` returning a scalar tensor.
            epochs: Number of server rounds.

        Returns:
            ``model``. An empty ``data_list`` leaves its parameters unchanged.
        """
        model.train()
        minibatch_size = len(data_list)
        # Anchor for the proximal term; fixed for every round of this call.
        self.initial_model_parameters = [param.clone().detach() for param in model.parameters()]
        if minibatch_size == 0:
            # Without clients the pseudo-gradient would be the parameters
            # themselves, and the server step would shrink the model toward
            # zero. Treat an empty cohort as a no-op instead.
            return model
        server_optimizer = self.server_optimizer(model.parameters(), **self.server_optimizer_hparams)
        for _ in range(epochs):
            # Build the pseudo-gradient model - mean(local models) in .grad:
            # start from the current parameters and subtract each local model / K.
            with torch.no_grad():
                for param in model.parameters():
                    if param.requires_grad:
                        param.grad = param.detach().clone()
            for data in data_list:
                local_model = deepcopy(model)
                # deepcopy also copies the partially built .grad; clear it.
                local_model.zero_grad()
                local_model = self._local_optimisation(local_model, data, loss, self.worker_optimizer_steps)
                with torch.no_grad():
                    for model_param, local_model_param in zip(model.parameters(), local_model.parameters()):
                        if model_param.requires_grad:
                            model_param.grad -= local_model_param.detach() / minibatch_size
                del local_model
            server_optimizer.step()
            server_optimizer.zero_grad()
        return model

class MimeLite(CohortOptimizer):
    """Cohort optimizer whose clients start from a copy of one shared optimizer state.

    In each round (one iteration of ``epochs`` in ``__call__``):

    * every client computes its dataset gradient of ``loss`` at the current
      model (``_calculate_grad``);
    * it then runs ``worker_optimizer_steps`` passes of ``worker_optimizer``,
      starting from a copy of the shared optimizer's state, with the proximal
      term from ``_add_proximal_term`` added to the gradients;
    * the shared optimizer takes one step using the cohort-averaged gradients
      from the first step, which advances its internal state;
    * the model's trainable parameters are then overwritten with the average
      of the local models.

    Hyper-parameters read from ``hparams``, in addition to the base class's:
    ``worker_optimizer`` / ``worker_optimizer_hparams`` (optimizer class and
    kwargs, used both locally and for the shared state); optional
    ``worker_optimizer_steps`` (default 1); and optional ``batch_size``
    (default ``np.inf``, meaning the whole dataset is one batch).
    """

    def __init__(self, hparams, device="cpu", process_count=1):
        super().__init__(hparams, device=device, process_count=process_count)
        self.worker_optimizer = hparams["worker_optimizer"]
        self.worker_optimizer_hparams = hparams["worker_optimizer_hparams"]
        self.worker_optimizer_steps = hparams["worker_optimizer_steps"] if "worker_optimizer_steps" in hparams else 1
        if "batch_size" in hparams:
            self.batch_size = hparams["batch_size"]
        else:
            self.batch_size = np.inf

        torch.cuda.empty_cache()

    def _make_dataloader(self, data):
        """Return a shuffled DataLoader over ``data``; ``np.inf`` batch size means one full batch."""
        batch_size = len(data) if self.batch_size == np.inf else self.batch_size
        return DataLoader(data, batch_size=batch_size, shuffle=True, num_workers=self.process_count)

    def _calculate_grad(self, model, data, loss):
        """Return a copy of ``model`` whose ``.grad`` holds its gradient over ``data``.

        Per-batch gradients are weighted by batch size and divided by
        ``len(data)``. This gives the dataset-mean gradient, assuming ``loss``
        uses mean reduction (TODO confirm for the losses used). Trainable
        parameters that receive no gradient from the loss end up with a zero
        ``.grad`` instead of raising. ``model``'s own ``.grad`` is overwritten
        as a side effect.
        """
        model.train()
        model_grad = deepcopy(model)
        # Fresh zero accumulators. This also discards any .grad that the
        # deepcopy carried over from ``model``.
        with torch.no_grad():
            for p in model_grad.parameters():
                if p.requires_grad:
                    p.grad = torch.zeros_like(p)
        for inputs, targets in self._make_dataloader(data):
            model.zero_grad()
            inputs, targets = inputs.to(self.device), targets.to(self.device)
            output = model(inputs)
            loss_value = loss(output, targets)
            loss_value.backward()
            with torch.no_grad():
                for p, model_p in zip(model_grad.parameters(), model.parameters()):
                    if p.requires_grad and model_p.grad is not None:
                        p.grad += model_p.grad * inputs.size(0)
            del inputs, targets, output, loss_value
        with torch.no_grad():
            for p in model_grad.parameters():
                if p.requires_grad:
                    p.grad /= len(data)
        return model_grad

    def _local_optimisation(self, model, data, loss, epochs, optimizer_state):
        """Train ``model`` in place on one client's ``data``.

        Args:
            model: Client copy of the current model.
            data: The client's dataset.
            loss: Callable ``loss(output, targets)``.
            epochs: Number of local passes over ``data``.
            optimizer_state: ``state_dict()`` of the shared optimizer. It is
                not modified.

        Returns:
            ``(model, model_grad)``: the locally trained model, and a copy of
            the starting model whose ``.grad`` holds the client's gradient at
            the starting point.
        """
        model.train()
        model_grad = self._calculate_grad(model, data, loss)
        optimizer = self.worker_optimizer(model.parameters(), **self.worker_optimizer_hparams)
        # Optimizer.load_state_dict is not guaranteed to copy state tensors:
        # PyTorch 2.x reuses them when dtype and device already match.
        # optimizer.step() then updates them in place. Deep-copy so local steps
        # cannot mutate the shared state or leak into the next client.
        optimizer.load_state_dict(deepcopy(optimizer_state))
        dataloader = self._make_dataloader(data)
        for _ in range(epochs):
            for inputs, targets in dataloader:
                optimizer.zero_grad()
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                output = model(inputs)
                loss_value = loss(output, targets)
                loss_value.backward()
                model = self._add_proximal_term(model, self.initial_model_parameters)
                optimizer.step()
                del inputs, targets, output, loss_value
        return model, model_grad

    def __call__(self, model, data_list, loss, epochs):
        """Run ``epochs`` rounds over the cohort ``data_list``.

        Args:
            model: Model updated in place.
            data_list: One dataset per client.
            loss: Callable ``loss(output, targets)`` returning a scalar tensor.
            epochs: Number of rounds.

        Returns:
            ``model``. An empty ``data_list`` leaves its parameters unchanged.
        """
        model.train()
        minibatch_size = len(data_list)
        # Anchor for the proximal term; fixed for every round of this call.
        self.initial_model_parameters = [param.clone().detach() for param in model.parameters()]
        if minibatch_size == 0:
            # Averaging zero local models would overwrite every trainable
            # parameter with zeros; treat an empty cohort as a no-op instead.
            return model
        optimizer = self.worker_optimizer(model.parameters(), **self.worker_optimizer_hparams)
        for _ in range(epochs):
            optimizer.zero_grad()
            # Accumulator for the mean of the local models (trainable params only).
            new_model = deepcopy(model)
            with torch.no_grad():
                for param in new_model.parameters():
                    if param.requires_grad:
                        param.zero_()
            # Nothing steps the shared optimizer until after the client loop,
            # so one snapshot serves every client of this round.
            server_state = optimizer.state_dict()
            for data in data_list:
                local_model = deepcopy(model)
                local_model, local_model_grad = self._local_optimisation(
                    local_model, data, loss, self.worker_optimizer_steps, server_state
                )
                with torch.no_grad():
                    for param, new_param in zip(local_model.parameters(), new_model.parameters()):
                        if new_param.requires_grad:
                            new_param += param / minibatch_size
                    # Average the starting-point gradients into model.grad for the shared step.
                    for param, new_param in zip(local_model_grad.parameters(), model.parameters()):
                        if new_param.requires_grad:
                            if new_param.grad is None:
                                new_param.grad = param.grad / minibatch_size
                            else:
                                new_param.grad += param.grad / minibatch_size
                del local_model, local_model_grad
            # This step advances the shared optimizer state. The parameter
            # change it makes is discarded by the copy below.
            optimizer.step()
            with torch.no_grad():
                for param, new_param in zip(model.parameters(), new_model.parameters()):
                    if param.requires_grad:
                        param.copy_(new_param)
            del new_model
        return model


def calculate_grad_norm(model):
    """Return the global L2 norm of all gradients currently stored on ``model``.

    Parameters whose ``.grad`` is ``None`` are skipped, so a model without any
    gradients yields ``0.0``. The value is computed with NumPy from the
    per-parameter squared norms.
    """
    squared_norms = [
        param.grad.data.norm(2).item() ** 2
        for param in model.parameters()
        if param.grad is not None
    ]
    return np.sqrt(np.sum(squared_norms))