import os
import torch
import random
import copy
import numpy as np
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torchvision.transforms as transforms
from utils.domainnet_data_utils import DomainNetDataset
from utils.PACS_utils import PACS, split_train_test_dataset, VLCS
from utils.Officehome_utils import OfficeHome
from utils.office_data_utils import OfficeDataset
from matplotlib import pyplot as plt
from torchvision.datasets import ImageFolder
from torchvision import datasets

# Run on the first CUDA device when one is visible, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

def _resize_only_transform():
    """Test-time transform: resize to 224x224, then convert to a tensor."""
    return transforms.Compose([
        transforms.Resize([224, 224]),
        transforms.ToTensor(),
    ])


def _crop_normalize_transform():
    """Test-time transform: resize to 256x256, center-crop 224x224, to tensor, normalize (mean 0.5, std 0.5)."""
    mean = [0.5, 0.5, 0.5]
    std = [0.5, 0.5, 0.5]
    return transforms.Compose([
        transforms.Resize([256, 256]),
        transforms.CenterCrop([224, 224]),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])


def _held_out_splits(full_sets):
    """Return the second (test) part of ``split_train_test_dataset`` for each dataset, in order."""
    testsets = []
    for full_set in full_sets:
        _, testset = split_train_test_dataset(full_set)
        testsets.append(testset)
    return testsets


def _loaders_without_target(testsets, batch, target_class):
    """Wrap each test set in a DataLoader that randomly samples only non-target indices."""
    # Every index list is computed before any loader is created, keeping the
    # same phase ordering as the per-dataset code this replaces.
    index_lists = [id_not_target(testset, target_class) for testset in testsets]
    return [
        torch.utils.data.DataLoader(
            testset,
            batch_size=batch,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(indices),
        )
        for testset, indices in zip(testsets, index_lists)
    ]


def poison_test_dataset(dataset_name, data_base_path, batch, target_class):
    """Build one test DataLoader per domain that never yields ``target_class`` samples.

    Each loader samples, in random order, only the indices whose label differs
    from ``target_class`` (see ``id_not_target``). Presumably this lets the
    caller measure backdoor success only on samples not already of the target
    class — confirm against callers.

    Args:
        dataset_name: one of "PACS", "domainnet", "officehome",
            "office_caltech", "VLCS".
        data_base_path: path forwarded to every dataset constructor.
        batch: batch size of each loader.
        target_class: label whose samples are excluded.

    Returns:
        A list of DataLoaders, one per domain, in this order:
          - PACS: art_painting, cartoon, sketch, photo (test part of
            ``split_train_test_dataset``)
          - domainnet: clipart, infograph, painting, quickdraw, real, sketch
            (``train=False``)
          - officehome: Art, Clipart, Product, RealWorld (test part of
            ``split_train_test_dataset``)
          - office_caltech: amazon, caltech, dslr, webcam (``train=False``)
          - VLCS: Caltech101, LabelMe, SUN09, VOC2007 (whole domain, no split)

    Raises:
        ValueError: if ``dataset_name`` is not one of the names above.
    """
    print("get poison test loader")

    if dataset_name == "PACS":
        transform_test = _resize_only_transform()
        full_sets = [PACS(data_base_path, domain, transform_test)
                     for domain in ("art_painting", "cartoon", "sketch", "photo")]
        testsets = _held_out_splits(full_sets)

    elif dataset_name == "domainnet":
        transform_test = _resize_only_transform()
        testsets = [DomainNetDataset(data_base_path, domain, transform=transform_test, train=False)
                    for domain in ("clipart", "infograph", "painting", "quickdraw", "real", "sketch")]

    elif dataset_name == "officehome":
        transform_test = _resize_only_transform()
        full_sets = [OfficeHome(data_base_path, domain, transform_test)
                     for domain in ("Art", "Clipart", "Product", "RealWorld")]
        testsets = _held_out_splits(full_sets)

    elif dataset_name == "office_caltech":
        transform_test = _crop_normalize_transform()
        testsets = [OfficeDataset(data_base_path, domain, transform=transform_test, train=False)
                    for domain in ("amazon", "caltech", "dslr", "webcam")]

    elif dataset_name == "VLCS":
        transform_test = _crop_normalize_transform()
        # VLCS domains are used whole; no train/test split is applied here.
        testsets = [VLCS(data_base_path, domain, transform_test)
                    for domain in ("Caltech101", "LabelMe", "SUN09", "VOC2007")]

    else:
        raise ValueError(f"Unsupported dataset_name for poison_test_dataset: {dataset_name!r}")

    return _loaders_without_target(testsets, batch, target_class)


def id_not_target(test_dataset, target_class):
    """Return the indices of samples whose label is NOT ``target_class``.

    Iterates the whole dataset once, so every sample (including any image
    loading and transforms the dataset performs) is materialized.

    Args:
        test_dataset: iterable of ``(input, label)`` pairs.
        target_class: label to exclude.

    Returns:
        Ascending list of indices with ``label != target_class``. If no sample
        carries ``target_class``, every index is returned.
    """
    return [
        index
        for index, (_, label) in enumerate(test_dataset)
        if label != target_class
    ]

