#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import torch
import numpy as np
from helpers import get_sparse_p
from helpers import infer_log_w, align_w_maps
from losses import pred_loss, pred_loss_shuffle, pred_loss_pos, pred_loss_pos2, pred_loss_pos3


class PredsegModule(torch.nn.Module):
    """Single module with neighbor prediction based on its output.

    The wrapped ``raw_module`` computes the output; an instance-normalized
    copy of that output is stored in ``self.feat`` on every forward pass and
    is what the prediction losses and the ``w`` inference operate on.

    Parameters
    ----------
    raw_module : torch.nn.Module
        Module whose output is predicted from its neighbors.
    neighbors : array-like
        Neighbor specification; converted with ``np.array``. Its first
        dimension is the number of neighbors.
        # presumably rows are (dy, dx) offsets -- confirm against helpers
    out_dim : int
        Number of output channels of ``raw_module``; sizes ``log_c`` and the
        instance norm.
    normalize_output : bool
        If True, ``forward`` returns the normalized features instead of the
        raw output of ``raw_module``.
    """

    def __init__(self, raw_module, neighbors, out_dim, normalize_output=False):
        super(PredsegModule, self).__init__()
        self.raw_module = raw_module
        self.neighbors = np.array(neighbors)
        self.normalize_output = normalize_output
        n_neighbors = self.neighbors.shape[0]
        # Assigning an nn.Parameter to an attribute already registers it with
        # the module, so no explicit register_parameter call is needed.
        self.log_c = torch.nn.Parameter(torch.Tensor(
            2 * np.ones((n_neighbors, out_dim)) - np.log(out_dim)))
        self.prior_w = torch.nn.Parameter(torch.Tensor(
            np.zeros(n_neighbors)))
        # Filled in by forward(); None until the module has been run once.
        self.feat = None
        self.norm = torch.nn.InstanceNorm2d(out_dim)

    def forward(self, x):
        """Run ``raw_module`` and remember the normalized output.

        The instance-normalized output is stored in ``self.feat``. The return
        value is that normalized tensor if ``normalize_output`` is set and
        the raw output of ``raw_module`` otherwise.
        """
        out = self.raw_module(x)
        self.feat = self.norm(out)
        return self.feat if self.normalize_output else out

    def _require_feat(self):
        """Return ``self.feat`` or raise if ``forward`` has not been run."""
        if self.feat is None:
            raise RuntimeError(
                'PredsegModule has no stored features; '
                'run a forward pass first')
        return self.feat

    @staticmethod
    def _select_loss(noise_dist):
        """Map a ``noise_dist`` name to the corresponding loss function.

        Raises
        ------
        ValueError
            If ``noise_dist`` is not one of the supported names.
        """
        if noise_dist == 'batch':
            return pred_loss
        if noise_dist == 'shuffle':
            return pred_loss_shuffle
        if noise_dist == 'pos':
            return pred_loss_pos
        if noise_dist == 'pos2':
            return pred_loss_pos2
        if noise_dist == 'pos3':
            return pred_loss_pos3
        raise ValueError(
            "unknown noise_dist %r; expected one of "
            "'batch', 'shuffle', 'pos', 'pos2', 'pos3'" % (noise_dist,))

    @staticmethod
    def _add_noise(feat, noise):
        """Mix ``feat`` with standard normal noise of relative size ``noise``.

        The features are scaled by ``sqrt(1 - noise**2)`` and ``noise`` times
        a standard normal sample is added. The random sample is drawn even
        for ``noise == 0``.
        """
        # 1.0 (not 1) keeps the scalar tensor floating point for noise=0.
        scale = torch.sqrt(torch.tensor(1.0 - noise ** 2))
        return scale * feat + noise * torch.randn_like(feat)

    def get_loss(self, noise_dist='batch', noise=0, **kwargs):
        """Return the prediction loss computed from the stored features.

        Parameters
        ----------
        noise_dist : str
            Which loss variant to use: 'batch', 'shuffle', 'pos', 'pos2' or
            'pos3'.
        noise : float
            Relative amount of Gaussian noise mixed into the features before
            the loss is evaluated.
        **kwargs
            Forwarded unchanged to the selected loss function.

        Raises
        ------
        RuntimeError
            If no forward pass has been run yet.
        ValueError
            If ``noise_dist`` is not a supported name.
        """
        stored = self._require_feat()
        loss_fn = self._select_loss(noise_dist)
        feat = self._add_noise(stored, noise)
        return loss_fn(
            feat, self.neighbors, self.prior_w,
            self.log_c, prec=1, device=stored.device, **kwargs)

    def grad_acc(self, noise_dist='pos', n_acc=1, retain_graph=False,
                 noise=0, **kwargs):
        """Apply the loss and backpropagate via the feature map.

        The loss is evaluated ``n_acc`` times on a detached copy of the
        stored features, accumulating its gradient at the feature-map level.
        If the stored features are part of an autograd graph, the accumulated
        gradient is then pushed through that graph with a single backward
        pass.

        Parameters
        ----------
        noise_dist : str
            Loss variant, see ``get_loss``.
        n_acc : int
            Number of loss evaluations to accumulate.
        retain_graph : bool
            Passed to the final backward pass through the stored features.
        noise : float
            Relative amount of Gaussian noise mixed in per evaluation.
        **kwargs
            Forwarded unchanged to the selected loss function.

        Returns
        -------
        float
            Sum of the loss values over the ``n_acc`` evaluations.

        Raises
        ------
        RuntimeError
            If no forward pass has been run yet.
        ValueError
            If ``noise_dist`` is not a supported name.
        """
        stored = self._require_feat()
        loss_fn = self._select_loss(noise_dist)
        # Detached leaf copy: loss gradients collect in feat.grad instead of
        # flowing through the network on every evaluation.
        feat = stored.detach()
        feat.requires_grad = True
        l_report = 0
        for _ in range(n_acc):
            feat_n = self._add_noise(feat, noise)
            loss = loss_fn(
                feat_n, self.neighbors, self.prior_w,
                self.log_c, prec=1, device=stored.device, **kwargs)
            loss.backward()
            l_report += loss.item()
        # feat.grad stays None when nothing was accumulated (e.g. n_acc=0).
        if stored.requires_grad and feat.grad is not None:
            gradient = feat.grad.detach()
            # Surrogate whose gradient w.r.t. the stored features equals the
            # accumulated gradient.
            loss_feat = torch.sum(stored * gradient)
            loss_feat.backward(retain_graph=retain_graph)
        return l_report

    def infer_w(self, shifts=None, subsamplings=None, resolution=None,
                interpolate=False):
        """Infer the log-probability ratio between w=1 and w=0.

        Parameters
        ----------
        shifts : list, optional
            Passed to ``align_w_maps``; defaults to ``[[0, 0]]``.
        subsamplings : list, optional
            Passed to ``align_w_maps``; defaults to ``[1]``.
        resolution : optional
            Target resolution. If None it is the last two dimensions of the
            inferred map plus twice ``shifts[0]``.
        interpolate : bool
            Passed to ``align_w_maps``.

        Returns
        -------
        tuple
            ``(w_map, self.neighbors, resolution)`` where ``w_map`` is the
            first return value of ``align_w_maps``.
        """
        # Fresh lists per call instead of shared mutable default arguments.
        if shifts is None:
            shifts = [[0, 0]]
        if subsamplings is None:
            subsamplings = [1]
        w_map = infer_log_w(
            self._require_feat(), self.neighbors, self.prior_w, self.log_c,
            prec=1)
        if resolution is None:
            resolution = np.array(w_map.shape[-2:]) + 2 * np.array(shifts[0])
        # NOTE(review): the neighbors returned by align_w_maps are discarded
        # and the module's own neighbors are returned -- confirm intended.
        w_map, _ = align_w_maps(
            [w_map], self.neighbors, resolution, shifts, subsamplings,
            interpolate=interpolate)
        return w_map, self.neighbors, resolution

    def get_sparse_w_matrix(self, shifts, subsamplings, resolution=None):
        """Produce a list of sparse connectivity matrices.

        Returns
        -------
        tuple
            ``(sparse_w, shapes)``: one ``get_sparse_p`` result and one shape
            per entry of the aligned ``w_map`` from ``infer_w``.
        """
        w_map, neighbors, resolution = self.infer_w(
            shifts, subsamplings, resolution)
        sparse_w = [get_sparse_p(w_m, neighbors) for w_m in w_map]
        shapes = [w_m.shape for w_m in w_map]
        return sparse_w, shapes

    def get_pred_pars(self):
        """Return the prediction parameters ``[log_c, prior_w]``."""
        return [self.log_c, self.prior_w]

    def get_other_pars(self):
        """Return all parameters of the wrapped ``raw_module`` as a list."""
        return list(self.raw_module.parameters())

    def p_modules(self):
        """Return the list of predseg modules, i.e. just this module."""
        return [self]


