import re
import os
import contextlib
import numpy as np
import torch
import matplotlib.pyplot as plt
import torch.nn.functional as F
from matplotlib.backends.backend_pdf import PdfPages

def save_single_tensor_image(tensors, image_id, folder_path="output_images", image_name="input", cmap='viridis'):
    """
    Save one image from a batch of tensors as a colormapped PNG.

    The file is written to ``<folder_path>/image_<image_name>_<image_id>.png``.
    ``folder_path`` is created if it does not already exist.

    Args:
        tensors (torch.Tensor): Batch of images, shape (bs, H, W) or (bs, 1, H, W).
        image_id (int): Index of the image in the batch to save.
        folder_path (str): Directory the PNG is written to.
        image_name (str): Tag embedded in the output filename.
        cmap (str): Matplotlib colormap applied to the image.
    """
    # Drop a singleton channel axis: (bs, 1, H, W) -> (bs, H, W).
    if tensors.ndim == 4 and tensors.shape[1] == 1:
        tensors = tensors.squeeze(1)

    # detach() so tensors that require grad can still be converted to NumPy.
    img = tensors[image_id].detach().cpu().numpy()

    os.makedirs(folder_path, exist_ok=True)

    # Use a dedicated figure so a figure the caller has open is neither
    # drawn on nor closed, and always release it even if saving fails.
    fig, ax = plt.subplots()
    try:
        ax.imshow(img, cmap=cmap)
        ax.axis('off')  # Hide axes
        fig.savefig(
            os.path.join(folder_path, f"image_{image_name}_{image_id}.png"),
            bbox_inches='tight',
            pad_inches=0,
        )
    finally:
        plt.close(fig)


#----------------------------------------------------------------------------
def save_tensors_to_pdf(tensor_list, tensor_names, folder_path="output_images", pdf_file_path="output_images.pdf", cmap='viridis', rows_per_page=10):
    """
    Save a list of image batches to a multi-page PDF laid out as a grid.

    Each column shows one tensor from ``tensor_list`` (titled with the matching
    entry of ``tensor_names``). Each row shows one sample index. At most
    ``rows_per_page`` rows are put on each page. Every image gets its own colorbar.

    Args:
        tensor_list (list): PyTorch tensors of shape (batch_size, H, W) or
            (batch_size, 1, H, W). All are assumed to share ``batch_size``.
        tensor_names (list): Column titles, one per tensor in ``tensor_list``.
        folder_path (str): Directory the PDF is written to (created if missing).
        pdf_file_path (str): PDF filename, joined onto ``folder_path``.
        cmap (str): Colormap to apply to the images.
        rows_per_page (int): Number of rows (samples) to display per page.
    """
    num_tensors = len(tensor_list)
    batch_size = tensor_list[0].shape[0]  # Assume all tensors have the same batch size

    # Prepare every tensor once instead of once per row. Drop a singleton
    # channel axis, detach from autograd, and move to CPU for NumPy conversion.
    images = []
    for tensor in tensor_list:
        if tensor.ndim == 4 and tensor.shape[1] == 1:
            tensor = tensor.squeeze(1)
        images.append(tensor.detach().cpu())

    os.makedirs(folder_path, exist_ok=True)

    with PdfPages(os.path.join(folder_path, pdf_file_path)) as pdf:
        for start_idx in range(0, batch_size, rows_per_page):
            end_idx = min(start_idx + rows_per_page, batch_size)
            num_rows = end_idx - start_idx

            # squeeze=False keeps axs a 2-D (num_rows x num_tensors) grid even
            # for a single row and/or a single column.
            fig, axs = plt.subplots(
                num_rows,
                num_tensors,
                figsize=(num_tensors * 4, num_rows * 4),
                squeeze=False,
            )
            try:
                fig.tight_layout(pad=3.0)

                # Rows are samples, columns are tensors.
                for row_idx, sample_idx in enumerate(range(start_idx, end_idx)):
                    for col_idx, image_batch in enumerate(images):
                        ax = axs[row_idx][col_idx]
                        im = ax.imshow(image_batch[sample_idx].numpy(), cmap=cmap)
                        ax.axis('off')

                        # Add color bar next to each image
                        cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
                        cbar.ax.tick_params(labelsize=8)

                        # Column titles only on the first row of each page.
                        if row_idx == 0:
                            ax.set_title(tensor_names[col_idx])

                # Save the current figure to the PDF
                pdf.savefig(fig)
            finally:
                plt.close(fig)


