'''Task-dependent measurement operators (A) and noise models (n) used to simulate a measurement y = Ax + n, plus dataset registration and loading helpers.'''

from abc import ABC, abstractmethod
from functools import partial
import yaml

import torch.nn as nn
from torch.nn import functional as F

import torchvision

import numpy as np
import torch
from math import pi

import scipy

from glob import glob
from PIL import Image
from typing import Callable, Optional
from torch.utils.data import DataLoader
from torchvision.datasets import VisionDataset
import os 
from torchvision import transforms
from roomlayout import LayoutSeg
from roomsegmentation import SegmentationModule, ModelBuilder

from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from roomtext import blip_itm, blip_decoder
from gg18 import ScaleHyperpriorSTE
from torchvision.models import resnet50, ResNet50_Weights

class Resizer(nn.Module):
    """Separable, optionally anti-aliased resampling of tensors with a fixed shape.

    For every axis whose scale factor is not 1, interpolation weights and the
    input indices they apply to ("field of view") are precomputed once in
    ``__init__`` and stored as non-trainable parameters, so ``forward`` is a
    gather followed by a weighted sum along each resized axis.

    Args:
        in_shape: full shape of the tensors passed to ``forward`` (all axes,
            e.g. ``(N, C, H, W)``).
        scale_factor: scalar or per-axis sequence. A scalar is applied to the
            last two axes (the only axis for 1-D input); a shorter sequence is
            left-padded with 1s so it addresses the trailing axes.
        output_shape: target size of the trailing axes, used to derive the
            scale when ``scale_factor`` is None. Leading axes it does not
            cover keep their input size.
        kernel: ``"cubic"``, ``"lanczos2"``, ``"lanczos3"``, ``"box"``,
            ``"linear"`` or ``None`` (cubic).
        antialiasing: if any scale is < 1, every resized axis uses a kernel
            stretched by ``1 / scale`` (and rescaled in amplitude).

    Raises:
        ValueError: if ``kernel`` is not one of the supported names.
    """

    def __init__(self, in_shape, scale_factor=None, output_shape=None, kernel=None, antialiasing=True):
        super(Resizer, self).__init__()
        scale_factor, output_shape = self.fix_scale_and_size(in_shape, output_shape, scale_factor)
        kernels = {
            "cubic": (cubic, 4.0),
            "lanczos2": (lanczos2, 4.0),
            "lanczos3": (lanczos3, 6.0),
            "box": (box, 1.0),
            "linear": (linear, 2.0),
            None: (cubic, 4.0)  # set default interpolation method as cubic
        }
        if kernel not in kernels:
            # Previously an unknown name surfaced as "cannot unpack NoneType".
            raise ValueError(f"Unknown interpolation kernel {kernel!r}; expected one of "
                             f"{[k for k in kernels if k is not None]} or None")
        method, kernel_width = kernels[kernel]

        # Antialiasing is only enabled when at least one axis is downscaled.
        antialiasing *= (np.any(np.array(scale_factor) < 1))

        # Resize axes in order of increasing scale factor, skipping unchanged axes.
        sorted_dims = np.argsort(np.array(scale_factor))
        self.sorted_dims = [int(dim) for dim in sorted_dims if scale_factor[dim] != 1]

        field_of_view_list = []
        weights_list = []
        for dim in self.sorted_dims:
            weights, field_of_view = self.contributions(in_shape[dim], output_shape[dim], scale_factor[dim], method,
                                                        kernel_width, antialiasing)

            weights = torch.tensor(weights.T, dtype=torch.float32)

            # Append singleton axes so the (taps, out_len) weights broadcast over the
            # remaining len(scale_factor) - 1 axes of the gathered input in forward().
            weights_list.append(
                nn.Parameter(torch.reshape(weights, list(weights.shape) + (len(scale_factor) - 1) * [1]),
                             requires_grad=False))
            field_of_view_list.append(
                nn.Parameter(torch.tensor(field_of_view.T.astype(np.int32), dtype=torch.long), requires_grad=False))

        self.field_of_view = nn.ParameterList(field_of_view_list)
        self.weights = nn.ParameterList(weights_list)

    def forward(self, in_tensor):
        """Resize ``in_tensor`` (whose shape must equal ``in_shape``) one axis at a time."""
        x = in_tensor
        for dim, fov, w in zip(self.sorted_dims, self.field_of_view, self.weights):
            # Bring the axis being resized to the front so x[fov] gathers along it.
            x = torch.transpose(x, dim, 0)
            # x[fov]: (taps, out_len, ...); weighted sum over the taps axis.
            x = torch.sum(x[fov] * w, dim=0)
            x = torch.transpose(x, dim, 0)
        return x

    def fix_scale_and_size(self, input_shape, output_shape, scale_factor):
        """Normalise ``scale_factor`` / ``output_shape`` to one entry per input axis.

        Both are aligned to the *trailing* axes of ``input_shape``: uncovered
        leading axes get scale 1 and keep their input size.

        Returns:
            ``(scale_factor, output_shape)``, each with ``len(input_shape)`` entries.
        """
        if scale_factor is not None:
            if np.isscalar(scale_factor):
                # A scalar scales the last two axes (or the only axis of a 1-D input).
                scale_factor = [scale_factor] * min(2, len(input_shape))
            scale_factor = list(scale_factor)
            scale_factor = [1] * (len(input_shape) - len(scale_factor)) + scale_factor
        if output_shape is not None:
            output_shape = list(np.uint(np.array(output_shape)))
            # Keep the input size on the leading axes not covered by output_shape
            # (previously the *trailing* input sizes were prepended by mistake).
            n_leading = max(len(input_shape) - len(output_shape), 0)
            output_shape = list(input_shape[:n_leading]) + output_shape
        if scale_factor is None:
            scale_factor = 1.0 * np.array(output_shape) / np.array(input_shape)
        if output_shape is None:
            output_shape = np.uint(np.ceil(np.array(input_shape) * np.array(scale_factor)))
        return scale_factor, output_shape

    def contributions(self, in_length, out_length, scale, kernel, kernel_width, antialiasing):
        """Compute per-output-sample weights and input indices for one axis.

        Returns:
            ``(weights, field_of_view)``, both indexed ``[output_sample, tap]``
            (``np.squeeze`` may drop size-1 axes). Indices are mapped into
            ``[0, in_length)`` by mirroring at the borders, and taps whose
            weight is zero for every output sample are dropped.
        """
        # With antialiasing, stretch the kernel by 1/scale and scale its amplitude.
        fixed_kernel = (lambda arg: scale * kernel(scale * arg)) if antialiasing else kernel
        kernel_width *= 1.0 / scale if antialiasing else 1.0
        out_coordinates = np.arange(1, out_length + 1)
        shifted_out_coordinates = out_coordinates - (out_length - in_length * scale) / 2
        # Position of each output sample in (1-based) input coordinates.
        match_coordinates = shifted_out_coordinates / scale + 0.5 * (1 - 1 / scale)
        left_boundary = np.floor(match_coordinates - kernel_width / 2)
        expanded_kernel_width = np.ceil(kernel_width) + 2
        field_of_view = np.squeeze(
            np.int16(np.expand_dims(left_boundary, axis=1) + np.arange(expanded_kernel_width) - 1))
        weights = fixed_kernel(1.0 * np.expand_dims(match_coordinates, axis=1) - field_of_view - 1)
        # Normalise each row to sum to 1 (rows summing to 0 are left as-is).
        sum_weights = np.sum(weights, axis=1)
        sum_weights[sum_weights == 0] = 1.0
        weights = 1.0 * weights / np.expand_dims(sum_weights, axis=1)
        # Mirror out-of-range indices back into the axis (0..n-1, n-1..0).
        mirror = np.uint(np.concatenate((np.arange(in_length), np.arange(in_length - 1, -1, step=-1))))
        field_of_view = mirror[np.mod(field_of_view, mirror.shape[0])]
        # Drop tap columns that are zero for all output samples.
        non_zero_out_pixels = np.nonzero(np.any(weights, axis=0))
        weights = np.squeeze(weights[:, non_zero_out_pixels])
        field_of_view = np.squeeze(field_of_view[:, non_zero_out_pixels])
        return weights, field_of_view


