import torch
import torch.nn as nn
from torch.nn import init

import math
import numpy as np

import torch.nn as nn
import torch
import numpy as np 
import pytorch_lightning as pl

import matplotlib as mpl
import matplotlib.pyplot as plt
from flow_utils import *
import torch.nn.functional as F

def conv(batchNorm, in_planes, out_planes, kernel_size=3, stride=1):
    """Build a Conv2d -> [BatchNorm2d] -> LeakyReLU(0.1) stack.

    Padding is ``(kernel_size - 1) // 2``, so odd kernels with stride 1 keep the
    spatial size. When ``batchNorm`` is truthy the convolution is built without
    a bias (the following BatchNorm2d supplies the shift); otherwise the
    convolution carries its own bias and no normalisation layer is added.
    """
    layers = [
        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                  padding=(kernel_size - 1) // 2, bias=not batchNorm),
    ]
    if batchNorm:
        layers.append(nn.BatchNorm2d(out_planes))
    layers.append(nn.LeakyReLU(0.1, inplace=True))
    return nn.Sequential(*layers)

def predict_flow(in_planes):
    """Return a 3x3, stride-1, padding-1 Conv2d (with bias) mapping
    ``in_planes`` feature channels to a 2-channel flow prediction."""
    layer_kwargs = dict(kernel_size=3, stride=1, padding=1, bias=True)
    return nn.Conv2d(in_planes, 2, **layer_kwargs)

def deconv(in_planes, out_planes, ksize=3):
    """Build a stride-2 ConvTranspose2d (padding 1, with bias) + LeakyReLU(0.1).

    Output spatial size is ``(H_in - 1) * 2 - 2 + ksize``: ``2*H_in - 1`` for
    the default ``ksize=3`` and ``2*H_in`` for ``ksize=4``.
    """
    upconv = nn.ConvTranspose2d(in_planes, out_planes, kernel_size=ksize,
                                stride=2, padding=1, bias=True)
    activation = nn.LeakyReLU(0.1, inplace=True)
    return nn.Sequential(upconv, activation)

def _nonzero_flow_pixels(flow_im):
    """Return ``(ys, xs)`` index arrays of the pixels of an (H, W, C) flow image
    whose flow vector has at least one non-zero channel; each pixel appears
    exactly once (unlike ``np.where`` on the 3-D array, which repeats a pixel
    once per non-zero channel)."""
    return np.nonzero(np.any(flow_im != 0, axis=2))


def plot_flow(ax, flow_im, skip=1):
    """Plot flow as a set of arrows on an existing axis.

    Params:
    ax: matplotlib axis to draw on (``imshow`` and ``quiver`` are called on it).
    flow_im: (H, W, C) array. Channel 1 is drawn as the horizontal (x) arrow
        component and channel 0 as the vertical (y) component.
    skip: fraction of the non-zero pixels to draw, clipped to [0, 1] and
        sampled uniformly at random without replacement. Despite its name it
        is a sampling ratio, not a stride.

    The axis background is black; arrows are coloured by the norm of their
    flow vector using the cividis colormap. If no pixel is selected, only the
    background is drawn.
    """
    h, w, _ = flow_im.shape
    ax.imshow(np.zeros((h, w, 3)))

    ys, xs = _nonzero_flow_pixels(flow_im)
    n_keep = int(len(xs) * np.clip(skip, 0.0, 1.0))
    if n_keep == 0:
        # Nothing to draw; Normalize.autoscale would fail on an empty array.
        return

    inds = np.random.choice(len(xs), size=n_keep, replace=False)
    ys, xs = ys[inds], xs[inds]
    vecs = flow_im[ys, xs, :]
    mags = np.linalg.norm(vecs, axis=1)
    norm = mpl.colors.Normalize()
    norm.autoscale(mags)
    cm = mpl.cm.cividis

    ax.quiver(xs, ys, vecs[:, 1], vecs[:, 0], alpha=0.8, color=cm(norm(mags)),
              angles='xy', scale_units='xy', scale=1, width=0.025, headwidth=5.)

