import argparse
import copy
import datetime
import logging
import os
import pickle
import shutil
import sys
import time
import traceback

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
from pytorchcv.model_provider import get_model as ptcv_get_model
from torch.autograd import Variable
from torch.utils.data import Dataset

# option file should be modified according to your experiment
from options import Option
from dataloader import DataLoader
from trainer_direct import Trainer
import utils as utils
from conditional_batchnorm import CategoricalConditionalBatchNorm2d
from quantization_utils.quant_modules import *


class Generator(nn.Module):
    """Class-conditional image generator used for the CIFAR datasets.

    A class embedding multiplicatively gates the latent vector. The result is
    projected to a 128-channel ``img_size // 4`` square feature map and then
    upsampled twice (x2 each) to ``img_size``. The output has
    ``settings.channels`` channels and passes through a final non-affine
    BatchNorm.
    """

    def __init__(self, options=None, conf_path=None):
        super().__init__()
        self.settings = options or Option(conf_path)
        cfg = self.settings

        # Creation order below matters: it fixes how the seeded RNG is consumed
        # during parameter initialisation.
        self.label_emb = nn.Embedding(cfg.nClasses, cfg.latent_dim)

        # Side length of the feature map before the two 2x upsamplings.
        self.init_size = cfg.img_size // 4
        seed_features = 128 * self.init_size ** 2
        self.l1 = nn.Sequential(nn.Linear(cfg.latent_dim, seed_features))

        self.conv_blocks0 = nn.Sequential(nn.BatchNorm2d(128))

        # NOTE: the second positional argument of BatchNorm2d is ``eps``, so
        # these layers use eps=0.8 (not momentum=0.8). Kept for compatibility.
        self.conv_blocks1 = nn.Sequential(
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.conv_blocks2 = nn.Sequential(
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, cfg.channels, 3, stride=1, padding=1),
            nn.Tanh(),
            nn.BatchNorm2d(cfg.channels, affine=False),
        )

    def forward(self, z, labels):
        """Map latent codes ``z`` and integer class ``labels`` to a batch of images."""
        gated = self.label_emb(labels) * z
        feat = self.l1(gated)
        feat = feat.view(feat.shape[0], 128, self.init_size, self.init_size)
        feat = self.conv_blocks0(feat)
        # Each remaining stage doubles the spatial size, then convolves.
        for stage in (self.conv_blocks1, self.conv_blocks2):
            feat = nn.functional.interpolate(feat, scale_factor=2)
            feat = stage(feat)
        return feat


class Generator_imagenet(nn.Module):
    """Class-conditional image generator used for ImageNet.

    Unlike ``Generator``, the class label is not embedded into the latent code.
    Instead it drives ``CategoricalConditionalBatchNorm2d`` layers (1000
    categories) after the projection and after the first two convolutions.
    The latent vector is projected to a 128-channel ``img_size // 4`` square
    map and upsampled twice (x2 each) to ``img_size``.
    """

    def __init__(self, options=None, conf_path=None):
        # nn.Module.__init__ must run before any attribute is assigned on the
        # module (the original assigned ``settings`` first, which only worked
        # because it is neither a Module nor a Parameter).
        super(Generator_imagenet, self).__init__()
        self.settings = options or Option(conf_path)

        # Number of categories for the conditional batch norms. Hard-coded to
        # ImageNet-1k rather than read from settings.nClasses.
        num_classes = 1000

        # Side length of the feature map before the two 2x upsamplings.
        self.init_size = self.settings.img_size // 4
        self.l1 = nn.Sequential(nn.Linear(self.settings.latent_dim, 128 * self.init_size ** 2))

        self.conv_blocks0_0 = CategoricalConditionalBatchNorm2d(num_classes, 128)

        # NOTE(review): the third positional argument (0.8) is forwarded to
        # CategoricalConditionalBatchNorm2d unchanged; whether it is eps or
        # momentum depends on that class's signature -- confirm there.
        self.conv_blocks1_0 = nn.Conv2d(128, 128, 3, stride=1, padding=1)
        self.conv_blocks1_1 = CategoricalConditionalBatchNorm2d(num_classes, 128, 0.8)
        self.conv_blocks1_2 = nn.LeakyReLU(0.2, inplace=True)

        self.conv_blocks2_0 = nn.Conv2d(128, 64, 3, stride=1, padding=1)
        self.conv_blocks2_1 = CategoricalConditionalBatchNorm2d(num_classes, 64, 0.8)
        self.conv_blocks2_2 = nn.LeakyReLU(0.2, inplace=True)
        self.conv_blocks2_3 = nn.Conv2d(64, self.settings.channels, 3, stride=1, padding=1)
        self.conv_blocks2_4 = nn.Tanh()
        self.conv_blocks2_5 = nn.BatchNorm2d(self.settings.channels, affine=False)

    def forward(self, z, labels):
        """Map latent codes ``z`` and class ``labels`` to a batch of images."""
        out = self.l1(z)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks0_0(out, labels)
        img = nn.functional.interpolate(img, scale_factor=2)
        img = self.conv_blocks1_0(img)
        img = self.conv_blocks1_1(img, labels)
        img = self.conv_blocks1_2(img)
        img = nn.functional.interpolate(img, scale_factor=2)
        img = self.conv_blocks2_0(img)
        img = self.conv_blocks2_1(img, labels)
        img = self.conv_blocks2_2(img)
        img = self.conv_blocks2_3(img)
        img = self.conv_blocks2_4(img)
        img = self.conv_blocks2_5(img)
        return img


