import os
import torch
import wandb
import matplotlib.pyplot as plt
from PIL import Image


def create_circle_filter(radius, size=(512, 512)):
    """Return a binary disc mask on a grid.

    Pixel ``(i, j)`` is 1.0 when its Euclidean distance to the point
    ``(size[0] / 2, size[1] / 2)`` is at most ``radius`` (inclusive), and
    0.0 otherwise.

    Args:
        radius: Disc radius in pixels.
        size: Grid shape as a 2-tuple, or a single int for a square grid.

    Returns:
        torch.Tensor: float32 tensor of shape ``size``.
    """
    if isinstance(size, int):
        size = (size, size)
    centers = torch.tensor([x / 2 for x in size])
    # Grid of shape (*size, 2) whose entry [i, j] holds the coordinates (i, j).
    coords = torch.cartesian_prod(*[torch.arange(x) for x in size]).reshape(*size, -1)
    # int64 coords minus float32 centers promotes to float32, so no cast is needed.
    dist = torch.sqrt(torch.sum((coords - centers) ** 2, dim=-1))
    return (dist <= radius).float()


def create_gaussian_filter(size=(512, 512), sigma=50):
    """
    Create a peak-normalised 2D Gaussian filter using PyTorch.

    Parameters:
    size (int or tuple of two ints): Filter shape; a single int gives a
        square (size x size) filter.
    sigma (float): The standard deviation of the Gaussian, in pixels.
        Must be non-zero.

    Returns:
    torch.Tensor: float32 tensor of shape ``size`` holding
        exp(-d**2 / (2 * sigma**2)) divided by its maximum, where d is the
        distance to the point (size[0] / 2, size[1] / 2). The filter is
        scaled so that its maximum is 1; it does NOT sum to 1.

    Raises:
    ValueError: If ``sigma`` is zero (the kernel would be all NaN).
    """
    if isinstance(size, int):
        size = (size, size)
    if sigma == 0:
        raise ValueError("sigma must be non-zero")
    centers = torch.tensor([x / 2 for x in size])
    # Grid of shape (*size, 2) whose entry [i, j] holds the coordinates (i, j).
    coords = torch.cartesian_prod(*[torch.arange(x) for x in size]).reshape(*size, -1)
    square_distance = torch.sum((coords - centers) ** 2, dim=-1)
    gaussian_kernel = torch.exp(-square_distance / (2 * sigma ** 2))
    # Peak-normalise. A unit-sum normalisation before this would be cancelled
    # by the division, so none is applied.
    return gaussian_kernel / torch.max(gaussian_kernel)

def load_model(model, model_path):
    """Load weights into ``model`` in place from the checkpoint at ``model_path``.

    The checkpoint is mapped onto the CPU and must be a mapping containing a
    ``'model_state_dict'`` entry. ``torch.load`` unpickles the file, so only
    point this at checkpoints from trusted sources.
    """
    cpu_device = torch.device('cpu')
    checkpoint = torch.load(model_path, map_location=cpu_device)
    state_dict = checkpoint['model_state_dict']
    model.load_state_dict(state_dict)
    print(f"model loaded from checkpoint {model_path}")


def save_image_and_log(file, name, commit=True, delete=False):
    """Save the current matplotlib figure and, if a wandb run is active, log it.

    The figure is written to ``file`` with ``plt.savefig``. When ``wandb.run``
    is set, the saved file is reopened with PIL and logged under ``name``.
    The current figure is always closed, even if saving or logging raises.

    Args:
        file: Destination path for ``plt.savefig``; reopened for logging.
        name: Key under which the image is logged to wandb.
        commit: Forwarded to ``wandb.log``.
        delete: If true, remove ``file`` after it has been logged.
    """
    try:
        plt.savefig(file)
        if wandb.run:
            # Use a context manager so PIL's file handle is released before
            # the optional os.remove below (an open handle blocks deletion on
            # Windows).
            with Image.open(file) as img:
                wandb.log({name: wandb.Image(img)}, commit=commit)
    finally:
        plt.close()
    if delete:
        os.remove(file)


