"""Softly gradient-coupled Resnets.
"""

import copy
import argparse

import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
import pytorch_lightning.metrics as metrics
import numpy as np
import pandas as pd

import functions.sgc_optim as sgc_optim

def float_interval(lower=None, upper=None):
    """Build an argparse ``type`` callable for floats in ``[lower, upper]``.

    Args:
        lower: Inclusive lower bound, or None for no lower bound.
        upper: Inclusive upper bound, or None for no upper bound.

    Returns:
        A function that converts its argument with ``float`` and returns the
        result if it lies within the bounds.

    Raises (from the returned function):
        ValueError: If the argument cannot be converted by ``float``.
        argparse.ArgumentTypeError: If the value violates a bound. argparse
            reports this as a regular usage error, and unlike ``assert`` the
            check is not removed when Python runs with ``-O``.
    """
    def fun(x):
        value = float(x)
        # `not value >= lower` (instead of `value < lower`) also rejects NaN,
        # matching the behaviour of the previous `assert value >= lower`.
        if lower is not None and not value >= lower:
            raise argparse.ArgumentTypeError(
                f"{value!r} is below the lower bound {lower!r}"
            )
        if upper is not None and not value <= upper:
            raise argparse.ArgumentTypeError(
                f"{value!r} is above the upper bound {upper!r}"
            )
        return value
    return fun

