import paddle
import paddle.vision.transforms as transforms
from paddle import nn, Tensor
import config
import sys

# sys.path.insert(0, "../..")

from classifier_models import *
import os
# import matplotlib.pyplot as plt
import numpy as np
from utils import progress_bar
from networks.models import NetC_MNIST, Normalize, Denormalize
from dataloader import get_dataloader


class RegressionModel(nn.Layer):
    """Stamp a learnable trigger onto inputs and classify them with a frozen model.

    The only parameters intended to be optimized are ``mask_tanh`` and
    ``pattern_tanh``. Both are unconstrained and are mapped into the open
    interval (0, 1) by :meth:`get_raw_mask` / :meth:`get_raw_pattern`.
    The classifier is loaded from a checkpoint under ``opt.checkpoints`` and
    has ``stop_gradient`` set on every one of its parameters.

    Args:
        opt: configuration namespace; reads ``EPSILON``, ``dataset``,
            ``checkpoints``, ``attack_method``, ``attack_ratio`` and
            ``attack_type``.
        init_mask: initial value of ``mask_tanh`` (anything with ``.shape``
            that ``nn.initializer.Assign`` accepts, e.g. a numpy array).
        init_pattern: initial value of ``pattern_tanh`` (same requirements).
    """

    def __init__(self, opt, init_mask, init_pattern):
        # Initialize nn.Layer internals before assigning any attribute; Paddle
        # requires this before parameters or sublayers can be registered.
        super(RegressionModel, self).__init__()
        self._EPSILON = opt.EPSILON
        mask_attr = paddle.ParamAttr(initializer=nn.initializer.Assign(init_mask))
        self.mask_tanh = paddle.create_parameter(init_mask.shape, paddle.float32, attr=mask_attr)
        pattern_attr = paddle.ParamAttr(initializer=nn.initializer.Assign(init_pattern))
        self.pattern_tanh = paddle.create_parameter(init_pattern.shape, paddle.float32, attr=pattern_attr)
        self.classifier = self._get_classifier(opt)
        self.normalizer = self._get_normalize(opt)
        self.denormalizer = self._get_denormalize(opt)

    def forward(self, x):
        """Return the classifier outputs for ``x`` blended with the current trigger.

        Each input becomes ``(1 - mask) * x + mask * pattern``.
        """
        mask = self.get_raw_mask()
        pattern = self.get_raw_pattern()
        if self.normalizer is not None:
            # The pattern lives in (0, 1); apply the dataset normalization so it
            # matches x. Assumes the dataloader already normalizes x — TODO confirm.
            pattern = self.normalizer(pattern)
        x = (1 - mask) * x + mask * pattern
        return self.classifier(x)

    def _squash(self, raw):
        """Map an unconstrained tensor into (0, 1) with a scaled, shifted tanh.

        Dividing by ``2 + EPSILON`` keeps the result strictly inside (0, 1).
        """
        return paddle.tanh(raw) / (2 + self._EPSILON) + 0.5

    def get_raw_mask(self):
        """Return the current mask, element-wise in (0, 1)."""
        return self._squash(self.mask_tanh)

    def get_raw_pattern(self):
        """Return the current pattern, element-wise in (0, 1)."""
        return self._squash(self.pattern_tanh)

    def _get_classifier(self, opt):
        """Build the classifier for ``opt.dataset``, load its weights and freeze it.

        Raises:
            Exception: if ``opt.dataset`` is not one of mnist/cifar10/gtsrb/celeba.
        """
        if opt.dataset == "mnist":
            classifier = NetC_MNIST()
        elif opt.dataset == "cifar10":
            classifier = PreActResNet18(num_classes=10)
        elif opt.dataset == "gtsrb":
            classifier = PreActResNet18(num_classes=43)
        elif opt.dataset == "celeba":
            classifier = ResNet18()
        else:
            raise Exception("Invalid Dataset")
        # Load the pretrained classifier weights stored under the "netC" key.
        ckpt_path = os.path.join(
            opt.checkpoints, opt.dataset, "target_morph_{}_{}_{}_paddle.pth.pdmodel".format(opt.attack_method, opt.attack_ratio, opt.attack_type)
        )
        print('load model from:{}'.format(ckpt_path))

        state_dict = paddle.load(ckpt_path)
        classifier.set_state_dict(state_dict["netC"])
        # Freeze the classifier: only the mask/pattern should receive gradients.
        for param in classifier.parameters():
            param.stop_gradient = True
        classifier.eval()
        return classifier

    def _get_denormalize(self, opt):
        """Return the inverse normalization for ``opt.dataset``, or None when unused.

        Raises:
            Exception: for an unknown dataset.
        """
        if opt.dataset == "cifar10":
            denormalizer = Denormalize(opt, [0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261])
        elif opt.dataset == "mnist":
            denormalizer = Denormalize(opt, [0.5], [0.5])
        elif opt.dataset in ("gtsrb", "celeba"):
            denormalizer = None
        else:
            raise Exception("Invalid dataset")
        return denormalizer

    def _get_normalize(self, opt):
        """Return the input normalization for ``opt.dataset``, or None when unused.

        Raises:
            Exception: for an unknown dataset.
        """
        # Re-imported locally; presumably guards against the module-level
        # name `transforms` being shadowed by `from classifier_models import *`
        # — TODO confirm before removing.
        import paddle.vision.transforms as transforms
        if opt.dataset == "cifar10":
            # NOTE: cifar10 uses paddle's transforms.Normalize, not networks.models.Normalize.
            normalizer = transforms.Normalize([0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261])
        elif opt.dataset == "mnist":
            normalizer = Normalize(opt, [0.5], [0.5])
        elif opt.dataset in ("gtsrb", "celeba"):
            normalizer = None
        else:
            raise Exception("Invalid dataset")
        return normalizer


