import torch
from torchvision import transforms
import numpy as np
from torchvision.transforms import ToPILImage
import random
class data_augmentation():
    """Batch image augmentations: noisy rescale, horizontal flip, random crop, rotation.

    Each augmentation takes a tensor ``imgs`` whose first dimension is iterated
    as the batch and whose last two dimensions are height and width. It returns
    a new tensor that concatenates the augmented copies along dim 0. Results are
    clamped to [0, 1] after resampling, so inputs are presumably float images in
    that range (NOTE(review): confirm with callers).

    Args:
        device: Default device for results when a method's own ``device``
            argument is left as ``None``.
    """

    def __init__(self, device='cuda'):
        self.device = device

    def _resolve_device(self, device):
        """Return ``device`` if one was given, otherwise the instance default."""
        return self.device if device is None else device

    def save_tensor_as_image(self, tensor, filename):
        """Save one ``(C, H, W)`` image tensor to ``filename`` as an RGB image.

        A single-channel tensor is replicated to three channels first.
        """
        tensor = tensor.detach().cpu()
        if tensor.shape[0] == 1:
            tensor = tensor.repeat(3, 1, 1)
        ToPILImage()(tensor).convert("RGB").save(filename)

    def scale(self, imgs, scales=None, device=None, debug=False):
        """Return ``imgs`` followed by one noisy down/up-rescaled copy per ratio.

        For every ``ratio`` in ``scales``, the following steps are applied:

        1. Add Gaussian noise (mean 0, std 0.05, drawn from NumPy's global RNG).
        2. Resize bicubically to ``ratio`` times the spatial size and clamp to [0, 1].
        3. Resize bicubically back to the original size and clamp again.

        Args:
            imgs: Image batch; the last two dims are height and width.
            scales: Iterable of resize ratios. ``None`` returns ``imgs`` as-is
                (the same object).
            device: Device for the computation and result; ``None`` uses
                ``self.device``.
            debug: If true, print ``scales`` and save the first image of each
                rescaled batch as ``scale_image_{ratio}.jpg``.

        Returns:
            ``torch.cat([imgs, rescaled_1, ..., rescaled_k], dim=0)``.
        """
        if scales is None:
            return imgs
        device = self._resolve_device(device)
        # Move up front so the noise and the images are on the same device.
        imgs = imgs.to(device)
        height, width = imgs.shape[-2], imgs.shape[-1]
        bicubic = transforms.InterpolationMode.BICUBIC
        restore = transforms.Resize((height, width), interpolation=bicubic)
        if debug:
            print(scales)
        result = []
        for ratio in scales:
            # Resize rejects zero-sized targets, which tiny ratios would produce.
            target = (max(1, int(ratio * height)), max(1, int(ratio * width)))
            noise = torch.from_numpy(
                np.random.normal(0.0, 0.05, imgs.shape)).float().to(device)
            shrunk = transforms.Resize(target, interpolation=bicubic)(imgs + noise)
            # Bicubic resampling can overshoot the valid pixel range.
            shrunk = torch.clamp(shrunk, 0.0, 1.0)
            restored = torch.clamp(restore(shrunk), 0.0, 1.0)
            if debug:
                self.save_tensor_as_image(restored[0], f'scale_image_{ratio}.jpg')
            result.append(restored)
        return torch.cat([imgs] + result, dim=0)

    def flip(self, imgs, device=None, debug=False):
        """Return ``imgs`` mirrored left-to-right along the last (width) axis.

        This is equivalent to applying ``RandomHorizontalFlip(p=1)`` to every
        image (p=1 means the flip always happens).

        Args:
            imgs: Image batch; the last dim is width.
            device: Device for the result; ``None`` uses ``self.device``.
            debug: If true, print the output shape.

        Returns:
            A new tensor with the same shape as ``imgs``.
        """
        device = self._resolve_device(device)
        flipped = torch.flip(imgs.to(device), dims=[-1])
        if debug:
            print(flipped.shape)
        return flipped

    def random_crop(self, imgs, device=None, n=1, debug=False):
        """Return ``n`` random resized crops per image, at the original size.

        The batch is first upscaled 2x with bicubic interpolation. Then
        ``RandomResizedCrop`` (torchvision's default scale/ratio ranges) is
        applied ``n`` times per image, and each crop is clamped to [0, 1].

        Args:
            imgs: Image batch; the last two dims are height and width.
            device: Device for the result; ``None`` uses ``self.device``.
            n: Number of crops drawn per image.
            debug: If true, print the output shape.

        Returns:
            Crops concatenated image-major (all crops of image 0, then image 1,
            ...). An empty batch is returned when ``imgs`` is empty or ``n <= 0``.
        """
        device = self._resolve_device(device)
        imgs = imgs.to(device)
        if n <= 0 or len(imgs) == 0:
            return imgs.new_empty((0, *imgs.shape[1:]))
        out_size = (imgs.shape[-2], imgs.shape[-1])
        upscale = transforms.Resize((out_size[0] * 2, out_size[1] * 2),
                                    interpolation=transforms.InterpolationMode.BICUBIC)
        crop = transforms.RandomResizedCrop(out_size)
        # Resize is deterministic, so it is done once for the whole batch.
        # The crop is called per image so that each draw gets its own random box.
        enlarged = upscale(imgs)
        result = [
            torch.clamp(crop(img), 0.0, 1.0).unsqueeze(0)
            for img in enlarged
            for _ in range(n)
        ]
        combined = torch.cat(result, dim=0)
        if debug:
            print(combined.shape)
        return combined

    def rotation(self, imgs, device=None, n=1, debug=False):
        """Rotate every image by each of ``n`` randomly drawn angles.

        ``n`` integer angles in [0, 360] degrees are drawn once per call with
        Python's ``random`` module. Every image is then rotated by exactly those
        angles via ``transforms.functional.rotate``. It uses the same defaults as
        ``RandomRotation``: nearest interpolation, no expansion, fill 0.

        Args:
            imgs: Image batch; the last two dims are height and width.
            device: Device for the result; ``None`` uses ``self.device``.
            n: Number of angles (rotated copies per image).
            debug: If true, save each rotated image as ``rotation_image_{k}.jpg``
                and print the output shape.

        Returns:
            Rotated images concatenated image-major (all angles for image 0,
            then image 1, ...). An empty batch is returned when ``imgs`` is empty
            or ``n <= 0``.
        """
        device = self._resolve_device(device)
        imgs = imgs.to(device)
        # Draw the angles before any early return so RNG consumption matches n.
        angles = [random.randint(0, 360) for _ in range(n)]
        if not angles or len(imgs) == 0:
            return imgs.new_empty((0, *imgs.shape[1:]))
        result = []
        for img in imgs:
            batch = img.unsqueeze(0)
            for idx, angle in enumerate(angles):
                # A fixed rotation: RandomRotation(angle) would instead sample
                # uniformly from (-angle, +angle) and ignore the drawn value.
                rotated = transforms.functional.rotate(batch, float(angle))
                if debug:
                    self.save_tensor_as_image(rotated[0], f'rotation_image_{idx}.jpg')
                result.append(rotated)
        combined = torch.cat(result, dim=0)
        if debug:
            print(combined.shape)
        return combined
        