from tqdm import tqdm

import torch

from .utils import check_proxymodel_valid
from backpack import extend


class ProxyBase:
    """Base class that trains a proxy model against a set of generators.

    Subclasses must implement ``_get_data``, ``_loss`` and
    ``_load_generator_statedicts``. ``_manipulate_weights`` is an optional
    hook that runs after every optimizer step; it is a no-op by default.

    Args:
        proxy_model: model to train. It is validated with
            ``check_proxymodel_valid``, wrapped with BackPACK's ``extend`` and
            moved to the resolved device.
        optimizer: optimizer stepped in ``train``. Presumably the caller built
            it over ``proxy_model``'s parameters; this is not checked here.
        batch_size: stored as ``self.batch_size``. It is not used in this
            class; presumably a subclass's ``_get_data`` consumes it.
        num_workers_per_dataset: stored as ``self._num_workers_per_dataset``.
            It is likewise unused in this class.
        device: device spec such as ``'cuda'``, ``'cuda:1'``, ``'cpu'`` or a
            ``torch.device``. CUDA is used only when requested and available;
            every other case resolves to CPU.
        **kwargs: accepted and ignored.
    """

    def __init__(self, proxy_model, optimizer, batch_size,
                 num_workers_per_dataset, device, **kwargs):
        super().__init__()

        check_proxymodel_valid(proxy_model)

        self.batch_size = batch_size
        self._num_workers_per_dataset = num_workers_per_dataset

        self.optimizer = optimizer

        # Number of ``train`` calls so far (shown in the progress-bar label).
        self._prox_num = 0
        # Optimizer steps taken across all ``train`` calls.
        self._total_iteration = 0

        self.device = self._resolve_device(device)
        proxy_model = extend(proxy_model)
        self.proxy_model = proxy_model.to(device=self.device,
                                          non_blocking=True)

    @staticmethod
    def _resolve_device(device):
        """Map a user-supplied device spec to the ``torch.device`` to use.

        Returns the requested CUDA device when ``device`` names one and CUDA
        is available. ``device`` may be a string such as ``'cuda'`` or
        ``'cuda:1'``, or a ``torch.device``. In every other case, including
        specs that ``torch.device`` cannot parse, this returns the CPU device
        instead of raising.
        """
        try:
            requested = torch.device(device)
        except (TypeError, ValueError, RuntimeError):
            return torch.device('cpu')
        if requested.type == 'cuda' and torch.cuda.is_available():
            return requested
        return torch.device('cpu')

    def _get_data(self, train=True):
        """Return one batch for the proxy; must be implemented by subclasses.

        ``train`` calls this with ``train=True`` and passes the result
        unchanged to ``_loss``.
        """
        raise NotImplementedError()

    def _load_generator_statedicts(self, generator_statedicts):
        """Load generator state dicts; must be implemented by subclasses.

        Nothing in this class calls it. Presumably it is called by external
        code; verify against the callers.
        """
        raise NotImplementedError()

    def train(self, generators, num_iterations):
        """Run ``num_iterations`` optimizer steps on the proxy model.

        Each step does the following:

        1. Put the proxy in training mode.
        2. Draw a batch with ``_get_data(train=True)``.
        3. Minimize ``_loss(generators, batch)`` with one optimizer step.
        4. Call ``_manipulate_weights`` on the proxy's parameters.

        At least one step is always taken, even when ``num_iterations`` is
        less than 1.

        Args:
            generators: passed through unchanged to ``_loss``.
            num_iterations: number of steps to run; also the progress-bar
                total.

        Returns:
            The negated loss tensor of the final step. The loss is minimized
            here, so its negation is the objective being maximized.
        """
        self._prox_num += 1
        local_iteration = 0

        num_params = sum(param.numel()
                         for param in self.proxy_model.parameters())
        desc = "Proxy ({}) Training (N_parameters: {})".format(self._prox_num,
                                                               num_params)
        with tqdm(desc=desc, total=num_iterations,
                  ascii=True, leave=False) as pbar:
            while True:
                # Re-enter training mode every step. Presumably this guards
                # against a subclass hook (``_get_data``, ``_loss``,
                # ``_manipulate_weights``) leaving the model in eval mode.
                self.proxy_model.train()
                data_train = self._get_data(train=True)

                # Calculate the loss that is *minimized*.
                self.optimizer.zero_grad()
                loss = self._loss(generators, data_train)

                loss.backward()
                self.optimizer.step()
                self._manipulate_weights(self.proxy_model.parameters())

                # Update the step counters and the progress bar.
                self._total_iteration += 1
                local_iteration += 1
                pbar.update(1)

                if local_iteration >= num_iterations:
                    # The objective is maximized by minimizing its negation,
                    # so negate the loss back before returning it.
                    return -loss

    def _loss(self, generators, data):
        """Return the scalar loss to *minimize*; must be implemented by subclasses.

        ``train`` calls ``.backward()`` on the result, so it must be a tensor
        that supports autograd.
        """
        raise NotImplementedError()

    def _manipulate_weights(self, parameters):
        """Hook called with the proxy's parameters after each optimizer step.

        The default implementation does nothing.
        """
        pass