def cubic(x):
    """Evaluate the cubic convolution kernel (Keys, a = -0.5) elementwise.

    Piecewise: ``1.5|x|^3 - 2.5|x|^2 + 1`` for ``|x| <= 1``,
    ``-0.5|x|^3 + 2.5|x|^2 - 4|x| + 2`` for ``1 < |x| <= 2``, zero elsewhere.
    """
    t = np.abs(x)
    near = t <= 1
    far = (t > 1) & (t <= 2)
    inner = 1.5 * t ** 3 - 2.5 * t ** 2 + 1
    outer = -0.5 * t ** 3 + 2.5 * t ** 2 - 4 * t + 2
    return inner * near + outer * far


def lanczos2(x):
    """Evaluate the Lanczos kernel with a = 2 elementwise.

    ``sin(pi x) sin(pi x / 2) / (pi^2 x^2 / 2)`` for ``|x| < 2``, zero elsewhere;
    float32 eps is added to numerator and denominator so ``x = 0`` yields 1.
    """
    eps = np.finfo(np.float32).eps
    numerator = np.sin(pi * x) * np.sin(pi * x / 2) + eps
    denominator = (pi ** 2 * x ** 2 / 2) + eps
    support = abs(x) < 2
    return (numerator / denominator) * support


def box(x):
    """Box (nearest-neighbour) kernel: 1.0 on ``[-0.5, 0.5)``, 0.0 elsewhere."""
    inside = (x >= -0.5) & (x < 0.5)
    return 1.0 * inside


