import torch

import json
import random
import numpy as np
from PIL import Image
import matplotlib as m
m.use("Agg")
import requests
from io import BytesIO

import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision.transforms as transforms

import datasets
from datasets import load_dataset, Dataset

from io_utils import *

from statistics import mean

def set_random_seed(seed=0):
    """Seed the PyTorch, NumPy and Python random number generators.

    Each generator receives a different offset of ``seed`` so the individual
    streams are not seeded with the same value.

    Args:
        seed: Base integer seed.
    """
    torch.manual_seed(seed + 0)
    np.random.seed(seed + 3)
    # A single manual_seed_all call seeds every CUDA device (including the
    # current one), so no separate per-device call is needed before it.
    torch.cuda.manual_seed_all(seed + 4)
    random.seed(seed + 5)


### credit to https://github.com/somepago/DCR
def insert_rand_word(sentence, word):
    """Return ``sentence`` with ``word`` inserted at a random word boundary.

    The sentence is split on single spaces and the insertion slot is drawn
    uniformly (via Python's ``random`` module) from every position, including
    before the first word and after the last one.

    Args:
        sentence: Space-separated text.
        word: Token to insert.

    Returns:
        The new sentence, re-joined with single spaces.
    """
    tokens = sentence.split(" ")
    position = random.randint(0, len(tokens))
    return " ".join(tokens[:position] + [word] + tokens[position:])


def prompt_augmentation(prompt, aug_style, tokenizer=None, args=None, repeat_num=4, embedding_layer=None):
    """Perturb ``prompt`` by inserting ``repeat_num`` extra words at random positions.

    Args:
        prompt: Text prompt to augment.
        aug_style: Augmentation strategy, one of
            ``"rand_numb_add"``    - insert random integers drawn from [0, 100000),
            ``"rand_word_add"``    - insert the decoding of random token ids
                                     drawn from [0, 49400),
            ``"rand_word_repeat"`` - re-insert words sampled from the original
                                     prompt.
        tokenizer: Object exposing ``decode(list_of_ids)``; only used by
            ``"rand_word_add"``.
        args: Unused by this function; kept for interface compatibility.
        repeat_num: Number of insertions to perform.
        embedding_layer: Unused by this function; kept for interface
            compatibility.

    Returns:
        The augmented prompt string.

    Raises:
        ValueError: If ``aug_style`` is not one of the supported styles.
    """
    if aug_style == "rand_numb_add":
        for _ in range(repeat_num):
            randnum = np.random.choice(100000)
            prompt = insert_rand_word(prompt, str(randnum))
    elif aug_style == "rand_word_add":
        for _ in range(repeat_num):
            # NOTE(review): 49400 is presumably just below the tokenizer's
            # vocabulary size - confirm against the tokenizer in use.
            randword = tokenizer.decode(list(np.random.randint(49400, size=1)))
            prompt = insert_rand_word(prompt, randword)
    elif aug_style == "rand_word_repeat":
        # Sample only from the words of the original prompt, not from words
        # inserted by earlier iterations.
        wordlist = prompt.split(" ")
        for _ in range(repeat_num):
            randword = np.random.choice(wordlist)
            prompt = insert_rand_word(prompt, randword)
    else:
        raise ValueError(f"Unsupported prompt augmentation style: {aug_style!r}")
    return prompt


def download_image(url):
    """Fetch ``url`` over HTTP and return its body decoded as an RGB PIL image.

    The request uses a 2 second timeout. The HTTP status code is not checked
    here: whatever body comes back is handed to ``Image.open``.

    Args:
        url: Address of the image to download.

    Returns:
        The downloaded image converted to RGB mode.
    """
    reply = requests.get(url, stream=True, timeout=2)
    payload = BytesIO(reply.content)
    picture = Image.open(payload)
    return picture.convert("RGB")