class FlowPickSplit(nn.Module):
    """Main model used for picknet.

    A strided conv trunk followed by a small head that predicts a single
    logit channel, resized to a fixed 20x20 map.

    Params:
    inchannels: number of input channels
    """
    def __init__(self, inchannels):
        super().__init__()
        # (input channels, stride) for the four 5x5, 32-filter trunk convs;
        # each conv is followed by an in-place ReLU.
        trunk_layers = []
        for in_ch, stride in ((inchannels, 2), (32, 2), (32, 2), (32, 1)):
            trunk_layers.append(nn.Conv2d(in_ch, 32, 5, stride))
            trunk_layers.append(nn.ReLU(True))
        self.trunk = nn.Sequential(*trunk_layers)

        self.head = nn.Sequential(
            nn.Conv2d(32, 32, 3, 1),
            nn.ReLU(True),
            nn.UpsamplingBilinear2d(scale_factor=2),
            nn.Conv2d(32, 1, 3, 1),
        )

        self.upsample = nn.Upsample(size=(20, 20), mode="bilinear")

    def forward(self, x):
        """Return (N, 1, 20, 20) logits for input ``x``."""
        features = self.trunk(x)
        return self.upsample(self.head(features))

class FlowPickNet(nn.Module):
    """Used for 2channel (or 4 channel) ablation.

    Params:
    inchannels: number of input channels
    im_w: input image size; stored as ``self.im_w`` but not used by the layers
    outchannels: number of output channels
    use_tanh: apply tanh to the head output before resizing
    use_pool: average-pool the 20x20 output map down to 1x1 per channel
    """
    def __init__(self, inchannels, im_w, outchannels=2, use_tanh=False, use_pool=False):
        super(FlowPickNet, self).__init__()
        self.trunk = nn.Sequential(nn.Conv2d(inchannels, 32, 5, 2),
                                    nn.ReLU(True),
                                    nn.Conv2d(32, 32, 5, 2),
                                    nn.ReLU(True),
                                    nn.Conv2d(32, 32, 5, 2),
                                    nn.ReLU(True),
                                    nn.Conv2d(32, 32, 5, 1),
                                    nn.ReLU(True))
        self.head = nn.Sequential(nn.Conv2d(32, 32, 3, 1),
                                   nn.ReLU(True),
                                   nn.UpsamplingBilinear2d(scale_factor=2),
                                   nn.Conv2d(32, outchannels, 3, 1))

        self.im_w = im_w
        self.use_tanh = use_tanh
        if use_tanh:
            self.tanh = nn.Tanh()
        self.use_pool = use_pool
        if use_pool:
            # kernel matches the fixed 20x20 output size, so this is a global average
            self.pool = nn.AvgPool2d(kernel_size=(20, 20))

    def forward(self, x):
        """Return (N, outchannels, 20, 20) maps, or (N, outchannels, 1, 1) with use_pool."""
        out = self.head(self.trunk(x))
        if self.use_tanh:
            out = self.tanh(out)
        # Same call nn.Upsample(size=(20, 20), mode="bilinear") makes internally,
        # without constructing a new module on every forward pass.
        out = F.interpolate(out, size=(20, 20), mode="bilinear")
        if self.use_pool:
            out = self.pool(out)
        return out

