import sys
import inspect
import random
import torch
import copy
import os
import numpy as np
from src.datasets.utils import apply_triggerV2, corner_mask_generation

from torch.utils.data.dataset import random_split
from torch.utils.data import Subset

from src.datasets.cars import Cars
from src.datasets.cifar10 import CIFAR10
from src.datasets.cifar100 import CIFAR100
from src.datasets.dtd import DTD
from src.datasets.eurosat import EuroSAT, EuroSATVal
from src.datasets.gtsrb import GTSRB
from src.datasets.imagenet import ImageNet
from src.datasets.mnist import MNIST
from src.datasets.resisc45 import RESISC45
from src.datasets.stl10 import STL10
from src.datasets.svhn import SVHN
from src.datasets.sun397 import SUN397
from src.datasets.cc3m  import CC3M
from src.datasets.pets import PETS

# Name -> class for every class bound in this module's namespace at this
# point: the dataset classes imported above plus any other imported class
# (e.g. Subset). Classes defined further down this module are not included,
# because this runs before their definitions. get_dataset() resolves dataset
# names through this mapping.
registry = dict(inspect.getmembers(sys.modules[__name__], inspect.isclass))


class GenericDataset:
    """Bare container for dataset splits and their loaders.

    Every attribute starts as ``None``; callers (e.g. the train/val splitter)
    populate them after construction.
    """

    def __init__(self):
        # Training split and its DataLoader.
        self.train_dataset = self.train_loader = None
        # Evaluation split and its DataLoader.
        self.test_dataset = self.test_loader = None
        # Class-name list copied from the source dataset.
        self.classnames = None
        # Optional poison-only split and its DataLoader.
        self.poison_set = self.poison_only_loader = None


