from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_butterfly.complex_utils import view_as_real, view_as_complex
import k_operation as kop
import k_matrix_conv as kmc
import numpy as np


class ButterfLeNet(nn.Module):
    """LeNet-5-style classifier with a configurable pair of feature layers.

    ``C1`` and ``C3`` are chosen by ``args``:

    * ``args.fc``      -- bias-free ``nn.Linear`` layers (pooling skipped);
    * ``args.conv``    -- bias-free circular-padded ``nn.Conv2d`` layers;
    * ``args.kmatrix`` -- ``kmc.KMatrix`` layers;
    * otherwise        -- ``kop.KOP2D`` layers, built with ``warm_start=True``
      when ``args.fixed`` or ``args.warm_start`` is set (``self.mode``
      records which), with C3 optionally reusing C1's K1/Kd/K2 factors
      (``args.tied``).

    Inputs are ``in_ch x 32 x 32`` images, where ``in_ch`` is 1 for
    (Fashion-)MNIST and 3 otherwise. The head C5 -> F6 -> OUTPUT ends in a
    ``log_softmax`` over 10 classes (100 with ``args.cifar100``).
    """

    def __init__(self, args):
        super().__init__()

        self.conv = args.conv
        self.fc = args.fc
        self.kmatrix = args.kmatrix
        # NOTE(review): ``bn`` is stored and BN1/BN2 are created below, but
        # ``forward`` never applies them -- confirm this is intended.
        self.bn = args.bn
        self.mode = ''

        in_ch = 1 if (args.fmnist or args.mnist) else 3
        kernel_size = 5
        in_size = 32
        # Keeps the spatial size unchanged for an odd kernel at stride 1.
        padding = (kernel_size - 1) // 2
        # Stored so the norm/regularizer helpers can build probe inputs of
        # the right shape instead of assuming 3 x 32 x 32.
        self.in_ch = in_ch
        self.in_size = in_size

        if args.fc:
            # Output widths match the conv path after each 2x2 pooling.
            self.C1 = nn.Linear(
                in_ch * in_size * in_size,
                6 * (in_size // 2) * (in_size // 2), bias=False)
            self.C3 = nn.Linear(
                6 * (in_size // 2) * (in_size // 2),
                16 * (in_size // 4) * (in_size // 4), bias=False)

        elif args.conv:
            self.C1 = nn.Conv2d(in_ch, 6, kernel_size,
                padding=padding, padding_mode='circular', bias=False)
            self.C3 = nn.Conv2d(6, 16, kernel_size,
                padding=padding, padding_mode='circular', bias=False)

        elif args.kmatrix:
            self.C1 = kmc.KMatrix(in_size, in_ch, 6, kernel_size).cuda()
            # C3 sees its input after S2 halves the spatial size.
            self.C3 = kmc.KMatrix(in_size // 2, 6, 16, kernel_size).cuda()

        else:
            warm = bool(args.fixed or args.warm_start)
            self.C1 = kop.KOP2D(in_size, in_ch, 6, kernel_size,
                nblocks=args.depth, warm_start=warm,
                padding=padding, stride=1).cuda()
            if args.tied:
                # NOTE(review): C3 is fed (in_size // 2)-sized maps after S2,
                # yet the tied variant is built with in_size while the untied
                # one uses in_size // 2 -- confirm this is intended.
                self.C3 = kop.KOP2D(in_size, 6, 16, kernel_size,
                    nblocks=args.depth, warm_start=warm,
                    padding=padding, stride=1,
                    K1=self.C1.K1, Kd=self.C1.Kd, K2=self.C1.K2).cuda()
            else:
                self.C3 = kop.KOP2D(in_size // 2, 6, 16, kernel_size,
                    nblocks=args.depth, warm_start=warm,
                    padding=padding, stride=1).cuda()
            self.mode = 'warm_start' if warm else 'from_scratch'

        # Registration/initialisation order below is kept stable so seeded
        # runs and saved optimizer states line up with earlier versions.
        self.C5 = nn.Linear(1024, 120)
        torch.nn.init.kaiming_normal_(self.C5.weight, nonlinearity='relu')

        self.S2 = nn.AvgPool2d(2, stride=2)
        self.S4 = nn.AvgPool2d(2, stride=2)

        self.F6 = nn.Linear(120, 84)
        torch.nn.init.kaiming_normal_(self.F6.weight, nonlinearity='relu')

        self.OUTPUT = nn.Linear(84, 100 if args.cifar100 else 10)
        torch.nn.init.kaiming_normal_(self.OUTPUT.weight, nonlinearity='relu')

        self.BN1 = nn.BatchNorm2d(6)
        self.BN2 = nn.BatchNorm2d(16)

        if args.searchdir != "" and not args.conv:
            print("Loading architecture weights from\n", args.searchdir)
            try:
                state_dict = torch.load(
                    args.searchdir + f"/models/model_{args.loadepoch}.pt")
            except (OSError, AttributeError):
                # Fall back to the final-epoch checkpoint when the requested
                # file is missing or ``args`` has no ``loadepoch``.
                state_dict = torch.load(
                    args.searchdir + f"/models/model_{args.epochs}.pt")
            with torch.no_grad():
                self.load_arch(state_dict)

    def load_arch(self, state_dict):
        """Copy the ``"twiddle"`` entries of *state_dict* into this model.

        Only keys containing ``"twiddle"`` are copied (in place, via
        ``copy_`` on this model's own state tensors). All other keys are
        ignored. Raises ``KeyError`` if a twiddle key has no counterpart
        in this model.
        """
        own_state = self.state_dict()
        for name, param in state_dict.items():
            if "twiddle" not in name:
                continue
            if isinstance(param, nn.Parameter):
                param = param.data
            own_state[name].copy_(param)

    def forward(self, x):
        """Return class log-probabilities (``log_softmax`` over dim 1)."""
        if self.fc:
            # Dense layers take flat vectors; pooling is skipped below.
            x = torch.flatten(x, 1)

        x = torch.relu(self.C1(x))
        if not self.fc:
            x = self.S2(x)

        x = torch.relu(self.C3(x))
        if not self.fc:
            x = torch.flatten(self.S4(x), 1)

        x = torch.relu(self.C5(x))
        x = torch.relu(self.F6(x))
        return F.log_softmax(self.OUTPUT(x), dim=1)

    def est_norms(self):
        """Placeholder: does nothing and returns ``None``."""
        pass

    # ----------------------------------------------------------- helpers

    def _uses_kop(self):
        """True when C1/C3 are the KOP2D layers (not fc, conv or kmatrix)."""
        return not self.conv and not self.fc and not self.kmatrix

    def _device(self):
        """Device holding C1's parameters (CUDA if C1 has no parameters)."""
        for p in self.C1.parameters():
            return p.device
        return torch.device('cuda')

    def _apply_layer(self, layer, x):
        """Apply C1/C3 *layer* to image-shaped *x*, flattening in fc mode."""
        if self.fc:
            x = torch.flatten(x, 1)
        return layer(x)

    def _dense_operator(self, layer, chan, dim):
        """Apply *layer* to every standard-basis input of shape
        ``(chan, dim, dim)``; entry ``i`` is the response to basis vector
        ``i``, so the result holds the layer's full linear operator."""
        size = chan * dim * dim
        iden = torch.eye(size).reshape(size, chan, dim, dim).to(self._device())
        return self._apply_layer(layer, iden)

    @staticmethod
    def _twiddle_norm(t, grad=False):
        """Norm of *t* (or of ``t.grad`` when *grad*) viewed as complex.

        Returns 0.0 when *grad* is requested but ``t.grad`` is ``None``.
        """
        with torch.no_grad():
            if grad:
                if t.grad is None:
                    return 0.0
                t = t.grad
            tc = view_as_complex(t)
            norm = (tc * tc.conj()).sum().sqrt()
            return view_as_real(norm)[0].item()

    def _twiddles(self):
        """The 12 twiddle tensors of the KOP layers, ordered C1 then C3,
        factors K1, Kd, K2, and map1 before map2 within each factor."""
        twiddles = []
        for layer in (self.C1, self.C3):
            for factor in (layer.K1, layer.Kd, layer.K2):
                twiddles.append(factor.map1.twiddle)
                twiddles.append(factor.map2.twiddle)
        return twiddles

    # ---------------------------------------------------- norm reporting

    def get_fnorms_butterfly(self, grad=False):
        """Return the 12 twiddle norms (or their gradients' norms when
        *grad*) as a NumPy array; an empty array unless C1/C3 are KOP2D."""
        if not self._uses_kop():
            return np.array([])
        return np.array([self._twiddle_norm(t, grad) for t in self._twiddles()])

    def get_fnorms(self):
        """Return ``(fnorms, op1, op2)`` as NumPy arrays.

        For KOP2D layers, ``fnorms`` is ``[||C1||_F, ||C3||_F]`` followed by
        the 12 twiddle norms, and ``op1``/``op2`` are C1/C3 applied to every
        standard-basis input. For the other layer types, all three arrays
        are empty.
        """
        fnorms = []
        op1 = []
        op2 = []

        if self._uses_kop():
            with torch.no_grad():
                operator = self._dense_operator(
                    self.C1, self.in_ch, self.in_size)
                fnorms.append(torch.sqrt(torch.sum(operator ** 2)).item())
                op1 = operator.cpu().numpy()

                operator = self._dense_operator(
                    self.C3, 6, self.in_size // 2)
                fnorms.append(torch.sqrt(torch.sum(operator ** 2)).item())
                op2 = operator.cpu().numpy()

                fnorms.extend(self._twiddle_norm(t) for t in self._twiddles())

        return np.array(fnorms), np.array(op1), np.array(op2)

    def kop_reg(self, k=250, approx=True):
        """Sum of the squared Frobenius norms of the C1 and C3 operators.

        With *approx*, the value is estimated from *k* Gaussian probes
        (for a linear map A, E||A x||^2 = ||A||_F^2 when x ~ N(0, I)).
        Otherwise it is computed exactly from the dense operators.

        Returns the int ``0`` in conv mode; otherwise returns a 0-dim
        tensor that stays attached to the autograd graph.
        """
        if self.conv:
            return 0

        device = self._device()
        c1_chan, c1_dim = self.in_ch, self.in_size
        c3_chan, c3_dim = 6, self.in_size // 2

        if approx:
            # Probes are drawn on the CPU RNG and then moved to the device,
            # as before, so seeded runs draw the same numbers.
            v1 = self._apply_layer(
                self.C1, torch.randn(k, c1_chan, c1_dim, c1_dim).to(device))
            v3 = self._apply_layer(
                self.C3, torch.randn(k, c3_chan, c3_dim, c3_dim).to(device))
            return ((v1 ** 2).reshape(k, -1).sum(dim=1).mean()
                    + (v3 ** 2).reshape(k, -1).sum(dim=1).mean())

        op1 = self._dense_operator(self.C1, c1_chan, c1_dim)
        op2 = self._dense_operator(self.C3, c3_chan, c3_dim)
        # Sum of squares == norm ** 2, without norm's NaN gradient at zero.
        return op1.pow(2).sum() + op2.pow(2).sum()