#----------------------------------------------------------------------------
def compute_relative_error(predicted, true, p=2):
    """
    Compute the per-sample relative error ||predicted - true||_p / ||true||_p.

    Each sample is flattened to a vector before the norm is taken, so any
    trailing shape is accepted.

    Args:
    predicted (torch.Tensor): The predicted tensor of shape (batch_size, ...).
    true (torch.Tensor): The ground truth tensor of shape (batch_size, ...).
    p (int, float or str): Norm order passed to ``torch.norm`` (default 2).

    Returns:
    torch.Tensor: Shape (batch_size,), the relative error of each sample.
    A sample whose ground truth has zero norm yields inf or nan.
    """
    # Ensure the tensors are of the same shape
    assert predicted.shape == true.shape, "Predicted and True tensors must have the same shape."
    num_examples = predicted.size()[0]

    flat_pred = predicted.reshape(num_examples, -1)
    flat_true = true.reshape(num_examples, -1)

    # p-norm of the difference (numerator) and of the ground truth (denominator)
    error_norm = torch.norm(flat_pred - flat_true, p=p, dim=1)
    true_norm = torch.norm(flat_true, p=p, dim=1)

    return error_norm / true_norm
#---------------------------------------------------------------------------
def compute_relative_error_componentwise(predicted, true, p=2):
    """
    Compute the p-norm of the elementwise relative error for each sample.

    For every element, the relative error (predicted - true) / true is formed.
    The p-norm of that vector is then taken per (flattened) sample.

    Args:
    predicted (torch.Tensor): The predicted tensor of shape (batch_size, ...).
    true (torch.Tensor): The ground truth tensor of shape (batch_size, ...).
    p (int, float or str): Norm order passed to ``torch.norm`` (default 2).

    Returns:
    torch.Tensor: Shape (batch_size,), one value per sample. Any zero entry
    in ``true`` yields inf or nan for that sample.
    """
    # Ensure the tensors are of the same shape
    assert predicted.shape == true.shape, "Predicted and True tensors must have the same shape."
    num_examples = predicted.size()[0]

    flat_pred = predicted.reshape(num_examples, -1)
    flat_true = true.reshape(num_examples, -1)

    # Elementwise relative error, then its p-norm per sample.
    error = (flat_pred - flat_true) / flat_true
    return torch.norm(error, p=p, dim=1)
#---------------------------------------------------------------------------
def compute_mean_absolute_error(predicted, true):
    """
    Compute the mean absolute error (MAE) between predicted and true tensors for each sample in a batch.

    The absolute error is averaged over every dimension except the first (batch) one.

    Args:
    predicted (torch.Tensor): The predicted tensor of shape (batch_size, ...).
    true (torch.Tensor): The ground truth tensor of shape (batch_size, ...).

    Returns:
    torch.Tensor: Shape (batch_size,), the MAE of each sample.
    """
    assert predicted.shape == true.shape, "Predicted and True tensors must have the same shape."

    abs_err = torch.abs(predicted - true)
    if abs_err.dim() <= 1:
        # Each sample is a single value, so there is nothing to average.
        # Passing dim=() to torch.mean would instead reduce over *all*
        # dimensions and collapse the batch to one number.
        return abs_err
    return torch.mean(abs_err, dim=tuple(range(1, abs_err.dim())))
