import os
import logging
import time
import glob
import sys
sys.path.append('...')
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
import math

import numpy as np
import tqdm
import torch
import torch.utils.data as data
import torchvision.transforms as transforms

from data import get_dataset, data_transform, inverse_data_transform
from functions.ckpt_util import get_ckpt_path, download
from functions.svd_ddnm import ddnm_diffusion, ddnm_plus_diffusion

import torchvision.utils as tvu

from guided_diffusion.models import Model
from guided_diffusion.script_util import create_model, create_classifier, classifier_defaults, args_to_dict
import random

from scipy.linalg import orth
from backdoorattack.attack import *
import backdoorattack.BackdoorBox as bb
from guided_diffusion.diffusion import Diffusion

from torchvision.datasets import ImageFolder
from torchvision.datasets.utils import verify_str_arg
from torchvision.datasets.utils import download_and_extract_archive

class TinyImageNet_pur(ImageFolder):
    """ImageFolder over one split of a purified TinyImageNet-style dataset.

    Images are read from
    ``<root>/../pur/<args.dataset>/<args.attack_method>/purified/<split>``,
    where ``<split>`` is one of :attr:`splits`.

    Args:
        root: Base directory; ``~`` is expanded.
        args: Object providing ``dataset`` and ``attack_method`` attributes,
            which select the sub-directory below ``../pur``.
        split: ``'train'`` or ``'val'``.
        download: Not supported; passing ``True`` raises
            ``NotImplementedError``.
        **kwargs: Forwarded unchanged to ``ImageFolder.__init__``.

    Raises:
        NotImplementedError: If ``download`` is true.
        RuntimeError: If the split directory does not exist.
    """

    splits = ('train', 'val')
    # Download metadata of the upstream TinyImageNet archive. It is disabled
    # in this class and kept for reference only:
    #   zip_md5 = '90528d7ca1a48142e341f4ef8d21d0de'
    #   filename = 'tiny-imagenet-200.zip'
    #   url = 'http://cs231n.stanford.edu/tiny-imagenet-200.zip'

    def __init__(self, root, args, split='train', download=False, **kwargs):
        # Must be an instance attribute: the ``dataset_folder`` property reads
        # ``self.base_folder`` (a plain local here raised AttributeError).
        self.base_folder = f'../pur/{args.dataset}/{args.attack_method}/purified'
        self.data_root = os.path.expanduser(root)
        self.split = verify_str_arg(split, "split", self.splits)

        print('fold exist or not', os.path.exists(self.split_folder))

        if download:
            self.download()

        if not self._check_exists():
            # The old hint ("use download=True") was misleading because no
            # download is implemented; report the missing path instead.
            raise RuntimeError(
                f'Dataset not found: {self.split_folder}.'
                ' Automatic download is not available for this dataset.')

        super().__init__(self.split_folder, **kwargs)

    @property
    def dataset_folder(self):
        """Directory that contains the split sub-directories."""
        return os.path.join(self.data_root, self.base_folder)

    @property
    def split_folder(self):
        """Directory of the selected split; used as the ImageFolder root."""
        return os.path.join(self.dataset_folder, self.split)

    def download(self):
        """Reject download requests with an explicit error.

        The download metadata is disabled in this class, so there is nothing
        to fetch. Without this method ``download=True`` failed with an opaque
        AttributeError.
        """
        raise NotImplementedError(
            'TinyImageNet_pur does not support downloading;'
            f' expected data under {self.split_folder}')

    def _check_exists(self):
        """Return True when the split directory exists on disk."""
        return os.path.exists(self.split_folder)

    def extra_repr(self):
        """Return the extra line shown in the dataset's ``repr``."""
        return f"Split: {self.split}"



class Purify(Diffusion):
    """Diffusion subclass that purifies a dataset and reloads the result.

    Args:
        args: Run configuration. This class reads ``args.simplified`` and
            ``args.pur_folder``; the object is also passed to ``Diffusion``.
        config: Passed through to ``Diffusion.__init__``.
        train_dataset: Dataset handed to ``sample()``; stored as
            ``poisoned_dataset``.
        test_dataset: Stored as ``test_dataset``; not used by this class.
    """

    def __init__(self, args, config, train_dataset, test_dataset):
        super().__init__(args, config, device=None)
        self.poisoned_dataset = train_dataset
        self.test_dataset = test_dataset
        self.args = args

    def pur(self):
        """Run sampling on the poisoned dataset and load the purified images.

        Returns:
            An ``ImageFolder`` rooted at ``args.pur_folder`` whose images are
            resized and center-cropped to 64 pixels and converted to tensors.
        """
        # NOTE(review): ``sample`` is not defined in this class; presumably it
        # is inherited from Diffusion and writes the purified images under
        # ``args.pur_folder`` -- confirm.
        self.sample(self.args.simplified, self.poisoned_dataset)

        transform_img = transforms.Compose([
            transforms.Resize(64),
            transforms.CenterCrop(64),
            transforms.ToTensor(),
        ])
        # Use the ImageFolder class imported at file level: the bare name
        # ``torchvision`` is not bound by any of this file's visible imports,
        # so ``torchvision.datasets.ImageFolder`` raised NameError.
        return ImageFolder(f'{self.args.pur_folder}', transform=transform_img)
        



