from Libraries import *
# Select the compute device once at import time: 'cuda' when PyTorch reports
# an available GPU, otherwise 'cpu'.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
use_gpu = device == 'cuda'
# Cross-entropy loss module, moved to the selected device.
criterion = nn.CrossEntropyLoss().to(device)
# 8 on GPU, 0 on CPU (presumably a DataLoader worker count -- confirm at the call sites).
workers = 8 if use_gpu else 0

###########################################################Network definitions##############################################################
class Net(nn.Module):
    '''
    Fully connected network with a single hidden layer and a ReLU activation.

    Parameters
    ----------
    net_size : int
        Width of the hidden layer.
    input_size : int
        Number of input features.
    output_size : int
        Number of output features.
    Bias : bool
        Forwarded as ``bias`` to both linear layers.
    '''

    def __init__(self, net_size, input_size, output_size, Bias):
        super().__init__()
        self.fc1 = nn.Linear(input_size, net_size, bias=Bias)
        self.fc2 = nn.Linear(net_size, output_size, bias=Bias)
        self.relu = nn.ReLU()

    def forward(self, x):
        '''Return ``fc2(relu(fc1(x)))``.'''
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)

    
class MultiNet(nn.Module):
    '''
    Fully connected network with two hidden layers, each followed by ReLU.

    Parameters
    ----------
    net_size1 : int
        Width of the first hidden layer.
    net_size2 : int
        Width of the second hidden layer.
    input_size : int
        Number of input features.
    output_size : int
        Number of output features.
    Bias : bool
        Forwarded as ``bias`` to all three linear layers.
    '''

    def __init__(self, net_size1, net_size2, input_size, output_size, Bias):
        super().__init__()
        self.fc1 = nn.Linear(input_size, net_size1, bias=Bias)
        self.fc2 = nn.Linear(net_size1, net_size2, bias=Bias)
        self.fc3 = nn.Linear(net_size2, output_size, bias=Bias)
        self.relu = nn.ReLU()

    def forward(self, x):
        '''Apply fc1 and fc2 with ReLU after each, then the linear output layer fc3.'''
        for hidden_layer in (self.fc1, self.fc2):
            x = self.relu(hidden_layer(x))
        return self.fc3(x)
    


class LeNet(nn.Module):
    '''
    LeNet-style convolutional classifier with 10 outputs.

    Two 5x5 convolutions (1 -> 6 -> 16 channels), each followed by ReLU and
    2x2 max-pooling, feed three linear layers (256 -> 256 -> 256 -> 10).
    The first linear layer expects 256 flattened features (16 x 4 x 4),
    which is what a 1 x 28 x 28 input produces.
    For Reference : http://yann.lecun.com/exdb/publis/pdf/lecun-98.pdf
    '''

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(256, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 10)

    def forward(self, x):
        '''Return the 10 class scores for a batch of single-channel images.'''
        # Feature extractor: conv -> ReLU -> 2x2 max-pool, applied twice.
        for conv in (self.conv1, self.conv2):
            x = F.max_pool2d(F.relu(conv(x)), 2)
        # Collapse everything except the batch dimension.
        flat = x.view(x.size(0), -1)
        hidden = F.relu(self.fc2(F.relu(self.fc1(flat))))
        return self.fc3(hidden)
    
