import os
import time
import json
import torch
import random
import hashlib
import numpy as np
import pandas as pd
from typing import Generator
from datasets import load_dataset
from torchvision import transforms
from PIL import Image, ImageFilter

class NoiseConfig:
    '''
    Default strengths for the image distortions.

    The attribute names match the ``args.<name>`` fields read by
    ``image_distortion`` in this file; presumably this class (or an instance
    of it) is what callers pass as ``args`` -- confirm against the call sites.
    '''
    # JPEG quality handed to ``img.save(..., quality=...)``.
    jpeg_ratio: int = 25
    # Fraction of each image side kept by the random crop; the rest is zeroed.
    random_crop_ratio: float = 0.6
    # Fraction of each image side covered by the zeroed-out (dropped) window.
    random_drop_ratio: float = 0.8
    # ``brightness`` argument given to torchvision's ``ColorJitter``.
    brightness_factor: float = 6
    # Radius given to PIL's ``ImageFilter.GaussianBlur``.
    gaussian_blur_r: int = 4
    # Size given to PIL's ``ImageFilter.MedianFilter``.
    median_blur_k: int = 7
    # Standard deviation of the additive Gaussian noise, multiplied by 255
    # before being added to the pixel values.
    gaussian_std: float = 0.05
    # Total salt-and-pepper probability; half is used for each extreme.
    sp_prob: float = 0.05
    # Downscale factor used before the image is resized back up.
    resize_ratio: float = 0.25

def get_num_params(optimizer):
    """Count the trainable parameters managed by an optimizer.

    Walks every entry of ``optimizer.param_groups`` and sums ``numel()`` over
    the parameters in ``group['params']`` whose ``requires_grad`` is true.
    The total is printed and also returned, so callers can use the value
    instead of parsing the log line.

    Args:
        optimizer: Object exposing ``param_groups``, an iterable of dicts
            that each hold a ``'params'`` iterable.

    Returns:
        The summed element count (``0`` when there are no trainable params).
    """
    total_params = sum(
        param.numel()
        for group in optimizer.param_groups
        for param in group['params']
        if param.requires_grad
    )
    print(f"Number of optimized parameters: {total_params}")
    return total_params


def load_prompt(path: str) -> Generator[str, None, None]:
    """Lazily yield prompts from the source identified by ``path``.

    Three sources are recognised:

    * the exact dataset id ``"Gustavosta/Stable-Diffusion-Prompts"``, loaded
      with ``load_dataset`` and read from its ``"test"`` split, ``"Prompt"``
      column;
    * a path ending in ``.json`` whose top-level ``'annotations'`` list holds
      dicts with a ``'caption'`` entry;
    * a path ending in ``.csv`` that has an ``"Our GT caption"`` column.

    Any other ``path`` produces an empty generator.
    """
    if path == "Gustavosta/Stable-Diffusion-Prompts":
        hub_dataset = load_dataset(path)
        yield from hub_dataset["test"]["Prompt"]
        return

    if path.endswith(".json"):
        # Read the whole file up front so the handle is closed before yielding.
        with open(path, "r", encoding="utf-8") as handle:
            annotations = json.load(handle)['annotations']
        for entry in annotations:
            yield entry['caption']
        return

    if path.endswith(".csv"):
        frame = pd.read_csv(path)
        for caption in frame["Our GT caption"]:
            yield caption

def analyze_data(data: torch.Tensor) -> None:
    """Print a short debugging summary of a tensor.

    The Python type is always printed first. For a ``torch.Tensor`` the
    dtype, shape, minimum and maximum follow, one per line.

    Raises:
        ValueError: If ``data`` is not a ``torch.Tensor`` (raised after the
            type line has been printed).
    """
    print(f"Data type: {type(data)}")

    if isinstance(data, torch.Tensor):
        print(f"Data dtype: {data.dtype}")
        print(f"Data shape: {data.shape}")
        lowest = data.min().item()
        print(f"Data min: {lowest}")
        highest = data.max().item()
        print(f"Data max: {highest}")
        return

    raise ValueError("Data is not a tensor")

