import torch
import logging
import numpy as np
from pytorch_lightning.utilities import rank_zero_only
from torch_geometric.data import Data, Batch
from torch_cluster import radius_graph, knn_graph

def to_coords(x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """
    Transforms the coordinates to a tensor X of shape [time, space, 2].
    Args:
        x: 1D tensor of spatial coordinates
        t: 1D tensor of temporal coordinates
    Returns:
        torch.Tensor: X[..., 0] is the space coordinate (in 2D)
                      X[..., 1] is the time coordinate (in 2D)
    """
    # Build the grid time-major directly. Passing `indexing` explicitly keeps
    # the layout fixed and avoids the deprecation warning that torch.meshgrid
    # emits when the argument is omitted.
    t_, x_ = torch.meshgrid(t, x, indexing="ij")
    return torch.stack((x_, t_), -1)

def make_coord(shape, ranges=None, flatten=True):
    """
    Make coordinates at grid centers.

    Args:
        shape: iterable giving the number of cells along each axis.
        ranges: optional sequence of (v0, v1) bounds, one pair per axis.
            When omitted every axis spans [-1, 1].
        flatten: when True the grid axes are collapsed, giving a tensor of
            shape [prod(shape), len(shape)]; otherwise [*shape, len(shape)].
    Returns:
        torch.Tensor: cell-center coordinates, last dimension indexes the axis.
    """
    coord_seqs = []
    for i, n in enumerate(shape):
        v0, v1 = (-1, 1) if ranges is None else ranges[i]
        r = (v1 - v0) / (2 * n)  # half of one cell width
        # Centers sit half a cell in from v0 and one full cell apart.
        coord_seqs.append(v0 + r + (2 * r) * torch.arange(n).float())
    # Explicit "ij" indexing: axis k of the grid follows shape[k]. Omitting
    # the argument is deprecated in torch.meshgrid.
    ret = torch.stack(torch.meshgrid(*coord_seqs, indexing="ij"), dim=-1)
    if flatten:
        ret = ret.reshape(-1, ret.shape[-1])
    return ret

def get_logger(name=__name__):
    """
    Build a python command line logger that is safe to use with multiple GPUs.

    Each level method of the logger obtained from ``logging.getLogger(name)``
    is replaced by the same method wrapped in ``rank_zero_only``.
    Adapted from:
    https://github.com/ashleve/lightning-hydra-template/blob/8b62eef9d0d9c863e88c0992595688d6289d954f/src/utils/utils.py#L12
    """
    rank_zero_levels = (
        "debug",
        "info",
        "warning",
        "error",
        "exception",
        "fatal",
        "critical",
    )

    log = logging.getLogger(name)

    # Without the rank-zero wrapper every GPU process in a multi-GPU run
    # would emit its own copy of each record.
    for method_name in rank_zero_levels:
        bound_method = getattr(log, method_name)
        setattr(log, method_name, rank_zero_only(bound_method))

    return log

def to_pixel_samples(img):
    """ Convert the image to coord-RGB pairs.
        img: Tensor, (C, L) or (C, H, W)

    Returns:
        (coord, rgb): ``coord`` comes from ``make_coord`` over the spatial
        dimensions of ``img`` (everything after the channel axis), ``rgb`` is
        ``img`` flattened to one row per spatial location with C columns.
    Raises:
        NotImplementedError: if ``img`` is neither 2- nor 3-dimensional.
    """
    if len(img.shape) not in (2, 3):
        raise NotImplementedError(
            f"to_pixel_samples expects a (C, L) or (C, H, W) tensor, got shape {tuple(img.shape)}"
        )
    # Use every spatial axis so that coord has one row per pixel, matching rgb:
    # (L,) for 1D signals and (H, W) for images.
    coord = make_coord(img.shape[1:])
    # reshape (rather than view) also accepts non-contiguous inputs.
    rgb = img.reshape(img.shape[0], -1).permute(1, 0)
    return coord, rgb


def get_edge_index_2d(x, regular, neighbors):
    """Build the edge index for a set of 2D points.

    Args:
        x: point positions passed to ``radius_graph`` / ``knn_graph``; ``x[0]``
            and ``x[1]`` are used to measure the grid spacing when ``regular``.
        regular: truthy to connect points within a fixed radius, falsy to
            connect each point to its ``neighbors`` nearest neighbours.
        neighbors: scales the radius when ``regular``; ``k`` for the kNN graph
            otherwise.
    Returns:
        The edge index returned by ``radius_graph`` or ``knn_graph`` (no
        self-loops are requested).
    """
    if regular:
        # Spacing is taken as the distance between the first two points
        # (assumes they are adjacent grid points -- TODO confirm).
        dx = torch.dist(x[0], x[1])
        # neighbors == 8 gives sqrt(2) * dx, the diagonal of one square cell;
        # the small constant keeps points lying exactly on the radius.
        radius = neighbors / 8 * np.sqrt(2) * dx + 0.0001
        edge_index = radius_graph(x, r=radius, loop=False)
    else:
        edge_index = knn_graph(x, k=neighbors, loop=False)

    return edge_index

def get_edge_index_period_2d(x, regular, neighbors, L=None):
    """Build the edge index for 2D points on a periodic domain.

    Args:
        x: point positions of shape [N, 2].
        regular: truthy to use a radius graph with periodic wrap-around, falsy
            to fall back to a plain (non-periodic) kNN graph.
        neighbors: scales the radius when ``regular``; ``k`` for the kNN graph
            otherwise.
        L: period of the domain, applied to both axes. When omitted, a
            module-level ``L`` is used if one exists.
    Returns:
        Edge index of shape [2, E] with unique columns and no self-loops when
        ``regular``; otherwise whatever ``knn_graph`` returns.
    Raises:
        ValueError: if ``regular`` is truthy and no period is available.
    """
    if not regular:
        return knn_graph(x, k=neighbors, loop=False)

    if L is None:
        # Backward compatibility: earlier code read the period from a global.
        L = globals().get("L")
    if L is None:
        raise ValueError(
            "get_edge_index_period_2d needs the domain period: pass L=<period>"
        )
    period = float(L)

    N = x.size(0)

    # Create periodic images. The original points go first so that indices
    # [0, N) refer to them and any image index maps back with `% N`.
    images = [x]
    for i in (-1, 0, 1):
        for j in (-1, 0, 1):
            if i == 0 and j == 0:
                continue
            shift = torch.tensor(
                [[i * period, j * period]], dtype=x.dtype, device=x.device
            )
            images.append(x + shift)
    all_images = torch.cat(images, dim=0)

    # Compute edge indices for all images
    spacing = torch.dist(x[0], x[1])
    radius = neighbors / 8 * np.sqrt(2) * spacing + 0.0001
    edge_index = radius_graph(all_images, r=radius, loop=False)

    # Keep only edges originating from original points and translate the
    # other endpoint back to the index of the original point it is an image of.
    keep = edge_index[0] < N
    src = edge_index[0][keep]
    dst = edge_index[1][keep] % N
    edge_index = torch.stack([src, dst], dim=0)

    # A point can only reach its own image when the radius spans a full
    # period; drop those edges to honour loop=False.
    edge_index = edge_index[:, edge_index[0] != edge_index[1]]

    # Ensure unique edges. unique() returns a single tensor here, so it must
    # not be tuple-unpacked.
    return torch.unique(edge_index, dim=1)

def generate_torchgeom_dataset_2d(data, time_slice, step, regular, neighbors, train=True):
    """Returns dataset that can be used to train our model.

    Args:
        data: collection of simulations indexable by ``0 .. len(data) - 1``;
            each entry is read through the keys 'u', 'x' and 't'.
        time_slice (int): number of leading time columns of 'u' used as the
            input ``u0``.
        step (int): the target horizon covers ``time_slice * (step + 1)``
            time columns.
        regular, neighbors: forwarded to ``get_edge_index_2d``.
        train (bool): when True the target ``u`` starts at column 0, otherwise
            at column ``time_slice``; both end at the horizon.
    Returns:
        dataset (list): list of torch_geometric Data objects, one per simulation.
    """

    dataset = []

    print("Data Generation Start..!")
    # Fixed sensors: the graph is built once from the first simulation and the
    # same tensor is shared by every sample. Recompute it per simulation from
    # `pos` if sensor positions vary.
    edge_index = torch.as_tensor(
        get_edge_index_2d(data[0]['x'], regular, neighbors), dtype=torch.long
    )
    n_sims = len(data)
    # Report progress about ten times; clamp to 1 so that fewer than ten
    # simulations do not cause a modulo by zero.
    report_every = max(1, n_sims // 10)
    horizon = time_slice * (step + 1)
    for sim_ind in range(n_sims):
        if sim_ind % report_every == 0:
            print( '{:4d} / {:4d} (Done..)' .format(sim_ind, n_sims))

        sim = data[sim_ind]
        pos = sim['x']
        target_start = 0 if train else time_slice
        u = sim['u'][:, target_start:horizon]
        t = sim['t'][0:horizon]
        # Model input: the first `time_slice` columns followed by the positions.
        u0 = np.concatenate([sim['u'][:, 0:time_slice], pos], axis=1)

        tg_data = Data(
            u0=torch.as_tensor(u0).type(torch.float),
            edge_index=edge_index,
            u=torch.as_tensor(u).type(torch.float),
            t=torch.as_tensor(t).type(torch.float),
            pos=pos,
            sim_ind=torch.tensor(sim_ind, dtype=torch.long)
        )

        dataset.append(tg_data)

    return dataset

# Custom collate function
def collate_fn(batch):
    """Merge a list of samples into one object via ``Batch.from_data_list``."""
    # NOTE(review): a function with the same name and body is defined again
    # further down in this file.
    merged = Batch.from_data_list(batch)
    return merged

def get_edge_index(x, regular, neighbors):
    """Build the edge index for a set of points.

    When ``regular`` is truthy the points are linked with ``radius_graph``,
    using a radius of ``neighbors / 2`` times the spacing ``x[1] - x[0]`` plus
    a small tolerance; otherwise each point is linked to its ``neighbors``
    nearest neighbours with ``knn_graph``. Self-loops are not requested.
    """
    if not regular:
        return knn_graph(x, k=neighbors,  loop=False)

    spacing = x[1] - x[0]
    cutoff = (neighbors/2) * spacing + 0.0001
    return radius_graph(x, r=cutoff, loop=False)


def get_edge_index_period_1d(x, regular, neighbors, L=None):
    """Build the edge index for 1D points on a periodic domain.

    Args:
        x: point positions, shape [N, 1] (a flat [N] tensor is also accepted
            by the periodic branch).
        regular: truthy to link every point to its ``neighbors`` closest
            points under the periodic distance, falsy to fall back to a plain
            (non-periodic) kNN graph.
        neighbors: number of neighbours per point.
        L: period of the domain. When omitted, a module-level ``L`` is used if
            one exists.
    Returns:
        Long tensor of shape [2, N * k] with row 0 holding the source point
        and row 1 its neighbours, grouped by source, when ``regular``;
        otherwise whatever ``knn_graph`` returns.
    Raises:
        ValueError: if ``regular`` is truthy and no period is available.
    """
    if not regular:
        return knn_graph(x, k=neighbors,  loop=False)

    if L is None:
        # Backward compatibility: earlier code read the period from a global.
        L = globals().get("L")
    if L is None:
        raise ValueError(
            "get_edge_index_period_1d needs the domain period: pass L=<period>"
        )

    # Calculate distances while considering the periodic boundary
    column = x.reshape(-1, 1)
    dist_matrix = torch.abs(column - column.t())
    dist_matrix = torch.min(dist_matrix, L - dist_matrix)

    # Sort every row at once. Column 0 is the point itself (distance 0), so
    # it is skipped when taking the closest `neighbors` entries.
    _, sorted_indices = dist_matrix.sort(dim=1)
    neighbors_idx = sorted_indices[:, 1:neighbors + 1]
    source = torch.arange(column.size(0), device=x.device).unsqueeze(1)
    source = source.expand_as(neighbors_idx)

    return torch.stack([source.reshape(-1), neighbors_idx.reshape(-1)], dim=0)


def generate_torchgeom_dataset(data, time_slice, step, regular, neighbors, train=True):
    """Returns dataset that can be used to train our model.

    Args:
        data: collection of simulations indexable by ``0 .. len(data) - 1``;
            each entry is read through the keys 'u', 'x' and 't'.
        time_slice (int): number of leading time columns of 'u' used as the
            input ``u0``.
        step (int): the target horizon covers ``time_slice * (step + 1)``
            time columns.
        regular, neighbors: forwarded to ``get_edge_index``.
        train (bool): when True the target ``u`` starts at column 0, otherwise
            at column ``time_slice``; both end at the horizon.
    Returns:
        dataset (list): list of torch_geometric Data objects, one per simulation.
    """

    dataset = []

    print("Data Generation Start..!")
    # Fixed sensors: the graph is built once from the first simulation and the
    # same tensor is shared by every sample. Recompute it per simulation from
    # `pos` if sensor positions vary.
    edge_index = torch.as_tensor(
        get_edge_index(data[0]['x'], regular, neighbors), dtype=torch.long
    )
    n_sims = len(data)
    # Report progress about ten times; clamp to 1 so that fewer than ten
    # simulations do not cause a modulo by zero.
    report_every = max(1, n_sims // 10)
    horizon = time_slice * (step + 1)
    for sim_ind in range(n_sims):
        if sim_ind % report_every == 0:
            print( '{:4d} / {:4d} (Done..)' .format(sim_ind, n_sims))

        sim = data[sim_ind]
        pos = sim['x']
        target_start = 0 if train else time_slice
        u = sim['u'][:, target_start:horizon]
        t = sim['t'][0:horizon]
        # Model input: the first `time_slice` columns followed by the positions.
        u0 = np.concatenate([sim['u'][:, 0:time_slice], pos], axis=1)

        tg_data = Data(
            u0=torch.as_tensor(u0).type(torch.float),
            edge_index=edge_index,
            u=torch.as_tensor(u).type(torch.float),
            t=torch.as_tensor(t).type(torch.float),
            pos=pos,
            sim_ind=torch.tensor(sim_ind, dtype=torch.long)
        )

        dataset.append(tg_data)

    return dataset

# Custom collate function
def collate_fn(batch):
    """Combine a list of samples into a single ``Batch`` using ``Batch.from_data_list``."""
    # NOTE(review): this repeats an identical definition found earlier in
    # this file; being the later one, it is the binding left in effect.
    combined = Batch.from_data_list(batch)
    return combined

def compute_test_error(model, test_loader):
    # Move model to CPU
    model = model.to("cpu")
    
    # Ensure model is in eval mode
    model.eval()
    
    # Accumulators for loss and relative error
    total_test_loss = 0.0
    total_test_rel_error = 0.0
    
    # Disable gradients to save memory
    with torch.no_grad():
        num_batches = 0
        for batch in test_loader:
            results = model.test_step(batch, batch_idx=num_batches)
            total_test_loss += results['test_loss'].item()
            total_test_rel_error += results['test_rel_error'].item()
            num_batches += 1
            
    # Compute average values
    avg_test_loss = total_test_loss / num_batches
    avg_test_rel_error = total_test_rel_error / num_batches
            
    # Print or return the results
    print(f"Average Test Loss: {avg_test_loss}, Average Test Relative Error: {avg_test_rel_error}")
    return avg_test_loss, avg_test_rel_error

def rel_L2_error(pred, true):
    """Relative L2 error of ``pred`` with respect to ``true`` along the last axis.

    Computes the square root of the summed squared difference divided by the
    summed squared reference, both reduced over ``dim=-1``.
    """
    residual = true - pred
    squared_error = torch.sum(residual ** 2, dim=-1)
    squared_norm = torch.sum(true ** 2, dim=-1)
    ratio = squared_error / squared_norm
    return ratio ** 0.5