class PostSplitBackdoorDataset(torch.utils.data.Dataset):
    """Wrap a dataset and apply a backdoor trigger to a subset of its samples.

    A fraction ``poison_rate`` of the wrapped dataset's indices is selected
    with the global ``random`` module (callers are expected to seed it). For
    those indices ``__getitem__`` returns the triggered image and replaces
    the label with ``target``. The selection can be cached to / reloaded
    from ``.npy`` files under ``save_path`` so repeated runs poison the same
    samples.

    Args:
        dataset_name: Used to build the cache file names.
        dataset: Wrapped dataset. Items must be tuples ``(img, label, ...)``
            or dicts with ``'images'`` and ``'labels'`` keys.
        attack: One of ``'badnet'``, ``'blended'``, ``'SIG'``, ``'warped'``,
            ``'badmerge'``. Validated lazily, when a sample is poisoned.
        poison_rate: Fraction of samples to poison. Outside ``mixed_ft``
            mode, ``1.0`` poisons every sample without touching the cache.
        target: Label assigned to poisoned samples.
        patch_size: Forwarded to the trigger generator.
        save_path: Directory for cached index files. ``None`` (or any falsy
            value) disables caching.
        seed: Only stored and logged; sampling uses the global ``random``
            state.
        mixed_ft: Use the separate ``*_mixedft`` cache files and always
            sample, even when ``poison_rate == 1.0``.
    """

    # attack name -> (patch_location, patch_type) passed to apply_triggerV2.
    _TRIGGER_SPECS = {
        'badnet': ('random', 'badnet'),
        'blended': ('blended', 'blended'),
        'SIG': ('SIG', 'SIG'),
        'warped': ('warped', 'warped'),
    }
    _BADMERGE_TRIGGER_PATH = './BadMerging/trigger/ViT-B-32/On_CIFAR100_Tgt_1_L_22.npy'

    def __init__(self, dataset_name, dataset, attack, poison_rate, target, patch_size, save_path=None, seed=None, mixed_ft=False):
        self.dataset_name = dataset_name
        self.dataset = dataset
        self.attack = attack
        self.poison_rate = poison_rate
        self.target = target
        self.patch_size = patch_size
        self.save_path = save_path
        self.mixed_ft = mixed_ft
        self.seed = seed
        print(f'seed for poison rate: {seed}')
        self.indices = list(range(len(dataset)))

        if self.mixed_ft:
            self.poisoned_indices, self.unpoisoned_indices = self._load_or_sample_indices(
                f"{self.dataset_name}_{self.poison_rate}_mixedft"
            )
            print(f"Length of poisoned indices for the mixed_ft: {len(self.poisoned_indices)}")
            print(f"Length of unpoisoned indices for the mixed_ft: {len(self.unpoisoned_indices)}")
        elif self.poison_rate != 1.0:
            self.poisoned_indices, self.unpoisoned_indices = self._load_or_sample_indices(
                f"{self.dataset_name}_{self.poison_rate}"
            )
        else:
            # Poison everything: nothing to sample or cache.
            self.poisoned_indices = self.indices
            self.unpoisoned_indices = []

        # Set for O(1) membership tests in __getitem__. The index containers
        # may be long lists or numpy arrays (when loaded from disk), where
        # `in` is a linear scan per sample.
        self._poisoned_set = {int(i) for i in self.poisoned_indices}

    def _load_or_sample_indices(self, file_stem):
        """Return ``(poisoned, unpoisoned)`` index collections.

        Loads them from ``<save_path>/<file_stem>_poisoned.npy`` and
        ``..._unpoisoned.npy`` when both files exist. Otherwise samples
        ``int(len(dataset) * poison_rate)`` indices with ``random.sample``
        and, when ``save_path`` is set, writes both files.
        """
        num_poisoned = int(len(self.indices) * self.poison_rate)
        # Only build a path when caching is enabled; os.path.join(None, ...)
        # would raise TypeError.
        stem = os.path.join(self.save_path, file_stem) if self.save_path else None

        if stem is not None:
            poisoned_path = stem + '_poisoned.npy'
            unpoisoned_path = stem + '_unpoisoned.npy'
            if os.path.exists(poisoned_path) and os.path.exists(unpoisoned_path):
                # allow_pickle kept for compatibility with existing caches;
                # only point save_path at files this code wrote.
                poisoned = np.load(poisoned_path, allow_pickle=True)
                unpoisoned = np.load(unpoisoned_path, allow_pickle=True)
                print(f"Loaded poisoned indices from {poisoned_path}")
                print(f"Loaded unpoisoned indices from {unpoisoned_path}")
                return poisoned, unpoisoned

        poisoned = random.sample(self.indices, num_poisoned)
        unpoisoned = list(set(self.indices) - set(poisoned))
        if stem is not None:
            np.save(stem + '_poisoned.npy', poisoned)
            np.save(stem + '_unpoisoned.npy', unpoisoned)
            print(f"Saved poisoned indices to {stem}_poisoned.npy")
            print(f"Saved unpoisoned indices to {stem}_unpoisoned.npy")
        return poisoned, unpoisoned

    def _inject_backdoor(self, img):
        """Return ``img`` with the trigger for ``self.attack`` applied.

        Raises:
            ValueError: if ``self.attack`` is not a supported attack name.
        """
        if self.attack in self._TRIGGER_SPECS:
            patch_location, patch_type = self._TRIGGER_SPECS[self.attack]
            trigger_applicator = apply_triggerV2(
                patch_size=self.patch_size, patch_location=patch_location, patch_type=patch_type
            )
            return trigger_applicator(img)

        if self.attack == 'badmerge':
            # NOTE(review): the trigger file is re-read for every poisoned
            # sample; consider caching it.
            trigger = torch.from_numpy(np.load(self._BADMERGE_TRIGGER_PATH))
            applied_patch, mask, _, _ = corner_mask_generation(trigger, image_size=(3, 224, 224))
            applied_patch = torch.from_numpy(applied_patch).type(torch.FloatTensor)
            mask = torch.from_numpy(mask)
            # mask-weighted patch plus (1 - mask)-weighted original image.
            return torch.mul(mask.type(torch.FloatTensor), applied_patch) \
                + torch.mul(1 - mask.expand(img.shape).type(torch.FloatTensor), img.type(torch.FloatTensor))

        raise ValueError(f"Unsupported attack: {self.attack}")

    def __getitem__(self, idx):
        """Fetch item ``idx``, triggering and relabelling it if it was selected.

        Returns the same container kind as the wrapped dataset:
        - for dict items, the dict itself, updated in place;
        - otherwise an ``(img, label)`` tuple (extra tuple fields are dropped).

        Non-tensor labels are converted to ``torch.long`` tensors.

        Raises:
            KeyError: dict item missing ``'images'`` or ``'labels'``.
            TypeError: item is neither a dict nor a tuple.
        """
        result = self.dataset[idx]
        if isinstance(result, dict):
            if 'images' not in result or 'labels' not in result:
                raise KeyError("Dictionary dataset must have 'images' and 'labels' keys")
            img, label = result['images'], result['labels']
        elif isinstance(result, tuple):
            img, label = result[:2]
        else:
            raise TypeError("Unsupported dataset item type")

        if not isinstance(label, torch.Tensor):
            label = torch.tensor(label, dtype=torch.long)

        if int(idx) in self._poisoned_set:
            img = self._inject_backdoor(img)
            label = torch.tensor(self.target, dtype=torch.long)  # flip to the target class

        if isinstance(result, dict):
            result['images'], result['labels'] = img, label
            return result
        return img, label

    def __len__(self):
        return len(self.dataset)


