import math
import os
import numpy as np
from PIL import Image
from torchvision import datasets
from torchvision import transforms
from .randaugment import *

class TransformFixMatch(object):
    """Build a (weak, strong) pair of augmented, normalized views of one image.

    Pipelines:
        * ``weak``   -- random horizontal flip, random crop to ``size_image``
          with reflect padding of ``int(size_image * 0.125)``, ToTensor,
          Normalize.
        * ``strong`` -- the same flip + crop, then ``RandAugmentMC(n=2, m=10)``,
          then ToTensor and Normalize.
        * ``weak2``  -- flip + ToTensor + Normalize (no crop). It is not used
          by ``__call__`` and is kept as a public attribute.

    Args:
        mean: Per-channel mean forwarded to ``transforms.Normalize``.
        std: Per-channel std forwarded to ``transforms.Normalize``.
        size_image: Crop size; also determines the padding amount.
        **kwargs: Accepted for call-site compatibility and ignored.
    """

    def __init__(self, mean, std, size_image=32, **kwargs):
        pad = int(size_image * 0.125)

        def flip_and_crop():
            # Each pipeline gets its own transform instances.
            return [
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(size=size_image,
                                      padding=pad,
                                      padding_mode='reflect'),
            ]

        def to_normalized_tensor():
            return [
                transforms.ToTensor(),
                transforms.Normalize(mean=mean, std=std),
            ]

        self.weak = transforms.Compose(flip_and_crop() + to_normalized_tensor())
        self.weak2 = transforms.Compose(
            [transforms.RandomHorizontalFlip()] + to_normalized_tensor())
        self.strong = transforms.Compose(
            flip_and_crop()
            + [RandAugmentMC(n=2, m=10)]
            + to_normalized_tensor())

    def __call__(self, x):
        """Return ``(weak_view, strong_view)`` for image ``x``."""
        # Tuple elements evaluate left to right: weak first, then strong.
        return self.weak(x), self.strong(x)
        
class TransformCausal(object):
    """Randomly pick one of four augmentations and return the view with its index.

    On each call an index is drawn uniformly from ``{0, 1, 2, 3}`` with
    ``np.random.randint``:
        * ``0``     -- the weak pipeline: flip, reflect-padded random crop,
          ToTensor, Normalize.
        * ``1..3``  -- flip + reflect-padded random crop, then the op at that
          index of ``MyRandAugmentMC``'s pool (followed by its Cutout), then
          ToTensor and Normalize. Pool entry 0 is therefore never selected
          here.

    Args:
        mean: Per-channel mean forwarded to ``transforms.Normalize``.
        std: Per-channel std forwarded to ``transforms.Normalize``.
        size_image: Crop size; padding is ``int(size_image * 0.125)``.
    """

    def __init__(self, mean, std, size_image=32):
        pad = int(size_image * 0.125)
        self.weak = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(size=size_image, padding=pad,
                                  padding_mode='reflect'),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ])
        # Geometric part of the strong path, applied before the indexed op.
        self.pre_strong = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(size=size_image, padding=pad,
                                  padding_mode='reflect'),
        ])
        self.ops = MyRandAugmentMC(n=1, m=10)
        self.post_strong = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ])

    def __call__(self, x):
        """Return ``(augmented_view, index)`` where ``index`` is in ``{0, 1, 2, 3}``."""
        label = np.random.randint(0, 4)
        if label == 0:
            return self.weak(x), label
        view = self.post_strong(self.ops(self.pre_strong(x), label))
        return view, label
        


class TransformConsistency(object):
    """Return a single weakly augmented, normalized view of an image.

    The view is built by a random horizontal flip, then a random crop to
    ``size_image`` with reflect padding of ``int(size_image * 0.125)``, then
    ToTensor and Normalize.

    ``ident`` (ToTensor + Normalize only) is built but not used by
    ``__call__``; it is kept as a public attribute.

    Args:
        mean: Per-channel mean forwarded to ``transforms.Normalize``.
        std: Per-channel std forwarded to ``transforms.Normalize``.
        size_image: Crop size; also determines the padding amount.
        **kwargs: Accepted for call-site compatibility and ignored.
    """

    def __init__(self, mean, std, size_image=32, **kwargs):
        self.ident = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ])
        steps = [
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(
                size=size_image,
                padding=int(size_image * 0.125),
                padding_mode='reflect',
            ),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ]
        self.weak = transforms.Compose(steps)

    def __call__(self, x):
        """Apply the weak pipeline to ``x``."""
        return self.weak(x)

class MyRandAugmentMC(object):
    """Apply the pool op selected by ``idx`` at fixed magnitude, then Cutout.

    Unlike a randomly sampled RandAugment policy, the caller chooses which op
    runs by passing ``idx``. ``TransformCausal`` passes the index it returns
    alongside the augmented image.

    Args:
        n: Must be >= 1. It is stored but not otherwise used, because exactly
            one op is applied per call.
        m: Magnitude in ``[1, 10]``. It is passed to the selected op as ``v``.
        cutout_size: Side length passed to ``CutoutAbs``. Defaults to
            ``int(32 * 0.5)`` (16), the value that was previously hard-coded
            and is only appropriate for 32x32 inputs. For other resolutions
            pass e.g. ``int(size_image * 0.5)``.

    Raises:
        AssertionError: If ``n`` or ``m`` is out of range. Validation uses
            ``assert``, so it is skipped under ``python -O``.
    """

    def __init__(self, n, m, cutout_size=None):
        assert n >= 1
        assert 1 <= m <= 10
        self.n = n
        self.m = m
        # None keeps the legacy 32x32-specific default for existing callers.
        self.cutout_size = int(32 * 0.5) if cutout_size is None else int(cutout_size)
        # Each entry is unpacked as (op, max_v, bias) in __call__.
        self.augment_pool = causal_augment_pool()

    def __call__(self, img, idx):
        """Apply pool entry ``idx`` with magnitude ``self.m``, then Cutout.

        Args:
            img: Image passed through to the pool op. Presumably a PIL image,
                since this runs before ToTensor in TransformCausal. TODO:
                confirm against randaugment.
            idx: Index into ``self.augment_pool``.

        Returns:
            The result of ``CutoutAbs`` applied to the transformed image.
        """
        op, max_v, bias = self.augment_pool[idx]
        img = op(img, v=self.m, max_v=max_v, bias=bias)
        return CutoutAbs(img, self.cutout_size)

class TransformTest(object):
    """Evaluation-time transform: ToTensor followed by Normalize.

    No random augmentation is applied.

    Args:
        mean: Per-channel mean forwarded to ``transforms.Normalize``.
        std: Per-channel std forwarded to ``transforms.Normalize``.
    """

    def __init__(self, mean, std):
        steps = [transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)]
        self.transform = transforms.Compose(steps)

    def __call__(self, x):
        """Return ``x`` converted to a normalized tensor."""
        return self.transform(x)