class Recorder:
    """Mutable state shared across epochs of the trigger optimization.

    Holds:

    - the best mask/pattern found so far and the regularization loss they reached;
    - the current regularizer weight (``cost``) plus the counters and flags
      that drive its up/down schedule;
    - the early-stop bookkeeping.
    """

    def __init__(self, opt):
        super().__init__()

        # Best optimization results
        self.mask_best = None
        self.pattern_best = None
        self.reg_best = float("inf")

        # Logs and counters for adjusting balance cost
        self.logs = []
        self.cost_set_counter = 0
        self.cost_up_counter = 0
        self.cost_down_counter = 0
        self.cost_up_flag = False
        self.cost_down_flag = False

        # Counter for early stop
        self.early_stop_counter = 0
        self.early_stop_reg_best = self.reg_best

        # Cost: weight of the regularization term. Decreases divide by
        # multiplier**1.5, so they are steeper than increases.
        self.cost = opt.init_cost
        self.cost_multiplier_up = opt.cost_multiplier
        self.cost_multiplier_down = opt.cost_multiplier ** 1.5

    def reset_state(self, opt):
        """Restore ``cost`` to ``opt.init_cost`` and clear the up/down counters and flags."""
        self.cost = opt.init_cost
        self.cost_up_counter = 0
        self.cost_down_counter = 0
        self.cost_up_flag = False
        self.cost_down_flag = False
        print("Initialize cost to {:f}".format(self.cost))

    def save_result_to_dir(self, opt):
        """Write the best mask, pattern and trigger (pattern * mask) as PNG files.

        Files go to ``<opt.result>/<opt.dataset>/<opt.attack_mode>/<opt.target_label>/``
        as ``mask.png``, ``pattern.png`` and ``trigger.png``. The directory is
        always created. If no best mask/pattern has been recorded yet, nothing
        is written.
        """
        result_dir = os.path.join(opt.result, opt.dataset, opt.attack_mode, str(opt.target_label))
        os.makedirs(result_dir, exist_ok=True)

        if self.mask_best is None or self.pattern_best is None:
            print("No mask/pattern recorded yet; nothing saved to {}".format(result_dir))
            return

        trigger = self.pattern_best * self.mask_best
        outputs = (
            ("mask.png", self.mask_best),
            ("pattern.png", self.pattern_best),
            ("trigger.png", trigger),
        )
        for file_name, image in outputs:
            self._write_png(image, os.path.join(result_dir, file_name))

    @staticmethod
    def _write_png(image, path):
        """Write a channel-first image to ``path`` as an 8-bit PNG.

        ``image`` is either an object with ``.numpy()`` (e.g. a paddle Tensor)
        or an array-like. After leading singleton dimensions are dropped, its
        shape must be (H, W), (1, H, W) or (3, H, W). Values are min-max
        rescaled to [0, 255] before quantization, so any value range is
        visible. A constant image becomes all zeros.

        Raises:
            ValueError: if the shape is not one of the supported layouts.
        """
        import struct
        import zlib

        data = image.numpy() if hasattr(image, "numpy") else image
        data = np.asarray(data, dtype=np.float64)
        while data.ndim > 3 and data.shape[0] == 1:
            data = data[0]
        if data.ndim == 2:
            data = data[np.newaxis]
        if data.ndim != 3 or data.shape[0] not in (1, 3):
            raise ValueError(
                "expected a (H, W), (1, H, W) or (3, H, W) image, got shape {}".format(data.shape)
            )

        low, high = data.min(), data.max()
        data = (data - low) / max(high - low, 1e-5)
        pixels = np.clip(data * 255.0 + 0.5, 0, 255).astype(np.uint8)

        channels, height, width = pixels.shape
        rows = np.ascontiguousarray(pixels.transpose(1, 2, 0))
        # Every PNG scanline starts with a filter-type byte; 0 means "no filter".
        raw = b"".join(b"\x00" + row.tobytes() for row in rows)

        def chunk(tag, payload):
            crc = zlib.crc32(tag + payload) & 0xFFFFFFFF
            return struct.pack(">I", len(payload)) + tag + payload + struct.pack(">I", crc)

        color_type = 0 if channels == 1 else 2  # 0: grayscale, 2: RGB
        header = struct.pack(">IIBBBBB", width, height, 8, color_type, 0, 0, 0)
        with open(path, "wb") as handle:
            handle.write(b"\x89PNG\r\n\x1a\n")
            handle.write(chunk(b"IHDR", header))
            handle.write(chunk(b"IDAT", zlib.compress(raw)))
            handle.write(chunk(b"IEND", b""))


