import torch
import wandb
from pathlib import Path
import omegaconf
from omegaconf import DictConfig
import os
import numpy as np
from PIL import Image, ImageDraw, ImageFont

def extract_output_dir(config: DictConfig) -> Path:
    '''
    Build the path of the output directory Hydra created for this run.

    The values cached by the ``now`` resolver (read from the config's
    private ``_metadata.resolver_cache``) are joined with '/' in the
    cache's iteration order. The result is placed under ``<cwd>/outputs``.

    Returns:
        pathlib.Path of the form ``<cwd>/outputs/<value1>/<value2>/...``.
    '''
    now_cache = config._metadata.resolver_cache['now']
    date_subpath = '/'.join(now_cache.values())
    return Path.cwd().joinpath('outputs', date_subpath)

def preprocess_config(config):
    '''
    Set ``config.exp.log_dir`` to the run's output directory.

    The config is modified in place and nothing is returned.
    '''
    output_dir = extract_output_dir(config)
    config.exp.log_dir = output_dir

def setup_wandb(config):
    '''
    Start a Weights & Biases run configured from the experiment config.

    The W&B group and the default run name are the last two path components
    of ``config.exp.log_dir``. The run name can be overridden with the
    ``WANDB_RUN_NAME`` environment variable. The fully resolved config is
    logged as the run config, and TensorBoard syncing is enabled.

    Args:
        config: experiment config providing ``exp.log_dir`` and a ``wandb``
            section with ``project``, ``entity``, ``tags`` and an optional
            ``id``. A non-None ``id`` is passed to ``wandb.init`` as the run id.
            ``tags`` may be None, a single tag, or a list of tags.

    Returns:
        Whatever ``wandb.init`` returns (the run object).
    '''
    # Path.parts is separator-agnostic; splitting str(path) on '/' fails on
    # Windows, where str(Path) uses backslashes.
    group, name = Path(config.exp.log_dir).parts[-2:]
    wandb_config = omegaconf.OmegaConf.to_container(
        config, resolve = True, throw_on_missing = True
    )
    name = os.getenv("WANDB_RUN_NAME", name)

    if "id" in config.wandb.keys() and config.wandb.id is not None:
        run_id = config.wandb.id
    else:
        run_id = None

    raw_tags = config.wandb.tags
    if raw_tags is None:
        tags = None
    elif isinstance(raw_tags, str):
        tags = [raw_tags]
    elif isinstance(raw_tags, (list, tuple, omegaconf.ListConfig)):
        # A list of tags must not be wrapped in another list (that would
        # produce [[...]] instead of a flat list of tag strings).
        tags = list(raw_tags)
    else:
        tags = [raw_tags]

    return wandb.init(
        project = config.wandb.project,
        entity = config.wandb.entity,
        dir = config.exp.log_dir,
        group = group,
        name = name,
        config = wandb_config,
        sync_tensorboard = True,
        tags = tags,
        id = run_id
    )
            

def get_batch_to_inp(config, batch_imgs, batch_maps):
    '''
    Duplicate every image and its map ``config.exp.n_inpaints`` times.

    Copies of the same sample stay adjacent along dim 0 (interleaved
    repetition), so sample i occupies rows i*n .. i*n + n - 1.

    Returns:
        Tuple ``(repeated_imgs, repeated_maps)``.
    '''
    repeats = config.exp.n_inpaints
    return tuple(
        t.repeat_interleave(repeats, dim = 0) for t in (batch_imgs, batch_maps)
    )


def get_target_id(config, batch_pred_labels):
    '''
    target_id indicates class for which attribution map is generated
    for multilabel, this is determined by the user
    for multiclass, this is determined by the classifier's prediction
        if config.exp.target_id is None, else it is equal to it
    '''
    if config.exp.task == 'multiclass':

        if config.exp.target_id is not None:
            target = torch.ones_like(batch_pred_labels) * config.exp.target_id

        else:
            target = batch_pred_labels

    elif config.exp.task == 'multilabel':
        # here we assume that a multilabel classifier was converted to a
        # model with (p, 1 - p) output pair, where p is the probability of
        # the positive target class. therefore, no matter which target label was
        # chosen, we always compute the attribution for p.
        target = torch.zeros_like(batch_pred_labels)

    else:
        raise NotImplementedError('Task not yet implemented')

    return target


def unlist(item):
    '''
    Unwrap a (possibly nested) singleton container down to its scalar.

    - Tensors: the first element of the flattened tensor is returned as a
      Python number. The same value is assumed across the batch (e.g. the
      same t for every sample).
    - Other objects with ``__len__`` and ``__getitem__`` must hold exactly one
      element, which is unwrapped recursively.
    - Strings and objects that are not indexable containers are returned unchanged.

    Raises:
        AssertionError: if a non-tensor container does not hold exactly one element.
    '''
    if isinstance(item, str):
        # Indexing a string yields a string ('a'[0] == 'a'). Treating strings as
        # containers would recurse forever, so they are leaves.
        return item
    if hasattr(item, '__len__') and hasattr(item, '__getitem__'):
        # we always have the same t across batch
        if isinstance(item, torch.Tensor):
            return item.flatten()[0].item()
        assert len(item) == 1, "Item is not a singleton"
        return unlist(item[0])
    return item