class AlexNet(nn.Module):
    '''
    A class used to implement a reduced AlexNet architecture.

    Three convolutional stages (each ending in a 3x3, stride-2 max-pool) are
    followed by a three-layer fully connected classifier with dropout.

    The classifier's first layer expects 256 flattened features, i.e. the
    convolutional stack must reduce the input to a 1 x 1 spatial map. For
    square inputs this holds for side lengths 67 through 98 (e.g. 96 x 96).

    Parameters
    ----------
    in_channel : int
        Number of channels of the input images.
    n_class : int
        Number of output classes (size of the final linear layer).
    '''
    def __init__(self, in_channel, n_class):
        super(AlexNet, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channel, 96, kernel_size=11, stride=4),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(96, 256, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2)
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(256, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2)
        )
        self.fc = nn.Sequential(
            nn.Linear(1*1*256, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            # Output width follows n_class instead of a hard-coded 10, so the
            # constructor argument is actually honoured.
            nn.Linear(512, n_class)
        )

    def forward(self, x):
        '''Return the ``n_class`` class scores for a batch of images.'''
        con1_x = self.conv1(x)
        con2_x = self.conv2(con1_x)
        con3_x = self.conv3(con2_x)
        # Flatten everything except the batch dimension for the classifier.
        lin_x = con3_x.view(con3_x.size(0), -1)
        y_hat = self.fc(lin_x)
        return y_hat
##################################################################Training and validation Functions######################################################   

def _configure_trainable_params(net, update, setting_title):
    """Set ``requires_grad`` on every parameter of ``net`` for the given update mode."""
    for name, param in net.named_parameters():
        is_bias = 'bias' in name
        in_head = setting_title in name
        if update == 'allbias':
            trainable = is_bias
        elif update == 'fcbias':
            trainable = is_bias and in_head
        else:  # 'all': freeze only the weights of the head layers
            trainable = not ('weight' in name and in_head)
        param.requires_grad_(trainable)


def train_net(net, loader, val_loader, criterion, update='bias', lrate=0.01, max_epochs=6, architecture='vgg16'):
    """Train ``net`` in place and return it.

    Parameters
    ----------
    net : PyTorch Model
        The model to train. Its parameters' ``requires_grad`` flags are
        overwritten according to ``update``.
    loader : PyTorch Dataloader
        Iterable yielding ``(input, target)`` training batches.
    val_loader : PyTorch Dataloader
        Currently unused; kept so existing callers keep working.
    criterion : PyTorch Loss Criterion
        Called as ``criterion(net(input), target)`` to obtain the loss.
    update : str
        Which parameters are trained. Parameter names are matched by
        substring:

        * ``'allbias'`` -- every parameter whose name contains ``'bias'``.
        * ``'fcbias'``  -- bias parameters whose name also contains the head
          title (see ``architecture``).
        * ``'all'``     -- every parameter except the weights whose name
          contains the head title.

        NOTE: the default ``'bias'`` is not one of the supported modes, so
        callers must pass ``update`` explicitly.
    lrate : float
        Learning rate handed to the optimizer. No schedule is applied; the
        rate stays constant for all epochs.
    max_epochs : int
        Number of passes over ``loader``.
    architecture : str
        ``'vgg16'`` selects ``'classifier'`` as the head title; any other
        value selects ``'fc'``.

    Returns
    -------
    PyTorch Model
        The same ``net`` object, after training.

    Raises
    ------
    ValueError
        If ``update`` is not ``'allbias'``, ``'fcbias'`` or ``'all'``.
    """
    # Fail fast: an unknown mode would otherwise leave the optimizer undefined
    # and only surface as an UnboundLocalError on the first training batch.
    if update not in ('allbias', 'fcbias', 'all'):
        raise ValueError(
            "update must be one of 'allbias', 'fcbias' or 'all', got %r" % (update,)
        )

    setting_title = 'classifier' if architecture == 'vgg16' else 'fc'
    _configure_trainable_params(net, update, setting_title)
    optimizer = radam.RAdam(net.parameters(), lr=lrate)

    for _ in range(max_epochs):
        net.train()
        for inp, target in loader:
            # Batches are moved to the module-level ``device``.
            inp, target = inp.to(device), target.to(device)
            optimizer.zero_grad()
            loss = criterion(net(inp), target)
            loss.backward()
            optimizer.step()
    return net

def validate(net, val_loader):
    """Returns a network's accuracy on a dataset, in percent.

    The network is switched to evaluation mode (and left in it), and the
    forward passes run under ``torch.no_grad()`` so no autograd graph is
    built during evaluation.

    Parameters
    ----------
    net : PyTorch Model
        The model to evaluate
    val_loader : PyTorch Dataloader
        Iterable yielding ``(input, target)`` batches

    Returns
    -------
    accuracy
        Percentage (0-100) of examples whose arg-max prediction along
        dimension 1 equals the target

    Raises
    ------
    ZeroDivisionError
        If ``val_loader`` yields no examples
    """
    net.eval()
    num_correct = 0
    num_examples = 0
    with torch.no_grad():
        for inp, target in val_loader:
            # Batches are moved to the module-level ``device``.
            predictions = net(inp.to(device)).argmax(1)
            correct = torch.eq(predictions, target.to(device))
            num_correct += correct.sum().item()
            num_examples += correct.shape[0]
    return num_correct / num_examples * 100