def get_dataset(dataset_name, data_dir=None):
    """Load a prompt dataset and report which field holds the prompt text.

    Three sources are supported, selected by substring match on
    ``dataset_name``:

    * contains ``"groundtruth"`` - read
      ``outputs/<dataset_name>/<dataset_name>.jsonl`` with ``load_jsonlines``;
      the prompt field is ``"caption"``.
    * contains ``"coco"`` - build an image/text dataset from
      ``<data_dir>/annotations/captions_train2017.json``, pairing every image
      with the first caption listed for it; the prompt field is ``"text"``.
    * anything else - ``load_dataset(dataset_name)["test"]``; the prompt field
      is ``"Prompt"``.

    Args:
        dataset_name: Dataset identifier (see above).
        data_dir: Root directory of the COCO data; only used by the COCO
            branch.

    Returns:
        A ``(dataset, prompt_key)`` tuple.
    """
    if "groundtruth" in dataset_name:
        dataset = load_jsonlines(f"outputs/{dataset_name}/{dataset_name}.jsonl")
        prompt_key = "caption"
    elif "coco" in dataset_name:
        # Context manager so the annotation file handle is closed as soon as
        # the JSON has been parsed.
        with open(f"{data_dir}/annotations/captions_train2017.json") as f:
            data = json.load(f)

        # Group captions by the image they describe, preserving file order.
        id_to_annos = {}
        for anno_i in data["annotations"]:
            id_to_annos.setdefault(anno_i["image_id"], []).append(anno_i["caption"])

        all_data = {"image": [], "text": []}

        for curr_img in data["images"]:
            # Only the first caption of each image is used as its prompt.
            all_data["text"].append(id_to_annos[curr_img["id"]][0])
            file_name = curr_img["file_name"]
            all_data["image"].append(f"{data_dir}/images/train2017/{file_name}")

        # Casting the path column to datasets.Image() makes it an image feature.
        dataset = Dataset.from_dict(all_data).cast_column("image", datasets.Image())
        prompt_key = "text"
    else:
        dataset = load_dataset(dataset_name)["test"]
        prompt_key = "Prompt"

    return dataset, prompt_key


def measure_CLIP_similarity(images, prompt, model, clip_preprocess, tokenizer, device):
    """Score how well each image matches ``prompt`` with a CLIP-style model.

    Args:
        images: Iterable of images accepted by ``clip_preprocess``.
        prompt: Text prompt to compare against.
        model: Object exposing ``encode_image`` and ``encode_text``.
        clip_preprocess: Callable turning one image into a tensor.
        tokenizer: Callable turning a list of strings into a token tensor.
        device: Device the inputs are moved to before encoding.

    Returns:
        Tensor of cosine similarities between the image and text features,
        averaged over the last axis (one value per image when the encoders
        return ``(batch, dim)`` features - TODO confirm for the model in use).
    """
    with torch.no_grad():
        pixel_batch = torch.stack([clip_preprocess(img) for img in images]).to(device)
        image_features = model.encode_image(pixel_batch)

        tokens = tokenizer([prompt]).to(device)
        text_features = model.encode_text(tokens)

        # L2-normalise in place so the matrix product is a cosine similarity.
        image_features.div_(image_features.norm(dim=-1, keepdim=True))
        text_features.div_(text_features.norm(dim=-1, keepdim=True))

        similarity = image_features @ text_features.T
        return similarity.mean(-1)
    

def measure_CLIP_text_similarity(target_prompt, prompt, model, clip_preprocess, tokenizer, device):
    """Compute the cosine similarity between two prompts in CLIP text space.

    Args:
        target_prompt: Reference prompt.
        prompt: Prompt compared against the reference.
        model: Object exposing ``encode_text``.
        clip_preprocess: Unused by this function; kept so the signature
            mirrors ``measure_CLIP_similarity``.
        tokenizer: Callable turning a list of strings into a token tensor.
        device: Device the tokens are moved to before encoding.

    Returns:
        Tensor holding the similarity of the two normalised text features,
        averaged over the last axis.
    """
    with torch.no_grad():
        # Encode both prompts first (reference, then candidate) ...
        target_features, candidate_features = [
            model.encode_text(tokenizer([text]).to(device))
            for text in (target_prompt, prompt)
        ]

        # ... then L2-normalise each in place so the product is a cosine.
        for features in (target_features, candidate_features):
            features.div_(features.norm(dim=-1, keepdim=True))

        similarity = target_features @ candidate_features.T
        return similarity.mean(-1)


### credit: https://github.com/somepago/DCR
def measure_SSCD_similarity(gt_images, images, model, device):
    """Compute pairwise similarities between two image sets with ``model``.

    Every image is converted to RGB, resized to 256, centre-cropped to 224,
    turned into a tensor and normalised with mean 0.5 / std 0.5 before being
    embedded. Embeddings are L2-normalised along dim 1, so the returned
    matrix product contains cosine similarities.

    Args:
        gt_images: Iterable of reference images exposing ``convert``.
        images: Iterable of images to compare against the references.
        model: Callable mapping an image batch to a 2-D feature tensor.
        device: Device the batches are moved to.

    Returns:
        Tensor whose entry ``[i, j]`` is the similarity between
        ``gt_images[i]`` and ``images[j]``.
    """
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])

    def to_batch(pil_images):
        # Stack the preprocessed images into one batch on the target device.
        return torch.stack([preprocess(im.convert('RGB')) for im in pil_images]).to(device)

    gt_batch = to_batch(gt_images)
    candidate_batch = to_batch(images)

    with torch.no_grad():
        gt_features = F.normalize(model(gt_batch).clone(), dim=1, p=2)
        candidate_features = F.normalize(model(candidate_batch).clone(), dim=1, p=2)
        return gt_features.mm(candidate_features.T)