#---------------------------------------------------------------------------
def compute_mean_squared_error(predicted, true):
    """
    Compute the mean squared error (MSE) between predicted and true tensors for each sample in a batch.

    The squared error is averaged over every dimension except the first (batch) one.

    Args:
    predicted (torch.Tensor): The predicted tensor of shape (batch_size, ...).
    true (torch.Tensor): The ground truth tensor of shape (batch_size, ...).

    Returns:
    torch.Tensor: Shape (batch_size,), the MSE of each sample.
    """
    assert predicted.shape == true.shape, "Predicted and True tensors must have the same shape."

    sq_err = (predicted - true) ** 2
    if sq_err.dim() <= 1:
        # Each sample is a single value, so there is nothing to average.
        # dim=() would make torch.mean reduce over all dimensions instead.
        return sq_err
    return torch.mean(sq_err, dim=tuple(range(1, sq_err.dim())))
#---------------------------------------------------------------------------
def calculate_metrics(predicted, true):
    """
    Evaluate every per-sample error metric defined in this module.

    Args:
    predicted (torch.Tensor): The predicted tensor of shape (batch_size, ...).
    true (torch.Tensor): The ground truth tensor of shape (batch_size, ...).

    Returns:
    dict: Maps metric name to the per-sample tensor returned by the
    corresponding compute_* function.
    """
    # NOTE: "componentwise_reltaive_error" is misspelled, but it is kept
    # verbatim because code reading this dict may rely on the exact key.
    metric_fns = (
        ("relative_error", compute_relative_error),
        ("componentwise_reltaive_error", compute_relative_error_componentwise),
        ("mean_absolute_error", compute_mean_absolute_error),
        ("mean_squared_error", compute_mean_squared_error),
    )
    return {key: fn(predicted, true) for key, fn in metric_fns}
#---------------------------------------------------------------------------

#### Normalization code from FNO github repo
# normalization, pointwise gaussian
class UnitGaussianNormalizer(object):
    """Pointwise Gaussian normalizer: zero mean and unit std at every location.

    ``mean`` and ``std`` are computed over dim 0 (the sample axis) of the
    tensor passed to the constructor, so they have the shape of one sample.
    """

    def __init__(self, x, eps=0.00001):
        super().__init__()

        # x could be in shape of ntrain*n or ntrain*T*n or ntrain*n*T
        self.mean = torch.mean(x, 0)
        self.std = torch.std(x, 0)
        # Added to std so features with zero spread do not divide by zero.
        self.eps = eps

    def encode(self, x):
        """Return ``(x - mean) / (std + eps)``."""
        x = (x - self.mean) / (self.std + self.eps)
        return x

    def decode(self, x, sample_idx=None):
        """Invert :meth:`encode`.

        Args:
            x: Encoded tensor (batch*n or T*batch*n).
            sample_idx: Optional index selecting a subset of the stored
                statistics. Its first element must have a ``.shape``. If that
                element has the same rank as ``mean``, the statistics are
                indexed directly. If ``mean`` has a higher rank, they are
                indexed along dim 1.

        Raises:
            ValueError: if ``sample_idx[0]`` has a higher rank than ``mean``.
        """
        if sample_idx is None:
            std = self.std + self.eps  # n
            mean = self.mean
        elif len(self.mean.shape) == len(sample_idx[0].shape):
            std = self.std[sample_idx] + self.eps  # batch*n
            mean = self.mean[sample_idx]
        elif len(self.mean.shape) > len(sample_idx[0].shape):
            std = self.std[:, sample_idx] + self.eps  # T*batch*n
            mean = self.mean[:, sample_idx]
        else:
            # No indexing rule applies; fail clearly instead of leaving
            # std/mean unbound.
            raise ValueError(
                "sample_idx[0] has higher rank than the stored statistics: "
                f"{len(sample_idx[0].shape)} > {len(self.mean.shape)}"
            )

        # x is in shape of batch*n or T*batch*n
        x = (x * std) + mean
        return x

    def cuda(self):
        """Move the statistics to the default CUDA device (in place)."""
        self.mean = self.mean.cuda()
        self.std = self.std.cuda()

    def cpu(self):
        """Move the statistics to the CPU (in place)."""
        self.mean = self.mean.cpu()
        self.std = self.std.cpu()