class direct_dataset(Dataset):
    """In-memory dataset of pre-generated synthetic images and their labels.

    Samples come from pairs of pickle shards: one shard holds images, the
    other holds labels. Each shard is passed through ``np.concatenate(...,
    axis=0)``, so it is presumably a list of per-batch numpy arrays (verify
    against the generation script). All shards are stacked along axis 0.
    ``__getitem__`` returns ``(transformed image tensor, label)``.
    """

    # Default shard path prefixes; the shard index and ".pickle" are appended.
    # NOTE(review): these defaults point at CIFAR-100 shards regardless of the
    # ``dataset`` argument.
    DEFAULT_DATA_PREFIX = (
        "./data_generate/hardsample/cifar100/0.001/4bit/resnet20_cifar100refined_gaussian_getDistilData_hardsample_twolabel_0.6cosineMargin0.021.0interClassMargin0.0augRHFRRC0.5_EMAMSE_doubleOptiEpo1000_patience=50OnlyLOR_4bit"
    )
    DEFAULT_LABEL_PREFIX = (
        "./data_generate/hardsample/cifar100/0.001/4bit/resnet20_cifar100labels_list_getDistilData_hardsample_twolabel_0.6cosineMargin0.021.0interClassMargin0.0augRHFRRC0.5_EMAMSE_doubleOptiEpo1000_patience=50OnlyLOR_4bit"
    )
    DEFAULT_FILE_INDICES = (1, 2, 3, 4)

    def __init__(self, logger, dataset, data_prefix=None, label_prefix=None, file_indices=None):
        """Load every requested shard into memory.

        Args:
            logger: logger that receives each shard path before it is opened.
            dataset: dataset name. "cifar100"/"cifar10" crop to 32x32; any
                other name crops to 224x224.
            data_prefix: image-shard path prefix (default ``DEFAULT_DATA_PREFIX``).
            label_prefix: label-shard path prefix (default ``DEFAULT_LABEL_PREFIX``).
            file_indices: shard indices to load (default 1..4).

        Raises:
            ValueError: if no shard index is given, or if image and label
                counts differ.
        """
        self.logger = logger

        # Samples are passed to the transforms as tensors. No ToTensor/Normalize
        # is applied here; presumably the stored data is already normalized
        # (verify).
        crop_size = 32 if dataset in ["cifar100", "cifar10"] else 224
        self.train_transform = transforms.Compose([
            transforms.RandomResizedCrop(size=crop_size, scale=(0.5, 1.0)),
            transforms.RandomHorizontalFlip(),
        ])

        data_prefix = self.DEFAULT_DATA_PREFIX if data_prefix is None else data_prefix
        label_prefix = self.DEFAULT_LABEL_PREFIX if label_prefix is None else label_prefix
        indices = self.DEFAULT_FILE_INDICES if file_indices is None else tuple(file_indices)
        if not indices:
            raise ValueError("file_indices must contain at least one shard index")

        # Collect per-shard arrays and concatenate once at the end. Growing one
        # array shard by shard re-copies everything loaded so far each time.
        data_chunks = []
        label_chunks = []
        for i in indices:
            gaussian_data = self._load_pickle(data_prefix + str(i) + ".pickle")
            data_chunks.append(np.concatenate(gaussian_data, axis=0))
            labels_list = self._load_pickle(label_prefix + str(i) + ".pickle")
            label_chunks.append(np.concatenate(labels_list, axis=0))

        self.tmp_data = np.concatenate(data_chunks, axis=0)
        self.tmp_label = np.concatenate(label_chunks, axis=0)

        if len(self.tmp_label) != len(self.tmp_data):
            raise ValueError(
                "image/label count mismatch: %d images vs %d labels"
                % (len(self.tmp_data), len(self.tmp_label)))
        print(self.tmp_data.shape, self.tmp_label.shape)
        print('direct datset image number', len(self.tmp_label))

    def _load_pickle(self, path):
        """Log ``path`` and return the object unpickled from it."""
        self.logger.info(path)
        # NOTE: pickle.load can execute arbitrary code; only load trusted,
        # locally generated files.
        with open(path, "rb") as fp:
            return pickle.load(fp)

    def __getitem__(self, index):
        """Return the augmented image at ``index`` and its label."""
        img = self.tmp_data[index]
        label = self.tmp_label[index]
        img = self.train_transform(torch.from_numpy(img))
        return img, label

    def __len__(self):
        return len(self.tmp_label)


