import torch
import time
from tqdm import tqdm
from torch.nn import Module
from torch.nn import Conv2d, BatchNorm2d, Linear, BatchNorm1d
from torch.nn import ReLU
from torch.nn import LogSoftmax, Sigmoid, Softmax
from torch.nn import ModuleList, Sequential
from torch import count_nonzero
from torch.nn import CrossEntropyLoss, MSELoss, NLLLoss
from torch.nn.functional import one_hot
from torch.optim import SGD, Adam
from einops import rearrange
import numpy as np
from torch.nn import functional as F



class SingleStepModule(Module):
    """One update step: a 3x3 convolution followed by optional 1x1 convolutions.

    Every convolution keeps the channel count unchanged and is followed by an
    optional ``BatchNorm2d`` and a ``LeakyReLU``.
    """

    def __init__(self, channels, extra_depth=0, use_bn=True):
        """
        Builds the convolution / normalization / activation stack.

        Args:
            channels (int): Number of input and output channels of every convolution.
            extra_depth (int): Number of 1x1 convolution stages appended after the 3x3 stage.
            use_bn (bool): Whether each convolution is followed by a BatchNorm2d layer.
        """
        super().__init__()

        self.channels = channels
        self.extra_depth = extra_depth

        # The first stage uses a 3x3 kernel, every additional stage a 1x1 kernel.
        kernel_sizes = [3] + [1] * extra_depth

        stack = []
        for kernel_size in kernel_sizes:
            # kernel_size // 2 gives padding 1 for the 3x3 kernel and 0 for the
            # 1x1 kernels, so the spatial size is preserved in both cases.
            stack.append(
                Conv2d(
                    in_channels=channels,
                    out_channels=channels,
                    kernel_size=kernel_size,
                    padding=kernel_size // 2,
                    padding_mode='circular',
                )
            )
            if use_bn:
                stack.append(BatchNorm2d(channels))
            stack.append(torch.nn.LeakyReLU())

        self.layers = Sequential(*stack)

    def forward(self, x):
        """
        Applies the layer stack to the input.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, channels, height, width).

        Returns:
            torch.Tensor: Result of passing ``x`` through all layers in order.
        """
        out = self.layers(x)
        return out
    

