import sys
import os

from torch.nn.modules.activation import ELU

import math
import torch
import torchvision
import torch.nn as nn
import numpy as np


from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torchvision import models

class PredictorAbstract(nn.Module):
    """Base class for predictors: stores the sizes and picks the output activation.

    Attributes set by the constructor:
        input_size: the ``input_size`` argument, stored as given.
        output_size: the ``output_size`` argument, stored as given.
        output: number of output units as a plain ``int`` (``output_size``
            itself when it is an integer, otherwise the product of its entries).
        activation: ``nn.LogSoftmax(-1)`` when ``output_size`` is an integer
            greater than 1, an identity module in every other case.
    """

    def __init__(self, input_size, output_size):
        """Record the sizes and select the final activation.

        Args:
            input_size: size description of the input; only stored here.
            output_size: an integer, or a sequence holding a single entry.

        Raises:
            NotImplementedError: if ``output_size`` is a sequence with more
                than one entry.
        """
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        # isinstance (rather than an exact type comparison) also accepts NumPy
        # integer scalars, which would otherwise fall through to len() below.
        if isinstance(output_size, (int, np.integer)):
            self.output = int(output_size)
            if self.output > 1:
                self.activation = nn.LogSoftmax(-1)
            else:
                # nn.Identity instead of a lambda: a lambda attribute makes
                # the whole module impossible to pickle.
                self.activation = nn.Identity()
        else:
            if len(output_size) > 1:
                raise NotImplementedError
            # Cast so that layer constructors receive a plain Python int.
            self.output = int(np.prod(output_size))
            self.activation = nn.Identity()

    def forward(self, x):
        """Subclasses must implement the forward pass."""
        raise NotImplementedError
    



def init_weights(m):
    """Initialise an ``nn.Linear`` layer in place; ignore any other module.

    The weight is drawn with Xavier-uniform initialisation and the bias, when
    the layer has one, is filled with 0.01. The signature matches what
    ``nn.Module.apply`` expects.

    Args:
        m: the module to initialise.
    """
    if not isinstance(m, nn.Linear):
        return
    # The un-suffixed ``xavier_uniform`` is a deprecated alias of the
    # in-place ``xavier_uniform_``.
    torch.nn.init.xavier_uniform_(m.weight)
    # Layers built with bias=False expose ``bias`` as None.
    if m.bias is not None:
        torch.nn.init.constant_(m.bias, 0.01)

class RealXClassifier(PredictorAbstract):
    """Fully connected classifier: two ReLU hidden layers and a linear head."""

    def __init__(self, input_size, output_size, middle_size=200):
        """Build the three linear layers and initialise them with ``init_weights``.

        Args:
            input_size: size of one input sample; its entries are multiplied
                together to get the number of input features.
            output_size: forwarded to ``PredictorAbstract``.
            middle_size: width of both hidden layers. This argument used to
                be ignored in favour of a hard-coded 200; the default keeps
                that width.
        """
        super().__init__(input_size=input_size, output_size=output_size)
        in_features = int(np.prod(input_size))
        self.fc1 = nn.Linear(in_features, middle_size)
        self.fc2 = nn.Linear(middle_size, middle_size)
        self.fc3 = nn.Linear(middle_size, self.output)

        for layer in (self.fc1, self.fc2, self.fc3):
            layer.apply(init_weights)

    # Defined as ``forward`` (not ``__call__``) so that nn.Module.__call__
    # keeps running hooks; ``model(x)`` still dispatches here.
    def forward(self, x):
        """Flatten everything after the first dimension and run the MLP.

        Args:
            x: input tensor; assumed to be batched along dim 0 - TODO confirm.

        Returns:
            ``self.activation`` applied to the output of the last layer.
        """
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)

        return self.activation(x)

class StupidClassifier(PredictorAbstract):
    """Single linear layer followed by ELU and the inherited output activation."""

    def __init__(self, input_size=(1, 28, 28), output_size=10, bias=True):
        """Create the single linear layer.

        Args:
            input_size: size of one input sample; its entries are multiplied
                together to get the number of input features.
            output_size: forwarded to ``PredictorAbstract``.
            bias: whether the linear layer has a bias term; also stored on
                the instance as ``self.bias``.
        """
        super().__init__(input_size=input_size, output_size=output_size)
        self.bias = bias
        self.fc1 = nn.Linear(int(np.prod(input_size)), self.output, bias=bias)

        self.elu = nn.ELU()

    # Defined as ``forward`` (not ``__call__``) so that nn.Module.__call__
    # keeps running hooks; ``model(x)`` still dispatches here.
    def forward(self, x):
        """Flatten everything after the first dimension and apply the layer.

        Args:
            x: input tensor; assumed to be batched along dim 0 - TODO confirm.

        Returns:
            ``self.activation(ELU(fc1(x)))``.
        """
        x = x.flatten(1)
        return self.activation(self.elu(self.fc1(x)))