class ExperimentDesign:
    """Configure a quantization experiment and dump noised copies of its dataset.

    The constructor performs these steps:
      * seeds torch;
      * copies the config file and driver scripts into ``settings.save_path``;
      * sets up logging and the data loaders;
      * loads a pretrained pytorchcv model and its teacher copy;
      * replaces the model's layers with quantized ones;
      * builds a ``Trainer``.

    ``run`` only writes forward-diffused versions of ``direct_dataset`` to
    disk.
    """

    def __init__(self, generator=None, options=None, conf_path=None, args=None):
        """Build all experiment components eagerly.

        Args:
            generator: image generator handed to the ``Trainer``.
            options: pre-built ``Option``; when falsy, one is loaded from
                ``conf_path``.
            conf_path: config file path; copied into the save directory when
                given.
            args: parsed CLI namespace; ``args.model_name`` selects the
                pytorchcv model.
        """
        self.settings = options or Option(conf_path)
        self.generator = generator
        self.train_loader = None
        self.test_loader = None
        self.model = None
        self.model_teacher = None

        self.optimizer_state = None
        self.trainer = None
        self.start_epoch = 0
        self.test_input = None

        self.args = args

        self.unfreeze_Flag = True

        os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID"

        self.settings.set_save_path()
        # Snapshot the config and scripts next to the results. Use the basename:
        # joining save_path with a relative path containing directories points
        # at a missing sub-directory. Joining with an absolute path resolves to
        # the source itself, which makes copyfile raise SameFileError.
        if conf_path is not None:
            shutil.copyfile(conf_path,
                            os.path.join(self.settings.save_path, os.path.basename(conf_path)))
        shutil.copyfile('./main_direct.py', os.path.join(self.settings.save_path, 'main_direct.py'))
        shutil.copyfile('./trainer_direct.py', os.path.join(self.settings.save_path, 'trainer_direct.py'))
        self.logger = self.set_logger()
        self.settings.paramscheck(self.logger)

        self.prepare()

    def set_logger(self):
        """Attach file (``save_path/train_test.log``) and stdout handlers to the 'baseline' logger.

        Handlers are appended on every call, so constructing several
        experiments in one process duplicates output.
        """
        logger = logging.getLogger('baseline')
        file_formatter = logging.Formatter('%(asctime)s %(levelname)s: %(message)s')
        console_formatter = logging.Formatter('%(message)s')
        # file log
        file_handler = logging.FileHandler(os.path.join(self.settings.save_path, "train_test.log"))
        file_handler.setFormatter(file_formatter)

        # console log
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setFormatter(console_formatter)

        logger.addHandler(file_handler)
        logger.addHandler(console_handler)

        logger.setLevel(logging.INFO)
        return logger

    def prepare(self):
        """Seed/check the GPU, build loaders and models, quantize the student, build the trainer."""
        self._set_gpu()
        self._set_dataloader()
        self._set_model()
        self._replace()
        self.logger.info(self.model)
        self._set_trainer()

    def _set_gpu(self):
        torch.manual_seed(self.settings.manualSeed)
        torch.cuda.manual_seed(self.settings.manualSeed)
        assert self.settings.GPU <= torch.cuda.device_count() - 1, "Invalid GPU ID"
        cudnn.benchmark = True

    def _set_dataloader(self):
        # create data loader
        data_loader = DataLoader(dataset=self.settings.dataset,
                                 batch_size=self.settings.batchSize,
                                 data_path=self.settings.dataPath,
                                 n_threads=self.settings.nThreads,
                                 ten_crop=self.settings.tenCrop,
                                 logger=self.logger)

        self.train_loader, self.test_loader = data_loader.getloader()

    def _set_model(self):
        """Load the pretrained model twice: once as the student, once as an eval-mode teacher.

        Raises:
            ValueError: if ``settings.dataset`` is not a supported dataset.
        """
        input_sizes = {"cifar100": 32, "cifar10": 32, "imagenet": 224}
        if self.settings.dataset not in input_sizes:
            raise ValueError("unsupport data set: " + self.settings.dataset)
        size = input_sizes[self.settings.dataset]
        # torch.autograd.Variable is deprecated; it returned a plain tensor anyway.
        self.test_input = torch.randn(1, 3, size, size).cuda()
        self.model = ptcv_get_model(self.args.model_name, pretrained=True)
        self.model_teacher = ptcv_get_model(self.args.model_name, pretrained=True)
        self.model_teacher.eval()

    def _set_trainer(self):
        # set lr master
        lr_master_S = utils.LRPolicy(self.settings.lr_S,
                                     self.settings.nEpochs,
                                     self.settings.lrPolicy_S)
        lr_master_G = utils.LRPolicy(self.settings.lr_G,
                                     self.settings.nEpochs,
                                     self.settings.lrPolicy_G)

        params_dict_S = {
            'step': self.settings.step_S,
            'decay_rate': self.settings.decayRate_S
        }

        params_dict_G = {
            'step': self.settings.step_G,
            'decay_rate': self.settings.decayRate_G
        }

        lr_master_S.set_params(params_dict=params_dict_S)
        lr_master_G.set_params(params_dict=params_dict_G)

        # set trainer
        self.trainer = Trainer(
            model=self.model,
            model_teacher=self.model_teacher,
            generator=self.generator,
            train_loader=self.train_loader,
            test_loader=self.test_loader,
            lr_master_S=lr_master_S,
            lr_master_G=lr_master_G,
            settings=self.settings,
            logger=self.logger,
            opt_type=self.settings.opt_type,
            optimizer_state=self.optimizer_state,
            run_count=self.start_epoch)

    def quantize_model(self, model):
        """Recursively replace layers of ``model`` with quantized counterparts.

        Conversions applied:
          * ``nn.Conv2d`` -> ``Quant_Conv2d`` and ``nn.Linear`` ->
            ``Quant_Linear``, both with ``settings.qw``-bit weights;
          * every ``nn.ReLU``/``nn.ReLU6`` is followed by a ``QuantAct`` with
            ``settings.qa``-bit activations;
          * ``nn.Sequential`` containers are rebuilt;
          * any other module is deep-copied, and its sub-module attributes
            (except those whose name contains 'norm') are replaced
            recursively.

        Exact ``type(...) ==`` checks are used, so subclasses of these layers
        are not converted.
        """
        weight_bit = self.settings.qw
        act_bit = self.settings.qa

        # quantize convolutional and linear layers
        if type(model) == nn.Conv2d:
            quant_mod = Quant_Conv2d(weight_bit=weight_bit)
            quant_mod.set_param(model)
            return quant_mod
        elif type(model) == nn.Linear:
            quant_mod = Quant_Linear(weight_bit=weight_bit)
            quant_mod.set_param(model)
            return quant_mod

        # quantize all the activation
        elif type(model) == nn.ReLU or type(model) == nn.ReLU6:
            return nn.Sequential(*[model, QuantAct(activation_bit=act_bit)])

        # recursively use the quantized module to replace the single-precision module
        elif type(model) == nn.Sequential:
            return nn.Sequential(*[self.quantize_model(m) for m in model.children()])
        else:
            q_model = copy.deepcopy(model)
            for attr in dir(model):
                mod = getattr(model, attr)
                if isinstance(mod, nn.Module) and 'norm' not in attr:
                    setattr(q_model, attr, self.quantize_model(mod))
            return q_model

    def _replace(self):
        self.model = self.quantize_model(self.model)

    def _visit_quant_acts(self, model, action):
        """Call ``action`` on every ``QuantAct`` reachable from ``model``; return ``model``.

        The traversal mirrors ``quantize_model``: Sequential children, then
        non-'norm' Module attributes.
        """
        if type(model) == QuantAct:
            action(model)
        elif type(model) == nn.Sequential:
            for child in model.children():
                self._visit_quant_acts(child, action)
        else:
            for attr in dir(model):
                mod = getattr(model, attr)
                if isinstance(mod, nn.Module) and 'norm' not in attr:
                    self._visit_quant_acts(mod, action)
        return model

    def freeze_model(self, model):
        """Freeze the activation range of every QuantAct in ``model``; return ``model``."""
        return self._visit_quant_acts(model, lambda act: act.fix())

    def unfreeze_model(self, model):
        """Unfreeze the activation range of every QuantAct in ``model``; return ``model``."""
        return self._visit_quant_acts(model, lambda act: act.unfix())

    def generate_linear_schedule(self, T, low, high):
        """Return ``T`` betas evenly spaced from ``low`` to ``high`` (inclusive)."""
        return np.linspace(low, high, T)

    def initialize_diffusion_constants(self, betas):
        """Precompute the forward-diffusion coefficients for a beta schedule.

        Returns CUDA float32 tensors, in this order: betas, alphas,
        cumulative alphas, sqrt(cumulative alphas), sqrt(1 - cumulative
        alphas), sqrt(1 / alphas), sigma = sqrt(betas).
        """
        alphas = 1.0 - betas
        alphas_cumprod = np.cumprod(alphas)
        # 1 / alphas is finite as long as every beta < 1 (true for the
        # 0.0001..0.02 schedule used here); clip alphas if that ever changes.

        sqrt_alphas_cumprod = np.sqrt(alphas_cumprod)
        sqrt_one_minus_alphas_cumprod = np.sqrt(1 - alphas_cumprod)
        reciprocal_sqrt_alphas = np.sqrt(1 / alphas)
        sigma = np.sqrt(betas)

        return (self.to_torch(betas), self.to_torch(alphas), self.to_torch(alphas_cumprod),
                self.to_torch(sqrt_alphas_cumprod), self.to_torch(sqrt_one_minus_alphas_cumprod),
                self.to_torch(reciprocal_sqrt_alphas), self.to_torch(sigma))

    def to_torch(self, x):
        return torch.tensor(x, dtype=torch.float32).cuda()

    def extract(self, a, t, x_shape):
        """Gather ``a[t]`` per batch item and reshape to ``(B, 1, ..., 1)`` for broadcasting against ``x_shape``."""
        b, *_ = t.shape
        out = a.gather(-1, t)
        return out.reshape(b, *((1,) * (len(x_shape) - 1)))

    def perturb_x(self, x, t, noise, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod):
        """Return ``sqrt(abar_t) * x + sqrt(1 - abar_t) * noise`` with per-item steps ``t``."""
        return (
                self.extract(sqrt_alphas_cumprod, t, x.shape) * x +
                self.extract(sqrt_one_minus_alphas_cumprod, t, x.shape) * noise
        )

    def _noise_at_step(self, x, t, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod):
        """Noise every item of batch ``x`` to integer step ``t`` with fresh Gaussian noise (CUDA)."""
        t_batch = torch.tensor([t]).repeat(x.shape[0]).cuda()
        noise = torch.randn_like(x)
        return self.perturb_x(x, t_batch, noise, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod)

    def forward_diffusion_process(self, x, betas, t):
        """Noise batch ``x`` to step ``t`` of the schedule ``betas``."""
        (_, _, _, sqrt_alphas_cumprod,
         sqrt_one_minus_alphas_cumprod, _, _) = self.initialize_diffusion_constants(betas)
        return self._noise_at_step(x, t, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod)

    def save_noised_images(self, dataloader, output_dir, total_steps=80, max_levels=10):
        """Write each batch of ``dataloader`` at increasing noise levels.

        Output layout is ``output_dir/noise_level_<T>/batch_<i>.pt``, holding
        ``(images, labels)`` on CPU.
        Levels ``0 .. min(total_steps, max_levels) - 1`` are written. Level 0
        is the clean batch; level ``T`` uses diffusion step ``T - 1`` of a
        ``total_steps - 1``-long linear beta schedule.

        NOTE(review): every level re-iterates ``dataloader``. With a shuffling
        loader or random transforms, ``batch_<i>`` at different levels
        contains different images.
        """
        os.makedirs(output_dir, exist_ok=True)
        betas = self.generate_linear_schedule(total_steps - 1, 0.0001, 0.02)
        # The coefficients depend only on the schedule; compute them once
        # instead of once per batch.
        (_, _, _, sqrt_alphas_cumprod,
         sqrt_one_minus_alphas_cumprod, _, _) = self.initialize_diffusion_constants(betas)

        # Stop before creating any directory past the last written level (the
        # original loop created an empty noise_level_10 before breaking).
        for T in range(min(total_steps, max_levels)):
            noise_output_dir = os.path.join(output_dir, f'noise_level_{T}')
            os.makedirs(noise_output_dir, exist_ok=True)

            for batch_idx, (images, labels) in enumerate(dataloader):
                images, labels = images.cuda(), labels.cuda()

                if T == 0:
                    images_T = images
                else:
                    images_T = self._noise_at_step(images, T - 1, sqrt_alphas_cumprod,
                                                   sqrt_one_minus_alphas_cumprod)

                torch.save((images_T.cpu(), labels.cpu()), os.path.join(noise_output_dir, f'batch_{batch_idx}.pt'))

                print(f'Saved batch {batch_idx} at noise level {T} to {noise_output_dir}')

    def run(self):
        """Load the pre-generated dataset and write its noised versions to disk."""
        dataset = direct_dataset(self.logger, self.settings.dataset)

        direct_dataload = torch.utils.data.DataLoader(dataset,
                                                      batch_size=min(self.settings.batchSize, len(dataset)),
                                                      shuffle=True,
                                                      num_workers=0,
                                                      pin_memory=True,
                                                      drop_last=True)
        output_dir = "./data_generate/hardsample/cifar100/0.001/4bit/noise/"
        self.save_noised_images(direct_dataload, output_dir, 80)



