import torch
import torch.nn as nn
import torchvision.transforms as transforms
import random

class ConvMix(nn.Module):
    """Random-convolution mixing augmentation for normalised 3-channel image batches.

    Each call repeatedly blends the input with the output of a freshly
    initialised random convolution, randomly mirroring samples within the
    valid (normalised) value range before, during and after the mixing.
    The normalisation statistics are chosen from ``args.dataset``:
    ``'Digits'`` uses mean/std 0.5, anything else uses the ImageNet statistics.
    """

    # Per-channel (mean, std) used both for ``self.norm`` and for deriving the
    # normalised image of the [0, 1] pixel range (see ``_channel_bounds``).
    _DIGITS_STATS = ((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    _IMAGENET_STATS = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

    # A sample is mirrored when its uniform [0, 1) draw exceeds this value,
    # i.e. with probability ~0.2. Lower threshold -> more flips.
    flip_threshold = 0.8

    def __init__(self, args):
        """Build the augmentation.

        Args:
            args: config object; only ``args.dataset`` is read.
        """
        super().__init__()
        mean, std = self._stats_for(args.dataset)
        self.norm = transforms.Normalize(list(mean), std=list(std))
        self.args = args

    @classmethod
    def _stats_for(cls, dataset):
        """Return the per-channel (mean, std) tuples for ``dataset``."""
        return cls._DIGITS_STATS if dataset == 'Digits' else cls._IMAGENET_STATS

    def _channel_bounds(self, x):
        """Return (ch_min, ch_max): the normalised images of pixel values 0 and 1.

        Both are shaped (1, 3, 1, 1) and match ``x``'s device and dtype so
        they broadcast over a (B, 3, H, W) batch without type promotion.
        """
        mean, std = self._stats_for(self.args.dataset)
        lo = [(0 - m) / s for m, s in zip(mean, std)]
        hi = [(1 - m) / s for m, s in zip(mean, std)]
        ch_min = torch.tensor(lo, dtype=x.dtype, device=x.device).view(1, 3, 1, 1)
        ch_max = torch.tensor(hi, dtype=x.dtype, device=x.device).view(1, 3, 1, 1)
        return ch_min, ch_max

    def flip(self, x, a, b):
        """Mirror ``x`` within the interval [a, b] (a maps to b and b to a).

        Algebraically identical to ``a + (1 - (x - a) / (b - a)) * (b - a)``,
        but needs no division, so it is also defined when ``a == b``.
        """
        return a + b - x

    def _random_flip(self, t, ch_min, ch_max):
        """Mirror each sample of ``t`` within [ch_min, ch_max] with probability ~0.2."""
        # Drawn on the CPU generator and then moved, as in the original code,
        # so seeded runs keep the same random stream.
        draws = torch.rand(t.size(0), 1, 1, 1).to(t.device)
        condition = draws > self.flip_threshold
        return torch.where(condition, self.flip(t, ch_min, ch_max), t)

    @staticmethod
    def _random_conv(x, k, d):
        """Apply a fresh, frozen 3->3 conv (kernel ``k``, dilation ``d``) that keeps H and W."""
        # Effective kernel extent is d*(k-1)+1, so this padding preserves size
        # for odd k.
        conv = nn.Conv2d(3, 3, k, padding=d * (k - 1) // 2, dilation=d)
        # Follow the input's device/dtype instead of hard-coding .cuda().
        conv = conv.to(device=x.device, dtype=x.dtype)
        nn.init.kaiming_uniform_(conv.weight)
        # The weights are discarded after this call. Freezing them avoids
        # building an autograd graph for them; gradients still flow to x
        # when x requires grad.
        conv.requires_grad_(False)
        return conv(x)

    def forward(self, x, k_min=1, k_max=15, p_min=0.1, p_max=0.3, N=3):
        """Return an augmented copy of ``x`` with the same shape.

        Args:
            x: image batch with 3 channels in dim 1, presumably already
                normalised with ``self.norm``'s statistics (TODO confirm
                with callers).
            k_min, k_max: range for the kernel size. The drawn value is
                rounded up to the next odd number, so it can reach
                k_max + 1 when k_max is even.
            p_min, p_max: range of the uniform mixing weight for the
                convolved image.
            N: number of noise + random-conv + blend rounds.

        Returns:
            Tensor shaped like ``x``. After the final sigmoid and ``self.norm``,
            its values lie between the normalised images of 0 and 1.
        """
        assert x.size(1) == 3
        # The bounds depend only on the dataset and on x's device/dtype, so
        # they are built once for all flips below.
        ch_min, ch_max = self._channel_bounds(x)
        x = self._random_flip(x, ch_min, ch_max)
        for _ in range(N):
            # Out-of-place, so no tensor is mutated under autograd.
            x = x + torch.randn_like(x) * 0.01
            k = random.randint(k_min, k_max) // 2 * 2 + 1
            p = random.uniform(p_min, p_max)
            d = random.randint(1, 2)

            tx = self._random_conv(x, k, d)
            tx = self.norm(torch.tanh(tx))
            tx = self._random_flip(tx, ch_min, ch_max)

            x = p * tx + (1 - p) * x

        x = self.norm(torch.sigmoid(x))
        return self._random_flip(x, ch_min, ch_max)