class FlowNetSmall(nn.Module):
    """FlowNet-style encoder/decoder with reduced filter counts (8..128).

    Params:
    input_channels: number of input channels
    batchNorm: add BatchNorm2d after each encoder convolution (see ``conv``)
    lossnorm: ``ord`` of the per-pixel vector norm used in ``loss``

    ``forward`` returns a 2-channel flow map, upsampled 4x from the finest
    prediction (``predict_flow2``). Each decoder level concatenates encoder
    features with up-sampled decoder features and flow, so the spatial sizes
    must line up at every level. For example, a 200x200 input gives encoder
    maps of 100, 50, 25, 13, 7 and 4 and a 200x200 output.
    """
    def __init__(self, input_channels = 12, batchNorm=True, lossnorm=2):
        super(FlowNetSmall,self).__init__()

        fs = [8, 16, 32, 64, 128] # filter sizes
        self.batchNorm = batchNorm
        self.lossnorm = lossnorm
        # out = (H + 2*3 - 7)//2 + 1, e.g. 384 -> 192, 200 -> 100
        self.conv1   = conv(self.batchNorm, input_channels, fs[0], kernel_size=7, stride=2)
        self.conv2   = conv(self.batchNorm, fs[0], fs[1], kernel_size=5, stride=2)
        self.conv3   = conv(self.batchNorm, fs[1], fs[2], kernel_size=5, stride=2)
        self.conv3_1 = conv(self.batchNorm, fs[2], fs[2])
        self.conv4   = conv(self.batchNorm, fs[2], fs[3], stride=2)
        self.conv4_1 = conv(self.batchNorm, fs[3], fs[3])
        self.conv5   = conv(self.batchNorm, fs[3], fs[3], stride=2)
        self.conv5_1 = conv(self.batchNorm, fs[3], fs[3])
        self.conv6   = conv(self.batchNorm, fs[3], fs[4], stride=2)
        self.conv6_1 = conv(self.batchNorm, fs[4], fs[4])

        # decoder input widths = encoder skip + previous deconv output + 2 flow channels
        self.deconv5 = deconv(fs[4],fs[3])
        self.deconv4 = deconv(fs[3]+fs[3]+2,fs[2])
        self.deconv3 = deconv(fs[3]+fs[2]+2,fs[1])
        self.deconv2 = deconv(fs[2]+fs[1]+2,fs[0], ksize=4)

        self.predict_flow6 = predict_flow(fs[4])
        self.predict_flow5 = predict_flow(fs[3]+fs[3]+2)
        self.predict_flow4 = predict_flow(fs[3]+fs[2]+2)
        self.predict_flow3 = predict_flow(fs[2]+fs[1]+2)
        self.predict_flow2 = predict_flow(fs[1]+fs[0]+2)

        self.upsampled_flow6_to_5 = nn.ConvTranspose2d(2, 2, 3, 2, 1, bias=False) # (H_in-1)*stride - 2*padding + (kernel-1) + 1
        self.upsampled_flow5_to_4 = nn.ConvTranspose2d(2, 2, 3, 2, 1, bias=False)
        self.upsampled_flow4_to_3 = nn.ConvTranspose2d(2, 2, 3, 2, 1, bias=False)
        self.upsampled_flow3_to_2 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False)

        # Xavier-uniform weights and U[0, 1) biases for every (transposed) conv.
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                if m.bias is not None:
                    init.uniform_(m.bias)
                init.xavier_uniform_(m.weight)
        self.upsample1 = nn.Upsample(scale_factor=4, mode='bilinear')

    def forward(self, x):
        """Return the predicted 2-channel flow for input ``x``."""
        out_conv1 = self.conv1(x)

        out_conv2 = self.conv2(out_conv1)
        out_conv3 = self.conv3_1(self.conv3(out_conv2))
        out_conv4 = self.conv4_1(self.conv4(out_conv3))
        out_conv5 = self.conv5_1(self.conv5(out_conv4))
        out_conv6 = self.conv6_1(self.conv6(out_conv5))

        flow6       = self.predict_flow6(out_conv6)
        flow6_up    = self.upsampled_flow6_to_5(flow6)
        out_deconv5 = self.deconv5(out_conv6)

        concat5 = torch.cat((out_conv5,out_deconv5,flow6_up),1)
        flow5       = self.predict_flow5(concat5)
        flow5_up    = self.upsampled_flow5_to_4(flow5)
        out_deconv4 = self.deconv4(concat5)

        concat4 = torch.cat((out_conv4,out_deconv4,flow5_up),1)
        flow4       = self.predict_flow4(concat4)
        flow4_up    = self.upsampled_flow4_to_3(flow4)
        out_deconv3 = self.deconv3(concat4)

        concat3 = torch.cat((out_conv3,out_deconv3,flow4_up),1)
        flow3       = self.predict_flow3(concat3)
        flow3_up    = self.upsampled_flow3_to_2(flow3)
        out_deconv2 = self.deconv2(concat3)

        concat2 = torch.cat((out_conv2,out_deconv2,flow3_up),1)
        flow2 = self.predict_flow2(concat2)

        # flow2 is at 1/4 input resolution (conv1 and conv2 both stride 2)
        out = self.upsample1(flow2)

        return out

    def loss(self, input_flow, target_flow, mask):
        """Masked mean flow error.

        Params:
        input_flow: predicted flow, (B, C, H, W) per the size() unpack below
        target_flow: ground-truth flow, same shape as input_flow
        mask: reshaped to (B, H*W), so it must hold one value per pixel;
            presumably a binary (B, 1, H, W) validity mask -- TODO confirm
        Returns:
        scalar tensor: the batch mean of each image's summed per-pixel
        ``lossnorm``-norm error divided by that image's mask sum. An image
        whose mask sums to 0 is divided by 1 instead of producing nan/inf.
        """
        b, c, h, w = input_flow.size()
        # Only the prediction is masked; NOTE(review): this assumes target_flow
        # is already zero outside the mask -- confirm against the data loader.
        diff_flow = torch.reshape(target_flow - input_flow * mask, (b, c, h * w))
        mask = torch.reshape(mask, (b, h * w))
        # norm of each pixel's flow-error vector (over channels): (B, H*W)
        norm_diff_flow = torch.linalg.norm(diff_flow, ord=self.lossnorm, dim=1)
        # Guard against an all-zero mask, which would otherwise divide by zero
        # and turn the whole batch loss into nan/inf.
        count = mask.sum(dim=1)
        count = torch.where(count > 0, count, torch.ones_like(count))
        mean_norm_diff_flow = norm_diff_flow.sum(dim=1) / count  # (B,)
        return mean_norm_diff_flow.mean()  # mean over the batch