def _index_file(save_path, filename):
    """Return ``save_path/filename``, or ``None`` when caching is disabled (``save_path is None``)."""
    return None if save_path is None else os.path.join(save_path, filename)


def _load_or_save_indices(path, indices, what):
    """Return cached indices from ``path`` if it exists, else save ``indices`` there and return them.

    With ``path is None`` nothing is read or written and ``indices`` is returned.
    ``what`` only labels the log messages.
    """
    if path is None:
        return indices
    if os.path.exists(path):
        print(f"{what.capitalize()} indices already exist at {path}")
        return torch.load(path)
    torch.save(indices, path)
    print(f"Saved {what} indices to {path}")
    return indices


def split_train_into_train_val(dataset, new_dataset_class_name, batch_size, num_workers, val_fraction, max_val_samples=None, seed=None, attack='badnet', poison_rate=0.5, target=7, patch_size=16, save_path='./save_new_indices/', mixed_ft=False, poison_size=2000):
    """Split ``dataset.train_dataset`` into train/val subsets and optionally carve out a poison pool.

    Returns an instance of a fresh ``GenericDataset`` subclass named
    ``new_dataset_class_name``:

    * ``train_dataset`` / ``train_loader``: the training split (shuffled
      loader). When ``attack``, ``target`` and ``patch_size`` are set and
      ``poison_rate > 0``, it is wrapped in ``PostSplitBackdoorDataset``.
    * ``test_dataset`` / ``test_loader``: the validation split. For
      ``'EuroSATVal'`` this is ``dataset.test_dataset`` and the whole train
      split is used for training.
    * ``poison_set`` / ``poison_only_loader``: set only when ``attack``,
      ``target`` and ``patch_size`` are set. ``poison_size`` samples are
      removed from the training split and fully triggered, or triggered at
      ``poison_rate`` when ``mixed_ft``.

    Seeds ``random``, ``numpy`` and ``torch`` globally with ``seed``. Index
    lists are cached as ``.pt`` files under ``save_path/seed_<seed>/`` and
    reused when present. ``save_path=None`` disables caching.

    Raises:
        AssertionError: ``seed`` is None, ``val_fraction`` is outside (0, 1),
            or the train/val split would be empty.
        ValueError: from ``random.sample`` when the training split has fewer
            than ``poison_size`` samples.
    """
    assert seed is not None, "Seed must be provided for reproducibility"
    assert 0. < val_fraction < 1.
    total_size = len(dataset.train_dataset)
    val_size = int(total_size * val_fraction)
    if max_val_samples is not None:
        val_size = min(val_size, max_val_samples)
    train_size = total_size - val_size
    assert val_size > 0
    assert train_size > 0

    is_eurosat_val = new_dataset_class_name == 'EuroSATVal'
    if is_eurosat_val:
        # EuroSATVal validates on dataset.test_dataset (below), so no val
        # split is taken; only the poison pool is carved from the train set.
        train_size = total_size

    print(f'Current seed for splitting: {seed}')
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Cache index files per seed; save_path=None disables caching entirely.
    if save_path is not None:
        save_path = os.path.join(save_path, f"seed_{seed}")
        os.makedirs(save_path, exist_ok=True)

    train_file = _index_file(save_path, f"{new_dataset_class_name}_train_Subset.pt")
    val_file = _index_file(save_path, f"{new_dataset_class_name}_val_Subset.pt")
    if train_file is not None and os.path.exists(train_file) and os.path.exists(val_file):
        print(f"Loading saved indices from {save_path}...")
        train_indices = torch.load(train_file)
        val_indices = torch.load(val_file)
    else:
        if train_file is not None:
            print(f"Splitting dataset and saving indices to {save_path}...")
        else:
            print("Splitting dataset (index caching disabled)...")
        # Dedicated generator: the split depends only on `seed`, not on the
        # global torch RNG state.
        perm = torch.randperm(total_size, generator=torch.Generator().manual_seed(seed)).tolist()
        train_indices, val_indices = perm[:train_size], perm[train_size:]
        if train_file is not None:
            torch.save(train_indices, train_file)
            torch.save(val_indices, val_file)
            print(f"Saved train indices to {train_file}")
            print(f"Saved val indices to {val_file}")

    poisonset = None
    if attack and target is not None and patch_size is not None:
        # Sample before consulting the cache so the global `random` state
        # advances identically on every run; PostSplitBackdoorDataset below
        # may draw from that same state.
        poison_indices = random.sample(train_indices, poison_size)
        remaining_indices = list(set(train_indices) - set(poison_indices))
        train_indices = _load_or_save_indices(
            _index_file(save_path, f"{new_dataset_class_name}_train_Subset_poisoned.pt"),
            remaining_indices, "poisoned train",
        )
        poison_indices = _load_or_save_indices(
            _index_file(save_path, f"{new_dataset_class_name}_poison_only_Subset.pt"),
            poison_indices, "poison",
        )
        poisonset = Subset(dataset.train_dataset, poison_indices)

    trainset = Subset(dataset.train_dataset, train_indices)
    valset = dataset.test_dataset if is_eurosat_val else Subset(dataset.train_dataset, val_indices)

    new_dataset_class = type(new_dataset_class_name, (GenericDataset,), {})
    new_dataset = new_dataset_class()

    new_dataset.train_dataset = trainset
    # poison_rate may be None (get_dataset's default): treat it as
    # "no train-set poisoning" instead of failing on `None > 0.0`.
    if attack and poison_rate is not None and poison_rate > 0.0 and target is not None and patch_size is not None:
        new_dataset.train_dataset = PostSplitBackdoorDataset(
            new_dataset_class_name, trainset, attack, poison_rate, target, patch_size, save_path, seed
        )

    new_dataset.train_loader = torch.utils.data.DataLoader(
        new_dataset.train_dataset,
        shuffle=True,
        batch_size=batch_size,
        num_workers=num_workers,
    )

    new_dataset.test_dataset = valset
    new_dataset.test_loader = torch.utils.data.DataLoader(
        new_dataset.test_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
    )

    if poisonset is not None:
        if mixed_ft:
            print("Using mixed fine-tuning")
            new_dataset.poison_set = PostSplitBackdoorDataset(
                new_dataset_class_name, poisonset, attack, poison_rate, target, patch_size, save_path, seed, mixed_ft=True
            )
        else:
            # Trigger every sample of the poison-only split.
            new_dataset.poison_set = PostSplitBackdoorDataset(
                new_dataset_class_name, poisonset, attack, 1.0, target, patch_size, save_path, seed
            )
        new_dataset.poison_only_loader = torch.utils.data.DataLoader(
            new_dataset.poison_set,
            batch_size=batch_size,
            num_workers=num_workers,
        )

    new_dataset.classnames = copy.copy(dataset.classnames)
    print(f"Length of trainset: {len(new_dataset.train_dataset)}")
    print(f"Length of valset: {len(new_dataset.test_dataset)}")
    print(f"Length of testset: {len(dataset.test_dataset)}")
    if poisonset is not None:
        print(f"Length of poisonset: {len(new_dataset.poison_set)}")

    return new_dataset