class ConvClassifier2(PredictorAbstract):
    """Stack of (conv, conv, average-pool) blocks followed by two linear layers.

    Block ``k`` (0-based) outputs ``2 ** (k + 5)`` channels and halves both
    spatial sides. No non-linearity is applied between the convolutions.
    """

    def __init__(self, input_size=(1, 28, 28), output_size=10):
        """Build the convolutional trunk and the classification head.

        Args:
            input_size: indexed as ``input_size[0]`` for channels and
                ``input_size[1]`` / ``input_size[2]`` for the two spatial sides.
            output_size: forwarded to ``PredictorAbstract``.
        """
        super().__init__(input_size=input_size, output_size=output_size)
        height, width = self.input_size[1], self.input_size[2]
        # One block is always built, so never report fewer than one; without
        # the clamp a side below 4 would size the head for 16 channels while
        # the trunk outputs 32.
        self.nb_block = max(1, int(math.log(min(height, width), 2) // 2))

        layers = []
        in_channels = input_size[0]
        for k in range(self.nb_block):
            out_channels = 2 ** (k + 5)
            layers.extend([
                nn.Conv2d(in_channels, out_channels, 3, stride=1, padding=1),
                nn.Conv2d(out_channels, out_channels, 3, stride=1, padding=1),
                nn.AvgPool2d(kernel_size=(2, 2), stride=2, padding=0),
            ])
            in_channels = out_channels
        self.conv = nn.ModuleList(layers)

        last_channel = in_channels
        # Each 2x2/stride-2 pooling floors the side length, so floor each side
        # separately. Dividing the total area by 4 ** nb_block only matches
        # when both sides are multiples of 2 ** nb_block.
        shrink = 2 ** self.nb_block
        last_size = (height // shrink) * (width // shrink)
        self.fc = nn.Linear(last_channel * last_size, 128)

        self.elu = nn.ELU()

        self.fc2 = nn.Linear(128, self.output)

    # Defined as ``forward`` (not ``__call__``) so that nn.Module.__call__
    # keeps running hooks; ``model(x)`` still dispatches here.
    def forward(self, x):
        """Run the trunk, flatten, then apply the two linear layers.

        Args:
            x: input tensor fed to the first ``nn.Conv2d``; assumed to be
                (batch, channels, height, width) - TODO confirm.

        Returns:
            ``self.activation(fc2(ELU(fc(flattened features))))``.
        """
        for layer in self.conv:
            x = layer(x)
        x = x.flatten(1)
        x = self.elu(self.fc(x))
        return self.activation(self.fc2(x))


            

class ResNet50(PredictorAbstract):
    """torchvision ResNet-50 whose final ``fc`` layer is replaced by a new head."""

    def __init__(self, input_size=(3, 224, 224), output_size=10):
        """Load the backbone with ``pretrained=True`` and swap its ``fc`` layer.

        Args:
            input_size: forwarded to ``PredictorAbstract``; not used to build
                the backbone.
            output_size: forwarded to ``PredictorAbstract``; sets the width
                of the new head.
        """
        super().__init__(input_size=input_size, output_size=output_size)

        # NOTE(review): newer torchvision releases prefer ``weights=...`` over
        # ``pretrained=True``; kept as is because the target version is unknown.
        self.model = models.resnet50(pretrained=True)
        # Read the width from the layer being replaced: ResNet-50 feeds 2048
        # features into ``fc``, so a hard-coded 512 fails on the first forward.
        self.model.fc = nn.Linear(self.model.fc.in_features, self.output)

    # Defined as ``forward`` (not ``__call__``) so that nn.Module.__call__
    # keeps running hooks; ``model(x)`` still dispatches here.
    def forward(self, x):
        """Run the backbone and apply the inherited output activation."""
        x = self.model(x)
        x = self.activation(x)
        return x

class ResNet34(PredictorAbstract):
    """torchvision ResNet-34 whose final ``fc`` layer is replaced by a new head."""

    def __init__(self, input_size=(3, 224, 224), output_size=10):
        """Load the backbone with ``pretrained=True`` and swap its ``fc`` layer.

        Args:
            input_size: forwarded to ``PredictorAbstract``; not used to build
                the backbone.
            output_size: forwarded to ``PredictorAbstract``; sets the width
                of the new head.
        """
        super().__init__(input_size=input_size, output_size=output_size)

        # NOTE(review): newer torchvision releases prefer ``weights=...`` over
        # ``pretrained=True``; kept as is because the target version is unknown.
        self.model = models.resnet34(pretrained=True)
        # Read the width from the layer being replaced instead of hard-coding
        # it, so the head always matches the backbone.
        self.model.fc = nn.Linear(self.model.fc.in_features, self.output)

    # Defined as ``forward`` (not ``__call__``) so that nn.Module.__call__
    # keeps running hooks; ``model(x)`` still dispatches here.
    def forward(self, x):
        """Run the backbone and apply the inherited output activation."""
        x = self.model(x)
        x = self.activation(x)
        return x