class FlowNetPickSplit(pl.LightningModule):
    """
    FlowNetSmall followed by Split PickNet.

    ``forward`` runs FlowNetSmall on the depth input to get a 2-channel flow
    image. Two FlowPickSplit heads then predict the first and second pick
    points. The second head sees the flow concatenated with a Gaussian image
    centred on the first head's argmax. When ``netcfg.predictplace`` is set,
    a parallel pair of heads (placenet1/placenet2) predicts place points the
    same way.

    cfg is read both as a mapping (``cfg['nettype']``, ``cfg['min_loss']``,
    ``cfg["im_width"]``) and by attribute (``cfg.lr``, ``cfg.wdecay``).
    It is presumably an OmegaConf DictConfig -- TODO confirm. ``cfg`` must
    not be None despite the default.
    """
    def __init__(self, cfg=None):
        super(FlowNetPickSplit, self).__init__()
        self.cfg = cfg
        self.nettype = cfg['nettype']
        self.netcfg = cfg[self.nettype]
        self.flownet = FlowNetSmall(input_channels=self.netcfg.inchannels,
                                    batchNorm=self.netcfg.batchNorm,
                                    lossnorm=self.netcfg.lossnorm)
        # First-point heads see the 2-channel flow; second-point heads see the
        # flow plus a 1-channel Gaussian around the first prediction.
        self.picknet1 = FlowPickSplit(inchannels=2)
        self.picknet2 = FlowPickSplit(inchannels=3)
        if self.netcfg.predictplace:
            self.placenet1 = FlowPickSplit(inchannels=2)
            self.placenet2 = FlowPickSplit(inchannels=3)

    def _predict_point_pair(self, flow_out, net1, net2):
        """Run a (first point, second point) head pair on ``flow_out``.

        Returns (logits1, [u1, v1], logits2, [u2, v2], gau1), where gau1 is the
        Gaussian image centred on the first prediction that was fed to net2.
        """
        logits1 = net1(flow_out)
        u1, v1 = self.get_pt(logits1)
        gau1 = self.get_gaussian(u1, v1)
        x2 = torch.cat([flow_out, gau1], dim=1)
        if self.netcfg.detachinput:
            # keep the second head's loss from back-propagating into the flownet
            x2 = x2.detach()
        logits2 = net2(x2)
        u2, v2 = self.get_pt(logits2)
        return logits1, [u1, v1], logits2, [u2, v2], gau1

    def forward(self, depth_input):
        """Predict flow and pick (and optionally place) points.

        Returns
        ``(flow_out, [u1, v1], [u2, v2], logits1, logits2, info)``, or with
        ``predictplace``
        ``(flow_out, [u1, v1], [u2, v2], logits1, logits2, [u1p, v1p],
        [u2p, v2p], logits1p, logits2p, info)``.
        ``info`` holds the Gaussian images fed to the second-stage heads.
        """
        flow_out = self.flownet(depth_input)
        logits1, pick1, logits2, pick2, pick1_gau = self._predict_point_pair(
            flow_out, self.picknet1, self.picknet2)
        if not self.netcfg.predictplace:
            return flow_out, pick1, pick2, logits1, logits2, {'pick1_gau': pick1_gau}

        # The second place point comes from placenet2. The previous code reused
        # picknet2 here, which left placenet2 unused and untrained.
        logits1p, place1, logits2p, place2, place1_gau = self._predict_point_pair(
            flow_out, self.placenet1, self.placenet2)
        return (flow_out, pick1, pick2, logits1, logits2, place1, place2,
                logits1p, logits2p, {'pick1_gau': pick1_gau, 'place1_gau': place1_gau})

    def pick_loss(self, logits1, logits2, pick1, pick2):
        """calculate loss for both picknets
        params:
        logits1: output of picknet1
        logits2: output of picknet2
        pick1: gt pickpt 1, indexed as pick1[:, 0] (row) and pick1[:, 1] (col)
        pick2: gt pickpt 2, same layout
        -----
        loss1: loss for picknet 1
        loss2: loss for picknet 2

        Labels are sigma=2 Gaussians on a 20x20 grid at the gt points divided by
        10 (assumes 10 image pixels per grid cell -- TODO confirm). With
        cfg['min_loss'], each sample uses whichever assignment of the two gt
        points to the two heads has the lower total BCE, so the loss does not
        depend on the order of the gt points.
        """
        # follow the logits' device instead of assuming the default CUDA device
        pick1 = pick1.to(logits1.device)
        pick2 = pick2.to(logits1.device)
        label_a = self.get_gaussian(pick1[:, 0] // 10, pick1[:, 1] // 10, sigma=2, size=20)
        label_b = self.get_gaussian(pick2[:, 0] // 10, pick2[:, 1] // 10, sigma=2, size=20)

        if self.cfg['min_loss']:
            def per_sample_bce(logits, label):
                return F.binary_cross_entropy_with_logits(
                    logits, label, reduction='none').mean(dim=(1, 2, 3))

            loss_1a = per_sample_bce(logits1, label_a)
            loss_1b = per_sample_bce(logits1, label_b)
            loss_2a = per_sample_bce(logits2, label_a)
            loss_2b = per_sample_bce(logits2, label_b)

            keep_order = (loss_1a + loss_2b) < (loss_1b + loss_2a)
            loss1 = torch.where(keep_order, loss_1a, loss_1b).mean()
            loss2 = torch.where(keep_order, loss_2b, loss_2a).mean()
        else:
            loss1 = F.binary_cross_entropy_with_logits(logits1, label_a)
            loss2 = F.binary_cross_entropy_with_logits(logits2, label_b)

        return loss1, loss2

    def loss(self, flow, flow_gt, flow_mask,
             logits1, logits2, pick1_gt, pick2_gt,
             logits1p=None, logits2p=None, place1_gt=None, place2_gt=None):
        """Total loss = pick loss [+ place loss] [+ flowlosswt * flow loss].

        Returns (loss, info), where info holds detached copies of each term.
        """
        loss1, loss2 = self.pick_loss(logits1, logits2, pick1_gt, pick2_gt)
        loss_pick = loss1 + loss2
        info = {'loss_pick': loss_pick.detach().clone()}
        loss = loss_pick

        if logits1p is not None:
            loss1p, loss2p = self.pick_loss(logits1p, logits2p, place1_gt, place2_gt)
            loss_p = loss1p + loss2p
            info['loss_place'] = loss_p.detach().clone()
            loss = loss + loss_p

        if self.netcfg.flowloss:
            # weight the flow loss
            flow_loss = self.flownet.loss(flow, flow_gt, flow_mask) * self.netcfg.flowlosswt
            info['loss_flow'] = flow_loss.detach().clone()
            loss = loss + flow_loss

        return loss, info

    def _run_batch(self, batch, batch_idx, stage, plot):
        """Forward + loss for one batch; plot the first batch when ``plot``.

        Returns (loss, linfo) as produced by ``self.loss``.
        """
        depth_input = batch['depths']
        flow_gt = batch['flow_lbl']
        pick1_gt, pick2_gt = batch['pick_lbl']
        flow_mask = batch['loss_mask']
        do_plot = plot and batch_idx == 0

        if self.netcfg.predictplace:
            place1_gt, place2_gt = batch['place_lbl']
            (flow, pick1, pick2, logits1, logits2,
             place1, place2, logits1p, logits2p, info) = self.forward(depth_input)
            loss, linfo = self.loss(flow, flow_gt, flow_mask,
                                    logits1, logits2, pick1_gt, pick2_gt,
                                    logits1p=logits1p, logits2p=logits2p,
                                    place1_gt=place1_gt, place2_gt=place2_gt)
            if do_plot:
                self.save_plot(depth_input, flow, flow_gt, flow_mask,
                               pick1, pick2, pick1_gt, pick2_gt, info,
                               place1, place2, place1_gt, place2_gt, stage=stage)
        else:
            flow, pick1, pick2, logits1, logits2, info = self.forward(depth_input)
            loss, linfo = self.loss(flow, flow_gt, flow_mask,
                                    logits1, logits2, pick1_gt, pick2_gt)
            if do_plot:
                self.save_plot(depth_input, flow, flow_gt, flow_mask,
                               pick1, pick2, pick1_gt, pick2_gt, info, stage=stage)
        return loss, linfo

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop. It is independent of forward
        loss, linfo = self._run_batch(batch, batch_idx, stage="train", plot=True)

        log_kw = dict(on_step=False, on_epoch=True, prog_bar=False)
        self.log('loss/train', loss, **log_kw)
        self.log('loss_pick/train', linfo['loss_pick'], **log_kw)
        if self.netcfg.predictplace:
            self.log('loss_place/train', linfo['loss_place'], **log_kw)
        if self.netcfg.flowloss:
            self.log('loss_flow/train', linfo['loss_flow'], **log_kw)
        return {'loss': loss}

    def validation_step(self, batch, batch_idx, log=True):
        # Plotting is gated on `log` in both branches; the predictplace branch
        # previously plotted even with log=False.
        loss, linfo = self._run_batch(batch, batch_idx, stage="val", plot=log)

        if log:
            self.log('loss/val', loss)
            # NOTE(review): assumes the second logger is TensorBoard-like
            # (exposes experiment.add_histogram) -- confirm the trainer setup.
            tb = self.logger[1].experiment
            tb.add_histogram("flownet.conv1.weight", self.flownet.conv1[0].weight, self.global_step)
            tb.add_histogram("flownet.predict_flow2.weight", self.flownet.predict_flow2.weight, self.global_step)
            tb.add_histogram("picknet1.conv1.weight", self.picknet1.trunk[0].weight, self.global_step)
            tb.add_histogram("picknet2.conv1.weight", self.picknet2.trunk[0].weight, self.global_step)
            self.log('loss_pick/val', linfo['loss_pick'], on_step=False, on_epoch=True, prog_bar=False)
            if self.netcfg.predictplace:
                tb.add_histogram("placenet1.conv1.weight", self.placenet1.trunk[0].weight, self.global_step)
                tb.add_histogram("placenet2.conv1.weight", self.placenet2.trunk[0].weight, self.global_step)
                self.log('loss_place/val', linfo['loss_place'], on_step=False, on_epoch=True, prog_bar=False)
            if self.netcfg.flowloss:
                self.log('loss_flow/val', linfo['loss_flow'], on_step=False, on_epoch=True, prog_bar=False)
        return {'loss': loss}

    def configure_optimizers(self):
        """Adam (lr/weight decay from cfg) with ReduceLROnPlateau on 'loss/val'."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.cfg.lr, weight_decay=self.cfg.wdecay)
        reduce_lr_on_plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=0.8,
            patience=3,
            verbose=True)
        scheduler = {'scheduler': reduce_lr_on_plateau, 'monitor': 'loss/val'}
        return [optimizer], [scheduler]

    def get_pt(self, logits):
        """Get the pick point
        Params:
        logits: (N, 1, W, W) logit map; channel 0 is used and the map is
            assumed square
        ----
        u,v: (N,) row and column of the argmax cell, each multiplied by 10
            (assumed image pixels per grid cell -- TODO confirm)
        """
        N = logits.size(0)
        W = logits.size(2)

        # sigmoid is monotonic, so argmax over the raw logits picks the same
        # cell -- but it avoids spurious ties where sigmoid saturates to 1.0.
        idx = logits[:, 0].reshape(N, W * W).argmax(dim=1)

        u = (idx // W) * 10
        v = (idx % W) * 10

        return u, v

    def get_gaussian(self, u, v, sigma=5, size=None):
        """Returns a gaussian image
        params:
        u,v : (N,) centres; u is the row coordinate, v the column coordinate
        sigma: std of gaussian
        size: side length of the square output (default cfg["im_width"])
        -------
        g: (N, 1, size, size) image, min-max normalised to [0, 1] per sample
        """
        if size is None:
            size = self.cfg["im_width"]

        x0 = u[:, None]
        y0 = v[:, None]

        # (1, size) grid, broadcast against the (N, 1) centres; created on the
        # centres' device instead of a hard-coded .cuda().
        coords = torch.arange(size, device=u.device).float()[None, :]
        gx = torch.exp(-(coords - x0) ** 2 / (2 * sigma ** 2))  # (N, size), rows
        gy = torch.exp(-(coords - y0) ** 2 / (2 * sigma ** 2))  # (N, size), cols
        g = torch.einsum('ni,no->nio', gx, gy)  # per-sample outer product

        gmin = g.amin(dim=(1, 2), keepdim=True)
        gmax = g.amax(dim=(1, 2), keepdim=True)
        # clamp avoids 0/0 when a map is constant (e.g. centre far off-grid)
        g = (g - gmin) / (gmax - gmin).clamp_min(torch.finfo(g.dtype).eps)
        return g.unsqueeze(1)

    @staticmethod
    def _plot_flow_arrows(axis, flow, step):
        """Draw every ``step``-th non-zero pixel of an (H, W, 2) flow image as a
        white arrow (channel 1 horizontal, channel 0 vertical) over a blank
        all-zeros image."""
        h, w, _ = flow.shape
        axis.imshow(np.zeros((h, w)), alpha=0.5)
        # one entry per pixel, not one per non-zero channel
        ys, xs = np.nonzero(np.any(flow != 0, axis=-1))
        ys, xs = ys[::step], xs[::step]
        axis.quiver(xs, ys, flow[ys, xs, 1], flow[ys, xs, 0],
                    alpha=0.8, color='white', angles='xy', scale_units='xy', scale=1)

    def save_plot(self, depth_input, flow, flow_gt, flow_mask,
                  pick1, pick2, pick1_gt, pick2_gt, info,
                  place1=None, place2=None, place1_gt=None, place2_gt=None,
                  stage='none', idx=0):
        """Log a 5-panel figure for batch item ``idx`` to ``self.logger[1]``
        under tag ``stage``.

        Panels: depth channel 0 with predicted/gt picks (and places), depth
        channel 1, loss mask, gt flow arrows, masked predicted flow arrows.
        """
        def to_np(t):
            return t.detach().cpu().numpy()

        im1 = to_np(depth_input[idx, 0])
        im2 = to_np(depth_input[idx, 1])
        mask = to_np(flow_mask[idx]).squeeze()
        flow_gt_np = to_np(flow_gt[idx].permute(1, 2, 0))
        # .copy(): on CPU .numpy() shares storage with the network output, so the
        # in-place masking below would otherwise corrupt a tensor autograd saved.
        flow_np = to_np(flow[idx].permute(1, 2, 0)).copy()
        pick1_gau = to_np(info['pick1_gau'][idx, 0])
        pick1 = [to_np(pick1[0][idx]), to_np(pick1[1][idx])]
        pick2 = [to_np(pick2[0][idx]), to_np(pick2[1][idx])]
        pick1_gt = to_np(pick1_gt[idx])
        pick2_gt = to_np(pick2_gt[idx])

        fig, ax = plt.subplots(1, 5, figsize=(32, 16))

        def mark(pt, color):
            # points are (row, col); scatter takes (x=col, y=row)
            ax[0].scatter([pt[1]], [pt[0]], s=100, c=color)

        ax[0].imshow(im1)
        ax[0].imshow(pick1_gau, alpha=0.5)
        mark(pick1, 'blue')
        mark(pick2, 'blue')
        mark(pick1_gt, 'green')
        mark(pick2_gt, 'green')

        if place1 is not None:
            place1_gau = to_np(info['place1_gau'][idx, 0])
            ax[0].imshow(place1_gau, alpha=0.5)
            mark([to_np(place1[0][idx]), to_np(place1[1][idx])], 'orange')
            mark([to_np(place2[0][idx]), to_np(place2[1][idx])], 'orange')
            mark(to_np(place1_gt[idx]), 'red')
            mark(to_np(place2_gt[idx]), 'red')

        ax[1].imshow(im2)
        ax[2].imshow(mask)

        self._plot_flow_arrows(ax[3], flow_gt_np, step=1)
        flow_np[mask == 0, :] = 0
        self._plot_flow_arrows(ax[4], flow_np, step=12)

        plt.tight_layout()
        self.logger[1].experiment.add_figure(stage, fig, self.global_step)
        plt.close(fig)

if __name__ == '__main__':
    # Smoke check: print a layer summary of FlowNetSmall for a 2-channel
    # 200x200 input. Requires a CUDA device because of the .cuda() call.
    # NOTE(review): `summary` is not imported explicitly in this file; it
    # presumably comes from `from flow_utils import *` -- TODO confirm.
    f = FlowNetSmall(input_channels=2).cuda()
    print(summary(f, [(2, 200, 200)]))