def test_dataset_oneclass(dataset_name, data_base_path, batch, site, classindex):
    """Return a DataLoader over the samples of a single class in one domain.

    Args:
        dataset_name: one of "PACS", "domainnet", "officehome".
        data_base_path: path forwarded to the dataset constructor.
        batch: loader batch size.
        site: domain name passed to the dataset constructor.
        classindex: label whose samples the loader yields.

    Returns:
        A DataLoader that randomly samples only the indices labelled
        ``classindex`` (see ``id_target``).

    Raises:
        ValueError: if ``dataset_name`` is not one of the names above.
    """
    print("get test_dataset_oneclass")
    transform_test = transforms.Compose([
        transforms.Resize([224, 224]),
        transforms.ToTensor(),
    ])

    # NOTE(review): PACS and OfficeHome use the whole domain here, whereas
    # poison_test_dataset uses only the test part of split_train_test_dataset
    # — confirm this is intended.
    if dataset_name == "PACS":
        target_testset = PACS(data_base_path, site, transform=transform_test)
    elif dataset_name == "domainnet":
        target_testset = DomainNetDataset(data_base_path, site, transform=transform_test, train=False)
    elif dataset_name == "officehome":
        target_testset = OfficeHome(data_base_path, site, transform=transform_test)
    else:
        raise ValueError(f"Unsupported dataset_name for test_dataset_oneclass: {dataset_name!r}")

    testset_target_label_id = id_target(target_testset, classindex)
    return torch.utils.data.DataLoader(
        target_testset,
        batch_size=batch,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(testset_target_label_id),
    )

def id_target(test_dataset, target_class):
    """Return the indices of samples whose label equals ``target_class``.

    Iterates the whole dataset once, so every sample is materialized.

    Args:
        test_dataset: iterable of ``(input, label)`` pairs.
        target_class: label to select.

    Returns:
        Ascending list of matching indices; an empty list when no sample
        carries ``target_class``.
    """
    return [
        index
        for index, (_, label) in enumerate(test_dataset)
        if label == target_class
    ]

def get_poison_batch(batch, poison_label_swap, poison_number_per_batch, evaluation=False):
    """Stamp the pixel-pattern trigger onto a batch and relabel the poisoned samples.

    Args:
        batch: ``(images, targets)`` pair; assumed to be torch tensors as
            produced by the default DataLoader collate (TODO confirm).
        poison_label_swap: label assigned to every poisoned sample.
        poison_number_per_batch: in training mode, samples with index
            ``< poison_number_per_batch`` are poisoned; ignored when
            ``evaluation`` is True.
        evaluation: when True, every sample in the batch is poisoned.

    Returns:
        ``(new_images, new_targets, poison_count)``, where ``poison_count`` is
        the number of samples that received the trigger. The tensors in
        ``batch`` are left unmodified.
    """
    images, targets = batch
    # Work on copies so the trigger and relabels are not written into the
    # caller's batch tensors.
    new_images = images.clone()
    new_targets = targets.clone()
    poison_count = 0

    for index in range(len(images)):
        # Poisoned samples form a prefix of the batch, so stop at the first
        # index past the training-time budget.
        if not evaluation and index >= poison_number_per_batch:
            break
        new_targets[index] = poison_label_swap
        new_images[index] = add_pixel_pattern(images[index])
        poison_count += 1

    return new_images, new_targets, poison_count


def get_poison_batch_ours(batch, poison_label_swap, trigger_generator, poison_number_per_batch, evaluation=False):
    """Poison a batch with triggers produced by ``trigger_generator``.

    Both tensors of ``batch`` are moved to the module-level ``device``, and
    the generator is called on the full image batch. The first element of its
    3-tuple result is taken as the triggered images.

    In evaluation mode, every image is replaced by its triggered version and
    every label becomes ``poison_label_swap``. Otherwise only the first
    ``poison_number_per_batch`` samples are replaced and relabelled; the rest
    keep their clean image and label.

    Returns:
        ``(new_images, new_targets)``.
    """
    clean_images, clean_targets = batch
    clean_images = clean_images.to(device)
    clean_targets = clean_targets.to(device)

    # The generator's other two outputs are not needed here.
    triggered, _, _ = trigger_generator(clean_images)

    if not evaluation:
        cut = poison_number_per_batch

        mixed_images = torch.zeros_like(clean_images)
        mixed_images[:cut] = triggered[:cut]
        mixed_images[cut:] = clean_images[cut:]

        mixed_targets = torch.zeros_like(clean_targets)
        mixed_targets[:cut] = poison_label_swap
        mixed_targets[cut:] = clean_targets[cut:]

        return mixed_images, mixed_targets

    return triggered, poison_label_swap * torch.ones_like(clean_targets)

