import torch
import time
from tqdm import tqdm
from torch.nn import Module
from torch.nn import Conv2d, BatchNorm2d, Linear, BatchNorm1d
from torch.nn import ReLU
from torch.nn import LogSoftmax, Sigmoid, Softmax
from torch.nn import ModuleList, Sequential
from torch import count_nonzero
from torch.nn import CrossEntropyLoss, MSELoss, NLLLoss
from torch.nn.functional import one_hot
from torch.optim import SGD, Adam
from einops import rearrange
import numpy as np
from torch.nn import functional as F



import torch
import torch.nn as nn

import numpy as np


def extract_ca_pairs_torch(
    input_batch: torch.Tensor,
    k: int,
    output_batch = None
):
    """
    Extracts per-cell neighbourhood patches (and flattened targets) from a batch of grids.

    Every cell of every grid yields one (2k+1) x (2k+1) patch centred on that
    cell. Grids are treated as tori: the input is padded circularly, so patches
    near an edge wrap around to the opposite side. The sliding windows are
    produced with ``Tensor.unfold`` rather than Python loops.

    Args:
        input_batch: A (B, C_in, H, W) tensor of input grids.
        k: The number of CA steps, defining the dependency radius. The patch
            side length is ``2 * k + 1``.
            NOTE(review): circular padding in PyTorch is expected to reject
            ``k`` larger than H or W -- confirm against the installed version.
        output_batch: Optional tensor of per-cell targets. It is flattened in
            row-major order without any permutation, so its entries line up
            with the patches only when it is laid out as (B, H, W) (or
            (B, 1, H, W)).

    Returns:
        If output_batch is None:
            - A (B*H*W, C_in, patch_size, patch_size) tensor of input patches,
              ordered by (batch, row, column).
        If output_batch is provided:
            - A tuple of the patches tensor and the 1-D flattened
              ``output_batch`` (``output_batch.numel()`` elements).
    """
    B, C_in, H, W = input_batch.shape
    patch_size = 2 * k + 1

    # 1. Pad the two spatial dims with wraparound so border cells get full
    #    neighbourhoods. F.pad takes (left, right, top, bottom) for 4-D input.
    padded_input = F.pad(input_batch, (k, k, k, k), mode='circular')

    # 2. Slide a window of `patch_size` with stride 1 over height, then width.
    #    Resulting shape: (B, C_in, H, W, patch_size, patch_size).
    patches = padded_input.unfold(2, patch_size, 1).unfold(3, patch_size, 1)

    # 3. Move the spatial position next to the batch axis and merge them:
    #    (B, C, H, W, p, p) -> (B, H, W, C, p, p) -> (B*H*W, C, p, p)
    patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(
        B * H * W, C_in, patch_size, patch_size
    )

    if output_batch is None:
        return patches

    # 4. Flatten the targets. `reshape` (unlike `view`) also accepts
    #    non-contiguous tensors such as slices or transposes.
    return patches, output_batch.reshape(-1)


def extract_ca_pairs_batch_channel_first(
    input_batch: np.ndarray, 
    output_batch: np.ndarray, 
    k: int
) -> tuple[np.ndarray, np.ndarray]:
    """
    Extracts CA pairs from a batch of grids with a channel-first format.

    NumPy, loop-based counterpart of the patch extraction: for every cell of
    every grid it collects the (2k+1) x (2k+1) neighbourhood (with wraparound
    at the borders) together with the output value(s) at that cell. Grids may
    be rectangular; height and width are handled independently.

    Args:
        input_batch: The (batch_size, num_channels, height, width) batch of inputs.
        output_batch: The (batch_size, num_out_channels, height, width) batch of
            outputs. Only the last two axes are indexed by cell position, so
            leading per-item axes are carried through unchanged.
        k: The number of CA steps, defining the dependency radius.

    Returns:
        A tuple ``(patches, outputs)``:
            - ``patches``: array of shape
              (batch_size*height*width, num_channels, 2k+1, 2k+1), ordered by
              (batch, row, column).
            - ``outputs``: array with one entry per patch holding the output
              value(s) at the patch centre.
        For an empty batch both arrays are empty with shape (0,).
    """
    batch_size, _, height, width = input_batch.shape
    patch_size = 2 * k + 1

    all_patches = []
    all_outputs = []

    # Process each item in the batch
    for i in range(batch_size):
        input_grid = input_batch[i]    # Shape: (channels, height, width)
        output_grid = output_batch[i]  # Shape: (out_channels, height, width)

        # Pad only the spatial dimensions (last two axes), not the channel
        # dimension; 'wrap' makes the grid toroidal.
        padded_grid = np.pad(
            input_grid, pad_width=((0, 0), (k, k), (k, k)), mode='wrap'
        )

        # Visit every spatial location of the original (unpadded) grid.
        for r in range(height):
            for c in range(width):
                # After padding by k, the window starting at (r, c) is centred
                # on original cell (r, c). All channels are kept.
                all_patches.append(
                    padded_grid[:, r : r + patch_size, c : c + patch_size]
                )
                # Output value(s) at the same cell, across any leading axes.
                all_outputs.append(output_grid[..., r, c])

    return np.array(all_patches), np.array(all_outputs)



