import os
import torch
import random
from torchvision.utils import save_image
from torchvision import transforms
from config import poison_seed
from PIL import Image

class poison_generator():
    """Build a clean-label, Refool-style poisoned training set on disk.

    ``int(len(dataset) * poison_rate)`` samples are drawn at random from those
    whose label already equals ``target_class``. They keep that label but get
    the trigger image alpha-blended into them. Every sample (clean or poisoned)
    is written to ``path`` as ``<index>.png``.
    """

    def __init__(self, img_size, dataset, dataset_name, poison_rate, path, target_class=0):
        """
        Args:
            img_size: image side length; selects the trigger file
                ``triggers/refool_{dataset_name}_{img_size}.png``.
            dataset: indexable dataset returning ``(img, label)`` pairs. Images
                are presumably float tensors shaped like the trigger
                (TODO confirm against callers).
            dataset_name: dataset identifier used in the trigger file name.
            poison_rate: fraction of ``len(dataset)`` to poison.
            path: directory where the per-sample PNGs are written.
            target_class: label whose samples are poisoned (default 0).
        """
        self.img_size = img_size
        self.dataset = dataset
        self.poison_rate = poison_rate
        self.path = path  # path to save the dataset
        self.target_class = target_class # by default : target_class = 0

        # number of images
        self.num_img = len(dataset)

        trigger_transform = transforms.Compose([
            transforms.ToTensor()
        ])
        trigger_path = 'triggers/refool_{}_{}.png'.format(dataset_name, img_size)
        # Close the trigger file deterministically instead of relying on PIL/GC.
        with Image.open(trigger_path) as trigger:
            self.trigger = trigger_transform(trigger.convert("RGB"))

    def generate_poisoned_training_set(self):
        """Write the poisoned training set to ``self.path`` and return its metadata.

        Returns:
            poison_indices: sorted list of dataset indices that were poisoned.
            label_set: ``torch.LongTensor`` holding the label of every sample.

        Side effects: saves ``<i>.png`` for every sample into ``self.path``. When
        at least one sample is poisoned, it also ``torch.save``s the blending
        weight as ``alpha`` and writes a ``demo.png`` into ``self.path[:-4]``.

        Raises:
            AssertionError: if the requested number of poisoned samples is not
                smaller than the number of target-class samples.
        """
        torch.manual_seed(poison_seed)
        random.seed(poison_seed)

        # Split indices by whether the sample already carries the target label.
        all_target_indices = []
        all_other_indices = []
        for i in range(self.num_img):
            _, gt = self.dataset[i]
            if gt == self.target_class:
                all_target_indices.append(i)
            else:
                all_other_indices.append(i)
        random.shuffle(all_target_indices)
        # all_other_indices is not used below; its shuffle is kept so the global
        # `random` state after this call is unchanged.
        random.shuffle(all_other_indices)
        num_target = len(all_target_indices)
        num_poison = int(self.num_img * self.poison_rate)
        # Clean-label attack: only samples already in target_class may be poisoned.
        assert num_poison < num_target

        poison_indices = all_target_indices[:num_poison]
        poison_indices.sort() # increasing order
        poison_lookup = set(poison_indices)

        label_set = []
        poison_imgs = []
        poison_file_paths = []
        for i in range(self.num_img):
            img, gt = self.dataset[i]
            img_file_path = os.path.join(self.path, '%d.png' % i)
            if i in poison_lookup:
                # Poisoned samples are saved below, once alpha is known.
                gt = self.target_class
                poison_imgs.append(img)
                poison_file_paths.append(img_file_path)
            else:
                save_image(img, img_file_path)
            label_set.append(gt)

        label_set = torch.LongTensor(label_set)

        if not poison_imgs:
            # Nothing to blend; alpha is undefined without poisoned images.
            return poison_indices, label_set

        # Stack (not cat): each entry must be one whole (C, H, W) image. Concatenating
        # on dim 0 would merge channels, and imgs[i] would then be a single channel.
        imgs = torch.stack(poison_imgs, dim=0)
        print(imgs.shape)
        x_weight = imgs.mean().item()
        mark_weight = self.trigger.mean().item()
        # Blend weight from the trigger's mean intensity relative to the poisoned images'.
        alpha = mark_weight / (x_weight + mark_weight)
        print(alpha)
        # NOTE(review): self.path[:-4] drops the last 4 characters of the path
        # (presumably a trailing 'data' directory) — confirm the expected layout.
        torch.save(alpha, os.path.join(self.path[:-4], 'alpha'))
        for img, img_file_path in zip(imgs, poison_file_paths):
            save_image(img + alpha * (self.trigger - img), img_file_path)

        # Save a blended copy of the first sample for visual inspection.
        img, _ = self.dataset[0]
        img = img + alpha * (self.trigger - img)
        save_image(img, os.path.join(self.path[:-4], 'demo.png'))

        return poison_indices, label_set



class poison_transform():
    """Inject the Refool trigger into a batch of normalized images at test time.

    Every label in the batch is set to ``target_class``. Each image is mapped to
    pixel space, blended as ``x + alpha * (trigger - x)``, and re-normalized.
    """
    def __init__(self, img_size, dataset_name, target_class=0, denormalizer=None, normalizer=None,
                 alpha=0.14707505793067072):
        """
        Args:
            img_size: image side length; selects the trigger file
                ``triggers/refool_{dataset_name}_{img_size}.png``.
            dataset_name: dataset identifier used in the trigger file name.
            target_class: label assigned to every transformed sample (default 0).
            denormalizer: callable mapping normalized data back to pixel space;
                ``None`` means identity.
            normalizer: callable mapping pixel-space data to normalized space;
                ``None`` means identity.
            alpha: trigger blending weight. The default is the value that used
                to be hard-coded. Pass the ``alpha`` saved by
                ``poison_generator.generate_poisoned_training_set`` to match a
                specific poisoned training set.
        """
        self.img_size = img_size
        self.target_class = target_class # by default : target_class = 0
        self.denormalizer = denormalizer
        self.normalizer = normalizer
        self.alpha = alpha

        trigger_transform = transforms.Compose([
            transforms.ToTensor()
        ])
        trigger_path = 'triggers/refool_{}_{}.png'.format(dataset_name, img_size)
        # Close the trigger file deterministically instead of relying on PIL/GC.
        with Image.open(trigger_path) as trigger:
            self.trigger = trigger_transform(trigger.convert("RGB")).cuda()

    def transform(self, data, labels):
        """Return poisoned copies of ``data`` and ``labels``; the inputs are not modified.

        Args:
            data: batch of normalized images, presumably ``(N, C, H, W)`` with
                ``(C, H, W)`` matching the trigger (TODO confirm).
            labels: label tensor for the batch.

        Returns:
            ``(data, labels)``: every label equals ``target_class``, and every
            image is ``x + alpha * (trigger - x)`` in pixel space, re-normalized.
        """
        data = data.clone()
        labels = labels.clone()

        # transform clean samples to poison samples
        labels[:] = self.target_class
        if self.denormalizer is not None:
            data = self.denormalizer(data)
        # Follow the batch's device instead of assuming it matches the trigger's.
        trigger = self.trigger.to(data.device)
        data = data + self.alpha * (trigger - data)
        if self.normalizer is not None:
            data = self.normalizer(data)

        return data, labels
