from torch import nn
import torch.nn.functional as F
import torch
import math
class Net(nn.Module):
    """Small two-conv / two-FC classifier with flat-parameter helpers.

    Architecture (see ``seq``): two convolutions, each followed by
    BatchNorm2d, ReLU and 2x2 max pooling, producing a ``ch x 4 x 4`` map
    that is flattened and passed through ``fc1 -> BatchNorm1d -> ReLU`` and
    ``fc2 -> BatchNorm1d`` to give ``odim`` logits.

    Besides ``forward``, the class offers helpers that view the trainable
    parameters as one flat vector (``state``, ``paranorm``,
    ``normalized_para``) and that update parameters manually
    (``add``, ``mul``, ``flow``).

    Args:
        hdim: width of the hidden fully connected layer.
        odim: number of output logits.
        ch: number of channels of both convolutional layers.
        dataset: ``"MNIST"`` (1 input channel, 5x5 first kernel) or
            ``"CIFAR10"`` (3 input channels, 9x9 first kernel).

    Raises:
        ValueError: if ``dataset`` is not one of the supported names.
    """

    def __init__(self, hdim=10, odim=10, ch=1, dataset="CIFAR10"):
        super().__init__()

        self.hdim = hdim
        self.ch = ch

        # Both branches are sized to yield a ch x 24 x 24 map, assuming the
        # standard inputs: 28 - 5 + 1 = 24 (MNIST), 32 - 9 + 1 = 24 (CIFAR10).
        if dataset == "MNIST":
            self.conv1 = nn.Conv2d(1, ch, kernel_size=5, stride=1)
        elif dataset == "CIFAR10":
            self.conv1 = nn.Conv2d(3, ch, kernel_size=9, stride=1)
        else:
            # Fail fast instead of building a model without conv1 that only
            # breaks later inside seq().
            raise ValueError(
                f"unsupported dataset {dataset!r}; expected 'MNIST' or 'CIFAR10'"
            )
        self.conv1bn = nn.BatchNorm2d(ch)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)  # ch x 12 x 12
        self.conv2 = nn.Conv2d(ch, ch, kernel_size=5, stride=1)  # ch x 8 x 8
        self.conv2bn = nn.BatchNorm2d(ch)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)  # ch x 4 x 4
        self.fc1 = nn.Linear(4 * 4 * ch, hdim)
        self.fc1bn = nn.BatchNorm1d(hdim)
        self.fc2 = nn.Linear(hdim, odim)
        self.fc2bn = nn.BatchNorm1d(odim)

    def normalized_para(self):
        """Return the trainable parameters as one flat vector of unit L2 norm.

        The vector is ``state() / paranorm()`` computed without autograd
        tracking. If all trainable parameters are zero, the division yields
        NaNs.
        """
        with torch.no_grad():
            return self.state() / self.paranorm()

    def gradnorm(self):
        """Return the L2 norm of the gradients of all trainable parameters.

        Parameters whose ``.grad`` is ``None`` (no backward pass yet, or
        gradients cleared to ``None``) contribute nothing instead of raising.
        The result is always a tensor, so callers can ``.detach()`` it.
        """
        S = 0.
        with torch.no_grad():
            for p in self.parameters():
                if p.requires_grad and p.grad is not None:
                    S += (p.grad * p.grad).sum()
        # S is still the float 0. when no gradient was found.
        return torch.as_tensor(S) ** .5

    def paranorm(self):
        """Return the L2 norm of all trainable parameters taken together."""
        S = 0.
        with torch.no_grad():
            for p in self.parameters():
                if p.requires_grad:
                    S += (p * p).sum()
        return S ** .5

    def compare(self, other):
        """Return the L2 distance between this model's parameters and ``other``'s.

        Parameters are paired in ``parameters()`` order (so ``other`` is
        presumably the same architecture). Pairs whose parameter in ``self``
        is not trainable are masked out.
        """
        S = 0.
        with torch.no_grad():
            for p1, p2 in zip(self.parameters(), other.parameters()):
                S += ((p1 - p2) * (p1 - p2) * p1.requires_grad).sum()
        return S ** .5

    def add(self, other):
        """Add ``other``'s parameters to this model's parameters in place.

        Unlike ``mul`` and ``flow``, this does not skip parameters with
        ``requires_grad=False``.
        """
        with torch.no_grad():
            for p, p1 in zip(self.parameters(), other.parameters()):
                p.copy_(p + p1)

    def mul(self, ratio=1):
        """Scale every trainable parameter in place by ``ratio``."""
        with torch.no_grad():
            for p in self.parameters():
                if p.requires_grad:
                    p.copy_(ratio * p)

    def flow(self, lr=0, wd=0):
        """Take one in-place SGD step with multiplicative weight decay.

        Applies ``p <- (1 - wd) * p - lr * p.grad`` to every trainable
        parameter; expects ``.grad`` to be populated (e.g. after backward).
        """
        with torch.no_grad():
            for p in self.parameters():
                if p.requires_grad:
                    p.copy_((1 - wd) * p - lr * p.grad)

    def seq(self, x):
        """Map an input batch to ``odim`` logits per sample."""
        x = F.relu(self.conv1bn(self.conv1(x)))  # ch x 24 x 24
        x = self.pool1(x)  # ch x 12 x 12
        x = F.relu(self.conv2bn(self.conv2(x)))  # ch x 8 x 8
        x = self.pool2(x)  # ch x 4 x 4
        x = x.view(-1, 4 * 4 * self.ch)  # flatten
        x = F.relu(self.fc1bn(self.fc1(x)))  # hdim
        x = self.fc2bn(self.fc2(x))  # odim
        return x

    def state(self):
        """Return all trainable parameters flattened into one 1-D tensor.

        Parameters are concatenated in ``parameters()`` order. Autograd
        tracking follows the caller's context. With no trainable parameters,
        an empty tensor is returned on the first parameter's device (CPU if
        the module has no parameters).
        """
        params = list(self.parameters())
        # A single cat over a list avoids re-copying the growing vector for
        # every parameter.
        flat = [p.view(-1) for p in params if p.requires_grad]
        if flat:
            return torch.cat(flat)
        return torch.empty(0, device=params[0].device if params else None)

    def forward(self, X, target, Criterion):
        """Compute the loss and accuracy of the model on one batch.

        Args:
            X: input batch; its first dimension is the batch size.
            target: class labels, cast to ``long`` before use.
            Criterion: callable ``Criterion(logits, target)`` returning a
                loss tensor, e.g. ``nn.CrossEntropyLoss()``.

        Returns:
            ``(L, A)``. ``L`` is the criterion output summed to a scalar;
            this is a no-op for a mean-reduced loss such as the default
            ``nn.CrossEntropyLoss``. ``A`` is the fraction of samples whose
            arg-max logit equals the target.
        """
        bs = X.size()[0]

        h = self.seq(X)
        target = target.long()

        L = Criterion(h, target).sum()
        correct = (torch.max(h, 1).indices == target).float()
        A = correct.sum() / bs
        return L, A
            