def get_timestep_value(relative_value: int | str, max_timestep):
    if isinstance(relative_value, str):
        assert relative_value[-1] == 'p', 'Relative value must be a percentage'
        relative_value = float(relative_value[:-1]) / 100
        relavtive_value_to_process = 1 - relative_value
        return int(max_timestep * relavtive_value_to_process)
    return int(relative_value)

  
def append_dims_view(tensor: torch.Tensor, desired_num_dims: int) -> torch.Tensor:
    '''
    Return a view of ``tensor`` with trailing size-1 dimensions appended
    until it has ``desired_num_dims`` dimensions.

    This is handy for broadcasting, e.g. turning shape (B,) into (B, 1, 1, 1).
    No data is copied: the result shares storage with ``tensor``.

    Raises:
        AssertionError: if ``tensor`` already has more than ``desired_num_dims`` dimensions.
    '''
    n_missing = desired_num_dims - tensor.dim()
    assert n_missing >= 0, 'Tensor already has more dimensions than desired'
    new_shape = tuple(tensor.shape) + (1,) * n_missing
    return tensor.view(new_shape)

def from_m1p1_to_01(x):
    '''
    Rescale a batch to [0, 1] when it appears to be in a signed range.

    If any sample (indexed along dim 0) has a negative minimum, every sample
    is min-max normalised independently. It is shifted so its minimum is 0,
    then divided by its resulting maximum. Otherwise ``x`` is returned
    unchanged (same object).

    Accepts tensors of any rank >= 2 (batch dimension first). A sample with
    constant values becomes all zeros instead of 0/0 = NaN.
    '''
    per_sample_min = x.flatten(start_dim = 1).min(dim = 1).values
    if not (per_sample_min < 0.).any():
        return x
    # (B, 1, ..., 1) so per-sample statistics broadcast against x of any rank.
    stat_shape = (-1,) + (1,) * (x.dim() - 1)
    x = x - per_sample_min.view(stat_shape)
    span = x.flatten(start_dim = 1).max(dim = 1).values
    # After the shift a constant sample is all zeros (span 0); divide by 1 instead.
    span = torch.where(span > 0, span, torch.ones_like(span))
    return x / span.view(stat_shape)

def from_01_to_m1p1(x):
    '''
    Map a batch from [0, 1] to [-1, 1] via ``(x - 0.5) * 2``.

    The batch counts as already signed, and is returned unchanged (same
    object), if any sample (indexed along dim 0) contains a negative value.
    This is the same detection rule ``from_m1p1_to_01`` uses. Expects a
    tensor of rank >= 2 with the batch dimension first.
    '''
    per_sample_min = x.flatten(start_dim = 1).min(dim = 1).values
    # A single negative sample means the batch is already in [-1, 1]. Converting
    # only when no sample is negative prevents double-converting batches in
    # which some signed samples happen to be entirely non-negative.
    if (per_sample_min < 0.).any():
        return x
    return (x - 0.5) * 2

def generate_number_image(
        n: int,
        width: int = 64,
        height: int = 16,
        size: int = 15,
        margin_left: int = 4,
        margin_top: int = -2,
    ) -> torch.Tensor:
    """
    Render ``n`` as black text on a white RGB canvas.

    Args:
        n: value to draw (converted with ``str``).
        width, height: canvas size in pixels.
        size: font size passed to ``ImageFont.load_default``.
        margin_left: x coordinate of the text anchor in pixels.
        margin_top: y coordinate of the text anchor in pixels. The default of
            -2 places the text two pixels above the top edge (no centering
            is computed).

    Returns:
        float32 tensor of shape (3, height, width) with values in [0, 1].
    """
    canvas = Image.new("RGB", (width, height), (255, 255, 255))
    # NOTE: the ``size`` argument of load_default requires Pillow >= 10.1.
    font = ImageFont.load_default(size=size)
    ImageDraw.Draw(canvas).text(
        (margin_left, margin_top), str(n), fill=0, font=font  # black text
    )
    pixels = np.array(canvas, dtype=np.float32) / 255.0  # (H, W, 3)
    return torch.from_numpy(pixels).permute(2, 0, 1)  # (3, H, W)


def list_from_string(string: str, trim_whitespace: bool = True, dtype: type = str) -> list:
    """
    Split a comma-separated string into a list, optionally converting items.

    Example: list_from_string("1, 2, 3", dtype=int) -> [1, 2, 3]
             list_from_string("a,b") -> ['a', 'b']
    Args:
        string (str): The string to convert. An empty string (or None) gives [].
        trim_whitespace (bool): Whether to strip whitespace around each item.
        dtype (type): The type to convert each item to (default is str).
        Supported: str, int, float. Values are parsed with ``int()`` /
        ``float()``, so signs and scientific notation (e.g. "-1e-3") are accepted.
    Raises:
        ValueError: if an item cannot be converted to ``dtype``, or if
            ``dtype`` is not one of the supported types.
    """
    if not string:
        return []
    items = string.split(',')
    if trim_whitespace:
        items = [item.strip() for item in items]

    if dtype == str:
        return items
    if dtype == int:
        kind = "integers"
    elif dtype == float:
        kind = "floats"
    else:
        # Previously fell through and silently returned None.
        raise ValueError(f"Unsupported dtype {dtype!r}; expected str, int or float.")

    try:
        return [dtype(item) for item in items]
    except ValueError as err:
        raise ValueError(f"All items must be {kind} if dtype is {dtype.__name__}.") from err
