from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_butterfly.complex_utils import view_as_real, view_as_complex
import k_operation as kop
import k_matrix_conv as kmc
import numpy as np


class ButterfLeNet(nn.Module):
    """LeNet-style classifier with convolutional or KOP2D feature layers.

    The network is three feature stages (layer -> BatchNorm2d -> 2x2 max-pool
    -> ReLU), followed by a hidden linear layer with BatchNorm1d + ReLU and a
    linear output layer whose log-softmax is returned.

    The feature layers ``C1``, ``C3`` and ``C5`` are either ``nn.Conv2d``
    (``args.conv``) or ``kop.KOP2D`` operators.  In the KOP2D case ``C5`` is
    constructed with the ``K1``/``Kd``/``K2`` attributes of ``C3``.

    Args:
        args: Namespace-like object.  Attributes read here: ``conv``, ``fc``,
            ``kmatrix``, ``fmnist``, ``mnist``, ``cifar100``, ``searchdir``;
            when ``conv`` is false also ``fixed``, ``warm_start`` and
            ``depth``; when ``searchdir`` is non-empty and ``conv`` is false
            also ``loadepoch`` and ``epochs``.
        num_channels: Output channels of the first feature layer; the second
            and third layers use 2x and 4x this value.
    """

    def __init__(self, args, num_channels=32):
        super().__init__()

        self.conv = args.conv
        self.fc = args.fc
        self.kmatrix = args.kmatrix
        self.mode = ''
        self.num_channels = num_channels

        # Grayscale datasets have one input channel, everything else three.
        in_ch = 1 if (args.fmnist or args.mnist) else 3
        kernel_size = 3
        in_size = 32

        # Kept so that the probing helpers (kop_reg / get_fnorms) can build
        # inputs that match the layers instead of hard-coding their shapes.
        self.in_ch = in_ch
        self.in_size = in_size

        if args.conv:
            self.C1 = self._make_conv(in_ch, num_channels, kernel_size)
            self.C3 = self._make_conv(
                num_channels, num_channels * 2, kernel_size)
            self.C5 = self._make_conv(
                num_channels * 2, num_channels * 4, kernel_size)
        else:
            warm = bool(args.fixed or args.warm_start)
            # Warm-started operators are built for the half-resolution input
            # in stages two and three; the from-scratch variant (marked TODO
            # upstream) keeps the full input size for every stage.
            later_size = in_size // 2 if warm else in_size

            self.C1 = self._make_kop(
                args, in_size, in_ch, num_channels, kernel_size, warm)
            self.C3 = self._make_kop(
                args, later_size, num_channels, num_channels * 2,
                kernel_size, warm)
            # C5 is handed C3's K1 / Kd / K2 at construction time.
            self.C5 = self._make_kop(
                args, later_size, num_channels * 2, num_channels * 4,
                kernel_size, warm,
                K1=self.C3.K1, Kd=self.C3.Kd, K2=self.C3.K2)

            self.mode = 'warm_start' if warm else 'from_scratch'

        self.F6 = nn.Linear(4 * 4 * num_channels * 4, num_channels * 4)

        if args.cifar100:
            self.OUTPUT = nn.Linear(num_channels * 4, 100)
        else:
            self.OUTPUT = nn.Linear(num_channels * 4, 10)

        self.BN1 = nn.BatchNorm2d(self.num_channels)
        self.BN2 = nn.BatchNorm2d(self.num_channels * 2)
        self.BN3 = nn.BatchNorm2d(self.num_channels * 4)
        self.fcbn1 = nn.BatchNorm1d(self.num_channels * 4)

        if args.searchdir != "" and not args.conv:
            print("Loading architecture weights from\n", args.searchdir)
            state_dict = self._load_search_checkpoint(args)
            with torch.no_grad():
                self.load_arch(state_dict)

    @staticmethod
    def _make_conv(c_in, c_out, kernel_size):
        """Build a bias-free convolution with circular padding of 1."""
        return nn.Conv2d(
            c_in, c_out, kernel_size,
            padding=1, padding_mode='circular',
            bias=False)

    @staticmethod
    def _make_kop(args, size, c_in, c_out, kernel_size, warm_start,
                  **shared):
        """Build one KOP2D layer with ``args.depth`` blocks.

        Extra keyword arguments (``K1``/``Kd``/``K2``) are forwarded to
        ``kop.KOP2D`` unchanged.  The layer is moved to CUDA when a CUDA
        device is available and otherwise left on the CPU, so the model can
        at least be constructed on CPU-only hosts.
        """
        layer = kop.KOP2D(
            size,
            c_in, c_out, kernel_size,
            nblocks=args.depth, warm_start=warm_start,
            padding=1, stride=1,
            **shared)
        if torch.cuda.is_available():
            layer = layer.cuda()
        return layer

    @staticmethod
    def _load_search_checkpoint(args):
        """Load the checkpoint written under ``args.searchdir``.

        The checkpoint for ``args.loadepoch`` is preferred; when that
        attribute or file is missing, the one for ``args.epochs`` is used.
        Any other failure (e.g. a corrupt file) is propagated rather than
        silently replaced by a different checkpoint.
        """
        model_dir = args.searchdir + "/models"
        try:
            # map_location keeps GPU-saved checkpoints loadable anywhere;
            # load_arch copies the values onto this model's own tensors.
            return torch.load(
                model_dir + f"/model_{args.loadepoch}.pt",
                map_location="cpu")
        except (AttributeError, OSError):
            return torch.load(
                model_dir + f"/model_{args.epochs}.pt",
                map_location="cpu")

    def load_arch(self, state_dict):
        """Copy every entry whose name contains ``"twiddle"`` into the model.

        Args:
            state_dict: Mapping from parameter name to tensor/parameter.

        Raises:
            KeyError: If a twiddle entry has no counterpart in this model.
        """
        own_state = self.state_dict()
        with torch.no_grad():
            for name, param in state_dict.items():
                if "twiddle" not in name:
                    continue
                if isinstance(param, nn.Parameter):
                    param = param.data
                own_state[name].copy_(param)

    def forward(self, x):
        """Return per-class log-probabilities for the image batch ``x``."""
        stages = (
            (self.C1, self.BN1),
            (self.C3, self.BN2),
            (self.C5, self.BN3),
        )
        for layer, norm in stages:
            x = norm(layer(x))
            x = torch.relu(F.max_pool2d(x, 2))

        # Flatten per sample.  Unlike ``view(-1, n)`` this never folds a
        # wrongly sized feature map into the batch dimension; a size
        # mismatch surfaces as a shape error in F6 instead.
        x = torch.flatten(x, 1)
        x = self.F6(x)
        x = self.fcbn1(x)
        x = torch.relu(x)
        x = self.OUTPUT(x)

        return F.log_softmax(x, dim=1)

    def est_norms(self):
        """Placeholder; intentionally does nothing."""
        pass

    def _twiddles(self):
        """Yield the twiddle tensors of C1 and C3 in reporting order.

        Order: for each of C1, C3 -> for each of K1, Kd, K2 -> map1, map2.
        C5 is not reported separately (it is built from C3's K1/Kd/K2).
        """
        for layer in (self.C1, self.C3):
            for factor in (layer.K1, layer.Kd, layer.K2):
                yield factor.map1.twiddle
                yield factor.map2.twiddle

    @staticmethod
    def _twiddle_norm(t, grad=False):
        """Return sqrt(sum(tc * conj(tc))) for twiddle ``t`` as a float.

        With ``grad=True`` the norm of ``t.grad`` is returned, or ``0.0``
        when no gradient has been populated yet.
        """
        with torch.no_grad():
            if grad:
                if t.grad is None:
                    return 0.0
                t = t.grad
            tc = view_as_complex(t)
            norm = (tc * tc.conj()).sum().sqrt()
            # The sum is real-valued; take the real component.
            return view_as_real(norm)[0].item()

    @staticmethod
    def _device_of(module):
        """Return the device of ``module``'s first parameter.

        Falls back to CUDA when available (else CPU) for parameter-free
        modules.
        """
        param = next(module.parameters(), None)
        if param is not None:
            return param.device
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def _identity_probe(self, layer, chan, dim):
        """Return an identity batch: one one-hot (chan, dim, dim) per row."""
        size = chan * dim * dim
        iden = torch.eye(size).reshape(size, chan, dim, dim)
        return iden.to(self._device_of(layer))

    def _probe_specs(self):
        """Return (layer, channels, spatial size) for the C1 and C3 probes."""
        return (
            (self.C1, self.in_ch, self.in_size),
            (self.C3, self.num_channels, self.in_size // 2),
        )

    def get_fnorms_butterfly(self, grad=False):
        """Return the twiddle norms of C1 and C3 as a NumPy array.

        Args:
            grad: When true, report the norms of the twiddle gradients
                instead of the twiddles themselves.

        Returns:
            Array of 12 norms, or an empty array when any of ``conv``,
            ``fc`` or ``kmatrix`` is set.
        """
        fnorms = []
        if not self.conv and not self.fc and not self.kmatrix:
            fnorms = [self._twiddle_norm(t, grad) for t in self._twiddles()]
        return np.array(fnorms)

    def get_fnorms(self):
        """Return ``(fnorms, op1, op2)`` as NumPy arrays.

        The computation is currently switched off, so all three arrays are
        empty.  When enabled it materialises C1 and C3 by applying them to
        an identity batch and reports their Frobenius norms followed by the
        twiddle norms.
        """
        fnorms = []
        op1 = []
        op2 = []

        # Upstream replaced the guard
        # `not self.conv and not self.fc and not self.kmatrix`
        # with a hard-coded False ("don't care about this right now").
        enabled = False
        if enabled:
            with torch.no_grad():
                dense = []
                for layer, chan, dim in self._probe_specs():
                    operator = layer(self._identity_probe(layer, chan, dim))
                    fnorms.append(torch.sqrt(torch.sum(operator ** 2)).item())
                    dense.append(operator.cpu().numpy())
                op1, op2 = dense

                fnorms.extend(
                    self._twiddle_norm(t) for t in self._twiddles())

        return np.array(fnorms), np.array(op1), np.array(op2)

    def kop_reg(self, k=250, approx=True):
        """Return a squared-norm regulariser over the C1 and C3 operators.

        Args:
            k: Number of random probe inputs used when ``approx`` is true.
            approx: When true, average the squared output norm over ``k``
                standard-normal inputs per layer; otherwise apply each layer
                to a full identity batch and sum the squared 2-norms.

        Returns:
            ``0`` when ``self.conv`` is set, otherwise a scalar tensor
            (gradients are not disabled here).
        """
        if self.conv:
            return 0

        terms = []
        for layer, chan, dim in self._probe_specs():
            if approx:
                # Sampled on the CPU, then moved to the layer's device.
                noise = torch.randn(k, chan, dim, dim).to(
                    self._device_of(layer))
                energy = (layer(noise) ** 2).reshape(k, -1)
                terms.append(torch.sum(energy, dim=1).mean())
            else:
                op = layer(self._identity_probe(layer, chan, dim))
                terms.append(torch.norm(op.flatten(), p=2) ** 2)

        return terms[0] + terms[1]
