import json
import logging
import numbers
import os
import random
import secrets
import shutil
import sys
from collections.abc import Sequence
from glob import glob
from typing import Tuple

import numpy as np
import cv2
import torch
import torch.nn as nn
import torchvision.transforms as transforms

from config import Constants


def _append_stage_layers(layers, stage_cfg, in_channels, kernel_size, batch_norm, dilation):
    """Append the modules described by ``stage_cfg`` to ``layers``.

    Returns the channel count produced by the last convolution that was added
    (or ``in_channels`` unchanged when the stage contains no convolution).
    """
    for v in stage_cfg:
        if v == 'M':
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        elif v == 'U':
            layers.append(nn.Upsample(scale_factor=2, mode='bilinear'))
        elif v == 'R':
            layers.append(nn.ReLU(inplace=True))
        elif isinstance(v, int):
            layers.append(nn.Conv2d(in_channels, v,
                                    kernel_size=kernel_size,
                                    padding=dilation,
                                    dilation=dilation))
            if batch_norm:
                layers.append(nn.BatchNorm2d(v))
            # The next convolution consumes what this one produces.
            in_channels = v
        # Any other entry is silently skipped.
    return in_channels


def layer_maker(cfg, in_channels=1, conv_kernel_size=3, up_kernel_size=3, batch_norm=False, dilation=1):
    """Build an ``nn.Sequential`` from a two-stage layer specification.

    Each stage is an iterable whose entries are interpreted as follows:
    ``'M'`` -> 2x2 max-pooling with stride 2, ``'U'`` -> bilinear 2x
    upsampling, ``'R'`` -> in-place ReLU, and an ``int`` -> a ``Conv2d`` with
    that many output channels (followed by ``BatchNorm2d`` when
    ``batch_norm`` is true). Entries matching none of these are ignored.

    Args:
        cfg: Indexable with at least two stages; ``cfg[0]`` is built with
            ``conv_kernel_size`` and ``cfg[1]`` with ``up_kernel_size``.
        in_channels (int, optional): Channels of the network input. Defaults to 1.
        conv_kernel_size (int, optional): Kernel size for ``cfg[0]`` convolutions.
        up_kernel_size (int, optional): Kernel size for ``cfg[1]`` convolutions.
        batch_norm (bool, optional): Add ``BatchNorm2d`` after every convolution.
        dilation (int, optional): Used as both dilation and padding of every
            convolution.

    Returns:
        nn.Sequential: All modules of stage 0 followed by those of stage 1.

    NOTE(review): padding is always set to ``dilation``, which keeps the
    spatial size only for a kernel size of 3 - confirm this is intended
    before passing other kernel sizes.
    """
    layers = []
    # Channel bookkeeping carries over from the first stage into the second.
    in_channels = _append_stage_layers(
        layers, cfg[0], in_channels, conv_kernel_size, batch_norm, dilation)
    _append_stage_layers(
        layers, cfg[1], in_channels, up_kernel_size, batch_norm, dilation)

    return nn.Sequential(*layers)

def random_segmentation(size):
    """Return a random binary mask of the given shape.

    Samples are drawn from a standard normal distribution via
    ``np.random.normal`` and thresholded in place: strictly positive samples
    become 1 and strictly negative samples become 0.

    Args:
        size: Output shape forwarded to ``np.random.normal``.

    Returns:
        The thresholded sample array (same float array that was sampled).
    """
    noise = np.random.normal(0, 1, size)
    # Boolean masks are evaluated one after the other; entries already set to
    # 1 are not negative, so the second assignment never touches them.
    noise[noise > 0] = 1
    noise[noise < 0] = 0
    return noise


class RandomRotation90(transforms.RandomRotation):
    """``RandomRotation`` variant whose angle is always a multiple of 90 degrees."""

    def get_params(self, img):
        """Pick one of 0, 90, 180 or 270 degrees uniformly at random.

        Args:
            img: Not used by this override.

        Returns:
            int: The rotation angle in degrees (a plain number, not a tuple).
        """
        quarter_turns = random.randint(0, 3)
        return quarter_turns * 90

