# from dataset import PoisonedCIFAR10
from datasets import *
import torch 
import torchvision
from typing import List


from ffcv.fields import IntField, RGBImageField
from ffcv.fields.decoders import IntDecoder, SimpleRGBImageDecoder, NDArrayDecoder, RandomResizedCropRGBImageDecoder
from ffcv.loader import Loader, OrderOption
from ffcv.pipeline.operation import Operation
from ffcv.transforms import RandomHorizontalFlip, Cutout, \
    RandomTranslate, Convert, ToDevice, ToTensor, ToTorchImage
from ffcv.transforms.common import Squeeze
from ffcv.writer import DatasetWriter
import random

def _attack_registry(dataset_name):
    """Return the attack registry for ``dataset_name``.

    Each entry maps an attack name to ``{'class': <dataset class>, 'type': <storage type>}``.
    ``'type'`` selects the image decoder: ``'RGB'`` -> ``SimpleRGBImageDecoder``,
    ``'torch'`` -> ``NDArrayDecoder``.

    Raises:
        ValueError: if ``dataset_name`` is not one of the supported datasets.
    """
    if dataset_name == "cifar10":
        return {
            'Badnet': {'class': BadnetCIFAR10, 'type': 'RGB'},
            'Blend': {'class': BlendCIFAR10, 'type': 'RGB'},
            'Wanet': {'class': WaNetCIFAR10, 'type': 'torch'},
            'CleanLabel': {'class': CleanLabelCIFAR10, 'type': 'torch'},
            'DynamicBackdoor': {'class': DynamicBackdoorCIFAR10, 'type': 'torch'},
            'SIG': {'class': SIGCIFAR10, 'type': 'RGB'},
            'LabelConsistent': {'class': LabelConsistentCIFAR10, 'type': 'RGB'},
            'Trojan': {'class': TrojanCIFAR10, 'type': 'RGB'},
            # NOTE(review): the CIFAR-10 registry points ISSBA at the ImageNet-style class;
            # confirm this is intentional.
            'ISSBA': {'class': ISSBAImagenet, 'type': 'RGB'},
            'AdaptiveBlend': {'class': AdapBlendCIFAR10, 'type': 'RGB'},
            'DFST': {'class': DFSTCIFAR10, 'type': 'RGB'},
            'Badnet_Adaptive2': {'class': BadnetCIFAR10_Adaptive2, 'type': 'RGB'},
            'Badnet_1to1': {'class': BadnetCIFAR10_1to1, 'type': 'RGB'},
            'Badnet_Adaptive3': {'class': BadnetCIFAR10_Adaptive3, 'type': 'RGB'},
            'Badnet_Adaptive1': {'class': BadnetCIFAR10_Adaptive1, 'type': 'RGB'},
            'Badnet_Adaptive4': {'class': BadnetCIFAR10_Adaptive4, 'type': 'RGB'},
            'Badnet_allto1': {'class': BadnetCIFAR10_allto1, 'type': 'RGB'},
        }
    if dataset_name == "imagenet200":
        # NOTE: LabelConsistent and Trojan are not yet wired up for this dataset.
        return {
            'Badnet': {'class': BadnetImagenet200, 'type': 'RGB'},
            'Blend': {'class': BlendImagenet200, 'type': 'RGB'},
            'Wanet': {'class': WaNetImagenet200, 'type': 'torch'},
            'CleanLabel': {'class': CleanLabelImagenet200, 'type': 'torch'},
        }
    if dataset_name == "tinyimagenet":
        # NOTE: LabelConsistent and Trojan are not yet wired up for this dataset.
        return {
            'Badnet': {'class': BadnetTinyimagenet, 'type': 'RGB'},
            'Blend': {'class': BlendTinyimagenet, 'type': 'RGB'},
            'Wanet': {'class': WaNetTinyimagenet, 'type': 'torch'},
            'CleanLabel': {'class': CleanLabelTinyimagenet, 'type': 'torch'},
        }
    raise ValueError(
        f"Unsupported dataset {dataset_name!r}; expected one of "
        f"'cifar10', 'imagenet200', 'tinyimagenet'"
    )


def _make_label_pipeline(device) -> "List[Operation]":
    """Build a fresh decode pipeline for an integer field (label / index / poisonlabel)."""
    pipeline: List[Operation] = [IntDecoder(), ToTensor(), ToDevice(device), Squeeze()]
    return pipeline