def get_poison_batch_invis(batch, poison_label_swap, trigger_generator, poison_number_per_batch, evaluation=False):
    """Poison a batch with ``trigger_generator(images, "abcdefg")``.

    The fixed second argument ``"abcdefg"`` is presumably a secret or message
    for an invisible (encoded) trigger; confirm against the generator.

    Both tensors of ``batch`` are moved to the module-level ``device``, and
    the first element of the generator's 3-tuple result is taken as the
    triggered images.

    In evaluation mode, every image is replaced and every label becomes
    ``poison_label_swap``. Otherwise only the first
    ``poison_number_per_batch`` samples are replaced and relabelled.

    Returns:
        ``(new_images, new_targets)``.
    """
    clean_images, clean_targets = batch
    clean_images = clean_images.to(device)
    clean_targets = clean_targets.to(device)

    # The generator's other two outputs are not needed here.
    triggered, _, _ = trigger_generator(clean_images, "abcdefg")

    if not evaluation:
        cut = poison_number_per_batch

        mixed_images = torch.zeros_like(clean_images)
        mixed_images[:cut] = triggered[:cut]
        mixed_images[cut:] = clean_images[cut:]

        mixed_targets = torch.zeros_like(clean_targets)
        mixed_targets[:cut] = poison_label_swap
        mixed_targets[cut:] = clean_targets[cut:]

        return mixed_images, mixed_targets

    return triggered, poison_label_swap * torch.ones_like(clean_targets)

def get_poison_batch_aaai(batch, poison_label_swap, poison_number_per_batch, evaluation=False):
    """Stamp the pixel-pattern trigger onto a batch and relabel the poisoned samples.

    Same poisoning rule as ``get_poison_batch`` but returns no poison count.
    The tensors are NOT moved to ``device`` here; the caller keeps them on
    whatever device ``batch`` is on.

    Args:
        batch: ``(images, targets)`` pair; assumed to be torch tensors as
            produced by the default DataLoader collate (TODO confirm).
        poison_label_swap: label assigned to every poisoned sample.
        poison_number_per_batch: in training mode, samples with index
            ``< poison_number_per_batch`` are poisoned; ignored when
            ``evaluation`` is True.
        evaluation: when True, every sample in the batch is poisoned.

    Returns:
        ``(new_images, new_targets)``; the tensors in ``batch`` are left
        unmodified.
    """
    images, targets = batch
    # Work on copies so the trigger and relabels are not written into the
    # caller's batch tensors.
    new_images = images.clone()
    new_targets = targets.clone()

    for index in range(len(images)):
        # Poisoned samples form a prefix of the batch.
        if not evaluation and index >= poison_number_per_batch:
            break
        new_targets[index] = poison_label_swap
        new_images[index] = add_pixel_pattern(images[index])

    return new_images, new_targets


def get_ra_batch_aaai(batch, poison_number_per_batch, evaluation=False):
    """
    Get robust accuracy batch: add trigger pattern to images but keep original labels unchanged.
    This is used to test model's robust accuracy under trigger patterns.

    Args:
        batch: ``(images, targets)`` pair of torch tensors; both are moved to
            the module-level ``device``.
        poison_number_per_batch: in training mode, samples with index
            ``< poison_number_per_batch`` get the trigger; ignored when
            ``evaluation`` is True.
        evaluation: when True, every sample gets the trigger.

    Returns:
        ``(new_images, targets)``: triggered images and the unchanged labels
        (on ``device``). The caller's image tensor is never modified.
    """
    images, targets = batch
    images, targets = images.to(device), targets.to(device)
    # ``Tensor.to`` returns the same tensor when it is already on ``device``,
    # so clone before stamping the trigger to avoid mutating the caller's batch.
    new_images = images.clone()

    for index in range(len(images)):
        # Triggered samples form a prefix of the batch.
        if not evaluation and index >= poison_number_per_batch:
            break
        new_images[index] = add_pixel_pattern(images[index])

    # Labels are intentionally left untouched.
    return new_images, targets

def add_pixel_pattern(image, size=6, color=(0.9, 0.2, 0.1)):
    """Return a copy of ``image`` with a solid square trigger in its top-left corner.

    The patch covers rows ``0..size-1`` and columns ``0..size-1``. Channel
    ``c`` of the patch is set to ``color[c]``. With the defaults this is the
    6x6 (0.9, 0.2, 0.1) pattern used by the batch-poisoning helpers. The input
    is never modified.

    Args:
        image: array indexed as ``image[channel, row, col]`` (assumed a CHW
            ``torch.Tensor`` from the DataLoader — TODO confirm for other
            callers). It must have at least ``len(color)`` channels.
        size: side length of the square patch, in pixels.
        color: per-channel fill values, one for each channel to paint.

    Returns:
        The triggered copy, with the same type, shape and dtype as ``image``.
    """
    # ``clone`` (unlike ``copy.deepcopy``) works on non-leaf tensors. It also
    # yields a tensor that may be written in place even when ``image``
    # requires grad.
    if isinstance(image, torch.Tensor):
        adv_image = image.clone()
    else:
        adv_image = copy.deepcopy(image)

    for channel, value in enumerate(color):
        # Slicing clips at the border, so images smaller than ``size`` get a
        # clipped patch rather than an IndexError.
        adv_image[channel, :size, :size] = value

    return adv_image