def lanczos3(x):
    """Evaluate the Lanczos kernel with a = 3 elementwise.

    ``sin(pi x) sin(pi x / 3) / (pi^2 x^2 / 3)`` for ``|x| < 3``, zero elsewhere;
    float32 eps is added to numerator and denominator so ``x = 0`` yields 1.
    """
    eps = np.finfo(np.float32).eps
    numerator = np.sin(pi * x) * np.sin(pi * x / 3) + eps
    denominator = (pi ** 2 * x ** 2 / 3) + eps
    support = abs(x) < 3
    return (numerator / denominator) * support


def linear(x):
    """Triangle (tent) kernel: ``1 - |x|`` on ``[-1, 1]``, zero elsewhere."""
    left = (x >= -1) & (x < 0)
    right = (x >= 0) & (x <= 1)
    return (x + 1) * left + (1 - x) * right

# Registry mapping operator names (as passed to get_operator) to operator classes.
__OPERATOR__ = {}


def register_operator(name: str):
    """Return a class decorator that registers the decorated class under ``name``.

    Raises:
        NameError: if ``name`` is already registered.
    """
    def wrapper(cls):
        # Test key membership rather than the truthiness of the stored object,
        # so any earlier registration is detected regardless of its value.
        if name in __OPERATOR__:
            raise NameError(f"Name {name} is already registered!")
        __OPERATOR__[name] = cls
        return cls
    return wrapper


def get_operator(name: str, **kwargs):
    """Instantiate the operator registered under ``name``, passing ``kwargs`` to it.

    Raises:
        NameError: if no operator is registered under ``name``.
    """
    try:
        operator_cls = __OPERATOR__[name]
    except KeyError:
        raise NameError(f"Name {name} is not defined.") from None
    return operator_cls(**kwargs)


class LinearOperator(ABC):
    """Abstract measurement operator A together with its transpose A^T.

    Subclasses implement ``forward`` (A x) and ``transpose`` (A^T y); the
    projection helpers are composed from those two and forward ``kwargs``
    to both.
    """

    @abstractmethod
    def forward(self, data, **kwargs):
        """Return A * data."""

    @abstractmethod
    def transpose(self, data, **kwargs):
        """Return A^T * data."""

    def ortho_project(self, data, **kwargs):
        """Return (I - A^T A) data."""
        back_projected = self.transpose(self.forward(data, **kwargs), **kwargs)
        return data - back_projected

    def project(self, data, measurement, **kwargs):
        """Return (I - A^T A) measurement - A data."""
        residual = self.ortho_project(measurement, **kwargs)
        return residual - self.forward(data, **kwargs)