def even_split_by_words(input_string, num_substrs=4):
    """Split a string into ``num_substrs`` word-balanced pieces.

    Words are separated on arbitrary whitespace. When the word count is not a
    multiple of ``num_substrs`` the leading pieces each receive one extra
    word; when there are fewer words than pieces the trailing pieces are
    empty strings.

    Args:
        input_string: Text to split.
        num_substrs: Number of pieces to produce.

    Returns:
        List of ``num_substrs`` strings, each re-joined with single spaces.
    """
    tokens = input_string.split()
    base_size, extra = divmod(len(tokens), num_substrs)

    pieces = []
    cursor = 0
    for index in range(num_substrs):
        # The first `extra` pieces absorb the leftover words, one each.
        size = base_size + 1 if index < extra else base_size
        pieces.append(' '.join(tokens[cursor:cursor + size]))
        cursor += size

    return pieces


def split_image_into_tiles(image, tile_size=(128, 128)):
    """Cut an image into a 4x4 grid of equally sized tiles.

    Args:
        image: Object exposing ``size`` as ``(width, height)`` and
            ``crop(box)``, e.g. a PIL image.
        tile_size: ``(tile_width, tile_height)`` of a single tile.

    Returns:
        List of 16 crops ordered column by column: left to right, and top to
        bottom within each column.

    Raises:
        ValueError: If the image is not exactly four tiles wide and four
            tiles high.
    """
    tile_w = tile_size[0]
    tile_h = tile_size[1]
    expected_w = tile_w * 4
    expected_h = tile_h * 4

    width, height = image.size
    if (width, height) != (expected_w, expected_h):
        raise ValueError(f"Image dimensions should be {expected_w}x{expected_h}")

    return [
        image.crop((left, top, left + tile_w, top + tile_h))
        for left in range(0, width, tile_w)
        for top in range(0, height, tile_h)
    ]

def calculate_l2_distance(image1, image2):
    """Return the L2 norm of the element-wise difference of two images.

    Both inputs are converted to float32 arrays before subtraction, so
    unsigned integer pixel data cannot wrap around.

    Args:
        image1: Array-like image data.
        image2: Array-like image data broadcastable against ``image1``.

    Returns:
        The norm of ``image1 - image2`` as a NumPy scalar.
    """
    first, second = (
        np.asarray(img).astype(np.float32) for img in (image1, image2)
    )
    return np.linalg.norm(first - second)

def max_tile_distance(img1, img2):
    """Return the largest per-tile L2 distance between two images.

    Both images are cut into tiles with ``split_image_into_tiles`` (default
    tile size), corresponding tiles are compared with
    ``calculate_l2_distance`` and the maximum is returned.

    Args:
        img1: First image.
        img2: Second image.

    Returns:
        The largest tile distance, or ``0`` when no tile distance exceeds
        zero.
    """
    first_tiles = split_image_into_tiles(img1)
    second_tiles = split_image_into_tiles(img2)

    distances = [
        calculate_l2_distance(tile_a, tile_b)
        for tile_a, tile_b in zip(first_tiles, second_tiles)
    ]
    # Leading 0 is the floor: max() keeps the first maximal item, so a
    # distance only replaces it when it is strictly greater.
    return max([0, *distances])

def measure_small_mse_similarity(gt_images, images):
    """Compute the pairwise max-tile L2 distance between two image lists.

    Args:
        gt_images: Sequence of reference images.
        images: Sequence of images to compare against the references.

    Returns:
        Nested list with ``len(gt_images)`` rows and ``len(images)`` columns
        where entry ``[i][j]`` is
        ``float(max_tile_distance(gt_images[i], images[j]))``. Despite the
        function name the values are distances: larger means less similar.
    """
    # Rows follow gt_images and columns follow images, so the result has the
    # right shape even when the two lists differ in length.
    return [
        [float(max_tile_distance(gt_image, image)) for image in images]
        for gt_image in gt_images
    ]
