import datetime
import os, sys
import random
import argparse
import numpy as np

from torch.autograd import Variable
from torch.utils.data import DataLoader
import torch
from torch import nn

import torch.nn.functional as F
import pdb

import time

from torch.nn.parameter import Parameter
import math



class Conv_DCFr(nn.Module):
    r"""2D DCF (Decomposed Convolutional Filters) convolution with external bases.

    Link to ICML paper:
    https://arxiv.org/pdf/1802.04145.pdf

    The module does not store a full convolution kernel. It learns one
    coefficient per (basis, output channel, input channel) and, on every
    forward pass, rebuilds the kernel as a linear combination of the basis
    elements handed in by the caller; the rebuilt kernel is then applied with
    :func:`torch.nn.functional.conv2d`.

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int): Side length of the (square) convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of
            the input. Default: 0
        num_bases (int): Number of basis elements for decomposition. Must be
            a positive integer; the default of ``-1`` is only a "not set"
            placeholder and is rejected.
        bias (bool, optional): If ``True``, adds a learnable bias to the output.
            Default: ``False``
        dilation (int or tuple, optional): Spacing between kernel elements.
            Default: 1

    Shape:
        - Input: a pair ``(x, bases)`` where ``x`` is
          :math:`(N, C_{in}, H_{in}, W_{in})` and ``bases`` has ``num_bases``
          as its last dimension and ``kernel_size * kernel_size`` entries in
          the remaining ones, i.e. normally
          ``(kernel_size * kernel_size, num_bases)``.
        - Output: :math:`(N, C_{out}, H_{out}, W_{out})`, with the spatial
          sizes determined by ``torch.nn.functional.conv2d``:

          .. math::
              H_{out} = \left\lfloor\frac{H_{in}  + 2 * \text{padding} - \text{dilation}
                        * (\text{kernel_size} - 1) - 1}{\text{stride}} + 1\right\rfloor

              W_{out} = \left\lfloor\frac{W_{in}  + 2 * \text{padding} - \text{dilation}
                        * (\text{kernel_size} - 1) - 1}{\text{stride}} + 1\right\rfloor

    Attributes:
        weight (Tensor): the learnable basis coefficients of shape
                         (num_bases, out_channels, in_channels)
        bias (Tensor):   the learnable bias of the module of shape
                         (out_channels), or ``None`` when ``bias=False``

    Raises:
        ValueError: if ``num_bases`` is not positive.

    Examples::

        >>> m = Conv_DCFr(16, 33, 3, stride=2, num_bases=6)
        >>> x = torch.randn(20, 16, 50, 50)
        >>> bases = torch.randn(3 * 3, 6)
        >>> output = m((x, bases))   # shape (20, 33, 24, 24)

    """
    __constants__ = ['kernel_size', 'stride', 'padding', 'num_bases']

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
                 num_bases=-1, bias=False, dilation=1):
        super(Conv_DCFr, self).__init__()
        # Fail early with a readable message: a non-positive count would
        # otherwise surface as a negative-dimension error from the tensor
        # allocation or a division by zero in reset_parameters().
        if num_bases < 1:
            raise ValueError(
                'num_bases must be a positive integer, got {!r}'.format(num_bases))

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        # Not used anywhere inside this class; kept so existing code that
        # reads or fills the attribute keeps working.
        self.kernel_list = {}
        self.num_bases = num_bases
        self.dilation = dilation

        # Learnable coefficients of the bases, one per
        # (basis, output channel, input channel).
        self.weight = Parameter(torch.empty(num_bases, out_channels, in_channels))
        if bias:
            self.bias = Parameter(torch.empty(out_channels))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the coefficients (zero-mean normal) and zero the bias.

        The standard deviation is ``1 / sqrt(in_channels * num_bases)``.
        """
        num_bases, _, in_channels = self.weight.shape
        stdv = 1. / math.sqrt(in_channels * num_bases)
        # In-place initialisation must not be tracked by autograd.
        with torch.no_grad():
            # Normal works better, working on more robust initializations
            self.weight.normal_(0, stdv)
            if self.bias is not None:
                self.bias.zero_()

    def forward(self, input):
        """Reconstruct the kernel from ``bases`` and convolve.

        Args:
            input (tuple): ``(x, computed_bases)``; ``x`` is the feature map
                passed to ``conv2d`` and ``computed_bases`` is the basis
                matrix that is multiplied with the flattened coefficients.

        Returns:
            Tensor: result of ``F.conv2d`` applied to ``x`` with the
            reconstructed kernel.
        """
        x, bases = input

        # (num_bases, out_channels * in_channels)
        coeff = self.weight.reshape(self.num_bases, -1)

        # bases @ coeff yields one row per spatial kernel position; unfold it
        # to (k, k, out, in) and reorder to conv2d's (out, in, k, k) layout.
        rec_kernel = torch.matmul(bases, coeff).reshape(
            self.kernel_size, self.kernel_size, self.out_channels, self.in_channels)
        rec_kernel = rec_kernel.permute(2, 3, 0, 1).contiguous()

        return F.conv2d(x, rec_kernel, self.bias, self.stride, self.padding,
                        dilation=self.dilation)

    def extra_repr(self):
        """Return the hyper-parameter summary shown in the module's repr."""
        return 'kernel_size={kernel_size}, stride={stride}, padding={padding}, num_bases={num_bases}'.format(
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            num_bases=self.num_bases,
        )