def to_tensor(data: "Image.Image") -> torch.Tensor:
    """Convert an image into a batched, channels-first tensor in [-1, 1].

    The input is turned into an array with ``np.array``, given a leading
    batch axis, cast to float, rescaled as ``x / 255 * 2 - 1`` and permuted
    from (1, H, W, C) to (1, C, H, W).

    Args:
        data: A PIL image (or any array-like ``np.array`` accepts) laid out
            as (H, W, C), or as (H, W) for a single-channel image. The
            rescaling assumes pixel values in [0, 255].

    Returns:
        A float tensor of shape (1, C, H, W); C is 1 for 2-D input.
    """
    pixels = torch.from_numpy(np.array(data))
    if pixels.dim() == 2:
        # Single-channel images have no channel axis; add one so the
        # permute below works for them as well.
        pixels = pixels.unsqueeze(-1)

    # (H, W, C) -> (1, H, W, C)
    batch = pixels.unsqueeze(0).float()
    # [0, 255] -> [0, 1] -> [-1, 1]
    batch = batch / 255
    batch = batch * 2 - 1
    # (1, H, W, C) -> (1, C, H, W)
    return batch.permute(0, 3, 1, 2)

def save_image(image: "Image.Image", path: str) -> None:
    """Save an image to ``path`` and its pixel array next to it as ``.npy``.

    Args:
        image: Object with a ``save(path)`` method that ``np.array`` can
            convert (a PIL image in this file's usage).
        path: Destination of the image file. When it ends in ``.png`` the
            array is written to the same name with ``.npy`` instead;
            otherwise ``np.save`` appends ``.npy`` to ``path`` itself.
    """
    image.save(path)

    png_suffix = ".png"
    if path.endswith(png_suffix):
        # Swap only the trailing extension. A blanket str.replace would also
        # rewrite ".png" inside directory names and send the array elsewhere.
        array_path = path[:-len(png_suffix)] + ".npy"
    else:
        array_path = path
    np.save(array_path, np.array(image))

# Names of the supported distortions. A name's position in this list is the
# integer ``choice`` that image_distortion uses to look it up.
noise_list = [
    'none',
    'jpeg',
    'random.crop',
    'random.drop',
    'resize',
    'gaussian.blur',
    'median.blur',
    'gaussian.noise',
    'salt.and.pepper',
    'brightness',
]

def set_random_seed(seed=0):
    """Seed the torch, CUDA, Python ``random`` and NumPy global generators.

    Each generator gets a different offset of ``seed`` so their streams are
    not seeded with the same value. NumPy is included because code in this
    file draws from ``np.random`` right after calling this function and
    would otherwise not be reproducible.

    Args:
        seed: Integer base seed.
    """
    torch.manual_seed(seed + 0)
    torch.cuda.manual_seed(seed + 1)
    torch.cuda.manual_seed_all(seed + 2)
    random.seed(seed + 3)
    # np.random.seed only accepts values in [0, 2**32 - 1]; wrap so that any
    # integer accepted by the calls above is accepted here too.
    np.random.seed((seed + 4) % (2 ** 32))

def _random_window(height, width, ratio, seed):
    """Pick a seed-determined window spanning ``ratio`` of each image side.

    Returns ``(top, bottom, left, right)`` slice bounds. A private generator
    is used so the result depends only on ``seed`` and the image size, not on
    the state of NumPy's global generator.
    """
    rng = np.random.RandomState(seed % (2 ** 32))
    window_w = int(width * ratio)
    window_h = int(height * ratio)
    left = rng.randint(0, width - window_w + 1)
    top = rng.randint(0, height - window_h + 1)
    return top, top + window_h, left, left + window_w


