import torch
import numpy as np
import pickle
import datetime
import os
import warnings
import random
import PIL
import pdb
import math
from .datasets import construct_datasets, Subset
from .diff_data_augmentation import RandomTransform

from ..consts import (
    PIN_MEMORY,
    BENCHMARK,
    DISTRIBUTED_BACKEND,
    SHARING_STRATEGY,
    MAX_THREADING,
)
from ..utils import set_random_seed

# Import-time side effects: configure torch globally from the project constants.
# Toggle cuDNN's benchmark mode according to consts.BENCHMARK.
torch.backends.cudnn.benchmark = BENCHMARK
# Choose how torch.multiprocessing shares tensors between processes
# (strategy name comes from consts.SHARING_STRATEGY).
torch.multiprocessing.set_sharing_strategy(SHARING_STRATEGY)


class Furnace:
    def __init__(
        self,
        args,
        batch_size,
        augmentations,
        setup=None,
    ):
        """Load the datasets, select the poison subset and build the data loaders.

        Args:
            args: Run configuration. This method reads ``poison_partition``,
                ``pbatch`` and ``pshuffle``; the helpers it calls read
                further attributes.
            batch_size: Batch size for the train and validation loaders
                (capped at the respective dataset size).
            augmentations: Augmentation selector forwarded to
                ``prepare_data`` through ``self.augmentations``.
            setup: Keyword arguments for ``Tensor.to`` (``device`` and
                ``dtype``). Defaults to CPU / ``torch.float``.
        """
        # Build the default per call instead of in the signature so that
        # instances never share (and accidentally mutate) one dict object.
        if setup is None:
            setup = dict(device=torch.device("cpu"), dtype=torch.float)

        self.args, self.setup = args, setup
        self.batch_size = batch_size
        self.augmentations = augmentations
        self.trainset, self.validset = self.prepare_data(normalize=True)
        num_workers = self.get_num_workers()

        # Without a partition count all poisons are drawn at once; otherwise
        # the training set is partitioned and the first partition is used.
        if self.args.poison_partition is None:
            self.full_construction()
        else:
            self.batched_construction()

        # Options shared by all three loaders.
        loader_options = dict(
            drop_last=False,
            num_workers=num_workers,
            pin_memory=PIN_MEMORY,
        )

        # Generate loaders:
        self.trainloader = torch.utils.data.DataLoader(
            self.trainset,
            batch_size=min(self.batch_size, len(self.trainset)),
            shuffle=True,
            **loader_options,
        )
        self.validloader = torch.utils.data.DataLoader(
            self.validset,
            batch_size=min(self.batch_size, len(self.validset)),
            shuffle=False,
            **loader_options,
        )
        # Clamp the poison batch size into [1, len(poisonset)]; the lower
        # bound keeps the loader constructible for an empty poison set.
        validated_batch_size = max(min(self.args.pbatch, len(self.poisonset)), 1)
        self.poisonloader = torch.utils.data.DataLoader(
            self.poisonset,
            batch_size=validated_batch_size,
            shuffle=self.args.pshuffle,
            **loader_options,
        )

        self.print_status()

    def print_status(self):
        """Print a one-line banner announcing poisoned dataset generation."""
        banner = "Poisoned dataset generation"
        print(banner)

    def get_num_workers(self):
        """Return the number of DataLoader worker processes to use.

        The count is capped at four per visible CUDA device (four in total
        without CUDA), at twice the torch thread count, and at
        ``MAX_THREADING``. Zero is returned when torch runs single-threaded
        or ``MAX_THREADING`` is not positive.
        """
        if torch.cuda.is_available():
            worker_cap = 4 * torch.cuda.device_count()
        else:
            worker_cap = 4

        thread_count = torch.get_num_threads()
        if thread_count <= 1 or MAX_THREADING <= 0:
            return 0
        return min(2 * thread_count, worker_cap, MAX_THREADING)

    def prepare_data(self, normalize=True):
        """Construct the train/validation datasets and the augmentation.

        Side effects: stores the dataset mean and std as broadcastable
        ``(1, C, 1, 1)`` tensors in ``self.dm`` / ``self.ds`` (moved with
        ``self.setup``) and, when augmentations are requested or
        ``args.paugment`` is set, stores a ``RandomTransform`` in
        ``self.augment``.

        Args:
            normalize: Forwarded to ``construct_datasets``.

        Returns:
            Tuple ``(trainset, validset)``.

        Raises:
            ValueError: If an augmentation is needed but no parameters are
                defined for ``args.dataset``, or if ``self.augmentations``
                is neither ``"default"`` nor falsy.
        """
        trainset, validset = construct_datasets(self.args.dataset, normalize)

        # Prepare data mean and std for later:
        self.dm = torch.tensor(trainset.data_mean)[None, :, None, None].to(**self.setup)
        self.ds = torch.tensor(trainset.data_std)[None, :, None, None].to(**self.setup)

        # Train augmentations are handled separately as they possibly have to be backpropagated
        if self.augmentations is not None or self.args.paugment:
            # "c10" is a substring of "c100", so this one test covers both
            # dataset families, which share the same 32x32 parameters.
            if "c10" in self.args.dataset:
                params = dict(source_size=32, target_size=32, shift=8, fliplr=True)
            else:
                # Fail with a clear message instead of an unbound ``params``.
                raise ValueError(
                    "No differentiable augmentation parameters defined for "
                    f"dataset: {self.args.dataset}."
                )

            if self.augmentations == "default":
                self.augment = RandomTransform(**params, mode="bilinear")
            elif not self.augmentations:
                # Reached when only args.paugment requested the transform.
                print("Data augmentations are disabled.")
                self.augment = RandomTransform(**params, mode="bilinear")
            else:
                raise ValueError(
                    f"Invalid diff. transformation given: {self.augmentations}."
                )

        return trainset, validset

    def full_construction(self):
        """Sample the poison indices from the whole training set in one go.

        Draws ``int(len(trainset) * args.budget)`` distinct indices (after
        seeding with ``args.poisonkey`` when it is given), stores them sorted
        in ``self.poison_ids``, maps each dataset index to its position in
        ``self.poison_lookup``, and sets ``self.poisonset`` to the matching
        subset. ``self.targetset`` is left empty.
        """
        seed = self.args.poisonkey
        if seed is not None:
            set_random_seed(int(seed))

        train_size = len(self.trainset)
        candidates = list(range(train_size))
        poison_count = int(train_size * self.args.budget)

        chosen = random.sample(candidates, poison_count)
        chosen.sort()
        self.poison_ids = chosen

        self.poisonset = Subset(self.trainset, indices=self.poison_ids)
        self.targetset = []
        self.poison_lookup = {
            dataset_idx: position for position, dataset_idx in enumerate(chosen)
        }

    def batched_construction(self):
        """Partition the training indices and select the first partition.

        All partitions returned by ``self._partition`` are kept in
        ``self.global_poison_ids``; the first one becomes ``self.poison_ids``
        and defines ``self.poison_lookup`` and ``self.poisonset``.
        ``self.completed_flag`` is reset to 0 and ``self.targetset`` is left
        empty.
        """
        seed = self.args.poisonkey
        if seed is not None:
            set_random_seed(int(seed))

        all_indices = list(range(len(self.trainset)))
        self.global_poison_ids = self._partition(
            all_indices, self.args.poison_partition
        )

        first_partition = self.global_poison_ids[0]
        self.poison_ids = first_partition
        self.completed_flag = 0

        self.poisonset = Subset(self.trainset, indices=self.poison_ids)
        self.targetset = []
        self.poison_lookup = dict(zip(first_partition, range(len(first_partition))))

    def initialize_poison(self, initializer=None):
        """Create the initial poison perturbation tensor.

        Args:
            initializer: One of ``"zero"``, ``"rand"``, ``"randn"`` or
                ``"normal"``. Defaults to ``self.args.init``.

        Returns:
            A CPU tensor with one perturbation per entry of
            ``self.poison_ids`` and the shape of a training image, clamped
            element-wise to ``+/- args.eps / data_std / 255``.

        Raises:
            NotImplementedError: For an unknown initializer name.
        """
        if initializer is None:
            initializer = self.args.init
        if initializer not in ("zero", "rand", "randn", "normal"):
            raise NotImplementedError(f"Unknown poison initializer: {initializer!r}.")

        # ds has to be placed on the default (cpu) device, not like self.ds
        ds = torch.tensor(self.trainset.data_std)[None, :, None, None]
        # Per-channel perturbation bound in normalized image space.
        bound = self.args.eps / ds / 255
        shape = (len(self.poison_ids), *self.trainset[0][0].shape)

        if initializer == "zero":
            init = torch.zeros(shape)
        elif initializer == "rand":
            # Uniform in [-1, 1), scaled to the bound.
            init = (torch.rand(shape) - 0.5) * 2
            init *= bound
        elif initializer == "randn":
            init = torch.randn(shape)
            init *= bound
        else:  # "normal": unscaled Gaussian, only limited by the clamp below
            init = torch.randn(shape)

        # Project into the allowed perturbation range.
        init = torch.max(torch.min(init, bound), -bound)

        # If distributed, sync poison initializations
        if self.args.local_rank is not None:
            if DISTRIBUTED_BACKEND == "nccl":
                # The tensor is moved to the configured device for the
                # broadcast and back to the CPU afterwards. ``Tensor.to`` is
                # not in-place, so its result has to be re-assigned.
                init = init.to(device=self.setup["device"])
                torch.distributed.broadcast(init, src=0)
                init = init.to(device=torch.device("cpu"))
            else:
                torch.distributed.broadcast(init, src=0)
        return init

    """ EXPORT METHODS """

    def export_poison(self, poison_delta, path=None, mode="automl"):
        """Export the poisons in one of several on-disk formats.

        Args:
            poison_delta: Perturbation tensor; row ``self.poison_lookup[idx]``
                is added to the training image with dataset index ``idx``.
            path: Not consulted. The output directory is always
                ``<args.poison_path>/<args.dataset>``; the parameter is kept
                so existing call sites continue to work.
            mode: Export format:
                ``"packed"`` - ``torch.save`` of ``[poison_delta, poison_ids]``;
                ``"limited"`` - PNGs of the poisoned training images and the
                targets, in per-class folders;
                ``"full"`` - PNGs of the whole training set, the validation
                set and the targets, in per-class folders;
                ``"poison_dataset"`` - PNGs of the poisoned training images,
                written flat into the output directory;
                ``"numpy"`` - the whole (poisoned) training set as
                ``.npy`` data and label arrays;
                ``"furnace-export"`` - pickle of ``[self, poison_delta]``
                into the current working directory.

        Raises:
            NotImplementedError: For any other mode. This includes the
                default value ``"automl"``, so callers must pass ``mode``.
        """
        path = self.args.poison_path + "/" + self.args.dataset
        os.makedirs(path, exist_ok=True)
        dm = torch.tensor(self.trainset.data_mean)[:, None, None]
        ds = torch.tensor(self.trainset.data_std)[:, None, None]

        def _torch_to_PIL(image_tensor):
            """Undo the normalization and convert a CHW tensor to a PIL image."""
            image_denormalized = torch.clamp(image_tensor * ds + dm, 0, 1)
            image_torch_uint8 = (
                image_denormalized.mul(255)
                .add_(0.5)
                .clamp_(0, 255)
                .permute(1, 2, 0)
                .to("cpu", torch.uint8)
            )
            return PIL.Image.fromarray(image_torch_uint8.numpy())

        def _poisoned(image, idx):
            """Return the image with its poison delta added, if it has one."""
            lookup = self.poison_lookup.get(idx)
            if lookup is None:
                return image
            # Out-of-place add: never modify a tensor owned by the dataset.
            return image + poison_delta[lookup, :, :, :]

        def _save_image(image, idx, location, train=True):
            """Write one image as ``<idx>.png``; poison it only if ``train``."""
            filename = os.path.join(location, str(idx) + ".png")
            if train:
                image = _poisoned(image, idx)
            _torch_to_PIL(image).save(filename)

        def _make_class_dirs(names, splits):
            """Create ``<path>/<split>/<class name>`` for every combination."""
            for name in names:
                for split in splits:
                    os.makedirs(os.path.join(path, split, name), exist_ok=True)

        def _export_targets(names):
            """Save the target images under their intended class folder."""
            # NOTE(review): self.poison_setup is not assigned anywhere in the
            # visible part of this class; it is only read when targetset is
            # non-empty - confirm where it is set.
            for enum, (target, _, idx) in enumerate(self.targetset):
                intended_class = self.poison_setup["intended_class"][enum]
                _save_image(
                    target,
                    idx,
                    location=os.path.join(path, "targets", names[intended_class]),
                    train=False,
                )
            print("Target images exported with intended class labels ...")

        if mode == "packed":
            # Plain file name: ``path`` is already the directory it is joined to.
            name = f"poisons_packed_{datetime.date.today()}.pth"
            torch.save([poison_delta, self.poison_ids], os.path.join(path, name))

        elif mode == "limited":
            # Save only the poisoned part of the training set
            names = self.trainset.classes
            _make_class_dirs(names, ("train", "targets"))
            for image, label, idx in self.trainset:
                if self.poison_lookup.get(idx) is not None:
                    _save_image(
                        image,
                        idx,
                        location=os.path.join(path, "train", names[label]),
                        train=True,
                    )
            print("Poisoned training images exported ...")

            # Save secret targets
            _export_targets(names)

        elif mode == "full":
            # Save training set
            names = self.trainset.classes
            _make_class_dirs(names, ("train", "test", "targets"))
            for image, label, idx in self.trainset:
                _save_image(
                    image,
                    idx,
                    location=os.path.join(path, "train", names[label]),
                    train=True,
                )
            print("Poisoned training images exported ...")

            for image, label, idx in self.validset:
                _save_image(
                    image,
                    idx,
                    location=os.path.join(path, "test", names[label]),
                    train=False,
                )
            print("Unaffected validation images exported ...")

            # Save secret targets
            _export_targets(names)

        elif mode == "poison_dataset":
            for image, label, idx in self.trainset:
                if self.poison_lookup.get(idx) is not None:
                    _save_image(image, idx, location=path, train=True)
            print("Poisoned training images exported ...")

        elif mode == "numpy":
            _, h, w = self.trainset[0][0].shape
            training_data = np.zeros([len(self.trainset), h, w, 3])
            labels = np.zeros(len(self.trainset))
            for image, label, idx in self.trainset:
                training_data[idx] = np.asarray(_torch_to_PIL(_poisoned(image, idx)))
                labels[idx] = label

            np.save(os.path.join(path, "poisoned_training_data.npy"), training_data)
            np.save(os.path.join(path, "poisoned_training_labels.npy"), labels)

        elif mode == "furnace-export":
            # Written relative to the current working directory, not ``path``.
            with open(
                f"furnace_{self.args.dataset}{self.args.model}.pkl", "wb"
            ) as file:
                pickle.dump(
                    [self, poison_delta], file, protocol=pickle.HIGHEST_PROTOCOL
                )

        else:
            raise NotImplementedError()

        print("Dataset fully exported.")