#---------------------------------------------------------------------------
# normalization, pointwise Gaussian, with encoded values scaled by 0.5
class ScaledGaussianNormalizer(object):
    """Pointwise Gaussian normalizer whose encoded output has std of about 0.5.

    ``mean`` and ``std`` are computed over dim 0 of the constructor input.
    ``encode`` standardizes and then multiplies by 0.5. The statistics are moved
    to the input's device on every call.
    """

    def __init__(self, x, eps=0.00001):
        super().__init__()

        # x could be in shape of ntrain*n or ntrain*T*n or ntrain*n*T
        self.mean = torch.mean(x, 0)
        self.std = torch.std(x, 0)
        # Added to std so features with zero spread do not divide by zero.
        self.eps = eps

    def encode(self, x):
        """Return ``0.5 * (x - mean) / (std + eps)``, computed on ``x``'s device."""
        device = x.device
        mean = self.mean.to(device)
        std = self.std.to(device)
        x = (x - mean) / (std + self.eps)
        return x*0.5

    def decode(self, x, sample_idx=None):
        """Invert :meth:`encode`.

        Args:
            x: Encoded tensor (batch*n or T*batch*n).
            sample_idx: Optional index selecting a subset of the stored
                statistics. Its first element must have a ``.shape``. If that
                element has the same rank as ``mean``, the statistics are
                indexed directly. If ``mean`` has a higher rank, they are
                indexed along dim 1.

        Raises:
            ValueError: if ``sample_idx[0]`` has a higher rank than ``mean``.
        """
        device = x.device
        if sample_idx is None:
            std = self.std.to(device) + self.eps  # n
            mean = self.mean.to(device)
        elif len(self.mean.shape) == len(sample_idx[0].shape):
            std = self.std[sample_idx].to(device) + self.eps  # batch*n
            mean = self.mean[sample_idx].to(device)
        elif len(self.mean.shape) > len(sample_idx[0].shape):
            std = self.std[:, sample_idx].to(device) + self.eps  # T*batch*n
            mean = self.mean[:, sample_idx].to(device)
        else:
            # No indexing rule applies; fail clearly instead of leaving
            # std/mean unbound.
            raise ValueError(
                "sample_idx[0] has higher rank than the stored statistics: "
                f"{len(sample_idx[0].shape)} > {len(self.mean.shape)}"
            )

        # x is in shape of batch*n or T*batch*n
        x = x*2  # reverse the 0.5 scaling applied in encode
        x = (x * std) + mean
        return x

    def cuda(self):
        """Move the statistics to the default CUDA device (in place)."""
        self.mean = self.mean.cuda()
        self.std = self.std.cuda()

    def cpu(self):
        """Move the statistics to the CPU (in place)."""
        self.mean = self.mean.cpu()
        self.std = self.std.cpu()

#---------------------------------------------------------------------------
# normalization, Gaussian
class GaussianNormalizer(object):
    """Global Gaussian normalizer built from one scalar mean and one scalar std.

    Unlike the pointwise normalizers, the statistics are taken over every
    element of ``x``, so both are 0-d tensors.
    """

    def __init__(self, x, eps=0.00001):
        super().__init__()
        self.mean = torch.mean(x)
        self.std = torch.std(x)
        self.eps = eps  # keeps the divisor nonzero when std is 0

    def encode(self, x):
        """Map ``x`` to ``(x - mean) / (std + eps)``."""
        scale = self.std + self.eps
        return (x - self.mean) / scale

    def decode(self, x, sample_idx=None):
        """Undo :meth:`encode`. ``sample_idx`` is accepted for API parity and ignored."""
        scale = self.std + self.eps
        return x * scale + self.mean

    def cuda(self):
        """Move the statistics to the default CUDA device (in place)."""
        for attr in ("mean", "std"):
            setattr(self, attr, getattr(self, attr).cuda())

    def cpu(self):
        """Move the statistics to the CPU (in place)."""
        for attr in ("mean", "std"):
            setattr(self, attr, getattr(self, attr).cpu())