def random_crop(img_size, crop_size):
    """Choose a random crop window inside an image.

    Args:
        img_size: Indexable whose first two items are (height, width) of the image.
        crop_size: Indexable whose first two items are (height, width) of the crop.

    Returns:
        tuple: ``(top, left, crop_height, crop_width)`` where ``top`` and
        ``left`` are drawn uniformly (inclusive) from the offsets that keep
        the crop inside the image.

    Raises:
        ValueError: From ``random.randint`` when the crop is larger than the
            image along either axis.
    """
    crop_height, crop_width = crop_size[0], crop_size[1]
    # Largest valid offsets along each axis.
    max_top = img_size[0] - crop_height
    max_left = img_size[1] - crop_width
    # Row offset is drawn before the column offset.
    top = random.randint(0, max_top)
    left = random.randint(0, max_left)
    return top, left, crop_height, crop_width

def get_image_filename_list(root) -> list:

    if os.path.exists(root) is not True:
        raise FileNotFoundError("{} does not exist.".format(root))

    img_dir = os.path.join(root, "raw")
    dot_dir = os.path.join(root, "dot")
    if os.path.exists(img_dir) is not True or os.path.exists(dot_dir) is not True:
        raise Exception("Unknown dataset structure {}".format(root))
    # Validate the file structure
    try:
        json_file = os.path.join(root, glob("*.json", root_dir=root)[0])
    except Exception:
        raise FileNotFoundError("Cannot find json file of filenames in {}".format(root))

    with open(json_file, 'r') as f:
        image_file_list = json.load(f)

    return image_file_list

@staticmethod
def _setup_size(size, error_msg):
    if size is None:
        return None

    if isinstance(size, numbers.Number):
        return int(size), int(size)

    if isinstance(size, Sequence) and len(size) == 1:
        return size[0], size[0]

    if len(size) != 2:
        raise ValueError(error_msg)

    return size


class AverageMeter(object):
    """Accumulate a stream of values and expose last/sum/count/avg/min."""

    def __init__(self) -> None:
        self.reset()

    def reset(self):
        """Clear all statistics."""
        self.val = 0
        self.avg = 0.0
        self.sum = 0.0
        self.count = 0
        # Start from +inf so the first observed value always becomes the
        # minimum, whatever its magnitude (a finite sentinel would win
        # against any larger value).
        self.min = float("inf")

    def update(self, value, n = 1):
        """Record ``value`` as observed ``n`` times.

        Args:
            value: The new observation; also stored as ``val``.
            n (int, optional): Weight of the observation. Defaults to 1.
        """
        self.val = value
        self.sum += value * n
        self.count += n
        # Guard against a zero total weight (e.g. a first update with n=0).
        self.avg = self.sum / self.count if self.count else 0.0
        self.min = min(self.min, value)

    def get(self, key):
        """Return one statistic by name.

        Args:
            key (str): One of ``'val'``, ``'sum'``, ``'count'``, ``'avg'``, ``'min'``.

        Raises:
            KeyError: If ``key`` is not one of the names above.
        """
        result = {'val': self.val, 'sum': self.sum, 'count': self.count,
                  'avg': self.avg, 'min': self.min}
        return result[key]