def run(net, X, y, Criterion, lr, wd, train=True):
    """Evaluate ``net`` on one batch and, if ``train``, take one SGD step.

    ``net(X, y, Criterion)`` must return ``(loss, accuracy)`` tensors. When
    training, the loss is back-propagated, the gradient norm is recorded via
    ``net.gradnorm()``, the parameters are updated with
    ``net.flow(lr, wd)`` and the gradients are cleared. Otherwise the
    forward pass runs under ``torch.no_grad()``.

    Returns:
        ``(L, A, G)`` as detached tensors: loss, accuracy and gradient norm
        (measured before the update). ``G`` is ``torch.zeros(1)`` when not
        training.
    """
    # NOTE(review): the module's train()/eval() mode is left as the caller
    # set it; this matters for layers such as BatchNorm.
    if train:
        L, A = net(X, y, Criterion)
        # The graph is back-propagated exactly once, so it need not be
        # retained (retain_graph=True only kept activations alive longer).
        L.backward()
        G = net.gradnorm()
        net.flow(lr, wd)
        net.zero_grad()
    else:
        with torch.no_grad():
            L, A = net(X, y, Criterion)
        G = torch.zeros(1)

    return L.detach(), A.detach(), G.detach()

def Run(net, DL, Criterion, lr, wd, train=True):
    """Apply ``run`` to every ``(X, y)`` batch of ``DL`` and average the results.

    Averages are taken per batch: every batch weighs equally, regardless of
    its size. Batches are counted while iterating, so ``DL`` only needs to
    be an iterable of ``(X, y)`` pairs.

    Returns:
        ``(mean_loss, mean_accuracy, mean_gradnorm)`` as Python floats.

    Raises:
        ValueError: if ``DL`` yields no batches.
    """
    loss_sum = accu_sum = grad_sum = 0.0
    n_batches = 0
    for X, y in DL:
        L, A, G = run(net, X, y, Criterion, lr, wd, train)
        loss_sum += L.item()
        accu_sum += A.item()
        grad_sum += G.item()
        n_batches += 1

    if n_batches == 0:
        # Previously this surfaced as an opaque ZeroDivisionError.
        raise ValueError("Run() received a data loader that yielded no batches")

    return loss_sum / n_batches, accu_sum / n_batches, grad_sum / n_batches

