import os
import time
import pickle
import sys
sys.path.append('..')

import torch
from torch import nn
from torch.autograd import Variable
import numpy as np
import torch.utils.data.sampler as sp

import data


class ClassifierTrainerWM:
    """Train a classifier and embed an entangled watermark into it.

    :meth:`train` runs three phases:

    1. Clean training on ``dataset`` for ``self.epochs`` epochs.
    2. Trigger generation: batches of ``wm_source`` images from
       ``wm_dataset`` are perturbed with signed-gradient steps, both on a
       cross-entropy loss w.r.t. ``wm_target`` and on ``model.snnl``
       computed together with real ``wm_target`` training images. The
       resulting triggers are cached in a pickle under
       ``<opt.data_dir>/checkpoints``.
    3. ``int(w_epochs / 2)`` epochs that interleave ``n_w_ratio`` clean
       batches with one batch of [triggers, ``wm_target`` images], trained
       with ``model.total_loss(..., snnl_factor=...)``.

    When ``continue_train`` is true, the model and optimizer are checkpointed
    under ``<opt.data_dir>/checkpoints`` after every epoch. The next call
    resumes from the newest checkpoint.

    The model must be callable on a batch and return a sequence of outputs
    whose last element is used as class scores (``predictions[-1]``). It must
    also provide ``total_loss`` and ``snnl`` methods.
    """

    def __init__(self, opt, model, dataset, wm_dataset, model_name, optim='adam',
                 continue_train=True, print_freq=1):
        """Store the training components and pick per-dataset hyper-parameters.

        Args:
            opt: options object. This class reads ``victim_dataset``,
                ``data_dir``, ``use_gpu`` and ``print_epoch``.
            model: network to train (see the class docstring for the interface).
            dataset: clean-data wrapper providing ``train_dataset``,
                ``train_dataset_no_aug``, ``train_dataloader()`` and
                ``test_dataloader()``.
            wm_dataset: wrapper whose ``train_dataset`` has ``data`` / ``labels``
                supporting boolean-mask indexing. Its ``wm_source`` samples
                seed the triggers.
            model_name: prefix for checkpoint and trigger file names.
            optim: ``'adam'`` selects Adam; any other value selects SGD.
            continue_train: checkpoint every epoch and resume from checkpoints.
            print_freq: print the phase-1 epoch summary every ``print_freq`` epochs.

        Raises:
            ValueError: if ``opt.victim_dataset`` is neither ``'cifar10'`` nor
                ``'cifar100'``.
        """
        self.opt = opt
        self.model = model
        self.dataset = dataset
        self.wm_dataset = wm_dataset
        self.model_name = model_name
        self.optim = optim
        self.continue_train = continue_train
        self.print_freq = print_freq

        if opt.victim_dataset == 'cifar10':
            self.n_w_ratio = 4      # clean batches per watermark batch in phase 3
            self.epochs = 50        # phase-1 epochs
            self.w_epochs = 6       # phase 3 runs int(w_epochs / 2) epochs
            self.t_lr = 0.1         # step size of the temperature update
            self.w_lr = 0.01        # step size of the trigger perturbation
            self.maxiter = 3        # outer trigger-optimisation rounds (alternative: 10)
        elif opt.victim_dataset == 'cifar100':
            self.n_w_ratio = 15
            self.epochs = 100
            self.w_epochs = 8
            self.t_lr = 0.01
            self.w_lr = 0.01
            self.maxiter = 100
        else:
            raise ValueError(f'Unimplemented victim dataset: {opt.victim_dataset}')

        # Weights and initial temperatures forwarded to model.total_loss /
        # model.snnl. Three entries each — presumably one per SNNL layer; confirm.
        self.factors = [1e5, 1e5, 1e5]
        self.temperatures = torch.Tensor([1, 1, 1])
        # Phase-2 CE steps continue while the fraction of the trigger batch
        # predicted as wm_target exceeds this value.
        self.threshold = 0.1
        self.wm_source = 8      # class whose images seed the triggers
        self.wm_target = 0      # class the triggers are trained to be predicted as
        self.metric = "cosine"  # not read anywhere in this class
        self.snnl_factor = 300  # forwarded to total_loss in phase 3 (alternative: 0.3)

        self.trigger_dataset = data.LabeledDataset()

    def weight_init(self, m):
        """Initialise one module in place; intended for ``model.apply``.

        - Linear: Xavier-normal weight, zero bias.
        - Conv2d / ConvTranspose2d: Kaiming-normal weight (``fan_out``); bias untouched.
        - BatchNorm2d: weight 0.2, bias 0.

        Parameters a module does not have (``bias=False``, ``affine=False``)
        are skipped instead of crashing.
        """
        if isinstance(m, nn.Linear):
            nn.init.xavier_normal_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.kaiming_normal_(m.weight, mode="fan_out")
        elif isinstance(m, nn.BatchNorm2d):
            if m.weight is not None:
                nn.init.constant_(m.weight, 2e-1)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def _ckpt_dir(self):
        """Return ``<opt.data_dir>/checkpoints``, creating it if missing."""
        ckpt_dir = os.path.join(self.opt.data_dir, 'checkpoints')
        os.makedirs(ckpt_dir, exist_ok=True)
        return ckpt_dir

    @staticmethod
    def _latest_checkpoint(ckpt_dir, root):
        """Find the newest ``<root>_<epoch>`` file in ``ckpt_dir``.

        Returns:
            ``(filename, epoch)`` for the highest epoch, or ``(None, 0)`` if
            no such file exists.
        """
        prefix = root + '_'
        best_name, best_epoch = None, 0
        for filename in os.listdir(ckpt_dir):
            if not filename.startswith(prefix):
                continue
            suffix = filename[len(prefix):]
            if suffix.isdigit() and int(suffix) >= best_epoch:
                best_name, best_epoch = filename, int(suffix)
        return best_name, best_epoch

    @staticmethod
    def _rotate_optimizer_ckpt(optimizer, ckpt_dir, root, epoch, previous):
        """Save the optimizer state as ``<root>_<epoch>`` and delete ``previous``.

        Returns:
            The file name just written, to pass as ``previous`` next time.
        """
        new_name = f'{root}_{epoch}'
        torch.save(optimizer.state_dict(), os.path.join(ckpt_dir, new_name))
        if previous is not None and previous != new_name:
            old_path = os.path.join(ckpt_dir, previous)
            if os.path.exists(old_path):
                os.unlink(old_path)
        return new_name

    @staticmethod
    def _next_batch(iterator, loader):
        """Return ``(batch, iterator)``, restarting ``loader`` once exhausted."""
        try:
            return next(iterator), iterator
        except StopIteration:
            iterator = iter(loader)
            return next(iterator), iterator

    def _full_targets(self, n, like):
        """Return an int64 tensor of ``n`` copies of ``wm_target`` on ``like``'s device."""
        return torch.full((n,), self.wm_target, dtype=torch.long, device=like.device)

    def train(self, lr=0.001, weight_decay=5e-4):
        """Run clean training, trigger generation and watermark embedding.

        Args:
            lr: optimizer learning rate.
            weight_decay: weight decay. Only used by SGD; Adam is built without it.

        Returns:
            ``self.model`` after watermark embedding.

        Side effects:
            - Filters ``wm_dataset.train_dataset`` in place to ``wm_source`` samples.
            - Appends to, or replaces, ``self.trigger_dataset.items``.
            - Replaces ``self.temperatures`` during phase 3.
            - Writes the trigger pickle under ``<opt.data_dir>/checkpoints``
              when it has to be generated.
            - With ``continue_train``, reads and writes model/optimizer
              checkpoints there too.
        """
        batch_size = 128
        half_batch = int(0.5 * batch_size)
        self.model.apply(self.weight_init)

        if self.continue_train:
            ckpt_dir = self._ckpt_dir()
            ckpt_path = os.path.join(ckpt_dir, f'{self.model_name}_init_state_dict')
            if os.path.exists(ckpt_path):
                self.model.load_state_dict(torch.load(ckpt_path, map_location='cpu'))

        if self.opt.use_gpu:
            self.model.cuda()

        if self.optim == 'adam':
            optimizer = torch.optim.Adam(self.model.parameters(), lr=lr, eps=1e-5)
        else:
            optimizer = torch.optim.SGD(self.model.parameters(), lr=lr, momentum=0.9,
                                        weight_decay=weight_decay)

        # Training indices labelled wm_target. NOTE: indexing the dataset runs
        # its transforms on every sample just to read the label.
        train_set = self.dataset.train_dataset
        index_list = [i for i in range(len(train_set)) if train_set[i][1] == self.wm_target]
        target_dataset = torch.utils.data.Subset(train_set, index_list)
        target_dataset_no_aug = torch.utils.data.Subset(self.dataset.train_dataset_no_aug, index_list)

        # Keep only wm_source samples in the watermark dataset (mutated in place).
        wm_train = self.wm_dataset.train_dataset
        source_mask = wm_train.labels == self.wm_source
        wm_train.data = wm_train.data[source_mask]
        wm_train.labels = wm_train.labels[source_mask]

        dataloader = self.dataset.train_dataloader()
        wm_dataloader = torch.utils.data.DataLoader(
            wm_train, batch_size=half_batch, shuffle=True, num_workers=4, drop_last=True)
        target_dataloader_no_aug = torch.utils.data.DataLoader(
            target_dataset_no_aug, batch_size=half_batch, shuffle=True, num_workers=4,
            drop_last=True)

        self.temperatures = Variable(self.temperatures, requires_grad=True)
        ce_loss_func = nn.CrossEntropyLoss()
        # All-zero ``w`` argument for total_loss on clean batches.
        # NOTE(review): sized for full batches of 128 — confirm total_loss
        # copes with a smaller final batch from train_dataloader().
        w_0 = torch.zeros(batch_size)

        # ---------------- Phase 1: clean training ----------------
        starting_epoch_n = 0
        optimizer_ckpt_path = None
        if self.continue_train:
            root_optimizer_ckpt_path = f'optimizer_for_{self.model_name}_init_state_dict'
            optimizer_ckpt_path, starting_epoch_n = self._latest_checkpoint(
                ckpt_dir, root_optimizer_ckpt_path)
            if optimizer_ckpt_path is not None:
                optimizer.load_state_dict(torch.load(
                    os.path.join(ckpt_dir, optimizer_ckpt_path), map_location='cpu'))

        for epoch in range(starting_epoch_n + 1, self.epochs + 1):
            start_time = time.perf_counter()
            self.model.train()
            for iter_n, (images, targets) in enumerate(dataloader):
                if self.opt.use_gpu:
                    images = images.cuda()
                    targets = targets.cuda()
                predictions = self.model(images)
                loss = self.model.total_loss(predictions, targets, w_0, self.factors, self.temperatures)
                self.model.zero_grad()
                loss.backward()
                optimizer.step()

            # The reported accuracy and loss are those of the last batch only.
            acc = predictions[-1].max(1)[1].eq(targets).float().mean().detach().cpu()
            end_time = time.perf_counter()
            if self.opt.print_epoch and epoch % self.print_freq == 0:
                print('Epoch %d/%d | Iter %d | Acc %.5f | Loss %.5f | Time %.2fs' %
                      (epoch, self.epochs, iter_n, acc, loss, end_time - start_time))

            self.evaluate()
            if self.continue_train:
                # ckpt_path always holds the latest epoch so training can resume.
                torch.save(self.model.state_dict(), ckpt_path)
                optimizer_ckpt_path = self._rotate_optimizer_ckpt(
                    optimizer, ckpt_dir, root_optimizer_ckpt_path, epoch, optimizer_ckpt_path)
        self.evaluate()

        # ---------------- Phase 2: trigger generation ----------------
        # 1 for the trigger half of a [trigger, target] batch, 0 for the other half.
        w_label = torch.cat((torch.ones(half_batch), torch.zeros(half_batch)))
        trigger_path = os.path.join(
            self.opt.data_dir, f'checkpoints/{self.model_name}_{self.threshold}_{self.maxiter}_trigger.pkl')
        if os.path.exists(trigger_path):
            # NOTE: pickle.load can execute arbitrary code — only load trigger
            # files written by this trainer.
            with open(trigger_path, 'rb') as pf:
                self.trigger_dataset.items = pickle.load(pf)
        else:
            wm_data_iter = iter(wm_dataloader)
            self.model.eval()
            for iter_n, (target_data, _) in enumerate(target_dataloader_no_aug):
                (images, _), wm_data_iter = self._next_batch(wm_data_iter, wm_dataloader)
                step = 0  # budget of 50 CE steps shared by all maxiter rounds of this batch
                if self.opt.use_gpu:
                    images = images.cuda()
                    target_data = target_data.cuda()
                current_trigger = Variable(images.detach(), requires_grad=True)
                for _ in range(self.maxiter):
                    # CE steps while more than `threshold` of the batch is
                    # predicted as wm_target.
                    # NOTE(review): the step adds sign(grad) of the CE loss
                    # w.r.t. wm_target, i.e. it ascends that loss — confirm the
                    # intended direction.
                    while self.batch_watermark_evaluate(current_trigger) > self.threshold and step < 50:
                        step += 1
                        batch_data = torch.cat((current_trigger, current_trigger))
                        predictions = self.model(batch_data)
                        ce_loss = ce_loss_func(
                            predictions[-1], self._full_targets(batch_data.shape[0], predictions[-1]))
                        ce_loss.backward()
                        current_trigger = torch.clamp(
                            current_trigger.detach() + self.w_lr * current_trigger.grad.data.sign().detach(),
                            min=0, max=1).detach()
                        if self.opt.use_gpu:
                            current_trigger = current_trigger.cuda()
                        current_trigger.requires_grad = True
                    # One signed-gradient step on model.snnl over [trigger, wm_target images].
                    batch_data = torch.cat((current_trigger, target_data))
                    predictions = self.model(batch_data)
                    loss_list = self.model.snnl(predictions, w_label, self.temperatures)
                    snnl_loss = loss_list[0] + loss_list[1] + loss_list[2]
                    snnl_loss.backward()
                    current_trigger = torch.clamp(
                        current_trigger.detach() + self.w_lr * current_trigger.grad.data.sign(),
                        min=0, max=1).detach()
                    current_trigger.requires_grad = True
                # Five final CE steps, same update rule as above.
                for _ in range(5):
                    batch_data = torch.cat((current_trigger, current_trigger))
                    predictions = self.model(batch_data)
                    ce_loss = ce_loss_func(
                        predictions[-1], self._full_targets(batch_data.shape[0], predictions[-1]))
                    ce_loss.backward()
                    current_trigger = torch.clamp(
                        current_trigger.detach() + self.w_lr * torch.sign(current_trigger.grad).detach(),
                        min=0, max=1).detach()
                    current_trigger.requires_grad = True

                print(f'[trigger] batch {iter_n + 1}/{int(len(target_dataset)/(0.5*batch_size))}: ' +
                      f'watermark extraction success rate: {self.batch_watermark_evaluate(current_trigger)}')
                for trigger_item in current_trigger:
                    self.trigger_dataset.items.append((trigger_item.detach().cpu(), self.wm_target))

            # The checkpoints directory may not exist yet when continue_train is off.
            os.makedirs(os.path.dirname(trigger_path), exist_ok=True)
            with open(trigger_path, 'wb') as pf:
                pickle.dump(self.trigger_dataset.items, pf)

        # ---------------- Phase 3: watermark embedding ----------------
        starting_epoch_n = 0
        optimizer_ckpt_path = None
        if self.continue_train:
            ckpt_path = os.path.join(
                ckpt_dir,
                f'{self.model_name}_wm_{self.threshold}_{self.maxiter}_{self.snnl_factor}_state_dict')
            if os.path.exists(ckpt_path):
                self.model.load_state_dict(torch.load(ckpt_path, map_location='cpu'))

            # Start from the optimizer state saved during phase 1 ...
            init_ckpt, _ = self._latest_checkpoint(
                ckpt_dir, f'optimizer_for_{self.model_name}_init_state_dict')
            if init_ckpt is not None:
                optimizer.load_state_dict(torch.load(
                    os.path.join(ckpt_dir, init_ckpt), map_location='cpu'))
                print('previous optimizer loaded')

            # ... then resume phase 3 itself if it was interrupted.
            root_optimizer_ckpt_path = (
                f'optimizer_for_{self.model_name}_wm_{self.threshold}_{self.maxiter}_{self.snnl_factor}_state_dict')
            optimizer_ckpt_path, starting_epoch_n = self._latest_checkpoint(
                ckpt_dir, root_optimizer_ckpt_path)
            if optimizer_ckpt_path is not None:
                optimizer.load_state_dict(torch.load(
                    os.path.join(ckpt_dir, optimizer_ckpt_path), map_location='cpu'))

        trigger_dataloader = torch.utils.data.DataLoader(
            self.trigger_dataset, batch_size=half_batch, shuffle=True, num_workers=4, drop_last=True)
        train_data_iter = iter(dataloader)
        print(len(self.trigger_dataset))

        for epoch in range(starting_epoch_n + 1, int(self.w_epochs / 2) + 1):
            self.model.train()
            trigger_data_iter = iter(trigger_dataloader)
            for target_data, _ in target_dataloader_no_aug:
                # int(n_w_ratio) clean batches per watermark batch; any
                # fractional part of n_w_ratio is ignored.
                for _ in range(int(self.n_w_ratio)):
                    (train_images, train_labels), train_data_iter = self._next_batch(
                        train_data_iter, dataloader)
                    if self.opt.use_gpu:
                        train_images = train_images.cuda()
                        train_labels = train_labels.cuda()
                    predictions = self.model(train_images)
                    loss = self.model.total_loss(
                        predictions, train_labels, w_0, self.factors, self.temperatures)
                    self.model.zero_grad()
                    loss.backward()
                    optimizer.step()

                # One watermark batch: [triggers, wm_target images], all labelled wm_target.
                (trigger_data, _), trigger_data_iter = self._next_batch(
                    trigger_data_iter, trigger_dataloader)
                if self.opt.use_gpu:
                    trigger_data = trigger_data.cuda()
                    target_data = target_data.cuda()
                batch_data = torch.cat((trigger_data, target_data))
                predictions = self.model(batch_data)
                targets = self._full_targets(batch_data.shape[0], predictions[-1])
                loss = self.model.total_loss(
                    predictions, targets, w_label, self.factors, self.temperatures,
                    snnl_factor=self.snnl_factor)
                optimizer.zero_grad()
                # Other backward passes that receive self.temperatures (the
                # clean batches above, phase-2 snnl) can leave gradients in
                # temperatures.grad. The optimizer does not own it, so clear it
                # here; the update below then uses this batch's gradient only.
                self.temperatures.grad = None
                loss.backward()
                optimizer.step()
                # Element-wise step: each temperature moves along its own gradient.
                self.temperatures = self.temperatures.detach() - self.t_lr * self.temperatures.grad.detach()
                self.temperatures.requires_grad = True

            accs = self.evaluate(print_result=False)
            extract_success_rate = self.watermark_evaluate()
            print(f'[watermark] epoch {epoch}/{int(self.w_epochs/2)}: ' +
                  f'accuracy: {accs} | ' +
                  f'watermark extraction success rate: {extract_success_rate}')

            if self.continue_train:
                # ckpt_path always holds the latest epoch so training can resume.
                torch.save(self.model.state_dict(), ckpt_path)
                optimizer_ckpt_path = self._rotate_optimizer_ckpt(
                    optimizer, ckpt_dir, root_optimizer_ckpt_path, epoch, optimizer_ckpt_path)

        accs = self.evaluate(print_result=False)
        extract_success_rate = self.watermark_evaluate()
        print(f'[watermarked victim] accuracy: {accs} | ' +
              f'watermark extraction success rate: {extract_success_rate}')

        return self.model

    def evaluate(self, print_result=True):
        """Return top-1 accuracy of the model on ``dataset.test_dataloader()``.

        Puts the model in eval mode and leaves it there. Each batch is indexed
        as ``batch[0]`` = inputs and ``batch[1]`` = labels. The model's last
        output is used as class scores.

        Args:
            print_result: also print ``'<model_name> accuracy: ...'``.

        Returns:
            0-d CPU tensor with the accuracy in [0, 1]. Returns ``tensor(0.)``
            when the loader yields no samples.
        """
        self.model.eval()
        correct = torch.tensor(0.)
        n_samples = 0
        with torch.no_grad():
            for batch in self.dataset.test_dataloader():
                images, targets = batch[0], batch[1]
                n_samples += targets.shape[0]
                if self.opt.use_gpu:
                    images = images.cuda()
                    targets = targets.cuda()
                predictions = self.model(images)
                correct += predictions[-1].max(1)[1].eq(targets).float().sum().cpu()
        accs = correct / n_samples if n_samples else torch.tensor(0.)
        if print_result:
            print('%s accuracy: %.5f' % (self.model_name, accs))
        return accs

    def batch_watermark_evaluate(self, trigger_set):
        """Return the fraction of ``trigger_set`` the model predicts as ``wm_target``.

        Puts the model in eval mode and leaves it there. With ``opt.use_gpu``,
        the model and inputs are moved to CUDA first.

        Returns:
            0-d CPU tensor.
        """
        self.model.eval()
        if self.opt.use_gpu:
            trigger_set = trigger_set.cuda()
            self.model.cuda()
        with torch.no_grad():
            predictions = self.model(trigger_set)
            hits = (predictions[-1].max(1)[1] == self.wm_target).float().sum().detach().cpu()
        return hits / trigger_set.shape[0]

    def watermark_evaluate(self, sample_size=400, num_workers=4):
        """Return the ``wm_target`` hit rate over the first trigger items.

        The first ``min(sample_size, len(self.trigger_dataset))`` items are
        visited in random order, 50 at a time.

        Args:
            sample_size: maximum number of trigger items to evaluate.
            num_workers: number of DataLoader worker processes.

        Returns:
            0-d tensor with the success rate. Returns ``tensor(0.)`` if there
            are no trigger items.
        """
        n_items = min(sample_size, len(self.trigger_dataset))
        if n_items <= 0:
            return torch.tensor(0.)
        dataloader = torch.utils.data.DataLoader(
            self.trigger_dataset, batch_size=50,
            sampler=sp.SubsetRandomSampler(list(range(n_items))), num_workers=num_workers)
        success = 0.
        total = 0
        for trigger_data, _ in dataloader:
            n = trigger_data.shape[0]
            total += n
            success += self.batch_watermark_evaluate(trigger_data) * n
        return success / total