def train(opt, init_mask, init_pattern):
    """Optimize a mask/pattern trigger toward ``opt.target_label``.

    Runs up to ``opt.epoch`` calls of ``train_step`` over the dataloader
    returned by ``get_dataloader(opt, train=False)``. It stops early as soon
    as ``train_step`` returns a truthy value. The recorder's results are then
    written with ``save_result_to_dir``.

    Args:
        opt: configuration namespace (also forwarded to every helper).
        init_mask: initial value for the model's ``mask_tanh`` parameter.
        init_pattern: initial value for the model's ``pattern_tanh`` parameter.

    Returns:
        tuple: ``(recorder, opt)``.
    """
    loader = get_dataloader(opt, train=False)

    # The model wraps the frozen classifier together with the trainable trigger.
    model = RegressionModel(opt, init_mask, init_pattern)
    optimizer = paddle.optimizer.Adam(
        learning_rate=opt.lr,
        beta1=0.5,
        beta2=0.9,
        parameters=model.parameters(),
    )

    # Keeps the best result plus the cost / early-stop schedules.
    recorder = Recorder(opt)

    for epoch_idx in range(opt.epoch):
        should_stop = train_step(model, optimizer, loader, recorder, epoch_idx, opt)
        if should_stop:
            break

    recorder.save_result_to_dir(opt)
    return recorder, opt


def train_step(regression_model, optimizerR, dataloader, recorder, epoch, opt):
    """Run one optimization epoch, then update the recorder's best result and schedules.

    Order of operations (unchanged from the original implementation):

    1. Optimize over one pass of ``dataloader``.
    2. If the mean per-batch attack accuracy reaches ``opt.atk_succ_threshold``
       and the mean regularization loss beats ``recorder.reg_best``, record
       the new best mask/pattern and save them.
    3. Print an epoch summary.
    4. If ``opt.early_stop`` is set, update the early-stop state.
    5. Unless early stopping fired, adjust the cost schedule. If no best
       mask was ever recorded, store the current mask/pattern as a fallback.

    Returns:
        bool: True when early stopping fired and the caller should stop.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    print("Epoch {} - Label: {} | {} - {}:".format(epoch, opt.target_label, opt.dataset, opt.attack_mode))

    avg_loss_ce, avg_loss_reg, avg_loss_acc, accuracy = _run_epoch(
        regression_model, optimizerR, dataloader, recorder.cost, opt
    )

    # Check to save best mask or not
    if avg_loss_acc >= opt.atk_succ_threshold and avg_loss_reg < recorder.reg_best:
        recorder.mask_best = regression_model.get_raw_mask().detach()
        recorder.pattern_best = regression_model.get_raw_pattern().detach()
        recorder.reg_best = avg_loss_reg
        recorder.save_result_to_dir(opt)
        print(" Updated !!!")

    # Show information
    print(
        "  Result: Accuracy: {:.3f} | Cross Entropy Loss: {:.6f} | Reg Loss: {:.6f} | Reg best: {:.6f}".format(
            accuracy, avg_loss_ce, avg_loss_reg, recorder.reg_best
        )
    )

    inner_early_stop_flag = False
    if opt.early_stop:
        inner_early_stop_flag = _update_early_stop(recorder, opt)

    if not inner_early_stop_flag:
        _update_cost_schedule(recorder, avg_loss_acc, opt)

        # Save the final version
        if recorder.mask_best is None:
            recorder.mask_best = regression_model.get_raw_mask().detach()
            recorder.pattern_best = regression_model.get_raw_pattern().detach()

    return inner_early_stop_flag


def _run_epoch(regression_model, optimizer, dataloader, cost, opt):
    """Optimize the trigger over one pass of ``dataloader``; return epoch metrics.

    The loss per batch is ``CE(model(x), target_label) + cost * ||mask||_p``,
    where ``p = opt.use_norm``.

    Returns:
        tuple of floats: ``(mean CE loss, mean reg loss, mean per-batch
        accuracy %, overall accuracy %)``.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    cross_entropy = nn.CrossEntropyLoss()
    regression_model.train()
    # `train()` also flips every sublayer -- including the frozen classifier --
    # into training mode. That would make any BatchNorm/Dropout inside it
    # behave as during training (and update running statistics), so put the
    # classifier back into eval mode.
    regression_model.classifier.eval()

    ce_losses, reg_losses, batch_accs = [], [], []
    total_pred = 0
    true_pred = 0
    for batch_idx, batch in enumerate(dataloader):
        optimizer.clear_grad()
        inputs = batch['input']
        sample_num = inputs.shape[0]
        total_pred += sample_num
        target_labels = paddle.full([sample_num], opt.target_label, dtype=paddle.int64)
        predictions = regression_model(inputs)

        loss_ce = cross_entropy(predictions, target_labels)
        loss_reg = paddle.norm(regression_model.get_raw_mask(), opt.use_norm)
        total_loss = loss_ce + cost * loss_reg
        total_loss.backward()
        optimizer.step()

        # Count predictions hitting the target label once; reuse for both metrics.
        hits = paddle.sum(
            (paddle.argmax(predictions, axis=1) == target_labels).astype(paddle.int64)
        ).detach()
        ce_losses.append(loss_ce.detach())
        reg_losses.append(loss_reg.detach())
        batch_accs.append(hits * 100.0 / sample_num)
        true_pred += hits
        progress_bar(batch_idx, len(dataloader))

    if not batch_accs:
        raise ValueError("dataloader yielded no batches; cannot optimize the trigger")

    avg_loss_ce = paddle.mean(paddle.stack(ce_losses)).item()
    avg_loss_reg = paddle.mean(paddle.stack(reg_losses)).item()
    avg_loss_acc = paddle.mean(paddle.stack(batch_accs)).item()
    accuracy = (true_pred * 100.0 / total_pred).item()
    return avg_loss_ce, avg_loss_reg, avg_loss_acc, accuracy