class SGCResNetModule(pl.LightningModule):
    """LightningModule that wraps an ``SGCResNet`` and trains it with ``CoupledSGD``."""

    def __init__(self, hparams, silent=True, reload=False):
        """Build the wrapped model from a hyperparameter namespace.

        Args:
            hparams: ``argparse.Namespace`` or dict. Must provide ``init``,
                ``n_channels``, ``stages``, ``n_blocks``, ``coupling`` and
                ``unit_type``. Optional entries and their fallbacks:
                ``in_channels`` (3), ``method`` ('uniform'),
                ``start_coupling_at`` (0), ``recurrent_batch_norm`` (False),
                ``decay_steps`` ([200, 300]), ``lr`` (0.1), ``gamma`` (0.1),
                ``outputs`` (None).
            silent: If False, print the hyperparameters.
            reload: If True, recurrent initialisation is disabled even when
                ``hparams.init == 'r'``.
        """
        super().__init__()
        if isinstance(hparams, dict):
            hparams = argparse.Namespace(**hparams)
        if not silent:
            print(hparams)
        self.save_hyperparameters(hparams)
        recurrent_init = (hparams.init == 'r') and not reload

        # A one-element list is unwrapped to a scalar base width; any other
        # list must hold exactly three entries.
        n_channels = hparams.n_channels
        if isinstance(n_channels, list):
            if len(n_channels) == 1:
                n_channels = n_channels[0]
            else:
                assert len(n_channels) == 3

        self.model = SGCResNet(
            stages=hparams.stages, n_blocks=hparams.n_blocks,
            n_channels=n_channels,
            in_channels=getattr(hparams, 'in_channels', 3),
            recurrent_init=recurrent_init,
            coupling=hparams.coupling,
            method=getattr(hparams, 'method', 'uniform'),
            unit_type=hparams.unit_type,
            start_coupling_at=getattr(hparams, 'start_coupling_at', 0),
            recurrent_batch_norm=getattr(hparams, 'recurrent_batch_norm', False)
        )
        self.decay_steps = getattr(hparams, 'decay_steps', [200, 300])
        self.lr = getattr(hparams, 'lr', 0.1)
        self.gamma = getattr(hparams, 'gamma', 0.1)
        self.outputs = getattr(hparams, 'outputs', None)

    def forward(self, x):
        """Return the wrapped model's output for ``x``."""
        return self.model(x)

    def _loss_and_accuracy(self, batch):
        """Run the model on ``batch`` and return ``(loss, accuracy)``.

        With ``self.outputs`` unset, cross-entropy and ``metrics.Accuracy``
        are used; otherwise ``BCEWithLogitsLoss`` and
        ``multi_class_accuracy`` with ``self.outputs`` as ``n_classes``.
        """
        x, y = batch
        y_hat = self.model(x)
        if self.outputs is None:
            loss = F.cross_entropy(y_hat, y)
            acc = metrics.Accuracy()(y_hat, y)
        else:
            loss = nn.BCEWithLogitsLoss()(y_hat, y)
            acc = multi_class_accuracy(self.outputs, y_hat, y)
        return loss, acc

    def training_step(self, batch, batch_idx):
        """Compute the training loss and log ``train_loss`` / ``train_acc``."""
        loss, acc = self._loss_and_accuracy(batch)
        result = pl.TrainResult(loss)
        result.log('train_loss', loss)
        result.log('train_acc', acc)
        return result

    def validation_step(self, batch, batch_idx):
        """Log ``val_loss`` / ``val_acc``; checkpointing monitors ``1 - acc``."""
        loss, acc = self._loss_and_accuracy(batch)
        result = pl.EvalResult(checkpoint_on=1-acc)
        result.log('val_loss', loss)
        result.log('val_acc', acc)
        return result

    def configure_optimizers(self):
        """Create the ``CoupledSGD`` optimizer and a ``MultiStepLR`` schedule."""
        optimizer = sgc_optim.CoupledSGD(
            self.model.parameters(), self.model, lr=self.lr, momentum=0.9, weight_decay=1e-4,
            # `method` must be the model's coupling method name; the numeric
            # coupling strength was previously passed here by mistake.
            coupling=self.model.coupling, method=self.model.method
        )
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=self.decay_steps, gamma=self.gamma)
        return [optimizer], [scheduler]

    def optimizer_step(self, current_epoch, batch_idx, optimizer, opt_idx,
                   lambda_closure, using_native_amp):
        """Step the optimizer with a fixed learning-rate warm-up.

        At global step 0 every param group's lr is set to 0.01; at global
        step 400 it is set to ``self.lr``. ``lambda_closure`` is not used.
        """
        # warm up lr
        if self.trainer.global_step == 0:
            for pg in optimizer.param_groups:
                pg['lr'] = 0.01
        if self.trainer.global_step == 400:
            for pg in optimizer.param_groups:
                pg['lr'] = self.lr

        # update params
        optimizer.step()
        optimizer.zero_grad()

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Return a parser extending ``parent_parser`` with the model options."""
        parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--stages', type=int, default=None)
        parser.add_argument('--n_blocks', type=int, default=8)
        parser.add_argument('--n_channels', type=int, nargs='+', default=[16])
        parser.add_argument('--coupling', type=float_interval(0,1), default=None)
        parser.add_argument('--init', choices=['r', 'nr'], default=None)
        parser.add_argument('--unit_type', choices=['preactivation', 'original'], default='preactivation')
        parser.add_argument('--in_channels', type=int, default=3)
        parser.add_argument('--decay_steps', nargs='+', type=int, default=[200, 300])
        parser.add_argument('--lr', type=float, default=0.1)
        parser.add_argument('--gamma', type=float, default=0.1)
        parser.add_argument('--outputs', type=int, default=None)
        parser.add_argument('--method', choices=['uniform', 'triangular'], default='uniform')
        parser.add_argument('--start_coupling_at', type=int, default=0)
        parser.add_argument('--recurrent_batch_norm', action='store_true')
        return parser

class SGCResNet(nn.Module):
    """ResNet made of ``SGCStage`` stacks joined by downsampling residual blocks."""

    def __init__(self, stages, n_blocks, n_channels, in_channels=3,
                 n_classes=10,
                 recurrent_init=True, coupling=0., method='uniform',
                 unit_type='preactivation', start_coupling_at=0,
                 recurrent_batch_norm=False):
        """
        Args:
            stages (int): Number of stages.
            n_blocks (int or list): Blocks per stage. A non-list value is
                repeated for every stage.
            n_channels (int or list): Channels per stage. A non-list value
                is the width of the first stage and is doubled per stage.
            in_channels (int): Channels of the input passed to the stem conv.
            n_classes (int): Output size of the final linear layer.
            recurrent_init (bool): Forwarded to every ``SGCStage``.
            coupling (float): Stored as ``self.coupling``.
            method (str): Stored as ``self.method``.
            unit_type (str): Type of residual unit.
                Choose between preactivation and original.
                Preactivation: BN -> ReLU -> Conv -> BN -> ReLU -> Conv -> Add
                (Recommended in He et al., 2016).
                Original: Conv -> BN -> ReLU -> Conv -> BN -> Add -> ReLU
                (Used in He et al., 2015).
            start_coupling_at (int): Forwarded to every ``SGCStage``.
            recurrent_batch_norm (bool): Forwarded to the weight init.
        """
        super().__init__()
        block = ResidualBlock
        if not isinstance(n_blocks, list):
            n_blocks = [n_blocks]*stages
        if not isinstance(n_channels, list):
            n_channels = [
                n_channels*(2**it_stage) for it_stage in range(stages)
            ]
        self.conv1 = nn.Conv2d(
            in_channels, n_channels[0], kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(n_channels[0])
        if unit_type == 'preactivation':
            # Only the preactivation variant applies a final BN in forward().
            self.bn2 = nn.BatchNorm2d(n_channels[stages-1])
        stgs = []
        for it in range(stages):
            if it > 0:
                # Every stage after the first is preceded by a downsampling
                # block that changes the channel count.
                stgs.append(block(
                    n_channels[it-1], n_channels[it], downsampling=True,
                    unit_type=unit_type
                ))
            stgs.append(SGCStage(
                block, n_blocks[it], n_channels[it],
                recurrent_init=recurrent_init,
                unit_type=unit_type,
                first_stage=(it == 0),
                start_coupling_at=start_coupling_at
            ))
        self.stages = nn.ModuleList(stgs)
        self.linear = nn.Linear(n_channels[stages-1], n_classes)
        self._init_weights(recurrent_batch_norm)
        self.coupling = coupling
        self.unit_type = unit_type
        self.method = method

    def _init_weights(self, recurrent_batch_norm=False):
        """Kaiming-initialise the stem conv and classifier, then every stage."""
        for layer in [self.conv1, self.linear]:
            nn.init.kaiming_normal_(layer.weight, nonlinearity='relu', mode='fan_out')
        for stage in self.stages:
            stage._init_weights(recurrent_batch_norm=recurrent_batch_norm)

    def forward(self, x):
        """Stem conv/BN/ReLU, all stages, pooling and the linear classifier."""
        x = F.relu(self.bn1(self.conv1(x)))
        for stage in self.stages:
            x = stage(x)
        if self.unit_type == 'preactivation':
            x = F.relu(self.bn2(x))
        # The pooling kernel is the size of the last dimension of x.
        x = F.avg_pool2d(x, x.size()[3])
        x = x.view(x.size(0), -1)
        x = self.linear(x)
        return x

    def couple_gradients(self, coupling, method):
        """Couple gradients in every stage that offers ``couple_gradients``.

        Does nothing for ``coupling == 0``. For ``coupling == 1`` the stages
        are always called with ``(1, 'uniform')``, whatever ``method`` is.
        """
        if coupling == 0:
            return
        for stage in self.stages:
            # Downsampling blocks in self.stages have no couple_gradients.
            if hasattr(stage, 'couple_gradients'):
                if coupling == 1:
                    stage.couple_gradients(1, 'uniform')
                else:
                    stage.couple_gradients(coupling, method)

    def effective_parameters(self):
        """Collect ``SGCStage.effective_parameters`` of all stages.

        Returns:
            pandas.DataFrame with the columns ``nonrec_summary``,
            ``first_weights``, ``mean`` and ``dev``, plus a ``stage`` column
            (index among the ``SGCStage`` entries) once at least one stage
            has been added.
        """
        df = pd.DataFrame({
            'nonrec_summary': [],
            'first_weights': [],
            'mean': [],
            'dev': []
        })
        i = 0
        for stage in self.stages:
            if isinstance(stage, SGCStage):
                new_entry = stage.effective_parameters()
                new_entry['stage'] = np.array([i])
                # DataFrame.append no longer exists in pandas >= 2.0;
                # pd.concat is the documented replacement.
                df = pd.concat([df, pd.DataFrame(new_entry)])
                i += 1
        return df

class SGCStage(nn.Module):
    """Sequence of same-width residual blocks whose gradients can be coupled."""

    def __init__(self, block, n_blocks, n_channels, recurrent_init=True,
                 unit_type='preactivation', first_stage=False,
                 start_coupling_at=0):
        """
        Args:
            block: Callable building one block; called as
                ``block(n_channels, n_channels, unit_type=..., first_block=...)``.
            n_blocks (int): Number of blocks in the stage.
            n_channels (int): Input and output channels of every block.
            recurrent_init (bool): If True, ``_init_weights`` initialises
                blocks 1.. from block 0.
            unit_type (str): Forwarded to ``block``.
            first_stage (bool): If True, block 0 is built with
                ``first_block=True``.
            start_coupling_at (int): Index of the first block taking part in
                gradient coupling.
        """
        super().__init__()
        self.blocks = nn.ModuleList([
            block(n_channels, n_channels, unit_type=unit_type,
                  first_block=bool(first_stage) and it == 0)
            for it in range(n_blocks)
        ])
        self.recurrent_init = recurrent_init
        self.first_stage = first_stage
        self.start_coupling_at = start_coupling_at

    def _init_weights(self, recurrent_batch_norm=False):
        """Initialise all blocks.

        With ``recurrent_init`` block 0 is initialised on its own and passed
        to every later block's ``_init_weights``. Otherwise each block is
        initialised independently (``recurrent_batch_norm`` is not forwarded
        in that case).
        """
        if self.recurrent_init:
            self.blocks[0]._init_weights(recurrent_batch_norm=recurrent_batch_norm)
            for it in range(1, len(self.blocks)):
                self.blocks[it]._init_weights(self.blocks[0],
                    recurrent_batch_norm=recurrent_batch_norm)
        else:
            for block in self.blocks:
                block._init_weights()

    def forward(self, x):
        """Apply the blocks in order."""
        for block in self.blocks:
            x = block(x)
        return x

    @property
    def depth(self):
        """Number of blocks in the stage."""
        return len(self.blocks)

    def couple_gradients(self, coupling, method):
        """Mix the ``.grad`` of corresponding parameters across blocks, in place.

        Args:
            coupling (float): Coupling strength. ``0`` leaves all gradients
                untouched.
            method (str): ``'uniform'`` blends each gradient with the sum over
                the coupled blocks; ``'triangular'`` applies the weights from
                ``coupling_matrix``. Any other value is a no-op.

        If this is the first stage and ``start_coupling_at == 0``, the
        ``'first'`` parameter subset is coupled over all blocks and the
        ``'not_first'`` subset over blocks 1 onwards. Otherwise all parameters
        of blocks ``start_coupling_at`` onwards are coupled.
        """
        if coupling == 0:
            # Nothing to mix; also avoids the division by zero that
            # coupling_matrix would hit for a zero coupling.
            return
        shared_weight = coupling/(coupling+self.depth*(1-coupling))
        if self.first_stage and (self.start_coupling_at == 0):
            groups = [
                (0, zip(*[
                    self.blocks[d].get_parameters(subset='first')
                    for d in range(self.depth)
                ])),
                (1, zip(*[
                    self.blocks[d].get_parameters(subset='not_first')
                    for d in range(1, self.depth)
                ])),
            ]
        else:
            groups = [
                (self.start_coupling_at, zip(*[
                    self.blocks[d].get_parameters(subset='all')
                    for d in range(self.start_coupling_at, self.depth)
                ])),
            ]
        for first_block, param_tuples in groups:
            for coupled_params in param_tuples:
                self._couple_group(
                    coupled_params, coupling, method, shared_weight, first_block
                )

    def _couple_group(self, coupled_params, coupling, method, shared_weight,
                      first_block):
        """Couple one tuple of corresponding parameters (one per block).

        ``first_block`` is the stage index of the block that owns
        ``coupled_params[0]``; it selects the matching part of the coupling
        matrix.
        """
        if method == 'uniform':
            gradsum = sum([p.grad for p in coupled_params]).detach()
            for p in coupled_params:
                p.grad = shared_weight*gradsum + (1-shared_weight)*p.grad
        elif method == 'triangular':
            stacked_grad = torch.stack([p.grad for p in coupled_params]).detach()
            # Restrict the matrix to the blocks that are actually coupled and
            # move it to the gradients' device/dtype before multiplying.
            coupling_mat = self.coupling_matrix(coupling, method)
            coupling_mat = coupling_mat[first_block:, first_block:].to(
                device=stacked_grad.device, dtype=stacked_grad.dtype
            )
            # coupled_grad[d] = sum_j coupling_mat[d, j] * grad_j
            coupled_grad = (
                coupling_mat.reshape(
                    *coupling_mat.shape,
                    *[1]*len(stacked_grad.shape[1:])
                ) *
                stacked_grad.unsqueeze(0)
            ).sum(dim=1)
            for d, p in enumerate(coupled_params):
                p.grad = coupled_grad[d]

    def coupling_matrix(self, coupling, method):
        """Return the ``depth x depth`` coupling weights for ``'triangular'``.

        Entry ``(i, j)`` is ``max(0, 1 - c * |i - j|)`` with
        ``c = 2*(1-coupling)/depth`` for ``coupling >= 0.5`` and
        ``c = 1/(2*coupling*depth)`` otherwise. Returns None for any other
        ``method``.
        """
        if coupling >= 0.5:
            coupling_ = 2*(1-coupling)/self.depth
        else:
            coupling_ = 1/(2*coupling*self.depth)
        if method == 'triangular':
            idx = torch.arange(self.depth, dtype=torch.float32)
            distance = torch.abs(idx.unsqueeze(1) - idx.unsqueeze(0))
            return torch.clamp(1 - coupling_*distance, min=0.)
        return None

    def effective_parameters(self):
        """Summarise parameter counts and the spread of weights across blocks.

        Blocks after index ``start_coupling_at`` are treated as the
        "recurrent" part; blocks up to and including that index as the
        non-recurrent part.

        Returns:
            dict with the keys ``nonrec_summary`` (number of recurrent conv
            weights plus ``first_weights``), ``first_weights`` (counted
            conv/BN parameters of the non-recurrent blocks), ``mean`` (number
            of elements of the per-position mean over recurrent blocks) and
            ``dev`` (sum of absolute normalised deviations from that mean,
            scaled by ``depth/(depth-1)``).
        """
        head_end = self.start_coupling_at + 1
        rec_blocks = self.blocks[head_end:]
        head_blocks = self.blocks[:head_end]

        rec_conv1_weights = torch.stack([block.conv1.weight for block in rec_blocks])
        rec_conv2_weights = torch.stack([block.conv2.weight for block in rec_blocks])
        rec_bn_weights = torch.stack([
            torch.cat([block.bn1.weight, block.bn1.bias,
                       block.bn2.weight, block.bn2.bias])
            for block in rec_blocks
        ])
        # Shape: (2, number of recurrent blocks, *conv weight shape).
        rec_weights = torch.stack([rec_conv1_weights, rec_conv2_weights])

        nonrec_spread = torch.numel(torch.stack([
            block.conv1.weight for block in head_blocks
        ])) + torch.numel(torch.stack([
            torch.cat([block.bn1.weight, block.bn1.bias])
            for block in head_blocks
        ]))
        # conv2/bn2 of block 0 are skipped when this is the first stage.
        id_start = 1 if self.first_stage else 0
        if id_start < head_end:
            tail_blocks = self.blocks[id_start:head_end]
            nonrec_spread += torch.numel(torch.stack([
                block.conv2.weight for block in tail_blocks
            ])) + torch.numel(torch.stack([
                torch.cat([block.bn2.weight, block.bn2.bias])
                for block in tail_blocks
            ]))

        # Deviations from the mean over blocks, normalised by the overall
        # standard deviation of the respective weights.
        conv_mean = rec_weights.mean(dim=1, keepdim=True)
        conv_dev = (rec_weights-conv_mean)/torch.sqrt(torch.mean((rec_weights-torch.mean(rec_weights))**2))
        bn_mean = rec_bn_weights.mean(dim=0, keepdim=True)
        bn_dev = (rec_bn_weights-bn_mean)/torch.sqrt(torch.mean((rec_bn_weights-torch.mean(rec_bn_weights))**2))
        rtn = {
            'nonrec_summary': np.array(
                (torch.numel(rec_weights) + nonrec_spread)
                ),
            'first_weights': np.array(nonrec_spread),
            'mean': np.array(torch.numel(conv_mean) + torch.numel(bn_mean)),
            'dev': np.array(
                torch.sum(self.depth/(self.depth-1)*conv_dev.abs()).detach().cpu().numpy() +\
                torch.sum(self.depth/(self.depth-1)*bn_dev.abs()).detach().cpu().numpy()
            )
        }
        return rtn

def l1(x):
    """Return the sum of the absolute values of all elements of ``x``."""
    magnitudes = x.abs()
    return magnitudes.sum()


class ResidualBlock(nn.Module):
    """Two-convolution residual unit in 'original' or 'preactivation' form."""

    def __init__(self, in_channels, channels, downsampling=False,
                 unit_type='preactivation', first_block=False):
        """
        Args:
            in_channels (int): Channels of the block input.
            channels (int): Channels of the block output.
            downsampling (bool): If True, the first conv uses stride 2 and
                the shortcut subsamples every second row/column and
                zero-pads the channel dimension up to ``channels``.
            unit_type (str): ``'preactivation'`` or ``'original'``.
            first_block (bool): For ``'preactivation'`` units, skip the
                leading BN/ReLU in ``forward``.

        Raises:
            ValueError: If ``unit_type`` is not one of the two supported
                values (previously this only surfaced later as a missing
                attribute or unbound variable).
        """
        super().__init__()
        if unit_type not in ('original', 'preactivation'):
            raise ValueError(
                f"unit_type must be 'original' or 'preactivation', got {unit_type!r}"
            )
        stride = 2 if downsampling else 1
        self.conv1 = nn.Conv2d(
            in_channels, channels, kernel_size=3, stride=stride, padding=1,
            bias=False
        )
        # bn1 sits after conv1 in the original unit and before it in the
        # preactivation unit, hence the different channel counts.
        if unit_type == 'original':
            self.bn1 = nn.BatchNorm2d(channels)
        else:
            self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv2 = nn.Conv2d(
            channels, channels, kernel_size=3, stride=1, padding=1,
            bias=False
        )
        self.bn2 = nn.BatchNorm2d(channels)
        if downsampling:
            # Split the extra channels between front and back so the padded
            # shortcut has exactly `channels` channels even when the
            # difference is odd.
            extra = channels - in_channels
            pad_front = extra // 2
            pad_back = extra - pad_front
            self.shortcut = LambdaLayer(
                lambda x: F.pad(
                    x[:, :, ::2, ::2],
                    (0, 0, 0, 0, pad_front, pad_back),
                    "constant", 0
                )
            )
        else:
            self.shortcut = nn.Identity()
        self.unit_type = unit_type
        self.first_block = first_block

    def _init_weights(self, block=None, recurrent_batch_norm=False):
        """Initialise the conv weights and optionally the BN scales.

        Args:
            block: If None, both convs get Kaiming-normal weights. Otherwise
                the conv weights of ``block`` are copied into this block.
            recurrent_batch_norm (bool): If True, set ``bn1.weight`` to 0.1,
                and ``bn2.weight`` too unless this is a preactivation first
                block.
        """
        if block is None:
            for conv in [self.conv1, self.conv2]:
                nn.init.kaiming_normal_(conv.weight, nonlinearity='relu', mode='fan_out')
        else:
            for conv, source_conv in zip(
                [self.conv1, self.conv2],
                [block.conv1, block.conv2]
            ):
                # Clone the values so this block owns its weight storage. A
                # shallow copy of a Parameter is backed by the same storage
                # as the source, which would tie the two blocks' weights.
                conv.weight = nn.Parameter(
                    source_conv.weight.detach().clone(),
                    requires_grad=source_conv.weight.requires_grad
                )
        if recurrent_batch_norm:
            nn.init.constant_(self.bn1.weight, 0.1)
            if not ((self.unit_type == 'preactivation') and self.first_block):
                nn.init.constant_(self.bn2.weight, 0.1)

    def forward(self, x):
        """Apply the residual unit to ``x``."""
        if self.unit_type == 'original':
            # Conv -> BN -> ReLU -> Conv -> BN -> Add -> ReLU
            out = F.relu(self.bn1(self.conv1(x)))
            out = self.bn2(self.conv2(out))
            out += self.shortcut(x)
            out = F.relu(out)
        elif self.unit_type == 'preactivation':
            # BN -> ReLU -> Conv -> BN -> ReLU -> Conv -> Add; the first
            # block applies conv1 directly to its input.
            if self.first_block:
                out = self.conv1(x)
            else:
                out = self.conv1(F.relu(self.bn1(x)))
            out = self.conv2(F.relu(self.bn2(out)))
            out += self.shortcut(x)
        return out

    def get_parameters(self, subset='all'):
        """Return the parameters belonging to ``subset``.

        Args:
            subset (str): ``'all'`` for every parameter, ``'first'`` for the
                parameters of conv1, bn2 and conv2 (as a list), or
                ``'not_first'`` for the parameters of bn1.

        Raises:
            NotImplementedError: For an unknown ``subset``, or for
                ``'not_first'`` on a block built with ``first_block=True``.
        """
        if subset == 'all':
            return self.parameters()
        if subset == 'first':
            return [*self.conv1.parameters(), *self.bn2.parameters(), *self.conv2.parameters()]
        if subset == 'not_first':
            if self.first_block:
                raise NotImplementedError()
            return self.bn1.parameters()
        raise NotImplementedError()

class LambdaLayer(nn.Module):
    """Module that applies a stored callable to its input."""

    def __init__(self, lambd):
        """Store ``lambd``, the callable invoked by ``forward``."""
        super().__init__()
        self.lambd = lambd

    def forward(self, x):
        """Return ``self.lambd(x)``."""
        fn = self.lambd
        result = fn(x)
        return result

def multi_class_accuracy(n_classes, outputs, target):
    """Fraction of rows whose top-``n_classes`` indices match between
    ``outputs`` and ``target``.

    For every row the indices of the ``n_classes`` largest entries (along
    dim 1) of ``outputs`` and of ``target`` are sorted and compared; a row
    counts as correct only if all of them agree.
    """
    predicted_idx = outputs.topk(n_classes, dim=1)[1]
    target_idx = target.topk(n_classes, dim=1)[1]
    # Sorting makes the comparison independent of the order topk returns.
    predicted_sorted = predicted_idx.sort()[0]
    target_sorted = target_idx.sort()[0]
    row_correct = (target_sorted == predicted_sorted).all(dim=1)
    return row_correct.float().mean()

def test():
    """Smoke test: with the default recurrent init, all eight blocks of the
    first stage have conv1 weights equal to those of its first block."""
    # Build the net through the module-level class with keyword arguments;
    # the previous call used an undefined `sgc` name and an outdated
    # positional signature.
    net = SGCResNet(stages=3, n_blocks=8, n_channels=16)
    weight = net.stages[0].blocks[0].conv1.weight
    for i in range(1, 8):
        assert torch.all(weight == net.stages[0].blocks[i].conv1.weight)