class ConcatDataset(Dataset):
    """Dataset of per-class mosaics built by tiling images into a grid.

    Every image listed in ``dataset.samples`` is resized to
    ``tile_size`` x ``tile_size`` pixels and pasted into an RGB canvas that
    holds ``grid_size`` x ``grid_size`` tiles. With the defaults that is up
    to 64 tiles of 32 pixels on a 256 x 256 canvas. Images are grouped by
    class, so each mosaic contains images of a single class and is labelled
    with that class index. The last mosaic of a class may be only partly
    filled; its unused cells keep the canvas's default fill.

    All mosaics are built eagerly in ``__init__`` and kept in memory.

    Args:
        dataset: Object exposing ``classes``, ``class_to_idx`` and
            ``samples``, where each sample starts with
            ``(image_path, class_index)``. Class indices are assumed to be
            hashable values such as ints -- TODO confirm for other sources.
        transform: Optional callable applied to a mosaic in ``__getitem__``.
        tile_size: Edge length in pixels of each resized image.
        grid_size: Number of tiles per mosaic row and per mosaic column.
    """

    def __init__(self, dataset, transform=None, tile_size=32, grid_size=8):
        self.dataset = dataset
        self.transform = transform
        self.classes = self.dataset.classes
        self.class_to_idx = self.dataset.class_to_idx
        self.idx_to_class = {v: k for k, v in self.class_to_idx.items()}
        self.samples = self.dataset.samples
        self.num_samples = len(self.samples)
        self.targets = [s[1] for s in self.samples]
        self.img_paths = [s[0] for s in self.samples]
        self.tile_size = tile_size
        self.grid_size = grid_size
        self.concat_images = []
        self.concat_labels = []

        # Group the paths by class index in a single pass instead of
        # rescanning every target once per class.
        paths_by_target = {}
        for path, target in zip(self.img_paths, self.targets):
            paths_by_target.setdefault(target, []).append(path)

        tiles_per_mosaic = grid_size * grid_size
        for class_label in self.classes:
            target = self.class_to_idx[class_label]
            class_paths = paths_by_target.get(target, [])
            # Consecutive chunks of at most ``tiles_per_mosaic`` images; the
            # final chunk of a class may be shorter.
            for start in range(0, len(class_paths), tiles_per_mosaic):
                chunk = class_paths[start:start + tiles_per_mosaic]
                self.concat_images.append(self._build_mosaic(chunk))
                self.concat_labels.append(target)

        self.num_concat_images = len(self.concat_images)

    def _build_mosaic(self, image_paths):
        """Paste the given images left-to-right, top-to-bottom on a new canvas."""
        side = self.tile_size * self.grid_size
        canvas = Image.new('RGB', (side, side))
        for position, path in enumerate(image_paths):
            grid_row, grid_col = divmod(position, self.grid_size)
            # The context manager releases the file handle right away; the
            # resized tile is an independent in-memory image.
            with Image.open(path) as img:
                tile = img.resize((self.tile_size, self.tile_size))
            # paste() takes (x, y): the column sets x and the row sets y.
            canvas.paste(
                tile, (grid_col * self.tile_size, grid_row * self.tile_size))
        return canvas

    def __len__(self):
        """Return the number of mosaics."""
        return self.num_concat_images

    def __getitem__(self, idx):
        """Return ``(mosaic, class_index)``, with ``transform`` applied if set."""
        img = self.concat_images[idx]
        label = self.concat_labels[idx]

        if self.transform is not None:
            img = self.transform(img)

        return img, label


# Example usage
# Preprocessing for the example: convert to a tensor, then normalize every
# channel with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# The rest of the example is kept as a comment rather than live code:
# ``YourExistingDataset`` is a placeholder that is not defined in this file,
# and the original loop had no body, so the module could not be imported.
# Replace the placeholder with a dataset that exposes ``classes``,
# ``class_to_idx`` and ``samples`` before running it.
#
#     dataset = YourExistingDataset(root='path/to/root/directory', transform=transform)
#     concat_dataset = ConcatDataset(dataset, transform=transform)
#     dataloader = DataLoader(concat_dataset, batch_size=32, shuffle=True)
#
#     # Access a batch of data (the loop variable is named ``images`` so it
#     # does not shadow the ``data`` module alias imported at the top).
#     for batch_idx, (images, labels) in enumerate(dataloader):
#         ...  # Train your model with the data