class LinearClassifier(nn.Module):
    """
    MLP that classifies every grid cell from its (2k+1) x (2k+1) neighbourhood.

    The input grid is split into one patch per cell (circular wraparound at the
    borders). Each patch is reduced to a single feature per neighbourhood cell
    -- input channel 0 minus input channel 1 -- and the resulting vector of
    (2k+1)**2 features is fed through a stack of Linear/ReLU layers.
    """

    def __init__(self, num_layers_hidden,k=5, hidden_dim=128, num_classes=2):
        """
        Args:
            num_layers_hidden: Number of additional hidden Linear+ReLU pairs
                placed between the input layer and the output layer.
            k: Neighbourhood radius; patches have side length ``2*k + 1``.
            hidden_dim: Width of every hidden layer.
            num_classes: Number of output logits per cell.
        """
        super().__init__()
        self.k = k
        self.num_layers = num_layers_hidden
        # One scalar feature per neighbourhood cell (see `forward`), so the
        # MLP input width equals the number of cells in a patch.
        self.input_channels = (2*k+1)**2
        self.hidden_dim = hidden_dim
        self.num_classes = num_classes

        layers = nn.ModuleList()
        layers.append(Linear(self.input_channels, hidden_dim))
        layers.append(ReLU())
        for _ in range(num_layers_hidden):
            layers.append(Linear(hidden_dim, hidden_dim))
            layers.append(ReLU())
        # Final projection to logits; no activation afterwards.
        layers.append(Linear(hidden_dim, num_classes))
        self.layers = layers

    def forward_inner(self, x):
        """
        Runs the MLP on already-flattened patch features.

        Args:
            x: Tensor whose last dimension is ``(2k+1)**2``.

        Returns:
            Logits with last dimension ``num_classes``.
        """
        for layer in self.layers:
            x = layer(x)
        return x

    def forward(self, x, labels=None):
        """
        Predicts per-cell logits and, when labels are given, loss and accuracy.

        Args:
            x: (B, C, H, W) input grids. Only channels 0 and 1 are used.
            labels: Optional per-cell class indices, expected as (B, H, W);
                they are flattened and passed to ``F.cross_entropy``.

        Returns:
            Without labels: a 1-tuple ``(logits,)`` with logits of shape
            (B*H*W, num_classes).
            With labels: ``(logits, loss, accuracy)`` where logits are reshaped
            to (B, num_classes, H, W) and loss/accuracy are scalar tensors.
        """
        output = extract_ca_pairs_torch(x, self.k, labels)
        if labels is not None:
            input_patches, output_values = output
        else:
            input_patches = output
            output_values = None

        # (N, C, p, p) -> (N, p*p, C): one "token" per neighbourhood cell.
        input_as_tokens = rearrange(input_patches, 'b c h w -> b (h w) c')
        # Collapse the channel axis into a single signed feature per cell:
        # channel 0 minus channel 1 (+1 / -1 for a two-state one-hot input).
        input_as_vec = input_as_tokens[:, :, 0] - input_as_tokens[:, :, 1]

        prediction = self.forward_inner(input_as_vec)
        if output_values is None:
            return prediction,

        loss = F.cross_entropy(prediction, output_values)
        accuracy = (torch.argmax(prediction, dim=1) == output_values).float().mean()

        # Rows of `prediction` are ordered (batch, row, column), so they fold
        # back onto the input's spatial layout. The class count comes from the
        # model instead of being hard-coded to 2.
        height, width = x.shape[2], x.shape[3]
        prediction = prediction.view(-1, height, width, self.num_classes)
        prediction = rearrange(prediction, 'b h w c -> b c h w')

        return prediction, loss, accuracy

    def calc_perturbation_sensitivity(self, num_samples=8, device=None):
        """
        Calculates the sensitivity of the model to perturbations in the input.

        Random one-hot grids of side ``2k+1`` are generated and the centre cell
        of each grid is replaced by ``1 - value`` in every channel (a state
        flip for two-state inputs). Because the grid side equals the patch side
        and patches wrap around, the flipped cell appears in every patch of its
        grid, each time at a different position. The result is the summed
        absolute logit difference divided by the number of patches.

        Args:
            num_samples (int): Number of random grids to use.
            device: Device on which to build the probe inputs. ``None`` (the
                default) uses the device of the model's parameters.

        Returns:
            tuple: Sensitivity value (float) and the total number of elements considered (int).
        """
        size = self.k*2+1
        if device is None:
            device = next(self.parameters()).device

        x_input = torch.randint(0, self.num_classes, (num_samples, size, size))
        x_input = one_hot(x_input.to(torch.int64), num_classes=self.num_classes)
        x_input = x_input.permute(0, 3, 1, 2)
        x_input = x_input.to(torch.float32).to(device=device)

        # Perturb only the centre cell of each grid.
        x_input_clone = x_input.clone()
        x_input_clone[:, :, self.k, self.k] = 1 - x_input[:, :, self.k, self.k]

        # Pure measurement: no autograd graph is needed.
        with torch.no_grad():
            prediction = self.forward(x_input)[0]
            prediction_clone = self.forward(x_input_clone)[0]

        difference = torch.abs(prediction - prediction_clone)
        sensitivity = difference.sum()

        num_elements = num_samples*size*size
        return sensitivity.item()/num_elements, num_elements

    def get_config(self):
        """
        Returns the configuration of the model.

        Note that this runs ``calc_perturbation_sensitivity`` once, i.e. it
        performs forward passes on freshly drawn random inputs.

        Returns:
            dict: Dictionary containing the configuration of the model.
        """
        d = {'model': 'Linear', 'num_iterations': self.k, 'num_classes': self.num_classes}
        d["start_perturbation_sensitivity"] = self.calc_perturbation_sensitivity()[0]
        d["hidden_dim"] = self.hidden_dim
        d["num_layers"] = self.num_layers
        return d