#---------------------------------------------------------------------------
# normalization, scaling by range
class RangeNormalizer(object):
    """Affinely rescale each feature so its training range maps onto [low, high].

    Per-feature min/max are taken over dim 0 of ``x`` and flattened.
    ``encode``/``decode`` therefore work on inputs viewed as (batch, -1) and
    restore the original shape afterwards.
    """

    def __init__(self, x, low=0.0, high=1.0):
        super().__init__()
        mymin = torch.min(x, 0)[0].view(-1)
        mymax = torch.max(x, 0)[0].view(-1)

        # encode(v) = a * v + b sends mymin -> low and mymax -> high.
        # NOTE: a feature with mymax == mymin gives a division by zero
        # (inf/nan in a and b).
        self.a = (high - low)/(mymax - mymin)
        self.b = -self.a*mymax + high

    def encode(self, x):
        """Map ``x`` into the target range. Works for non-contiguous inputs too."""
        s = x.size()
        # reshape (not view) so non-contiguous tensors are accepted.
        x = x.reshape(s[0], -1)
        x = self.a*x + self.b
        return x.reshape(s)

    def decode(self, x, sample_idx=None):
        """Invert :meth:`encode`. ``sample_idx`` is accepted for API parity with
        the other normalizers and is ignored."""
        s = x.size()
        x = x.reshape(s[0], -1)
        x = (x - self.b)/self.a
        return x.reshape(s)

    def cuda(self):
        """Move the scale/offset to the default CUDA device (in place)."""
        self.a = self.a.cuda()
        self.b = self.b.cuda()

    def cpu(self):
        """Move the scale/offset to the CPU (in place)."""
        self.a = self.a.cpu()
        self.b = self.b.cpu()
#---------------------------------------------------------------------------
class ScaledGaussianNormalizer2(object):
    """Pointwise normalizer that maps data to zero mean and a standard deviation of 0.5.

    ``mean`` and ``std`` are computed over dim 0 of the constructor input.
    ``std`` is clamped to at least ``eps`` (with a printed warning) so constant
    features do not divide by zero.
    """
    def __init__(self, x, eps=0.00001):
        super().__init__()

        # x could be in shape of ntrain*n or ntrain*T*n or ntrain*n*T
        self.mean = torch.mean(x, 0)
        self.std = torch.std(x, 0)
        self.eps = eps
        self.std = torch.clamp(self.std, min=eps)
        # After clamping, min == eps exactly when some std was <= eps.
        if self.std.min() <= eps:
            print(f"Warning: Some std values were clamped to {eps}")

    def encode(self, x):
        """Return ``(x - mean) * 0.5 / std``."""
        x = (x - self.mean) * (0.5 / self.std)
        return x

    def decode(self, x, sample_idx=None):
        """Invert :meth:`encode`.

        Args:
            x: Encoded tensor (batch*n or T*batch*n).
            sample_idx: Optional index selecting a subset of the stored
                statistics. Its first element must have a ``.shape``. If that
                element has the same rank as ``mean``, the statistics are
                indexed directly. If ``mean`` has a higher rank, they are
                indexed along dim 1.

        Raises:
            ValueError: if ``sample_idx[0]`` has a higher rank than ``mean``.
        """
        if sample_idx is None:
            std = self.std  # n
            mean = self.mean
        elif len(self.mean.shape) == len(sample_idx[0].shape):
            std = self.std[sample_idx]  # batch*n
            mean = self.mean[sample_idx]
        elif len(self.mean.shape) > len(sample_idx[0].shape):
            std = self.std[:, sample_idx]  # T*batch*n
            mean = self.mean[:, sample_idx]
        else:
            # No indexing rule applies; fail clearly instead of leaving
            # std/mean unbound.
            raise ValueError(
                "sample_idx[0] has higher rank than the stored statistics: "
                f"{len(sample_idx[0].shape)} > {len(self.mean.shape)}"
            )

        x = (x / (0.5 / std)) + mean
        return x

    def cuda(self):
        """Move the statistics to the default CUDA device (in place)."""
        self.mean = self.mean.cuda()
        self.std = self.std.cuda()

    def cpu(self):
        """Move the statistics to the CPU (in place)."""
        self.mean = self.mean.cpu()
        self.std = self.std.cpu()

