"""Implement regression variants."""

import torch
from generalized_model import GeneralizedDARTS
from regression_ops import OPS, Genotype


class _NormalizedOp(torch.nn.Sequential):
    """``Sequential(op, norm)`` that accepts the three-argument op signature.

    ``torch.nn.Sequential.forward`` only takes a single input, so wrapping an op
    that is called as ``op(x, y, x0)`` in a plain ``Sequential`` fails at call
    time. This subclass keeps the ``Sequential`` layout (children named ``'0'``
    and ``'1'``, hence identical ``state_dict`` keys) but routes ``(x, y, x0)``
    to the op and applies the normalization to its output.
    """

    def forward(self, x, y, x0):
        op, norm = self[0], self[1]
        return norm(op(x, y, x0))


class MixedOp(torch.nn.Module):
    """Weighted combination of candidate operations on one edge.

    Args:
        A: operator object handed to every ``OPS[primitive]`` factory.
        sublayers: forwarded to the ``OPS`` factories.
        channels: forwarded to the ``OPS`` factories and used for the optional
            ``BatchNorm2d`` normalization.
        operators: iterable of keys into ``OPS``; one candidate op is built per key.
        norm: ``'DARTS'`` appends a non-affine ``BatchNorm2d`` to every op,
            ``'renorm'`` appends one shared, stateless ``BatchNorm2d``; any other
            value leaves the ops unnormalized.

    Raises:
        ValueError: if ``operators`` is ``None``.
    """

    CUTOFF_VALUE: torch.jit.Final[float] = 1e-6

    def __init__(self, A, sublayers=1, channels=1, operators=None, norm='none'):
        super().__init__()
        if operators is None:
            raise ValueError('MixedOp requires an explicit list of operator names.')

        self._ops = torch.nn.ModuleList()
        self.names = []
        if norm == 'renorm':
            # affine=False and track_running_stats=False leave this module without
            # parameters or buffers, so one instance can be shared by all ops.
            renorm = torch.nn.BatchNorm2d(channels, affine=False, track_running_stats=False)

        for primitive in operators:
            op = OPS[primitive](A, sublayers, channels)
            if norm == 'DARTS':
                op = _NormalizedOp(op, torch.nn.BatchNorm2d(channels, affine=False))
            elif norm == 'renorm':
                op = _NormalizedOp(op, renorm)
            self._ops.append(op)
            self.names.append(primitive)

    def forward(self, x, y, x0, weights):
        """Return ``sum_i weights[i] * op_i(x, y, x0)``.

        ``weights`` is indexed positionally, one entry per candidate op; each op
        output must match the shape of ``x``.
        """
        outputs = torch.zeros_like(x)
        for idx, op in enumerate(self._ops):
            outputs += weights[idx] * op(x, y, x0)
        return outputs

    def reset_parameters(self):
        """Propagate parameter resets to all ops (including their norm layers)."""
        for op in self._ops:
            modules = op if isinstance(op, torch.nn.Sequential) else (op,)
            for module in modules:
                module.reset_parameters()


