import torch.nn as nn
import torchvision.transforms as T

import os
import numpy as np

from lightly.transforms.simsiam_transform import SimSiamTransform
from lightly.transforms.simclr_transform import SimCLRTransform
from lightly.transforms.byol_transform import BYOLTransform

class DefaultTransform:
    """Deterministic evaluation transform: square resize, tensor conversion, normalization.

    Args:
        img_size: side length; inputs are resized to ``(img_size, img_size)``.
        mean: per-channel means forwarded to ``T.Normalize``.
        std: per-channel standard deviations forwarded to ``T.Normalize``.
    """

    def __init__(self, img_size, mean, std):
        steps = [
            T.Resize((img_size, img_size)),
            T.ToTensor(),
            T.Normalize(mean, std),
        ]
        self.transform = T.Compose(steps)

    def __call__(self, x):
        # Delegate to the composed pipeline built in __init__.
        return self.transform(x)
    
class MultiViewTransform:
    """Apply every transform in ``transforms`` to the same input.

    Calling the instance returns a list with one entry per transform, in the
    order the transforms were given.
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image):
        return [apply_view(image) for apply_view in self.transforms]
    
class DummyTransform:
    """Debugging pass-through: print diagnostics for each input, then return it unchanged.

    Useful when inserted between pipeline stages to see which process runs the
    stage and what type flows through it.
    """

    def __init__(self, label):
        self.label = label

    def __call__(self, x):
        # One line each: current process id, this stage's label, the input's type.
        for info in (os.getpid(), self.label, type(x)):
            print(info)
        return x

class ContrastiveTransform(MultiViewTransform):
    """Random augmentation repeated ``config.n_views`` times (SimCLR/SimSiam/VICReg style).

    Every view is produced by the same random pipeline:
    random resized crop -> horizontal flip -> (maybe) colour jitter ->
    (maybe) grayscale -> (maybe) Gaussian blur -> tensor -> normalize.
    Vertical flipping is currently disabled here (``config.vf`` is not read).

    Config fields read: n_views, min_crop, max_crop, hf, cj, cj_b, cj_c, cj_s,
    cj_h, gs, gb, gb_kernel, sigma1, sigma2.
    """

    def __init__(self, img_size, mean, std, config):
        self.n_views = config.n_views

        crop = T.RandomResizedCrop(size=img_size, scale=(config.min_crop, config.max_crop))
        hflip = T.RandomHorizontalFlip(p=config.hf)
        jitter = T.RandomApply(
            [
                T.ColorJitter(
                    brightness=config.cj_b,
                    contrast=config.cj_c,
                    saturation=config.cj_s,
                    hue=config.cj_h,
                )
            ],
            p=config.cj,
        )
        gray = T.RandomGrayscale(p=config.gs)
        blur = T.RandomApply(
            [T.GaussianBlur(kernel_size=config.gb_kernel, sigma=(config.sigma1, config.sigma2))],
            p=config.gb,
        )

        self.transform = T.Compose(
            [crop, hflip, jitter, gray, blur, T.ToTensor(), T.Normalize(mean, std)]
        )

        # The same pipeline object is reused for every view; randomness is
        # re-sampled on each call, so the views differ.
        super().__init__(transforms=[self.transform] * self.n_views)
    
class FromNumpyMultiViewTransform(MultiViewTransform):
    """Adapt an existing multi-view transform so that it accepts numpy images.

    Each per-view pipeline in ``transform.transforms`` gets ``T.ToPILImage()``
    prepended.
    - A ``T.Compose`` view is flattened, so the PIL conversion becomes its first step.
    - Any other callable is wrapped as a single step after the conversion.

    Args:
        transform: an object exposing a ``transforms`` list of per-view
            callables (e.g. a ``MultiViewTransform`` subclass).
    """

    def __init__(self, transform):
        adapted_views = []
        for view in transform.transforms:
            # Previously the comprehension reused the name ``transform`` for its
            # loop variable and assumed every view was a Compose.
            steps = view.transforms if isinstance(view, T.Compose) else [view]
            adapted_views.append(T.Compose([T.ToPILImage(), *steps]))
        super().__init__(transforms=adapted_views)

class FromNumpyDefaultTransform:
    """Deterministic transform for numpy inputs.

    The pipeline converts the input to a PIL image, resizes it to
    ``(img_size, img_size)``, converts it to a tensor and normalizes it with
    ``mean``/``std``.
    """

    def __init__(self, img_size, mean, std):
        pipeline = [
            T.ToPILImage(),
            T.Resize((img_size, img_size)),
            T.ToTensor(),
            T.Normalize(mean, std),
        ]
        self.transform = T.Compose(pipeline)

    def __call__(self, x):
        # Run the composed pipeline built in __init__.
        return self.transform(x)

class PassThroughTransform:
    """Identity transform: return the input object itself, untouched."""

    def __call__(self, x):
        return x

class ToNumpy:
    """Convert an array-like (e.g. a PIL image) into a ``numpy.uint8`` array.

    No clipping is performed. Values are cast with numpy's default (unsafe)
    ``astype`` casting.
    """

    def __call__(self, x):
        as_array = np.array(x)
        return as_array.astype(np.uint8)
    
class NumpyRawTransform:
    """Turn an image into a ``uint8`` numpy array of size ``img_size``.

    Args:
        img_size: target side length.
        alter_raw: behaviour depends on this flag.
            - False (default): the image is deterministically resized to
              ``(img_size, img_size)``.
            - True: a light augmentation is applied instead. This is a random
              resized crop with area scale in ``(0.9, 1)``, followed by a mild
              colour jitter.
    """

    def __init__(self, img_size, alter_raw=False):
        if alter_raw:
            # Removed a leftover debug print that fired on every construction.
            steps = [
                T.RandomResizedCrop(img_size, (0.9, 1)),
                T.ColorJitter(
                    brightness=0.05,
                    contrast=0.05,
                    saturation=0.05,
                    hue=0.02,
                ),
            ]
        else:
            steps = [T.Resize((img_size, img_size))]
        # Both variants end with a conversion to a uint8 numpy array.
        self.transform = T.Compose(steps + [ToNumpy()])

    def __call__(self, x):
        return self.transform(x)
    
class DAATransform(MultiViewTransform):
    """One clean view plus ``config.n_views`` randomly augmented views.

    Calling the instance (via the inherited ``MultiViewTransform.__call__``)
    returns a list of ``1 + config.n_views`` tensors.
    - Index 0 is ``base_transform(image)``: resize, tensor, normalize only.
    - The remaining entries are separate calls of the random ``transform`` pipeline.

    Config fields read: n_views, min_crop, max_crop, hf, cj, cj_b, cj_c, cj_s,
    cj_h, gs, solar_thresh, p_solar.
    """

    def __init__(self, img_size, mean, std, config):
        self.n_views = config.n_views

        self.base_transform = T.Compose([
            T.Resize((img_size, img_size)),
            T.ToTensor(),
            T.Normalize(mean, std)
        ])

        self.transform = T.Compose([
            T.RandomResizedCrop(size=img_size, scale=(config.min_crop, config.max_crop)),
            T.RandomHorizontalFlip(p=config.hf),
            T.RandomApply([T.ColorJitter(
                brightness=config.cj_b,
                contrast=config.cj_c,
                saturation=config.cj_s,
                hue=config.cj_h,
            )], p=config.cj),
            T.RandomGrayscale(p=config.gs),
            # Applied before ToTensor, so solar_thresh is presumably in the input
            # image's pixel range (e.g. 0-255 for PIL) -- confirm against configs.
            T.RandomSolarize(config.solar_thresh, config.p_solar),
            T.ToTensor(),
            T.Normalize(mean, std)
        ])

        # The duplicate __call__ override was removed; the inherited one is identical.
        super().__init__(transforms=[self.base_transform] + [self.transform] * self.n_views)

class ExtractPatches(nn.Module):
    """Cut a single ``(C, H, W)`` tensor into a stack of square patches.

    Patches of size ``patch_size`` are taken every ``stride`` pixels along
    height and width. Trailing rows/columns that cannot fill a whole patch are
    dropped.

    Returns:
        Tensor of shape ``(nH * nW, C, patch_size, patch_size)``, where
        ``nH = (H - patch_size) // stride + 1`` and likewise for ``nW``.
        Patches are ordered row-major: patch ``i * nW + j`` starts at
        ``(i * stride, j * stride)``.

    Raises:
        ValueError: if ``x`` is not 3-dimensional.
    """

    def __init__(self, patch_size, stride):
        super().__init__()
        self.patch_size = patch_size
        self.stride = stride

    def forward(self, x):
        if x.dim() != 3:
            raise ValueError(
                f"ExtractPatches expects a (C, H, W) tensor, got shape {tuple(x.shape)}"
            )
        channels = x.size(0)
        size, step = self.patch_size, self.stride

        # (C, H, W) -> (C, nH, nW, p, p)
        patches = x.unfold(1, size, step).unfold(2, size, step)
        # Move the grid dims first so flattening enumerates patches row-major.
        return patches.permute(1, 2, 0, 3, 4).reshape(-1, channels, size, size)

class PCMCTransform(MultiViewTransform):
    """Random augmentation repeated ``config.n_views`` times, with each view split into patches.

    Per view, the pipeline runs these steps in order:
    1. Random resized crop.
    2. Horizontal and vertical flips.
    3. Colour jitter and grayscale, each applied with some probability.
    4. Gaussian blur, applied with some probability.
    5. Conversion to a tensor.
    6. Non-overlapping ``patch_size`` x ``patch_size`` patch extraction
       (``ExtractPatches(patch_size, patch_size)``).
    7. Normalization of the resulting patch stack.

    Config fields read: n_views, min_crop, max_crop, hf, vf, cj, cj_b, cj_c,
    cj_s, cj_h, gs, gb, gb_kernel, sigma1, sigma2.
    """

    def __init__(self, img_size, mean, std, patch_size, config):
        self.n_views = config.n_views

        crop = T.RandomResizedCrop(size=img_size, scale=(config.min_crop, config.max_crop))
        flips = [T.RandomHorizontalFlip(p=config.hf), T.RandomVerticalFlip(p=config.vf)]
        jitter = T.RandomApply(
            [
                T.ColorJitter(
                    brightness=config.cj_b,
                    contrast=config.cj_c,
                    saturation=config.cj_s,
                    hue=config.cj_h,
                )
            ],
            p=config.cj,
        )
        gray = T.RandomGrayscale(p=config.gs)
        blur = T.RandomApply(
            [T.GaussianBlur(kernel_size=config.gb_kernel, sigma=(config.sigma1, config.sigma2))],
            p=config.gb,
        )

        self.transform = T.Compose(
            [
                crop,
                *flips,
                jitter,
                gray,
                blur,
                T.ToTensor(),
                ExtractPatches(patch_size, patch_size),
                # Normalize sees a (N, C, p, p) patch stack here.
                T.Normalize(mean, std),
            ]
        )

        super().__init__(transforms=[self.transform] * self.n_views)

    
def load_transform(transform_type, img_size, mean, std, config):
    """Build the training/eval transform named by ``transform_type``.

    Dispatch:
    - ``'simclr'``, ``'simsiam'`` or ``'vicreg'`` -> ``ContrastiveTransform``.
    - ``'daa'`` -> ``DAATransform``.
    - ``'pcmc'`` -> ``PCMCTransform`` (patch size taken from ``config.patch_size``).
    - anything else -> ``DefaultTransform`` (``config`` is ignored).
    """
    if transform_type in ('simclr', 'simsiam', 'vicreg'):
        return ContrastiveTransform(img_size, mean, std, config)
    if transform_type == 'daa':
        return DAATransform(img_size, mean, std, config)
    if transform_type == 'pcmc':
        return PCMCTransform(img_size, mean, std, config.patch_size, config)
    return DefaultTransform(img_size, mean, std)
    


if __name__ == '__main__':
    # Smoke test: build a DAATransform from a hand-written config and print its views.
    class Config:
        class Agent:
            def __init__(self):
                # Colour jitter settings.
                self.cj = 0.8
                self.cj_b = 0.4
                self.cj_c = 0.4
                self.cj_s = 0.4
                self.cj_h = 0.1
                self.n_views = 3
                # Remaining fields DAATransform reads (previously missing ->
                # AttributeError). Example values only.
                self.min_crop = 0.2
                self.max_crop = 1.0
                self.hf = 0.5
                self.gs = 0.2
                self.solar_thresh = 128
                self.p_solar = 0.2

        def __init__(self):
            self.agent = self.Agent()

    config = Config()
    # DAATransform reads hyper-parameters directly off the object it receives,
    # so pass the nested agent config rather than the wrapper.
    t = DAATransform(128, [0.1, 0.1, 0.1], [0.1, 0.1, 0.1], config.agent)

    print(t.transforms)