#---------------------------------------------------------------------------
def get_darcy_loss(a, u, device=torch.device('cuda')):
    """
    Compute the pointwise PDE residual  d/dx(a du/dx) + d/dy(a du/dy) + 1.

    Derivatives are central differences with unit grid spacing, built from
    the filter [1, 0, -1] / 2. Borders use zero padding. ``F.conv2d`` is a
    cross-correlation, so each filter yields the *negated* central difference.
    Applying two filters in sequence cancels the sign. Presumably this is the
    residual of -div(a grad u) = 1 (Darcy flow with unit forcing); confirm
    against the caller.

    Args:
        a: Coefficient field, [B, 1, H, W].
        u: Solution field, [B, 1, H, W].
        device: Device the computation runs on. Inputs are moved there.

    Returns:
        torch.Tensor: Residual of shape [B, H, W] in the promoted floating
        dtype of ``a`` and ``u``.
    """
    # F.conv2d needs input and weight to share a dtype. Promote a and u to a
    # common floating dtype and build the filters in it, so float32 as well as
    # float64 inputs work. The filters are rebuilt on every call.
    dtype = torch.promote_types(u.dtype, a.dtype)
    if not dtype.is_floating_point:
        dtype = torch.float64
    u = u.to(device=device, dtype=dtype)
    a = a.to(device=device, dtype=dtype)

    deriv_x = torch.tensor([[1, 0, -1]], dtype=dtype, device=device).view(1, 1, 1, 3) / 2
    deriv_y = torch.tensor([[1], [0], [-1]], dtype=dtype, device=device).view(1, 1, 3, 1) / 2

    # First derivatives along W (x) and H (y); padding keeps the spatial size.
    grad_x = F.conv2d(u, deriv_x, padding=(0, 1))
    grad_y = F.conv2d(u, deriv_y, padding=(1, 0))
    # Flux components a * grad u.
    flux_x = a * grad_x
    flux_y = a * grad_y
    # Divergence of the flux.
    result = F.conv2d(flux_x, deriv_x, padding=(0, 1)) + F.conv2d(flux_y, deriv_y, padding=(1, 0))
    pde_loss = result + 1.0
    return pde_loss.squeeze(1)
#---------------------------------------------------------------------------
def count_parameters(layer):
    """Return the total element count of the parameters of *layer* that require grad."""
    total = 0
    for param in layer.parameters():
        if not param.requires_grad:
            continue  # frozen parameters are not counted
        total += param.numel()
    return total
#---------------------------------------------------------------------------
def analyze_uno(model):
    """Analyze encoder and decoder modes in SongUNO model.

    For every ``UNOBlock`` in ``model.enc`` and then ``model.dec``, print its
    ``min_n_modes`` and its trainable-parameter count.
    """
    # Local import, presumably to avoid an import cycle with training.networks.
    from training.networks import UNOBlock

    for title, attr in (("Encoder Blocks", "enc"), ("Decoder Blocks", "dec")):
        print(f"\n--- {title} ---")
        for name, block in getattr(model, attr).items():
            if isinstance(block, UNOBlock):
                print(f"{name}: Modes={block.min_n_modes}, Params={count_parameters(block)}")
#---------------------------------------------------------------------------
def analyze_spectralconv2d(model):
    """Analyze spectral convolution layers in SongUNO.

    For every ``UNOBlock`` in ``model.enc`` and then ``model.dec`` whose
    ``conv0`` is a ``SpectralConv2d_opscaling``, print the layer's
    ``(modes1, modes2)`` and its trainable-parameter count.
    """
    # Local import, presumably to avoid an import cycle with training.networks.
    from training.networks import UNOBlock, SpectralConv2d_opscaling

    print("\n--- SpectralConv2d Layers ---")
    for attr in ("enc", "dec"):
        for name, block in getattr(model, attr).items():
            if isinstance(block, UNOBlock) and isinstance(block.conv0, SpectralConv2d_opscaling):
                conv = block.conv0
                print(f"{name}: Modes={conv.modes1, conv.modes2}, Params={count_parameters(conv)}")
#---------------------------------------------------------------------------