class RegressionSearchSpace(GeneralizedDARTS):
    """Simple multi-layered choice of prox/grad operations.

    Every layer is a ``MixedOp`` over ``operators``; the architecture weights
    (one ``(1, len(operators))`` row per layer) weight the candidates.
    """

    def _initialize_layers(self, operator, *args, genotype=None, layers=2, sublayers=2, channels=1, operators=None,
                           deep_supervised=False, norm='none', update=None, dataset=None, randomized_init=False, noise_level=None):
        """Initialize the model layers here.

        Builds one ``MixedOp`` per layer and, when ``channels`` differs from
        ``operator.channels``, 3-wide input/output projection convolutions with a
        ReLU in between (identities otherwise). With ``update == 'liu-single'``
        and an operator set containing ``'net'`` together with
        ``'learnable_grad'`` or ``'learnable_grad_proj_b'``, parameters are
        warm-started from disk via ``reset_weights``.
        """
        self.operator = operator.to(**self.setup)
        self.operators = operators
        self.dataset = dataset

        self.deep_supervised = deep_supervised
        self.randomized_init = randomized_init
        self.layers = layers
        self.noise_level = noise_level

        self.learned_ops = torch.nn.ModuleList()
        for _ in range(layers):
            self.learned_ops.append(MixedOp(self.operator, sublayers,
                                            channels, operators, norm=norm))

        if operator.channels != channels:
            _Conv = torch.nn.Conv2d if operator.dimension == 2 else torch.nn.Conv1d
            self.projection_input = _Conv(operator.channels, channels,
                                          kernel_size=3, padding=1, bias=True)
            self.nonlin = torch.nn.ReLU()
            self.projection_output = _Conv(channels, operator.channels,
                                           kernel_size=3, padding=1, bias=True)
        else:
            self.projection_input = torch.nn.Identity()
            self.projection_output = torch.nn.Identity()
            self.nonlin = torch.nn.Identity()

        # torch.nn.L1Loss()  # torch.nn.MSELoss(reduction='mean')
        self.criterion = torch.nn.MSELoss(reduction='mean')
        # Warm start only when the network step AND a learned gradient step are
        # both candidates. Each membership test must be spelled out: a bare
        # ``'learnable_grad' and 'net' in operators`` reduces to ``'net' in operators``.
        if update == 'liu-single' and operators is not None and 'net' in operators and (
                'learnable_grad' in operators or 'learnable_grad_proj_b' in operators):
            self.reset_weights()

    def _forward_implementation(self, y, weights, x_true=None):
        """Implement the forward pass here, assume weights are probabilities.

        Starts from ``x0 = operator.T(y)``, applies every layer with its weight
        row, and returns ``(x0 + projection_output(nonlin(x)), aux_loss)``.
        ``aux_loss`` accumulates the per-layer criterion against ``x_true`` when
        deep supervision is on (it stays 0 if ``x_true`` is None), and is None
        otherwise.
        """
        x0 = self.operator.T(y)
        x = self.nonlin(self.projection_input(x0))
        aux_loss = 0 if self.deep_supervised else None
        for weight, op in zip(weights, self.learned_ops):
            x = op(x, y, x0, weight[0])
            if self.deep_supervised and x_true is not None:
                aux_loss += self.criterion(x0 + self.projection_output(self.nonlin(x)), x_true)
        return x0 + self.projection_output(self.nonlin(x)), aux_loss

    @torch.no_grad()
    def project_onto_constraint(self):
        """Project onto ||theta_k|| < alpha_k, keeping alpha fixed."""
        pass

    @torch.no_grad()
    def reset_weights(self):
        """Initialize weights if network parameter already at optimum (parameter warm start).

        Loads saved parameter/buffer lists for the learned-gradient and network
        ops from ``state_dicts/<dataset>`` (``state_dicts/Blur_00`` for
        noise-free Blur), interleaves them in per-layer chunks and copies them
        positionally into ``self.parameters()`` and, depending on the dataset,
        ``self.buffers()``.

        Assumption (positional copy relies on it):
        ``learned_ops[0]._ops[0]`` is a ``RegularizedGradDescentStep`` and
        ``learned_ops[0]._ops[1]`` is a ``DeepStep``.
        """
        # Paths are relative to the current working directory.
        # NOTE(review): depending on the torch version's ``weights_only`` default,
        # torch.load may unpickle arbitrary objects -- only load trusted files.
        if self.noise_level == 0 and self.dataset == 'Blur':
            state_dir = 'state_dicts/' + self.dataset + '_00'
        else:
            state_dir = 'state_dicts/' + self.dataset
        learn_grad_parameters = torch.load(state_dir + '/model_learn_grad_parameter.obj')
        learn_grad_buffers = torch.load(state_dir + '/model_learn_grad_buffers.obj')
        net_parameters = torch.load(state_dir + '/model_net_parameter.obj')
        net_buffers = torch.load(state_dir + '/model_net_buffers.obj')

        if self.dataset in ('Blur', 'Downsampling', 'Cosine'):
            learn_grad_parameters = learn_grad_parameters[:-4]
            net_parameters = net_parameters[:-4]
            optimal_parameters = [[*learn_grad_parameters[4 * (i - 1):4 * i], *net_parameters[4 * (i - 1):4 * i]]  # for #layer=10
                                  for i in range(1, len(learn_grad_parameters) // 4 + 1)]
            optimal_buffers = [[*learn_grad_buffers[4 * (i - 1):4 * i], *net_buffers[4 * (i - 1):4 * i]]
                               for i in range(1, len(learn_grad_parameters) // 4 + 1)]
        else:
            learn_grad_parameters = learn_grad_parameters[:-4]
            net_parameters = net_parameters[:-4]
            if self.dataset != 'BSDS-Denoising':
                learn_grad_buffers = learn_grad_buffers[2:]
                net_buffers = net_buffers[2:]

            optimal_parameters = [[*learn_grad_parameters[5 * (i - 1):5 * i], *net_parameters[5 * (i - 1):5 * i]]  # for #layer=7
                                  for i in range(1, len(learn_grad_parameters) // 5 + 1)]
            optimal_buffers = [[*learn_grad_buffers[3 * (i - 1):3 * i], *net_buffers[3 * (i - 1):3 * i]]
                               for i in range(1, len(learn_grad_buffers) // 3 + 1)]

        optimal_parameters = [item for sublist in optimal_parameters for item in sublist]
        optimal_buffers = [item for sublist in optimal_buffers for item in sublist]

        for param, old_state in zip(self.parameters(), optimal_parameters):
            param.copy_(old_state)
        if self.dataset in ('BSDS-Downsampling', 'BSDS-Blur'):
            # The model's first two buffers keep their current values; the loaded
            # buffers fill the remaining slots.
            buf = [p.detach().clone() for p in self.buffers()]
            optimal_buffers = [buf[0], buf[1], *optimal_buffers]
            for buffer, old_state in zip(self.buffers(), optimal_buffers):
                buffer.copy_(old_state)
        elif self.dataset in ('Blur', 'Downsampling', 'Cosine'):
            pass  # buffers are not restored for these datasets
        else:
            for buffer, old_state in zip(self.buffers(), optimal_buffers):
                buffer.copy_(old_state)

    def _initialize_alphas(self):
        """Initialize alphas here: one ``(1, n_ops)`` row per layer summing to 1."""
        alphas = []
        for op in self.learned_ops:
            op_length = len(op._ops)
            if self.randomized_init:
                vals = torch.rand((1, op_length))
                alphas.append(vals / vals.sum())
            else:
                alphas.append(torch.ones((1, op_length)) / op_length)
        return alphas

    def reset_parameters(self):
        """Re-call all init hooks of all modules."""
        for module in self.learned_ops:
            module.reset_parameters()

    def genotype(self):
        """Return a Genotype listing the highest-weighted operator name per layer."""
        operations = []
        for alpha in self.arch_parameters():
            operations.append(self.operators[int(alpha.argmax(dim=-1))])
        return Genotype(operations=operations)


class RegressionCell(torch.nn.Module):
    """Search cell with a single input state and a single output state.

    Step ``i`` (0-based) adds up one ``MixedOp`` edge from each of the ``i + 1``
    states available at that point: the cell input plus the outputs of the
    earlier steps. The last ``multiplier`` states are concatenated along dim 1
    and mapped back to ``channels`` by a ReLU followed by a kernel-size-1
    convolution (``Conv2d`` for 2-D operators, ``Conv1d`` otherwise).
    """

    def __init__(self, steps, multiplier, operator, sublayers=2, channels=1, operators=None, norm=None):
        super().__init__()
        self._steps = steps
        self._multiplier = multiplier

        # Edges are stored step by step; step i owns i + 1 consecutive entries.
        edge_count = sum(i + 1 for i in range(steps))
        self._ops = torch.nn.ModuleList(
            [MixedOp(operator, sublayers, channels, operators, norm=norm) for _ in range(edge_count)]
        )

        conv_cls = torch.nn.Conv1d if operator.dimension != 2 else torch.nn.Conv2d
        self.nonlin = torch.nn.ReLU()
        self.reduction = conv_cls(channels * multiplier, channels, kernel_size=1, padding=0, bias=True)

    def forward(self, inputs, data, backprojection, weights):
        """Run all steps; ``weights[e]`` is passed to edge ``e`` in storage order."""
        states = [inputs]
        base = 0
        for _ in range(self._steps):
            node = sum(
                self._ops[base + idx](state, data, backprojection, weights[base + idx])
                for idx, state in enumerate(states)
            )
            base += len(states)
            states.append(node)
        merged = torch.cat(states[-self._multiplier:], dim=1)
        return self.reduction(self.nonlin(merged))


class BroadRegressionSearchSpace(RegressionSearchSpace):
    """Simple multi-layered choice of prox/grad operations. DARTS_Single not implemented.

    The model is a stack of ``RegressionCell`` s whose edges choose among
    ``operators`` plus an extra ``'zero'`` operation.
    """

    def _initialize_layers(self, operator, *args, genotype=None, layers=3, sublayers=2, channels=1, operators=None,
                           steps=3, multiplier=2, deep_supervised=False, norm='none', update=None, dataset=None,
                           randomized_init=False, noise_level=None):
        """Initialize the model layers here.

        Builds ``layers`` cells and, when ``channels`` differs from
        ``operator.channels``, 3-wide input/output projection convolutions with a
        ReLU in between (identities otherwise).
        """
        self.operator = operator.to(**self.setup)
        self.operators = [*operators, 'zero']
        self.dataset = dataset

        self.deep_supervised = deep_supervised
        self.randomized_init = randomized_init
        self.layers = layers
        self.steps = steps
        self.noise_level = noise_level

        self.cells = torch.nn.ModuleList()
        for _ in range(layers):
            self.cells.append(RegressionCell(steps, multiplier, self.operator, sublayers=sublayers, channels=channels,
                                             operators=self.operators, norm=norm))

        if operator.channels != channels:
            _Conv = torch.nn.Conv2d if operator.dimension == 2 else torch.nn.Conv1d
            self.projection_input = _Conv(operator.channels, channels,
                                          kernel_size=3, padding=1, bias=True)
            self.nonlin = torch.nn.ReLU()
            self.projection_output = _Conv(channels, operator.channels,
                                           kernel_size=3, padding=1, bias=True)
        else:
            self.projection_input = torch.nn.Identity()
            self.projection_output = torch.nn.Identity()
            self.nonlin = torch.nn.Identity()

        # torch.nn.L1Loss()  # torch.nn.MSELoss(reduction='mean')
        self.criterion = torch.nn.MSELoss(reduction='mean')

    def _forward_implementation(self, y, weights, x_true=None):
        """Implement the forward pass here, assume weights are probabilities.

        Starts from ``x0 = operator.T(y)``, applies every cell with its alpha
        matrix, and returns ``(x0 + projection_output(nonlin(x)), aux_loss)``.
        ``aux_loss`` accumulates the per-cell criterion against ``x_true`` under
        deep supervision (0 if ``x_true`` is None) and is None otherwise.
        """
        x0 = self.operator.T(y)
        x = self.nonlin(self.projection_input(x0))
        aux_loss = 0 if self.deep_supervised else None
        for weight, cell in zip(weights, self.cells):
            x = cell(x, y, x0, weight)
            if self.deep_supervised and x_true is not None:
                aux_loss += self.criterion(x0 + self.projection_output(self.nonlin(x)), x_true)
        return x0 + self.projection_output(self.nonlin(x)), aux_loss

    def _initialize_alphas(self):
        """Initialize alphas here: one ``(k, n_ops)`` matrix per cell, rows summing to 1.

        NOTE: ``k`` uses the DARTS two-input count (``2 + i`` rows for step
        ``i``), while ``RegressionCell.forward`` only consumes the first
        ``sum(1 + i)`` rows; the remaining rows are unused. ``k`` is left as is
        so alpha shapes stay unchanged.
        """
        k = sum(2 + i for i in range(self.steps))
        op_length = len(self.operators)

        alphas = []
        for _ in self.cells:
            if self.randomized_init:
                vals = torch.rand((k, op_length))
                # Normalize each edge's row, like the uniform branch below.
                alphas.append(vals / vals.sum(dim=-1, keepdim=True))
            else:
                alphas.append(torch.ones((k, op_length)) / op_length)
        return alphas

    def genotype(self):
        """Derive the discrete architecture from the current alpha probabilities.

        For step ``i`` of a cell, the ``i + 1`` consecutive rows that
        ``RegressionCell.forward`` uses for that step are inspected (incoming
        edges from the cell input and the ``i`` earlier steps). Up to two edges
        with the strongest non-``'zero'`` operation are kept, each recorded as
        ``(op_name, source)`` where ``source`` indexes the cell's state list
        (0 = cell input).

        Returns:
            Genotype whose ``operations`` maps ``'Cell {idx}'`` to its gene list.
        """
        zero_idx = self.operators.index('zero')

        def _best_op(row):
            # First index with the largest weight, ignoring the 'zero' op.
            return max((k for k in range(len(row)) if k != zero_idx), key=lambda k: row[k])

        def _parse(weights):
            gene = []
            start = 0
            for i in range(self.steps):
                n_inputs = i + 1
                rows = weights[start:start + n_inputs]
                edges = sorted(range(n_inputs), key=lambda j: -rows[j][_best_op(rows[j])])[:2]
                for j in edges:
                    gene.append((self.operators[_best_op(rows[j])], j))
                start += n_inputs
            return gene

        chromosomes = {}
        for idx, cell_probs in enumerate(self.return_probabilities()):
            chromosomes[f'Cell {idx}'] = _parse(cell_probs.data.cpu().numpy())
        return Genotype(operations=chromosomes)


class BroadRegressionSearchSpaceSignalRecovery(BroadRegressionSearchSpace):
    """Simple multi-layered choice of prox/grad operations. DARTS_Single not implemented.

    Layer construction, alphas and genotype parsing are inherited from
    ``BroadRegressionSearchSpace``. Only the forward pass differs: iterations
    start from zeros instead of ``operator.T(y)``, and the output has no skip
    connection to the starting point.
    """

    def _forward_implementation(self, y, weights, x_true=None):
        """Implement the forward pass here, assume weights are probabilities.

        Returns ``(projection_output(nonlin(x)), aux_loss)``. ``aux_loss``
        accumulates the per-cell criterion against ``x_true`` under deep
        supervision (0 if ``x_true`` is None) and is None otherwise.
        """
        x0 = torch.zeros_like(y)
        x = self.nonlin(self.projection_input(x0))
        aux_loss = 0 if self.deep_supervised else None
        for weight, cell in zip(weights, self.cells):
            x = cell(x, y, x0, weight)
            if self.deep_supervised and x_true is not None:
                aux_loss += self.criterion(self.projection_output(self.nonlin(x)), x_true)
        return self.projection_output(self.nonlin(x)), aux_loss