class PredsegSequential(torch.nn.Sequential):
    """Sequential container that knows about its PredsegModule submodules.

    This is a variant of the sequential model which contains predseg
    modules. It allows getting overall connectivities and losses.
    To use it, first generate your layers including the predseg modules and
    then call this on them instead of ``torch.nn.Sequential`` to bind them
    into one module.
    """

    def __init__(self, *modules):
        """Bind ``modules`` into one sequential model."""
        super(PredsegSequential, self).__init__(*modules)

    def p_modules(self):
        """Return a list of all predseg submodules.

        The whole module tree is searched, so PredsegModules nested inside
        other containers are included.
        """
        return [m for m in self.modules() if isinstance(m, PredsegModule)]

    def get_loss(self, noise_dist='batch', **kwargs):
        """Return the summed loss computed from all predseg submodules.

        ``noise_dist`` and ``**kwargs`` are forwarded to each submodule's
        ``get_loss``. Returns the integer 0 if there are no predseg
        submodules.
        """
        loss = 0
        for m in self.p_modules():
            loss = loss + m.get_loss(noise_dist=noise_dist, **kwargs)
        return loss

    def grad_acc(self, noise_dist='pos', **kwargs):
        """Apply the loss and backpropagate via the feature maps.

        Calls ``grad_acc`` on every predseg submodule. All but the last call
        use ``retain_graph=True`` so the shared autograd graph is still
        available for the following submodules; the last call releases it.

        Returns
        -------
        float or int
            Sum of the values reported by the submodules; 0 if there are no
            predseg submodules.
        """
        p_m = self.p_modules()
        if not p_m:
            # Nothing to backpropagate; mirrors get_loss returning 0.
            return 0
        l_report = 0
        for m in p_m[:-1]:
            l_report += m.grad_acc(
                noise_dist=noise_dist, retain_graph=True, **kwargs)
        l_report += p_m[-1].grad_acc(
            noise_dist=noise_dist, retain_graph=False, **kwargs)
        return l_report

    def infer_w(self, shifts, subsamplings, resolution=None, interpolate=False):
        """Infer the w maps of all predseg submodules and align them jointly.

        Each submodule's ``infer_w`` is called with its default arguments;
        ``shifts``, ``subsamplings`` and ``interpolate`` are applied once, in
        the joint ``align_w_maps`` call. If ``resolution`` is None it is
        derived from the first submodule's map and ``shifts[0]``.

        Returns
        -------
        tuple
            ``(w_maps, neighbors, resolution)`` with the first two taken from
            ``align_w_maps``.
        """
        w_maps = []
        neighbors = []
        for m in self.p_modules():
            w_m, neigh, _ = m.infer_w()
            w_maps.append(w_m)
            neighbors.append(neigh)
        if resolution is None:
            resolution = w_maps[0].shape[-2:] + 2 * np.array(shifts[0])
        w_maps, neighbors = align_w_maps(
            w_maps, neighbors, resolution, shifts, subsamplings,
            interpolate=interpolate)
        return w_maps, neighbors, resolution

    def infer_w_sep(self, shifts, subsamplings, resolution=None):
        """Infer the w maps of all predseg submodules separately.

        NOTE(review): ``shifts``, ``subsamplings`` and ``resolution`` are
        accepted but not used; every submodule's ``infer_w`` is called with
        its defaults -- confirm this is intended.

        Returns
        -------
        tuple
            ``(w_maps, neighbors, resolutions)``, three lists with one entry
            per predseg submodule.
        """
        w_maps = []
        neighbors = []
        resolutions = []
        for m in self.p_modules():
            w_m, neigh, res = m.infer_w()
            w_maps.append(w_m)
            neighbors.append(neigh)
            resolutions.append(res)
        return w_maps, neighbors, resolutions

    def get_sparse_w_matrix(self, shifts, subsamplings, resolution=None,
                            interpolate=False):
        """Produce a list of sparse connectivity matrices.

        Returns ``(sparse_w, resolution)`` where ``sparse_w`` holds one
        ``get_sparse_p`` result per jointly aligned w map from ``infer_w``.
        """
        w_map, neighbors, resolution = self.infer_w(
            shifts, subsamplings, resolution=resolution,
            interpolate=interpolate)
        sparse_w = [get_sparse_p(w_m, neighbors) for w_m in w_map]
        return sparse_w, resolution

    def get_sparse_w_matrix_sep(self, shifts, subsamplings, resolution=None):
        """Produce one list of sparse connectivity matrices per submodule.

        Returns ``(sparse_ws, resolutions)`` where ``sparse_ws`` is a list
        (one entry per predseg submodule) of lists of ``get_sparse_p``
        results, built from ``infer_w_sep``.
        """
        w_maps, neighbors, resolutions = self.infer_w_sep(
            shifts, subsamplings, resolution=resolution)
        sparse_ws = [[get_sparse_p(w_m, neigh) for w_m in w_map]
                     for w_map, neigh in zip(w_maps, neighbors)]
        return sparse_ws, resolutions

    def get_pred_pars(self, device=None):
        """Return the parameters for the spatial prediction.

        Concatenates ``get_pred_pars()`` of all predseg submodules.
        ``device`` is accepted but not used.
        """
        pars = []
        for m in self.p_modules():
            pars += m.get_pred_pars()
        return pars

    def get_other_pars(self, device=None):
        """Return the parameters that are not spatial-prediction parameters.

        Collects the parameters owned directly (``recurse=False``) by every
        module in the tree that is not a PredsegModule, which leaves out the
        parameters held directly by the PredsegModules themselves.
        ``device`` is accepted but not used.
        """
        pars = []
        for m in self.modules():
            if not isinstance(m, PredsegModule):
                pars += list(m.parameters(recurse=False))
        return pars