@register_operator(name='super_resolution')
class SuperResolutionOperator(LinearOperator):
    """Super-resolution degradation operator.

    ``forward`` (A) downsamples by ``1 / scale_factor`` with a ``Resizer``
    precomputed for ``in_shape``; ``transpose`` (A^T) upsamples by
    ``scale_factor`` with ``F.interpolate`` using its default mode.
    """

    def __init__(self, in_shape, scale_factor, device):
        self.device = device
        # A^T: upsample with F.interpolate's default (nearest) mode.
        self.up_sample = partial(F.interpolate, scale_factor=scale_factor)
        # A: downsample by the inverse factor; Resizer weights are fixed to in_shape.
        inverse_scale = 1/scale_factor
        self.down_sample = Resizer(in_shape, inverse_scale).to(device)

    def forward(self, data, **kwargs):
        """Return A(data), the downsampled tensor."""
        low_res = self.down_sample(data)
        return low_res

    def transpose(self, data, **kwargs):
        """Return A^T(data), the upsampled tensor."""
        high_res = self.up_sample(data)
        return high_res

    def project(self, data, measurement, **kwargs):
        """Return data - A^T(A(data)) + A^T(measurement)."""
        reprojected = self.transpose(self.forward(data))
        lifted = self.transpose(measurement)
        return data - reprojected + lifted


@register_operator(name='roomlayout')
class Layout(LinearOperator):
    """Room-layout operator whose A is a frozen ``LayoutSeg`` network (ResNet-101 backbone).

    ``forward`` requires a ``mode`` keyword: ``'init'`` returns the argmax over
    dim 1 of the network scores (keepdim), any other value returns the scores.
    ``transpose`` is the identity.
    """

    def __init__(self, weight_path, device):
        self.model = LayoutSeg.load_from_checkpoint(weight_path, backbone='resnet101').to(device)
        # NOTE(review): presumably a Lightning-style freeze (no grads / eval mode) — confirm.
        self.model.freeze()

    def forward(self, data, **kwargs):
        # Validate the required keyword before running the network, not after.
        assert 'mode' in kwargs, "Layout.forward requires a 'mode' keyword argument"
        # The model returns a pair; only the first element (scores) is used.
        scores, _ = self.model(data)
        if kwargs['mode'] == 'init':
            return torch.argmax(scores, dim=1, keepdim=True)
        return scores

    def transpose(self, data, **kwargs):
        # Identity. Accepts **kwargs because LinearOperator.ortho_project/project
        # forward their kwargs (e.g. the mandatory mode=) to transpose as well.
        return data
    