def add_defects(target, num_defects=1):
    """Return a copy of ``target`` with random straight-line defects drawn in.

    Two families of ``num_defects`` lines are rasterised on a 512x512 grid,
    restricted to the disc of radius 256 centred at (256, 256):

    * "negative" lines set their pixels to 0;
    * "positive" lines (applied afterwards) set their pixels to 1.

    A line is the set of grid points (r, c) with
    ``|c - f * r + offset| < thickness``. Here r is the first grid coordinate
    and c the second; flat index ``r * 512 + c`` is written. Pixels outside
    the disc are never changed, and ``target`` itself is not modified.

    Args:
        target: Tensor whose flattened view is indexed with positions
            ``< 512 * 512`` (presumably a single 512x512 image; TODO confirm
            with callers).
        num_defects: Number of lines in each family. 0 returns an
            unmodified copy.

    Returns:
        A modified clone of ``target``.
    """
    mod_img = target.clone()
    flat = mod_img.view(-1)
    w = 512
    cp = torch.cartesian_prod(torch.arange(w), torch.arange(w))
    rows, cols = cp[:, 0], cp[:, 1]
    circle_mask = (rows - w / 2) ** 2 + (cols - w / 2) ** 2 <= (w / 2) ** 2

    # The random draws happen in this exact order so seeded runs reproduce.
    # torch.randint's upper bound is exclusive, so every thickness equals 4.
    neg_factors = torch.randint(-100, 100, (num_defects,)) / 20
    neg_thickness = torch.randint(4, 5, (num_defects,))
    neg_offset = torch.randint(-100, 100, (num_defects,))

    pos_factors = torch.randint(0, 100, (num_defects,)) / 20
    pos_thickness = torch.randint(4, 5, (num_defects,))
    pos_offset = torch.randint(-512, 512, (num_defects,))

    def line_indices(f, th, of):
        # Flat indices of in-disc pixels lying on the line.
        return torch.where((torch.abs(cols - f * rows + of) < th) & circle_mask)[0]

    for f, th, of in zip(neg_factors, neg_thickness, neg_offset):
        flat[line_indices(f, th, of)] = 0
    for f, th, of in zip(pos_factors, pos_thickness, pos_offset):
        flat[line_indices(f, th, of)] = 1

    return mod_img

def get_uniform_angles(num_angles=180, device=None):
    """Return ``num_angles`` evenly spaced angles covering [0, 180) (presumably degrees).

    The spacing is ``180 / num_angles``: the first angle is 0 and the last
    is ``180 * (1 - 1 / num_angles)``, so 180 itself is excluded.
    """
    last_angle = 180 * (1 - 1 / num_angles)
    return torch.linspace(start=0, end=last_angle, steps=num_angles, device=device)

def find_match_indices(tensor1, tensor2):
    """Return the positions in ``tensor1`` of the elements of ``tensor2``.

    Every element of ``tensor2`` is compared against every element of
    ``tensor1``. The result lists, for each element of ``tensor2`` in order,
    the index (or indices, if the value occurs several times) where it
    appears in ``tensor1``. Values of ``tensor2`` that are absent from
    ``tensor1`` contribute nothing, so callers are expected to pass a
    ``tensor2`` whose values are a subset of ``tensor1``.

    Memory use is O(len(tensor1) * len(tensor2)) because a full comparison
    table is built.

    :param tensor1: 1-D tensor to search in.
    :param tensor2: 1-D tensor of values to look up.
    :return: 1-D int64 tensor of indices into ``tensor1``.
    """
    # Boolean table of shape [len(tensor2), len(tensor1)] via broadcasting.
    matches = tensor2[:, None] == tensor1
    # nonzero() yields (row, col) pairs in row-major order: column 0 indexes
    # tensor2 (so results follow tensor2's order), column 1 indexes tensor1.
    return matches.nonzero(as_tuple=False)[:, 1]