def _make_image_pipeline(dataset_cls, storage_type, device) -> "List[Operation]":
    """Build the image decode pipeline, ending in a float16 tensor on ``device``.

    Raises:
        ValueError: if ``storage_type`` is neither ``'torch'`` nor ``'RGB'``.
    """
    if storage_type not in ('torch', 'RGB'):
        raise ValueError(f"Unknown image storage type {storage_type!r}")

    # Compare by class name: the registry stores class objects, not strings.
    if dataset_cls.__name__ == 'ISSBAImagenet':
        # Resize-crop to 224x224 with scale and ratio both pinned to 1.0.
        pipeline: List[Operation] = [
            RandomResizedCropRGBImageDecoder((224, 224), scale=(1.0, 1.0), ratio=(1.0, 1.0))
        ]
    elif storage_type == 'torch':
        pipeline = [NDArrayDecoder()]
    else:
        pipeline = [SimpleRGBImageDecoder()]

    pipeline.extend([ToTensor(), ToDevice(device, non_blocking=True)])
    if storage_type == 'RGB':
        # RGB fields need ToTorchImage after decoding, before the dtype conversion.
        pipeline.append(ToTorchImage())
    pipeline.append(Convert(torch.float16))
    return pipeline


def _beton_path(args, split):
    """Return the FFCV ``.beton`` path for ``split``.

    When ``args.target == 1`` the target is omitted from the filename; otherwise it is
    inserted between the attack name and the poison ratio.
    """
    if args.target == 1:
        return f'data/{args.dataset}_{args.attack}_{args.poison_ratio}_{split}.beton'
    return f'data/{args.dataset}_{args.attack}_{args.target}_{args.poison_ratio}_{split}.beton'


def create_dataloader(args, batch_size,  pathname, device, partition, seq=False, num_workers=8):
    """Build FFCV loaders for the train, clean-test and poisoned-test splits.

    Args:
        args: namespace providing ``dataset``, ``attack``, ``poison_ratio``, ``target``
            and ``save_samples``.
        batch_size: batch size used by all three loaders.
        pathname: passed to the train dataset's ``save_images`` when
            ``args.save_samples`` is ``'True'`` (or the bool ``True``).
        device: target device for the ``ToDevice`` transforms.
        partition: forwarded to the train-split dataset constructor.
        seq: if truthy, read samples sequentially; otherwise in random order.
        num_workers: number of FFCV loader workers (default 8).

    Returns:
        Tuple ``(train_loader, test_clean_loader, test_poison_loader)``. Images are
        produced as float16 tensors; only the train loader drops its last partial batch.

    Raises:
        ValueError: if ``args.dataset`` is not supported.
        KeyError: if ``args.attack`` is not registered for ``args.dataset``.
    """
    registry = _attack_registry(args.dataset)
    try:
        spec = registry[args.attack]
    except KeyError:
        raise KeyError(
            f"Attack {args.attack!r} is not available for dataset {args.dataset!r}; "
            f"choose one of {sorted(registry)}"
        ) from None
    dataset_cls = spec['class']
    storage_type = spec['type']

    # Instantiate all three splits before building loaders. Presumably the constructors
    # materialise the .beton files read below -- TODO confirm in the dataset classes.
    splits = {
        'train': dataset_cls('data', train=True, poison_ratio=args.poison_ratio,
                             target=args.target, partition=partition),
        'test_clean': dataset_cls('data', train=False, poison_ratio=0, target=args.target),
        'test_poison': dataset_cls('data', train=False, poison_ratio=1, target=args.target,
                                   asr_calc=True),
    }

    if str(args.save_samples) == 'True':
        splits['train'].save_images(pathname)

    order = OrderOption.SEQUENTIAL if seq else OrderOption.RANDOM

    loaders = {}
    for split in ('train', 'test_clean', 'test_poison'):
        # Train-time augmentation (e.g. RandomHorizontalFlip) is intentionally disabled.
        # Each integer field gets its own operation instances rather than sharing one list.
        loaders[split] = Loader(
            _beton_path(args, split),
            batch_size=batch_size,
            num_workers=num_workers,
            order=order,
            drop_last=(split == 'train'),
            pipelines={
                'image': _make_image_pipeline(dataset_cls, storage_type, device),
                'label': _make_label_pipeline(device),
                'index': _make_label_pipeline(device),
                'poisonlabel': _make_label_pipeline(device),
            },
        )

    return loaders['train'], loaders['test_clean'], loaders['test_poison']