def main():
    """Parse CLI arguments, seed all RNGs, build the generator and run the experiment.

    Raises:
        ValueError: if the configured dataset has no matching generator.
    """
    # Respect a GPU selection already present in the environment; otherwise
    # fall back to GPU 3 (the previously hard-coded value). This must be set
    # before CUDA is initialised.
    os.environ.setdefault('CUDA_VISIBLE_DEVICES', '3')
    parser = argparse.ArgumentParser(description='Baseline')
    parser.add_argument('--conf_path', type=str, metavar='conf_path', required=True,
                        help='input the path of config file')
    parser.add_argument('--id', type=int, metavar='experiment_id', default=0,
                        help='Experiment ID')
    parser.add_argument('--model_name', type=str, metavar='model_name', required=True,
                        help='model_name')
    args = parser.parse_args()

    option = Option(args.conf_path)
    # Seed and experiment ID are derived from --id so repeats are distinguishable.
    option.manualSeed = args.id + 1
    option.experimentID = option.experimentID + "{:0>2d}_repeat".format(args.id)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.cuda.manual_seed_all(option.manualSeed)
    torch.manual_seed(option.manualSeed)
    np.random.seed(option.manualSeed)

    if option.dataset in ["cifar100", "cifar10"]:
        generator = Generator(option)
    elif option.dataset in ["imagenet"]:
        generator = Generator_imagenet(option)
    else:
        raise ValueError("invalid data set")

    experiment = ExperimentDesign(generator, option, args.conf_path, args)
    experiment.run()


if __name__ == '__main__':
    main()