@register_operator(name='roomsegmentation')
class Segmentation(LinearOperator):
    """Semantic segmentation operator: frozen MobileNetV2-dilated encoder + C1 deep-sup decoder (150 classes).

    ``forward`` requires a ``mode`` keyword: ``'init'`` returns the argmax over
    dim 1 of the decoder output (keepdim), any other value returns that output.
    ``transpose`` is the identity.
    """

    def __init__(self, device):
        # Use the requested device (it was previously ignored in favour of 'cuda').
        self.encoder = ModelBuilder.build_encoder(
            arch="mobilenetv2dilated", fc_dim=320,
            weights="./bins/encoder_epoch_20_A.pth").to(device)
        self.decoder = ModelBuilder.build_decoder(
            arch="c1_deepsup", fc_dim=320, num_class=150,
            weights="./bins/decoder_epoch_20_A.pth", use_softmax=True).to(device)
        # Freeze both networks.
        for param in self.encoder.parameters():
            param.requires_grad = False
        for param in self.decoder.parameters():
            param.requires_grad = False
        self.transform = transforms.Compose([transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

    def forward(self, data, **kwargs):
        # Validate the required keyword before running the networks, not after.
        assert 'mode' in kwargs, "Segmentation.forward requires a 'mode' keyword argument"
        # Assumes data lies in [-1, 1]; map to [0, 1] before normalising.
        data = (data + 1) / 2.0
        data = self.transform(data)
        pred = self.decoder(self.encoder(data, return_feature_maps=True), segSize=(256, 256))
        if kwargs['mode'] == 'init':
            return torch.argmax(pred, dim=1, keepdim=True)
        return pred

    def transpose(self, data, **kwargs):
        # Identity; accepts **kwargs because LinearOperator.ortho_project/project
        # forward their kwargs to transpose.
        return data

@register_operator(name='roomtext')
class Image2Text(LinearOperator):
    """Image-text operator built on two BLIP models.

    ``forward(data, mode='init')`` generates captions with the BLIP decoder
    (beam search, 3 beams, 5-20 tokens). Any other ``mode`` returns the negated
    softmax ITM score (column 1) between ``data`` and ``kwargs['caption']``.
    """

    def __init__(self, device) -> None:
        # NOTE(review): normalisation constants presumably match the BLIP checkpoints — confirm.
        self.normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
        self.transform_test = transforms.Compose([transforms.Resize((384, 384), interpolation=InterpolationMode.BICUBIC), self.normalize,])
        # Use the requested device (it was previously ignored in favour of 'cuda').
        self.itm_model = blip_itm(pretrained='./bins/model_base_retrieval_coco.pth', image_size=384, vit='base')
        self.itm_model.eval()
        self.itm_model = self.itm_model.to(device)
        self.blip_decoder_model = blip_decoder(pretrained='./bins/model_base_caption_capfilt_large.pth', image_size=384, vit='base')
        self.blip_decoder_model.eval()
        self.blip_decoder_model = self.blip_decoder_model.to(device)

    def forward(self, data, **kwargs):
        # Assumes data lies in [-1, 1]; map to [0, 1] before resizing/normalising.
        data = (data + 1.0)/2
        data = self.transform_test(data)
        if kwargs['mode'] == 'init':
            return self.blip_decoder_model.generate(data, sample=False, num_beams=3, max_length=20, min_length=5)
        itm_output = self.itm_model(data, kwargs['caption'], match_head='itm')
        # Column 1 of the softmax, presumably the "match" probability — confirm.
        itm_score = torch.nn.functional.softmax(itm_output, dim=1)[:, 1]
        # Negated so that a better match gives a lower value.
        return - itm_score

    def transpose(self, data, **kwargs):
        # No transpose is defined for this operator. Accepts **kwargs because
        # LinearOperator.ortho_project/project forward their kwargs to transpose.
        return None


@register_operator(name='catcls2')
class CatClassification2(LinearOperator):
    """Classification operator using torchvision's ResNet-50 with IMAGENET1K_V2 weights.

    ``forward(data, mode='init')`` returns the argmax class index over dim 1
    (keepdim); any other ``mode`` returns the softmax class probabilities.
    """

    # Normalisation statistics that torchvision's ImageNet ResNet-50 weights expect.
    # The previous values were the CLIP/BLIP constants copied from Image2Text.
    IMAGENET_MEAN = (0.485, 0.456, 0.406)
    IMAGENET_STD = (0.229, 0.224, 0.225)

    def __init__(self, device):
        self.model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2).to(device)
        self.model.eval()
        self.normalize = transforms.Normalize(self.IMAGENET_MEAN, self.IMAGENET_STD)
        # Direct resize to 224x224 (no crop), bicubic.
        self.transform_test = transforms.Compose([transforms.Resize((224, 224), interpolation=InterpolationMode.BICUBIC), self.normalize,])

    def forward(self, data, **kwargs):
        # Assumes data lies in [-1, 1]; map to [0, 1] before resizing/normalising.
        data = (data + 1.0) / 2.0
        data = self.transform_test(data)
        score = F.softmax(self.model(data), dim=1)
        if kwargs['mode'] == 'init':
            return torch.argmax(score, dim=1, keepdim=True)
        return score

    def transpose(self, data, **kwargs):
        # No transpose is defined for this operator. Accepts **kwargs because
        # LinearOperator.ortho_project/project forward their kwargs to transpose.
        return None


# Registry mapping dataset names (as passed to get_dataset) to dataset classes.
__DATASET__ = {}


def register_dataset(name: str):
    """Return a class decorator that registers the decorated dataset class under ``name``.

    Raises:
        NameError: if ``name`` is already registered.
    """
    def wrapper(cls):
        # Test key membership rather than the truthiness of the stored object.
        if name in __DATASET__:
            raise NameError(f"Name {name} is already registered!")
        __DATASET__[name] = cls
        return cls
    return wrapper


def get_dataset(name: str, root: str, **kwargs):
    """Instantiate the dataset registered under ``name`` as ``cls(root=root, **kwargs)``.

    Raises:
        NameError: if no dataset is registered under ``name``.
    """
    try:
        dataset_cls = __DATASET__[name]
    except KeyError:
        raise NameError(f"Dataset {name} is not defined.") from None
    return dataset_cls(root=root, **kwargs)


def get_dataloader(dataset: VisionDataset,
                   batch_size: int, 
                   num_workers: int, 
                   train: bool):
    """Wrap ``dataset`` in a ``DataLoader`` with ``batch_size`` and ``num_workers``.

    In training mode (``train=True``) batches are shuffled and the final
    incomplete batch is dropped; otherwise order is preserved and kept whole.
    """
    loader_options = dict(shuffle=train, num_workers=num_workers, drop_last=train)
    return DataLoader(dataset, batch_size, **loader_options)


@register_dataset(name='ffhq')
class FFHQDataset(VisionDataset):
    """Image dataset over every ``*.png`` found recursively under ``root`` (sorted by path).

    Items are RGB PIL images, passed through ``transforms`` when one is given.

    Args:
        root: directory searched recursively for PNG files.
        transforms: optional callable applied to each image.
        max_num: cap on the reported length (default 1000, the previously
            hard-coded value); ``None`` disables the cap.
    """

    def __init__(self, root: str, transforms: Optional[Callable]=None, max_num: Optional[int]=1000):
        super().__init__(root, transforms)
        self.max_num = max_num
        self.fpaths = sorted(glob(os.path.join(root, '**', '*.png'), recursive=True))
        assert len(self.fpaths) > 0, "File list is empty. Check the root."

    def __len__(self):
        if self.max_num is None:
            return len(self.fpaths)
        return min(len(self.fpaths), self.max_num)

    def __getitem__(self, index: int):
        fpath = self.fpaths[index]
        # Close the file handle promptly; convert() returns an independent copy.
        with Image.open(fpath) as raw:
            img = raw.convert('RGB')

        if self.transforms is not None:
            img = self.transforms(img)

        return img

@register_dataset(name='lsunlayout')
class LSUNLayout(VisionDataset):
    """Layout label maps loaded from every ``*.png`` found recursively under ``root``.

    Each item is ``round(transforms(image) * 5) - 1``, so ``transforms`` must
    produce a tensor. NOTE(review): assumes that tensor lies in [0, 1] with
    values near multiples of 1/5 (e.g. ToTensor over labels stored as k*51),
    which yields labels 0..4 as the original comment states; values rounding
    to 0 would map to -1 — confirm against the data.

    Args:
        root: directory searched recursively for PNG files.
        transforms: callable converting each PIL image to a tensor.
        max_num: cap on the reported length (default 1000, the previously
            hard-coded value); ``None`` disables the cap.
    """

    def __init__(self, root: str, transforms: Optional[Callable]=None, max_num: Optional[int]=1000):
        super().__init__(root, transforms)
        self.max_num = max_num
        self.fpaths = sorted(glob(os.path.join(root, '**', '*.png'), recursive=True))
        assert len(self.fpaths) > 0, "File list is empty. Check the root."

    def __len__(self):
        if self.max_num is None:
            return len(self.fpaths)
        return min(len(self.fpaths), self.max_num)

    def __getitem__(self, index: int):
        fpath = self.fpaths[index]
        # Apply the transform while the file is open so the pixel data is read
        # before the handle is closed by the context manager.
        with Image.open(fpath) as raw:
            img = raw
            if self.transforms is not None:
                img = self.transforms(raw)
        img = torch.round(img * 5) - 1
        return img # [0,1,2,3,4]