def get_dataset(dataset_name, preprocess, location, batch_size=128, num_workers=16, val_fraction=0.1, max_val_samples=5000, attack=None, poison_rate=None, target=None, patch_size=16, save_path='./save_new_indices/', mixed_ft=False, seed=None):
    """Instantiate a dataset by name, deriving ``<Base>Val`` splits on demand.

    * Names without a ``Val`` suffix must be keys of ``registry``. The class
      is constructed with ``preprocess``, ``location``, ``batch_size`` and
      ``num_workers``.
    * ``'EuroSATVal'`` with an ``attack`` is built from its registry class
      and then passed through ``split_train_into_train_val``, which
      validates on that class's test split. Without an attack, and for any
      other ``...Val`` name present in ``registry``, the registry class is
      constructed directly.
    * Any other ``<Base>Val`` name loads ``<Base>`` and carves a validation
      split out of its training data via ``split_train_into_train_val``.
      ``val_fraction``, ``max_val_samples`` and the backdoor options
      (``attack``, ``poison_rate``, ``target``, ``patch_size``,
      ``save_path``, ``mixed_ft``) are forwarded.

    Raises:
        AssertionError: ``seed`` is None, or a non-``Val`` name is not in
            ``registry``.
    """
    assert seed is not None, "Seed must be provided for reproducibility"
    # Shared by both split call sites so they cannot drift apart (the derived
    # <Base>Val path previously hard-coded patch_size=16).
    split_kwargs = dict(
        attack=attack, poison_rate=poison_rate, target=target, patch_size=patch_size,
        save_path=save_path, mixed_ft=mixed_ft, seed=seed,
    )

    if dataset_name.endswith('Val'):
        if dataset_name not in registry:
            # Strip only the trailing "Val" (str.split would also cut at an
            # earlier occurrence).
            base_dataset_name = dataset_name[:-len('Val')]
            base_dataset = get_dataset(base_dataset_name, preprocess, location, batch_size, num_workers, seed=seed)
            return split_train_into_train_val(
                base_dataset, dataset_name, batch_size, num_workers, val_fraction, max_val_samples, **split_kwargs
            )
        dataset_class = registry[dataset_name]
        if dataset_name == 'EuroSATVal' and attack is not None:
            dataset = dataset_class(
                preprocess, location=location, batch_size=batch_size, num_workers=num_workers
            )
            return split_train_into_train_val(
                dataset, dataset_name, batch_size, num_workers, val_fraction, max_val_samples, **split_kwargs
            )
    else:
        assert dataset_name in registry, f'Unsupported dataset: {dataset_name}. Supported datasets: {list(registry.keys())}'
        dataset_class = registry[dataset_name]

    return dataset_class(
        preprocess, location=location, batch_size=batch_size, num_workers=num_workers
    )