class PixelTransformerClassifier(nn.Module):
    """
    Transformer that classifies every grid cell from its (2k+1) x (2k+1) neighbourhood.

    The input grid is split into one patch per cell (circular wraparound at the
    borders). Each cell of a patch becomes one token, a learned CLS token is
    prepended, and the encoder output at the CLS position is projected to
    class logits.
    """

    def __init__(self, k=5, input_channels=2, embed_dim=128, num_heads=8, num_layers=6, num_classes=2):
        """
        Args:
            k: Neighbourhood radius; patches have side length ``2*k + 1``.
            input_channels: Number of channels per input cell.
            embed_dim: Token embedding width (``d_model`` of the encoder).
            num_heads: Number of attention heads per encoder layer.
            num_layers: Number of encoder layers.
            num_classes: Number of output logits per cell.
        """
        super().__init__()
        n_size = 2*k+1
        seq_len = n_size * n_size
        self.k = k
        self.input_channels = input_channels
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.num_classes = num_classes
        # 1. Linear embedding layer for each pixel
        self.pixel_embedding = nn.Linear(input_channels, embed_dim)

        # 2. CLS token and positional embeddings
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        # +1 for the CLS token
        self.pos_embed = nn.Parameter(torch.randn(1, seq_len + 1, embed_dim))

        # 3. Transformer Encoder
        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # 4. Classifier Head
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward_inner(self, x):
        """
        Runs the transformer on tokenised patches.

        Args:
            x: (N, (2k+1)**2, input_channels) tensor, one token per patch cell.

        Returns:
            (N, num_classes) logits taken from the CLS token.
        """
        batch_size = x.shape[0]

        # 1. Embed every cell: (N, seq, input_channels) -> (N, seq, embed_dim)
        x = self.pixel_embedding(x)

        # 2. Prepend CLS token and add positional embeddings
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed

        # 3. Pass through Transformer Encoder
        x = self.transformer_encoder(x)

        # 4. Take the output of the CLS token for classification
        cls_output = x[:, 0]
        return self.classifier(cls_output)

    def forward(self, x, labels=None):
        """
        Predicts per-cell logits and, when labels are given, loss and accuracy.

        Args:
            x: (B, input_channels, H, W) input grids.
            labels: Optional per-cell class indices, expected as (B, H, W);
                they are flattened and passed to ``F.cross_entropy``.

        Returns:
            Without labels: a 1-tuple ``(logits,)`` with logits of shape
            (B*H*W, num_classes).
            With labels: ``(logits, loss, accuracy)`` where logits are reshaped
            to (B, num_classes, H, W) and loss/accuracy are scalar tensors.
        """
        output = extract_ca_pairs_torch(x, self.k, labels)
        if labels is not None:
            input_patches, output_values = output
        else:
            input_patches = output
            output_values = None

        # (N, C, p, p) -> (N, p*p, C): one token per neighbourhood cell.
        input_as_tokens = rearrange(input_patches, 'b c h w -> b (h w) c')

        prediction = self.forward_inner(input_as_tokens)
        if output_values is None:
            return prediction,

        loss = F.cross_entropy(prediction, output_values)
        accuracy = (torch.argmax(prediction, dim=1) == output_values).float().mean()

        # Rows of `prediction` are ordered (batch, row, column), so they fold
        # back onto the input's spatial layout. The class count comes from the
        # model instead of being hard-coded to 2.
        height, width = x.shape[2], x.shape[3]
        prediction = prediction.view(-1, height, width, self.num_classes)
        prediction = rearrange(prediction, 'b h w c -> b c h w')

        return prediction, loss, accuracy

    def calc_perturbation_sensitivity(self, num_samples=8, device=None):
        """
        Calculates the sensitivity of the model to perturbations in the input.

        Random one-hot grids of side ``2k+1`` are generated and the centre cell
        of each grid is replaced by ``1 - value`` in every channel (a state
        flip for two-state inputs). Because the grid side equals the patch side
        and patches wrap around, the flipped cell appears in every patch of its
        grid, each time at a different position. The result is the summed
        absolute logit difference divided by the number of patches.

        The two forward passes run in eval mode and without gradient tracking
        so that dropout in the encoder layers does not add random differences
        to the measurement; the previous train/eval mode is restored afterwards.

        Args:
            num_samples (int): Number of random grids to use.
            device: Device on which to build the probe inputs. ``None`` (the
                default) uses the device of the model's parameters.

        Returns:
            tuple: Sensitivity value (float) and the total number of elements considered (int).
        """
        size = self.k*2+1
        if device is None:
            device = next(self.parameters()).device

        # The probe must have as many channels as the pixel embedding expects.
        x_input = torch.randint(0, self.input_channels, (num_samples, size, size))
        x_input = one_hot(x_input.to(torch.int64), num_classes=self.input_channels)
        x_input = x_input.permute(0, 3, 1, 2)
        x_input = x_input.to(torch.float32).to(device=device)

        # Perturb only the centre cell of each grid.
        x_input_clone = x_input.clone()
        x_input_clone[:, :, self.k, self.k] = 1 - x_input[:, :, self.k, self.k]

        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                prediction = self.forward(x_input)[0]
                prediction_clone = self.forward(x_input_clone)[0]
        finally:
            self.train(was_training)

        difference = torch.abs(prediction - prediction_clone)
        sensitivity = difference.sum()

        num_elements = num_samples*size*size
        return sensitivity.item()/num_elements, num_elements

    def get_config(self):
        """
        Returns the configuration of the model.

        Note that this runs ``calc_perturbation_sensitivity`` once, i.e. it
        performs forward passes on freshly drawn random inputs.

        Returns:
            dict: Dictionary containing the configuration of the model.
        """
        d = {'model': 'Transformer', 'num_iterations': self.k, 'num_classes': self.num_classes}
        d["start_perturbation_sensitivity"] = self.calc_perturbation_sensitivity()[0]
        d["embed_dim"] = self.embed_dim
        d["num_heads"] = self.num_heads
        d["num_layers"] = self.num_layers
        return d


if __name__=="__main__":
    # Smoke test: one labelled forward pass plus the sensitivity probe.
    # Runs on the GPU when one is available and falls back to the CPU otherwise.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = PixelTransformerClassifier(k=2, num_layers=2, embed_dim=64, num_heads=4)
    # Alternative model with the same forward interface:
    # model = LinearClassifier(num_layers_hidden=2, k=2, hidden_dim=64, num_classes=2)
    model.to(device)

    test_input = torch.randn(4, 2, 11, 11, device=device)  # (batch, channels, n, n)
    # Labels are per-cell class indices of shape (batch, n, n); passing the
    # float input itself as labels does not match the (batch*n*n, num_classes)
    # logits inside cross_entropy.
    test_labels = torch.randint(0, 2, (4, 11, 11), device=device)

    # prediction: (batch, num_classes, n, n); loss and accuracy are scalars.
    prediction, loss, accuracy = model(test_input, labels=test_labels)
    print(model.calc_perturbation_sensitivity(num_samples=16, device=device))