from torch.utils.data import DataLoader
from torch.utils.data import ConcatDataset
from copy import deepcopy
import numpy as np
import ray
import torch

class CohortOptimizer:
    """Base class for optimizers that update a model from a list of datasets.

    Stores the shared configuration and provides the proximal-gradient helper;
    subclasses supply the actual training procedure in ``__call__``.

    Args:
        hparams: Hyperparameter mapping. The optional ``"gamma"`` entry is the
            proximal coefficient; when the key is absent ``gamma`` is set to
            ``np.inf``, which disables the proximal term.
        device: Stored on the instance as ``self.device``.
        process_count: Stored on the instance as ``self.process_count``.
    """

    def __init__(self, hparams, device="cpu", process_count=1):
        self.hparams = hparams
        self.device = device
        self.process_count = process_count
        self.gamma = hparams["gamma"] if "gamma" in hparams else np.inf

    def __call__(self, model, data_list, loss, epochs):
        """Run the optimizer; must be overridden by subclasses.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def _add_proximal_term(self, model, initial_model_parameters):
        """Add ``(p - p_initial) / gamma`` to the gradient of each trainable parameter.

        Args:
            model: Module whose parameter gradients are modified in place.
            initial_model_parameters: Iterable of reference tensors, paired
                positionally with ``model.parameters()``.

        Returns:
            The same ``model`` object. It is returned untouched when ``gamma``
            is ``None`` or ``np.inf``.
        """
        if self.gamma is None or self.gamma == np.inf:
            return model
        with torch.no_grad():
            for p, initial_p in zip(model.parameters(), initial_model_parameters):
                if not p.requires_grad:
                    continue
                # The subtraction already yields a fresh tensor, so neither
                # operand needs to be cloned first.
                proximal_grad = (1 / self.gamma) * (p.data - initial_p.data)
                if p.grad is None:
                    p.grad = proximal_grad
                else:
                    p.grad += proximal_grad
        return model

class CohortOptimizerProx(CohortOptimizer):
    """Cohort optimizer with proximal local training and a server optimizer step.

    In every round each dataset of the cohort trains a private deep copy of the
    model with ``worker_optimizer`` (gradients are augmented by
    ``_add_proximal_term``). The server gradient is then set to the current
    parameters minus the average of the locally trained parameters, and
    ``server_optimizer`` applies it.

    Args:
        hparams: Hyperparameter mapping. Required keys: ``"worker_optimizer"``
            and ``"server_optimizer"`` (callables taking the parameters plus
            keyword arguments), ``"worker_optimizer_hparams"`` and
            ``"server_optimizer_hparams"`` (their keyword arguments).
            Optional keys: ``"worker_optimizer_steps"`` (default ``1``),
            ``"batch_size"`` (default: the whole dataset as one batch),
            ``"ray"`` (default ``False``) and ``"gamma"``.
        device: Device that mini-batches are moved to.
        process_count: Passed to ``DataLoader`` as ``num_workers``.
    """

    def __init__(self, hparams, device="cpu", process_count=1):
        super().__init__(hparams, device=device, process_count=process_count)
        self.worker_optimizer = hparams["worker_optimizer"]
        self.worker_optimizer_hparams = hparams["worker_optimizer_hparams"]
        self.worker_optimizer_steps = hparams["worker_optimizer_steps"] if "worker_optimizer_steps" in hparams else 1
        self.server_optimizer = hparams["server_optimizer"]
        self.server_optimizer_hparams = hparams["server_optimizer_hparams"]
        self.batch_size = hparams["batch_size"] if "batch_size" in hparams else np.inf

        # The Ray code path is not implemented yet, so the flag is optional.
        self.ray = hparams["ray"] if "ray" in hparams else False
        # Start Ray only if it is requested and not initialized yet.
        if self.ray and not ray.is_initialized():
            ray.init()

        # Touching the cuFFT plan cache initializes CUDA and therefore raises
        # on hosts without a usable GPU; only clear caches when CUDA exists.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.backends.cuda.cufft_plan_cache.clear()

    def _local_optimisation(self, model, data, loss, epochs):
        """Train ``model`` on one dataset with the worker optimizer.

        Relies on ``self.initial_model_parameters``, which ``__call__`` sets
        before invoking this method.

        Args:
            model: Model to train in place.
            data: Dataset yielding ``(inputs, targets)`` pairs.
            loss: Callable ``loss(output, targets)`` returning a scalar tensor.
            epochs: Number of passes over ``data``.

        Returns:
            The trained ``model`` (same object that was passed in).
        """
        model.train()
        optimizer = self.worker_optimizer(model.parameters(), **self.worker_optimizer_hparams)
        batch_size = len(data) if self.batch_size == np.inf else self.batch_size
        dataloader = DataLoader(data, batch_size=batch_size, shuffle=True, num_workers=self.process_count)
        for _ in range(epochs):
            for inputs, targets in dataloader:
                optimizer.zero_grad()
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                output = model(inputs)
                loss_value = loss(output, targets)
                loss_value.backward()
                model = self._add_proximal_term(model, self.initial_model_parameters)
                optimizer.step()
                # Drop batch references before the next batch is loaded.
                del inputs, targets, output, loss_value
        return model

    # TODO: add a Ray-based variant of ``_local_optimisation`` that trains the
    # per-dataset model copies remotely and returns their state dicts.

    def __call__(self, model, data_list, loss, epochs):
        """Run ``epochs`` server rounds over the cohort ``data_list``.

        Args:
            model: Model updated in place.
            data_list: Sequence of datasets, one per cohort member.
            loss: Callable ``loss(output, targets)`` returning a scalar tensor.
            epochs: Number of server rounds.

        Returns:
            The updated ``model``. An empty ``data_list`` leaves it unchanged.

        Raises:
            NotImplementedError: If the ``ray`` flag is set and a round runs.
        """
        model.train()
        cohort_size = len(data_list)
        self.initial_model_parameters = [param.clone().detach() for param in model.parameters()]
        if cohort_size == 0:
            # With nobody to average over, the server gradient would just be
            # the parameters themselves and the step would shrink the weights.
            return model
        server_optimizer = self.server_optimizer(model.parameters(), **self.server_optimizer_hparams)
        for _ in range(epochs):
            if self.ray:
                raise NotImplementedError("Ray-based local optimisation is not implemented")
            # Seed the server gradient with the current parameters; the local
            # averages are subtracted below.
            with torch.no_grad():
                for param in model.parameters():
                    if param.requires_grad:
                        param.grad = param.data.clone()
            for data in data_list:
                local_model = deepcopy(model)
                local_model.zero_grad()
                local_model = self._local_optimisation(local_model, data, loss, self.worker_optimizer_steps)
                with torch.no_grad():
                    for model_param, local_model_param in zip(model.parameters(), local_model.parameters()):
                        if model_param.requires_grad:
                            # The division already allocates a new tensor.
                            model_param.grad -= local_model_param.data / cohort_size
                del local_model
            server_optimizer.step()
            server_optimizer.zero_grad()
        return model

class MimeLite(CohortOptimizer):
    """Cohort optimizer that reuses the server optimizer state on every worker.

    In every round each dataset of the cohort (a) computes a full-dataset
    gradient at the current server parameters and (b) trains a private deep
    copy of the model with a worker optimizer initialised from the server
    optimizer's state. The server then steps its optimizer with the averaged
    full-dataset gradients and replaces the trainable parameters by the
    average of the locally trained parameters.

    Args:
        hparams: Hyperparameter mapping. Required keys: ``"worker_optimizer"``
            (callable taking the parameters plus keyword arguments) and
            ``"worker_optimizer_hparams"`` (its keyword arguments).
            Optional keys: ``"worker_optimizer_steps"`` (default ``1``),
            ``"batch_size"`` (default: the whole dataset as one batch),
            ``"ray"`` (default ``False``) and ``"gamma"``.
        device: Device that mini-batches are moved to.
        process_count: Passed to ``DataLoader`` as ``num_workers``.
    """

    def __init__(self, hparams, device="cpu", process_count=1):
        super().__init__(hparams, device=device, process_count=process_count)
        self.worker_optimizer = hparams["worker_optimizer"]
        self.worker_optimizer_hparams = hparams["worker_optimizer_hparams"]
        self.worker_optimizer_steps = hparams["worker_optimizer_steps"] if "worker_optimizer_steps" in hparams else 1
        self.batch_size = hparams["batch_size"] if "batch_size" in hparams else np.inf

        # The Ray code path is not implemented yet, so the flag is optional.
        self.ray = hparams["ray"] if "ray" in hparams else False
        # Start Ray only if it is requested and not initialized yet.
        if self.ray and not ray.is_initialized():
            ray.init()

        # Touching the cuFFT plan cache initializes CUDA and therefore raises
        # on hosts without a usable GPU; only clear caches when CUDA exists.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.backends.cuda.cufft_plan_cache.clear()

    def _calculate_grad(self, model, data, loss):
        """Compute the sample-weighted average gradient of ``loss`` over ``data``.

        Each batch gradient is weighted by its batch size and the sum is
        divided by ``len(data)``.
        NOTE(review): this equals the full-dataset gradient only if ``loss``
        averages over the batch -- confirm against the callers.

        Args:
            model: Model at which the gradient is evaluated; its own ``.grad``
                fields are overwritten as a side effect.
            data: Dataset yielding ``(inputs, targets)`` pairs.
            loss: Callable ``loss(output, targets)`` returning a scalar tensor.

        Returns:
            A deep copy of ``model`` whose parameters carry the averaged
            gradient in ``.grad``. Parameters that never received a gradient
            keep ``.grad`` as ``None``.
        """
        model.train()
        model_grad = deepcopy(model)
        model_grad.zero_grad(set_to_none=False)
        batch_size = len(data) if self.batch_size == np.inf else self.batch_size
        dataloader = DataLoader(data, batch_size=batch_size, shuffle=True, num_workers=self.process_count)
        for inputs, targets in dataloader:
            model.zero_grad()
            inputs, targets = inputs.to(self.device), targets.to(self.device)
            output = model(inputs)
            loss_value = loss(output, targets)
            loss_value.backward()
            with torch.no_grad():
                for p, model_p in zip(model_grad.parameters(), model.parameters()):
                    # Parameters unused in the forward pass have no gradient.
                    if not p.requires_grad or model_p.grad is None:
                        continue
                    # The product is a new tensor, so no clone is needed.
                    weighted_grad = model_p.grad * inputs.size(0)
                    if p.grad is None:
                        p.grad = weighted_grad
                    else:
                        p.grad += weighted_grad
            del inputs, targets, output, loss_value
        with torch.no_grad():
            for p in model_grad.parameters():
                if p.requires_grad and p.grad is not None:
                    p.grad /= len(data)
        return model_grad

    def _local_optimisation(self, model, data, loss, epochs, optimizer_state):
        """Train ``model`` on one dataset, starting from the given optimizer state.

        Relies on ``self.initial_model_parameters``, which ``__call__`` sets
        before invoking this method.

        Args:
            model: Model to train in place.
            data: Dataset yielding ``(inputs, targets)`` pairs.
            loss: Callable ``loss(output, targets)`` returning a scalar tensor.
            epochs: Number of passes over ``data``.
            optimizer_state: ``state_dict()`` of the server optimizer; it is
                not modified.

        Returns:
            Tuple ``(model, model_grad)``: the trained model and the result of
            ``_calculate_grad`` evaluated before any local step.
        """
        model.train()
        model_grad = self._calculate_grad(model, data, loss)
        optimizer = self.worker_optimizer(model.parameters(), **self.worker_optimizer_hparams)
        # ``state_dict()`` hands out references to the state tensors and
        # ``load_state_dict`` is not guaranteed to copy them. Load a private
        # copy so in-place updates made by local steps cannot leak into the
        # server optimizer or into other workers.
        optimizer.load_state_dict(deepcopy(optimizer_state))
        batch_size = len(data) if self.batch_size == np.inf else self.batch_size
        dataloader = DataLoader(data, batch_size=batch_size, shuffle=True, num_workers=self.process_count)
        for _ in range(epochs):
            for inputs, targets in dataloader:
                optimizer.zero_grad()
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                output = model(inputs)
                loss_value = loss(output, targets)
                loss_value.backward()
                model = self._add_proximal_term(model, self.initial_model_parameters)
                optimizer.step()
                # Drop batch references before the next batch is loaded.
                del inputs, targets, output, loss_value
        return model, model_grad

    # TODO: add a Ray-based variant of ``_local_optimisation`` that trains the
    # per-dataset model copies remotely and returns their state dicts.

    def __call__(self, model, data_list, loss, epochs):
        """Run ``epochs`` server rounds over the cohort ``data_list``.

        Args:
            model: Model updated in place.
            data_list: Sequence of datasets, one per cohort member.
            loss: Callable ``loss(output, targets)`` returning a scalar tensor.
            epochs: Number of server rounds.

        Returns:
            The updated ``model``. An empty ``data_list`` leaves it unchanged.

        Raises:
            NotImplementedError: If the ``ray`` flag is set and a round runs.
        """
        model.train()
        cohort_size = len(data_list)
        self.initial_model_parameters = [param.clone().detach() for param in model.parameters()]
        if cohort_size == 0:
            # Averaging over nobody would overwrite every trainable parameter
            # with zeros.
            return model
        optimizer = self.worker_optimizer(model.parameters(), **self.worker_optimizer_hparams)
        for _ in range(epochs):
            if self.ray:
                raise NotImplementedError("Ray-based local optimisation is not implemented")
            optimizer.zero_grad()
            # Running average of the locally trained parameters, one
            # accumulator per model parameter.
            averaged_params = [torch.zeros_like(param.data) for param in model.parameters()]
            # The server state does not change while the workers run.
            server_state = optimizer.state_dict()
            for data in data_list:
                local_model = deepcopy(model)
                local_model, local_model_grad = self._local_optimisation(
                    local_model, data, loss, self.worker_optimizer_steps, server_state
                )
                with torch.no_grad():
                    for local_param, server_param, averaged in zip(
                        local_model.parameters(), model.parameters(), averaged_params
                    ):
                        if server_param.requires_grad:
                            averaged += local_param.data / cohort_size
                    for grad_param, server_param in zip(local_model_grad.parameters(), model.parameters()):
                        if not server_param.requires_grad or grad_param.grad is None:
                            continue
                        contribution = grad_param.grad / cohort_size
                        if server_param.grad is None:
                            server_param.grad = contribution
                        else:
                            server_param.grad += contribution
                del local_model, local_model_grad
            # This step refreshes the optimizer state from the averaged
            # gradients; the parameters it writes are overwritten below.
            optimizer.step()
            with torch.no_grad():
                for param, averaged in zip(model.parameters(), averaged_params):
                    if param.requires_grad:
                        param.data.copy_(averaged)
            del averaged_params
        return model


def calculate_grad_norm(model):
    """Return the L2 norm of all parameter gradients of ``model`` combined.

    Parameters whose ``grad`` is ``None`` are ignored; if no parameter has a
    gradient the result is ``0.0``.
    """
    squared_norms = [
        param.grad.data.norm(2).item() ** 2
        for param in model.parameters()
        if param.grad is not None
    ]
    return np.sqrt(np.sum(squared_norms))