class CA_CNN_Convolutional(Module):
    """Stack of ``num_iterations`` SingleStepModule blocks between two 1x1 convolutions.

    The input is embedded from ``num_classes`` to ``width`` channels, passed
    through the step modules (optionally with residual connections) and
    projected back to ``num_classes`` channels.
    """

    def __init__(self, width, num_iterations, num_classes, extra_depth=0, residual=True, use_bn=True, device='cuda'):
        """
        Initializes the CA_CNN_Convolutional model.

        Args:
            width (int): Number of channels in the intermediate layers.
            num_iterations (int): Number of iterations (steps) in the model.
            num_classes (int): Number of input channels and of output classes.
            extra_depth (int): Number of additional 1x1 convolution stages in each step.
            residual (bool): Whether to use residual connections around each step.
            use_bn (bool): Whether the step modules use batch normalization.
            device (str or torch.device): Device the model is moved to after construction.
        """
        super().__init__()
        self.width = width
        self.extra_depth = extra_depth
        self.use_bn = use_bn
        self.num_iterations = num_iterations
        self.num_classes = num_classes
        self.residual = residual

        self.input_emb = Conv2d(in_channels=num_classes, out_channels=width, kernel_size=1, padding=0, padding_mode='circular')

        self.layers = torch.nn.ModuleList()
        for _ in range(num_iterations):
            self.layers.append(SingleStepModule(width, extra_depth=extra_depth, use_bn=use_bn))

        self.output_emb = Conv2d(in_channels=width, out_channels=num_classes, kernel_size=1, padding=0, padding_mode='circular')

        self.to(device)

    def forward(self, x, labels=None):
        """
        Forward pass through the CA_CNN_Convolutional model.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, num_classes, height, width).
            labels (torch.Tensor, optional): Class indices compared against the
                argmax over dim 1 of the output and passed to ``F.cross_entropy``.

        Returns:
            tuple: ``(logits,)`` when ``labels`` is None, otherwise
            ``(logits, loss, accuracy)`` where ``logits`` has shape
            (batch_size, num_classes, height, width), ``loss`` is the
            cross-entropy tensor and ``accuracy`` is a Python float.
        """
        x = self.input_emb(x)
        for i in range(self.num_iterations):
            step_out = self.layers[i](x)
            x = x + step_out if self.residual else step_out
        x = self.output_emb(x)

        if labels is None:
            # Always return a tuple so callers can index [0] in both modes.
            return (x,)

        predicted = torch.argmax(x, dim=1)
        accuracy = torch.mean(1.0 * (predicted == labels)).item()
        return x, F.cross_entropy(x, labels), accuracy

    def calc_perturbation_sensitivity(self, num_samples=256, device=None):
        """
        Calculates the sensitivity of the model to a perturbation of the centre cell.

        Random one-hot grids of side ``2 * num_iterations + 1`` are generated,
        the channel vector at the centre cell is replaced by ``1 - value`` and
        the summed absolute difference between the two model outputs is
        reported.

        Args:
            num_samples (int): Number of random grids to evaluate.
            device (str or torch.device, optional): Device for the generated
                inputs. Defaults to the device of the model's parameters.

        Returns:
            tuple: ``(sensitivity, count)`` where ``count`` is
            ``num_samples * size * size`` and ``sensitivity`` is the summed
            absolute output difference divided by ``count``.
        """
        if device is None:
            # Follow the model instead of assuming CUDA, so this also works for
            # models that were constructed on (or moved to) the CPU.
            device = next(self.parameters()).device

        centre = self.num_iterations
        size = 2 * self.num_iterations + 1

        x_input = torch.randint(0, self.num_classes, (num_samples, size, size))
        x_input = one_hot(x_input.to(torch.int64), num_classes=self.num_classes)
        x_input = x_input.permute(0, 3, 1, 2)
        x_input = x_input.to(torch.float32).to(device=device)

        # Perturb the centre cell in a copy. For two classes this swaps the
        # class; for more classes the result is no longer a one-hot vector.
        x_input_clone = x_input.clone()
        x_input_clone[:, :, centre, centre] = 1 - x_input[:, :, centre, centre]

        # NOTE: the model's current train/eval mode is used as-is. In training
        # mode BatchNorm layers normalize with batch statistics and update
        # their running estimates during these two passes.
        with torch.no_grad():  # only a scalar is needed, so skip autograd bookkeeping
            prediction = self.forward(x_input)[0]
            prediction_clone = self.forward(x_input_clone)[0]
            sensitivity = torch.abs(prediction - prediction_clone).sum().item()

        count = num_samples * size * size
        return sensitivity / count, count

    def get_config(self):
        """
        Returns the configuration of the model.

        Computing ``start_perturbation_sensitivity`` runs the model on randomly
        generated inputs, so this call is not free.

        Returns:
            dict: Dictionary containing the configuration of the model.
        """
        d = {
            'model': 'CA_CNN_Convolutional',
            'num_iterations': self.num_iterations,
            'num_classes': self.num_classes,
            'residual': self.residual,
            'width': self.width,
            'extra_depth': self.extra_depth,
        }
        d["start_perturbation_sensitivity"] = self.calc_perturbation_sensitivity()[0]
        d["use_bn"] = self.use_bn
        return d

if __name__=="__main__":

    # Fall back to the CPU so the demo also runs on machines without CUDA.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # The constructor already moves the model to `device`.
    net = CA_CNN_Convolutional(32, 3, 2, extra_depth=1, residual=True, device=device)

    # Print the sensitivity estimate four times for each sample count
    # (256, 512, 1024) to show how much it varies between random draws.
    for i in range(4, 7):
        num_samples = 16 * (2 ** i)
        for _ in range(4):
            print(num_samples, net.calc_perturbation_sensitivity(num_samples=num_samples, device=device))