class ModelSelection(object):
    """Record per-call metrics, write checkpoints and optionally stop early.

    Checkpoints are written with ``torch.save`` into ``Constants.CHECK_POINT``
    as ``<random hex id>.pt``. Each checkpoint file holds the metrics passed to
    ``save`` plus the keys ``'filename'``, ``'model_state_dict'`` and
    ``'optimizer_dict'``.
    """

    # Class-level defaults; every instance gets its own values in __init__ so
    # that records are never shared between two selectors.
    records: dict
    metrics = None
    early_stop_counter = 0

    def __init__(self, optimal_save: bool = True, optimal_key: str = None, reversed: bool = False, early_stopping: bool = False, patience: int = 10):
        """ModelSelection

        Args:
            optimal_save (bool, optional): Only save the optimal result to save the disk space.
                When False, every call to ``save`` writes a checkpoint. Defaults to True.
            optimal_key (str, optional): The key in metrics for opting optimal record.
                Required when ``optimal_save`` is True. Defaults to None.
            reversed (bool, optional): Default is the less result the better, Reversed indicates opting the larger record. Defaults to False.
            early_stopping (bool, optional): Raise ``StopIteration`` from ``save`` after
                ``patience`` consecutive calls without improvement. Defaults to False.
            patience (int, optional): Number of non-improving calls tolerated when
                ``early_stopping`` is enabled. Defaults to 10.

        Raises:
            ValueError: If ``optimal_save`` is True but ``optimal_key`` is None.
        """
        # Explicit raise instead of assert: asserts vanish under ``python -O``.
        if optimal_save and optimal_key is None:
            raise ValueError("optimal_key is required when optimal_save is True")

        os.makedirs(Constants.CHECK_POINT, exist_ok=True)

        self.optimal_save = optimal_save
        self.optimal_key: str = optimal_key
        self.optimal_reverse = reversed
        self.stopping_patience = patience if early_stopping else None

        # Per-instance state (a class-level dict would be shared by all instances).
        self.records = {}
        self.metrics = None
        self.early_stop_counter = 0
        # Always defined, so save()/load()/clean() work when optimal_save is False.
        self.optimal_record = None

    def save(self, model: nn.Module, optimizer: torch.optim.Optimizer, info: dict):
        """Register ``info`` and write a checkpoint when appropriate.

        A checkpoint is written for the first record, for every record that
        improves on the current optimum, and - when ``optimal_save`` is False -
        for every record. When a checkpoint is written, ``info['filename']`` is
        set to the generated identifier.

        Args:
            model (nn.Module): Model whose ``state_dict`` is stored.
            optimizer (torch.optim.Optimizer): Optimizer whose ``state_dict`` is stored.
            info (dict): Metrics of this call; must contain ``optimal_key`` when one is set.

        Raises:
            StopIteration: When early stopping is enabled and the number of
                consecutive non-improving calls reaches the patience.
        """
        filename = secrets.token_hex(4)
        self.records[filename] = info

        if self.metrics is None:
            self.metrics: list = list(info.keys())

        is_best = False
        stop = False
        if self.optimal_key is not None:
            if self.optimal_record is None:
                # The very first record is the optimum by definition.
                is_best = True
            else:
                best = self.optimal_record[self.optimal_key]
                current = info[self.optimal_key]
                if self.optimal_reverse:
                    # Negate both so that "smaller is better" also covers
                    # the larger-is-better case.
                    best, current = -best, -current

                if best > current:
                    is_best = True
                    self.early_stop_counter = 0
                else:
                    self.early_stop_counter += 1
                    stop = (self.stopping_patience is not None
                            and self.early_stop_counter >= self.stopping_patience)

        if is_best or not self.optimal_save:
            info['filename'] = filename
            if is_best:
                self.optimal_record = info
            # Put the state dicts into a copy so the (potentially large)
            # tensors do not end up inside the caller's dict or the records.
            checkpoint = dict(info)
            checkpoint['model_state_dict'] = model.state_dict()
            checkpoint['optimizer_dict'] = optimizer.state_dict()
            torch.save(checkpoint, os.path.join(
                Constants.CHECK_POINT, filename + ".pt"))

        if stop:
            raise StopIteration(
                "Early stopping is triggered")

    def get_records(self):
        """Return the mapping of checkpoint identifier to recorded ``info``."""
        return self.records

    def load(self,  model: nn.Module, optimizer: torch.optim.Optimizer, device: torch.device, identifier: str = None) -> Tuple[nn.Module, torch.optim.Optimizer, dict]:
        """Restore model and optimizer state from a checkpoint.

        Args:
            model (nn.Module): Receives the stored model state.
            optimizer (torch.optim.Optimizer): Receives the stored optimizer state.
            device (torch.device): Passed to ``torch.load`` as ``map_location``.
            identifier (str, optional): Checkpoint to load. Ignored when
                ``optimal_save`` is True, in which case the optimal record is loaded.

        Returns:
            The model, the optimizer and the remaining checkpoint content
            (without the two state dicts).

        Raises:
            RuntimeError: If ``optimal_save`` is True and nothing was saved yet.
            KeyError: If ``optimal_save`` is False and ``identifier`` is unknown.
        """
        if self.optimal_save:
            if self.optimal_record is None:
                raise RuntimeError("No checkpoint has been saved yet.")
            identifier = self.optimal_record['filename']
        elif identifier not in self.records:
            raise KeyError(
                "Cannot find the required checkpoint of {}.".format(identifier))

        checkpoint: dict = torch.load(
            os.path.join(Constants.CHECK_POINT, identifier + ".pt"), map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_dict'])
        del checkpoint['model_state_dict']
        del checkpoint['optimizer_dict']
        return model, optimizer, checkpoint

    def clean(self, keep_optimal: bool = False):
        """Delete the checkpoint files of all recorded identifiers.

        Args:
            keep_optimal (bool, optional): Keep the file of the optimal record
                (if there is one). Defaults to False.
        """
        kept = None
        if keep_optimal and self.optimal_record is not None:
            kept = self.optimal_record['filename']

        for filename in self.records:
            if filename == kept:
                continue
            try:
                os.remove(os.path.join(Constants.CHECK_POINT, filename + ".pt"))
            except FileNotFoundError:
                # Records without a checkpoint on disk are expected.
                pass


# Logging out module
class Model_Logger(logging.Logger):
    """Logger that writes to a log file and a stream, and reports uncaught exceptions.

    On construction the logger attaches a ``LocalFileHandler`` (DEBUG level,
    file ``<Constants.LOG_FOLDER>/<Constants.LOG_NAME>.log`` opened in ``a+``
    mode) and a ``logging.StreamHandler`` (at the logger's own level), then
    installs its ``exception_hook`` as ``sys.excepthook``.
    """

    _instance = None

    def __init__(self, name, level = logging.INFO):
        """Create the logger and attach both handlers.

        Args:
            name: Logger name.
            level: Level of the logger and of the stream handler.
                Defaults to ``logging.INFO``.
        """
        super().__init__(name, level)
        self.name = name
        self.level = level

        # One shared record layout for both outputs.
        log_format = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

        # File output: everything from DEBUG upwards.
        log_path = os.path.join(Constants.LOG_FOLDER, f"{Constants.LOG_NAME}.log")
        file_handler = LocalFileHandler(filename=log_path, mode='a+')
        file_handler.setFormatter(log_format)
        file_handler.setLevel(logging.DEBUG)

        # Stream output: filtered at the logger's level.
        console_handler = logging.StreamHandler()
        console_handler.setFormatter(log_format)
        console_handler.setLevel(self.level)

        for handler in (file_handler, console_handler):
            self.addHandler(handler)

        self.enable_exception_hook()

    def exception_hook(self, exc_type, exc_value, exc_traceback):
        """Log an uncaught exception at ERROR level, including its traceback."""
        details = (exc_type, exc_value, exc_traceback)
        self.error("Uncaught exception", exc_info=details)

    def enable_exception_hook(self):
        """Install ``exception_hook`` as the process-wide ``sys.excepthook``."""
        # Replaces whatever hook was installed before, including one set by
        # an earlier Model_Logger instance.
        sys.excepthook = self.exception_hook
        # Set the exception hook


class LocalFileHandler(logging.FileHandler):
    def __init__(self, filename: str,
                 mode: str = "a",
                 encoding: str | None = None,
                 delay: bool = False,
                 errors: str | None = None) -> None:
        super().__init__(filename, mode, encoding, delay, errors)

def loss_less_dot_resize(dot, size):
    """Resize a dot map by re-plotting every dot at its scaled position.

    Every element of ``dot`` greater than zero is treated as a dot. Its
    position in the last two dimensions is scaled to the new size, rounded,
    and marked with 1 in a fresh map. Dots that round onto the same target
    pixel collapse into one.

    Args:
        dot (torch.Tensor): Dot map with at least two dimensions; the last
            two are used as (height, width). Dots from any leading
            dimensions are all plotted into the same output map.
        size: Pair ``(new_height, new_width)`` of ints.

    Returns:
        torch.Tensor: Float tensor of shape ``(1, 1, new_height, new_width)``
        on the same device as ``dot``, with 1 at dot positions and 0 elsewhere.
    """
    h, w = dot.shape[-2:]
    new_height, new_width = size
    scaling_factor_height = new_height / h
    scaling_factor_width = new_width / w

    # nonzero() yields one column per input dimension; only the last two are
    # spatial. Taking the first two would read batch/channel indices as y/x
    # for inputs with leading dimensions.
    dot_coords = torch.nonzero(dot > 0)[:, -2:].float()
    dot_coords[:, 0] *= scaling_factor_height  # y-coordinate
    dot_coords[:, 1] *= scaling_factor_width   # x-coordinate

    # Round the coordinates and convert to integers
    dot_coords = torch.round(dot_coords).long()

    # Allocate on the input's device so the index tensors can be used directly.
    new_dot_map_tensor = torch.zeros(
        (new_height, new_width), dtype=torch.uint8, device=dot.device)

    if new_dot_map_tensor.numel() > 0:
        # Rounding can push a border dot one pixel past the edge when
        # shrinking; clamp it back instead of dropping it so no dot is lost.
        rows = dot_coords[:, 0].clamp(0, new_height - 1)
        cols = dot_coords[:, 1].clamp(0, new_width - 1)
        new_dot_map_tensor[rows, cols] = 1

    return new_dot_map_tensor.float().unsqueeze(0).unsqueeze(0)  # Make it a channel and batch