def _update_early_stop(recorder, opt):
    """Advance the early-stop counter; return True when training should stop.

    Stopping requires that the cost was raised and lowered at least once, and
    that ``reg_best`` failed to drop below
    ``opt.early_stop_threshold * early_stop_reg_best`` for
    ``opt.early_stop_patience`` consecutive epochs.
    """
    if recorder.reg_best < float("inf"):
        if recorder.reg_best >= opt.early_stop_threshold * recorder.early_stop_reg_best:
            recorder.early_stop_counter += 1
        else:
            recorder.early_stop_counter = 0

    recorder.early_stop_reg_best = min(recorder.early_stop_reg_best, recorder.reg_best)

    if (
        recorder.cost_down_flag
        and recorder.cost_up_flag
        and recorder.early_stop_counter >= opt.early_stop_patience
    ):
        print("Early_stop !!!")
        return True
    return False


def _update_cost_schedule(recorder, avg_loss_acc, opt):
    """Raise or lower ``recorder.cost`` based on whether the attack succeeded.

    After ``opt.patience`` consecutive successful epochs the cost is
    multiplied by ``cost_multiplier_up``. After ``opt.patience`` consecutive
    failures it is divided by ``cost_multiplier_down``. A zero cost that
    keeps succeeding for ``opt.patience`` epochs is reset via ``reset_state``.
    """
    succeeded = avg_loss_acc >= opt.atk_succ_threshold

    # Check cost modification
    if recorder.cost == 0 and succeeded:
        recorder.cost_set_counter += 1
        if recorder.cost_set_counter >= opt.patience:
            recorder.reset_state(opt)
    else:
        recorder.cost_set_counter = 0

    if succeeded:
        recorder.cost_up_counter += 1
        recorder.cost_down_counter = 0
    else:
        recorder.cost_up_counter = 0
        recorder.cost_down_counter += 1

    if recorder.cost_up_counter >= opt.patience:
        recorder.cost_up_counter = 0
        print("Up cost from {} to {}".format(recorder.cost, recorder.cost * recorder.cost_multiplier_up))
        recorder.cost *= recorder.cost_multiplier_up
        recorder.cost_up_flag = True

    elif recorder.cost_down_counter >= opt.patience:
        recorder.cost_down_counter = 0
        print("Down cost from {} to {}".format(recorder.cost, recorder.cost / recorder.cost_multiplier_down))
        recorder.cost /= recorder.cost_multiplier_down
        recorder.cost_down_flag = True


if __name__ == "__main__":
    # Intentionally a no-op: running this file directly only executes the
    # module-level imports; `train` has to be called from another script.
    ...
