import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import albumentations as A
from albumentations.pytorch import ToTensorV2


class Predictor:
    """Run a dense (per-pixel) prediction model in one of several inference modes.

    Modes:
        * ``"batch"`` / ``"singlescale"``: the input tensor is moved to
          ``device`` and passed straight to the model.
        * ``"multiscale"``: the image is resized by every factor in ``scales``,
          predicted, upsampled back to the original size and averaged.
        * ``"sliding"``: the image is cut into overlapping tiles which are
          predicted one by one; overlapping predictions are averaged.

    In the ``"multiscale"`` and ``"sliding"`` modes the image is normalized
    here with ``A.Normalize()`` followed by ``ToTensorV2()``, so callers are
    expected to pass the raw, un-normalized image in these modes.

    Args:
        model: Callable mapping a ``(N, C, H, W)`` tensor to per-class scores.
        num_classes: Number of channels in the model output.
        device: Device on which inputs and accumulators are placed.
        mode: One of ``"batch"``, ``"singlescale"``, ``"multiscale"``,
            ``"sliding"``.
        scales: Sequence of resize factors; required for ``"multiscale"``.
        flip: If true, also predict on the horizontally flipped input and
            average both predictions (multiscale and sliding modes only).
    """

    def __init__(self, model, num_classes, device, mode="batch", scales=None, flip=False):
        assert mode in (
            "batch", "singlescale", "multiscale", "sliding"
        ), "Mode '{}' not supported".format(mode)
        self.model = model
        self.num_classes = num_classes
        self.device = device
        self.mode = mode
        self.scales = scales
        self.normalize = A.Compose([
            A.Normalize(),
            ToTensorV2()
        ])
        self.flip = flip

    @staticmethod
    def _to_hwc_numpy(image):
        """Drop a leading batch dim of size 1 and return an ``(H, W, C)`` numpy array.

        A first dimension of size 1 or 3 (after the squeeze) is taken to be
        the channel axis and is moved to the end.
        """
        image = image.squeeze(dim=0)
        if image.size(0) in (1, 3):
            image = image.permute(1, 2, 0)
        # detach/cpu so tensors that live on an accelerator or track gradients
        # can still be converted; contiguous so the array is densely packed.
        return image.detach().cpu().contiguous().numpy()

    def _predict_with_flip(self, batch):
        """Run the model on ``batch``, averaging with the h-flipped pass if enabled."""
        prediction = self.model(batch)
        if self.flip:
            flipped_prediction = self.model(batch.flip(-1))
            # Flip the prediction back so it is aligned with the un-flipped one.
            prediction = 0.5 * (flipped_prediction.flip(-1) + prediction)
        return prediction

    def multi_scale_predict(self, image):
        """Average the model's predictions over several rescaled copies of ``image``.

        Args:
            image: Raw image tensor, ``(H, W, C)`` or ``(C, H, W)``, optionally
                with a leading batch dimension of size 1.

        Returns:
            Tensor of shape ``(1, num_classes, H, W)`` on ``self.device``.

        Raises:
            ValueError: If ``self.scales`` is ``None`` or empty.
        """
        if self.scales is None or len(self.scales) == 0:
            raise ValueError("multi_scale_predict requires a non-empty `scales` sequence")

        image = self._to_hwc_numpy(image)
        height, width = image.shape[0], image.shape[1]
        upsample = nn.Upsample(size=(height, width), mode="bilinear", align_corners=True)
        total_predictions = torch.zeros(
            (1, self.num_classes, height, width), device=self.device
        )

        for scale in self.scales:
            resizer = A.Resize(int(height * scale), int(width * scale))
            scaled_image = resizer(image=image)["image"]
            scaled_input = self.normalize(image=scaled_image)["image"]
            scaled_input = scaled_input.unsqueeze(dim=0).to(self.device)

            scaled_prediction = upsample(self.model(scaled_input))
            if self.flip:
                flipped_prediction = upsample(self.model(scaled_input.flip(-1)))
                scaled_prediction = 0.5 * (flipped_prediction.flip(-1) + scaled_prediction)
            total_predictions += scaled_prediction

        return total_predictions / len(self.scales)

    def pad_image(self, img, target_size):
        """Zero-pad a ``(N, C, H, W)`` tensor on the bottom/right up to ``target_size``.

        Args:
            img: 4-D tensor whose last two dimensions are height and width.
            target_size: ``(height, width)`` to reach; dimensions that are
                already at least this large are left untouched.

        Returns:
            The padded tensor.
        """
        rows_to_pad = max(target_size[0] - img.shape[2], 0)
        cols_to_pad = max(target_size[1] - img.shape[3], 0)
        # F.pad order for the last two dims: (left, right, top, bottom).
        padded_img = F.pad(img, (0, cols_to_pad, 0, rows_to_pad), "constant", 0)
        return padded_img

    def sliding_predict(self, image):
        """Predict ``image`` tile by tile and average overlapping predictions.

        Tiles are ``1 / 2.5`` of the image size along each axis and overlap
        by roughly one third. Border tiles are zero-padded to the full tile
        size before being fed to the model and the padding is cropped from
        the prediction afterwards.

        Args:
            image: Raw image tensor, ``(H, W, C)`` or ``(C, H, W)``, optionally
                with a leading batch dimension of size 1.

        Returns:
            Tensor of shape ``(num_classes, H, W)`` on ``self.device``.
        """
        image = self._to_hwc_numpy(image)
        height, width = image.shape[0], image.shape[1]

        # Integer tile size (at least one pixel) so it can be used for slicing.
        tile_size = (max(int(height // 2.5), 1), max(int(width // 2.5), 1))
        overlap = 1 / 3
        # One stride per axis: reusing the row stride for columns would leave
        # uncovered columns on tall images (stride larger than the tile width).
        stride_y = math.ceil(tile_size[0] * (1 - overlap))
        stride_x = math.ceil(tile_size[1] * (1 - overlap))

        num_rows = int(math.ceil((height - tile_size[0]) / stride_y) + 1)
        num_cols = int(math.ceil((width - tile_size[1]) / stride_x) + 1)
        total_predictions = torch.zeros(
            (self.num_classes, height, width), device=self.device
        )
        count_predictions = torch.zeros((height, width), device=self.device)

        # Normalize once; the result is a (C, H, W) tensor.
        orig_input = self.normalize(image=image)["image"].to(self.device)
        for row in range(num_rows):
            for col in range(num_cols):
                x_min, y_min = col * stride_x, row * stride_y
                x_max = min(x_min + tile_size[1], width)
                y_max = min(y_min + tile_size[0], height)

                # Add the batch dimension expected by pad_image and the model.
                tile = orig_input[:, y_min:y_max, x_min:x_max].unsqueeze(dim=0)
                padded_input = self.pad_image(tile, tile_size)
                padded_prediction = self._predict_with_flip(padded_input)

                # Assumes the model output has the same spatial size as its
                # input; crop away the part that corresponds to padding.
                predictions = padded_prediction[0, :, :tile.shape[2], :tile.shape[3]]
                count_predictions[y_min:y_max, x_min:x_max] += 1
                total_predictions[:, y_min:y_max, x_min:x_max] += predictions

        return total_predictions / count_predictions

    def __call__(self, image: torch.Tensor):
        """Dispatch ``image`` to the prediction routine selected by ``self.mode``."""
        if self.mode in ("batch", "singlescale"):
            return self.model(image.to(self.device))
        if self.mode == "multiscale":
            return self.multi_scale_predict(image)
        return self.sliding_predict(image)