def image_distortion(img, seed, args, choice=None):
    """Apply one distortion from ``noise_list`` to a PIL image.

    Args:
        img: PIL image to distort. The array-based distortions (crop, drop,
            Gaussian noise, salt-and-pepper) work on ``np.array(img)`` and
            assume 8-bit pixel data.
        seed: Integer seed. Only random crop and random drop use it: they
            call ``set_random_seed(seed)`` and place their window
            deterministically from ``seed``.
        args: Object exposing ``jpeg_ratio``, ``random_crop_ratio``,
            ``random_drop_ratio``, ``resize_ratio``, ``gaussian_blur_r``,
            ``median_blur_k``, ``gaussian_std``, ``sp_prob`` and
            ``brightness_factor``. If the attribute for the selected
            distortion is ``None`` the image is returned unchanged.
        choice: Index into ``noise_list``. When ``None`` an index is drawn
            with ``random.randint``.

    Returns:
        ``(img, noise_name)``: the (possibly distorted) image and
        ``noise_list[choice]``. The name is reported even when the
        distortion was skipped because its setting is ``None``.
    """
    import io  # local import: only needed for the in-memory JPEG round trip

    if choice is None:
        choice = random.randint(0, len(noise_list) - 1)
    noise_name = noise_list[choice]

    if choice == 1 and args.jpeg_ratio is not None:
        # Encode and decode in memory. A temp file in the working directory
        # can collide between processes, is left behind on errors, and
        # cannot be deleted on Windows while the image still holds it open.
        buffer = io.BytesIO()
        img.save(buffer, format="JPEG", quality=args.jpeg_ratio)
        buffer.seek(0)
        img = Image.open(buffer)
        img.load()

    elif choice == 2 and args.random_crop_ratio is not None:
        set_random_seed(seed)
        pixels = np.array(img)
        # NumPy image arrays are (height, width, ...), not (width, height).
        height, width = pixels.shape[:2]
        top, bottom, left, right = _random_window(
            height, width, args.random_crop_ratio, seed
        )
        # Keep only the window; everything outside it becomes zero.
        kept = np.zeros_like(pixels)
        kept[top:bottom, left:right] = pixels[top:bottom, left:right]
        img = Image.fromarray(kept)

    elif choice == 3 and args.random_drop_ratio is not None:
        set_random_seed(seed)
        pixels = np.array(img)
        height, width = pixels.shape[:2]
        top, bottom, left, right = _random_window(
            height, width, args.random_drop_ratio, seed
        )
        # Zero the window and keep everything outside it.
        pixels[top:bottom, left:right] = 0
        img = Image.fromarray(pixels)

    elif choice == 4 and args.resize_ratio is not None:
        width, height = img.size
        # Explicit (h, w) targets: an int size makes Resize scale the
        # *smaller* edge, which does not restore non-square images.
        reduced = (
            max(1, int(height * args.resize_ratio)),
            max(1, int(width * args.resize_ratio)),
        )
        img = transforms.Resize(size=reduced)(img)
        img = transforms.Resize(size=(height, width))(img)

    elif choice == 5 and args.gaussian_blur_r is not None:
        img = img.filter(ImageFilter.GaussianBlur(radius=args.gaussian_blur_r))

    elif choice == 6 and args.median_blur_k is not None:
        img = img.filter(ImageFilter.MedianFilter(args.median_blur_k))

    elif choice == 7 and args.gaussian_std is not None:
        pixels = np.array(img)
        noise = np.random.normal(0, args.gaussian_std, pixels.shape) * 255
        # Add in a signed integer type and clip afterwards. Casting the
        # noise to uint8 first makes the sum wrap around modulo 256, so the
        # clip never takes effect.
        noisy = pixels.astype(np.int32) + noise.astype(np.int32)
        img = Image.fromarray(np.clip(noisy, 0, 255).astype(np.uint8))

    elif choice == 8 and args.sp_prob is not None:
        pixels = np.array(img)
        prob_zero = args.sp_prob / 2
        prob_one = 1 - prob_zero
        # One uniform draw per array element (so per channel, not per pixel).
        rdn = np.random.rand(*pixels.shape)
        pixels[rdn > prob_one] = 0
        pixels[rdn < prob_zero] = 255
        img = Image.fromarray(pixels)

    elif choice == 9 and args.brightness_factor is not None:
        img = transforms.ColorJitter(brightness=args.brightness_factor)(img)

    return img, noise_name