import os
import random
import time
from functools import lru_cache
from re import split

from PIL import Image
from typing import Dict


from src.model.single_plant import SinglePlant
from src.utilities import Utility


# TODO parse config file

class Foliage:
    def __init__(self, config):

        self.PLANT_PER_PATCH = config.get("num_plants")
        self.BASE_IMAGE_SIZE = eval(config.get("foliage_size"))
        self.SINGLE_IMAGE_SIZE = config.get("single_plant_size")
        self.OFFSET = config.get("plant_offset")
        self.TYPE = config.get("type")
        self.DISEASE_RATE = config.get("disease_rate")
        self.config = config
        self.grid = (4,3) # row x col
        self.BASE_BACKGROUND_IMAGE = config.get("background_image_path")
        self._setup()

    def _setup(self):
        self._image_cache: Dict[str, Image.Image] = {}
        self.single_plant = SinglePlant(self.config)
        self.utility = Utility()
        self.background_paths = self.__get_list_of_background_images()
    def __get_list_of_background_images(self):
        """Pre-load all background image paths"""
        return [
            os.path.join(self.BASE_BACKGROUND_IMAGE, img)
            for img in os.listdir(self.BASE_BACKGROUND_IMAGE)
            if img.endswith('.png')
        ]

    # TODO: unify the image caches that are duplicated across several classes
    # by creating one image-utility class that handles all common image tasks.
    @lru_cache(maxsize=128)
    def _load_and_prepare_image(self, image_path: str) -> Image.Image:
        if image_path not in self._image_cache:
            img = Image.open(image_path).convert("RGBA")
            img = img.resize(self.BASE_IMAGE_SIZE, Image.Resampling.LANCZOS)
            self._image_cache[image_path] = img
        return self._image_cache[image_path].copy()

    def _get_coordinates_for_single_plant(self):

        column_gap = 2
        if self.TYPE == "tomato":
            column_gap = 2

        x_coord = self.BASE_IMAGE_SIZE[0] // 3 - self.SINGLE_IMAGE_SIZE // 2 - self.OFFSET
        y_coord_step = self.SINGLE_IMAGE_SIZE - (self.OFFSET * 2)

        coords = []
        for y in range(0, self.BASE_IMAGE_SIZE[1], y_coord_step):
            coords.append((x_coord, y))
            coords.append((x_coord + (self.SINGLE_IMAGE_SIZE - int(self.OFFSET * column_gap)), y - self.OFFSET))
            coords.append((x_coord + 2 * (self.SINGLE_IMAGE_SIZE - int(self.OFFSET * column_gap)), y - self.OFFSET))
        return coords

    def get_patch_indices(self, cluster_id, rows=4, cols=3, patch_size=2):
        """
        Returns the flat indices of a patch in a 1D array representing a 2D grid.

        Parameters:
            cluster_id (int): Number of the patch (starting from 1)
            rows (int): Number of rows in the grid
            cols (int): Number of columns in the grid
            patch_size (int): Size of the square patch (default is 2 for 2x2)

        Returns:
            List[int]: Flat indices of the patch elements
        """
        max_row_start = rows - patch_size
        max_col_start = cols - patch_size
        total_patches = (max_row_start + 1) * (max_col_start + 1)

        if cluster_id < 1 or cluster_id > total_patches:
            raise ValueError(f"Patch number must be between 1 and {total_patches}")

        # Map patch_number to its (row, col) starting position

        row_id = (cluster_id - 1) // (max_col_start + 1)
        col_id = (cluster_id - 1) % (max_col_start + 1)

        indices = []
        for dy in range(patch_size):
            for dx in range(patch_size):
                r = row_id + dy
                c = col_id + dx
                flat_index = r * cols + c
                indices.append(flat_index)

        # the indices start from 1 horizontally
        return indices

    def get_neighbors(self, index, rows=4, cols=3):
        neighbors = []

        # Convert flat index -> (row, col)
        row, col = divmod(index, cols)

        # Up
        if row > 0:
            neighbors.append(index - cols)
        # Down
        if row < rows - 1:
            neighbors.append(index + cols)
        # Left
        if col > 0:
            neighbors.append(index - 1)
        # Right
        if col < cols - 1:
            neighbors.append(index + 1)

        return neighbors


    def get_cluster_coords_index(self):
        # returns a tuple containing list of neighbors and hotspot index
        hotspot = random.randint(1, 11)
        neighbors = self.get_neighbors(hotspot)
        return neighbors, hotspot


    # def split_disease_rate(self, disease_rate, num_cut=4):
    #     # Generate random cut points
    #     cuts = sorted(random.sample(range(1, disease_rate), num_cut - 1))
    #     # Add start and end boundaries
    #     cuts = [0] + cuts + [disease_rate]
    #     # Take differences
    #     result = [cuts[i + 1] - cuts[i] for i in range(num_cut)]
    #     return result

    def get_num_hotspots(self,disease_rate, threshold = 10 ):
        if disease_rate > threshold:
            return self.utility.get_biased_random_number(range=[1,2], weights=[0.2, 0.8])
        else:
            return self.utility.get_biased_random_number(range=[1,2], weights=[0.8, 0.2])

    def get_cluster_idx_and_their_disease_rate(self, disease_rate):
        # number of clusters
        num_clusters = self.get_num_hotspots(disease_rate)

        cluster_indexes = []
        disease_rates = []
        cluster_coord_index = []
        for _ in range(num_clusters):
            cluster_index, hotspot = self.get_cluster_coords_index()
            cluster_indexes.extend(cluster_index)
            cluster_indexes.append(hotspot)

            # accounting for the whole foliage, the disease rate will increase for a single cluster
            disease_rate = disease_rate * 12 // (len(cluster_index) + 1)

            split_disease_rate = self.utility.get_normalized_disease_rates(disease_rate, len(cluster_index) + 1,
                                                                           normalize=True)

            # moving maximum disease rate to the last so that hotspot has max percentage
            max_val = max(split_disease_rate)
            split_disease_rate.remove(max_val)  # remove first occurrence
            split_disease_rate.append(max_val)

            disease_rates.extend(split_disease_rate)
        updated_list = list(set(cluster_indexes))
        return updated_list, disease_rates[:len(updated_list)]


    def get_patch_of_leaves(self, disease_rate, disease="healthy") -> Image:

        # angles = self.utility.get_random_angle(self.PLANT_PER_PATCH)
        coords = self._get_coordinates_for_single_plant()

        cluster_coord_index, split_disease_rates = self.get_cluster_idx_and_their_disease_rate(disease_rate)

        background_image = self._load_and_prepare_image(random.choice(self.background_paths))
        background_image = background_image.resize(self.BASE_IMAGE_SIZE)

        disease_rate_counter= 0

        for idx, coord in enumerate(coords):

            if idx in cluster_coord_index:
                single_plant = self.single_plant.get_single_plant(split_disease_rates[disease_rate_counter], disease)
                disease_rate_counter +=1
            else:
                single_plant = self.single_plant.get_single_plant()
            # single_plant = single_plant.rotate(angle)
            background_image.paste(single_plant, coord, single_plant)
        return background_image

if __name__ == '__main__':
    # Manual smoke test: build one patch from a config file, report how long
    # the generation took and display the result.
    # Usage: python <this file> [path/to/config.json]
    import sys

    default_config_path = "/Users/C00540403/Documents/research/Foliagen/FoliageGenerator/src/soybean/config.json"
    config_path = sys.argv[1] if len(sys.argv) > 1 else default_config_path

    utility = Utility()
    config = utility.json_parser(config_path)
    trifoliate_patch = Foliage(config)

    start_time = time.time()
    # The first argument is the numeric disease rate (taken from the config);
    # the disease name goes second. Passing the name first makes the rate
    # comparison in get_num_hotspots fail with a TypeError.
    patch = trifoliate_patch.get_patch_of_leaves(trifoliate_patch.DISEASE_RATE, "bacterial_blight")
    end_time = time.time()
    print("Total time taken: ", end_time - start_time)

    # Shown after the timing so only patch generation is measured.
    patch.show()