import torch
from torch import nn, softmax
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import torch.utils.data.sampler as sp
from advertorch.attacks import LinfBasicIterativeAttack
import sys
sys.path.append('..')
from utils import classifier_dict, load_model_weights
# import loss

class PseudoTrainer():
    """Train a substitute model to imitate a victim from pseudo-labelled data.

    The substitute is fitted with a KL-divergence loss against the soft labels
    stored in ``sub_dataset`` and is periodically evaluated for clean accuracy,
    fidelity (agreement with the victim's predictions) and transfer-attack
    success rate of adversarial examples crafted on it.
    """

    def __init__(self, opt,
        victim, substitute,
        sub_dataset, eval_dataset,
        n_epochs=10):
        """
        Args:
            opt: options namespace. This class reads ``use_gpu``, ``work_dir``,
                ``print_freq``, ``source``, ``eval_model``, ``victim_n_classes``,
                ``noise_eps`` and ``targeted``.
            victim: model whose predictions the substitute should imitate.
            substitute: model trained by ``train``.
            sub_dataset: dataset whose items unpack as ``(seed, data, prob)``.
            eval_dataset: object exposing ``test_dataloader()`` and ``test_dataset``.
            n_epochs: number of training epochs.
        """
        self.opt = opt
        self.victim = victim
        self.substitute = substitute
        self.sub_dataset = sub_dataset
        self.eval_dataset = eval_dataset
        self.n_epochs = n_epochs
        # Second element returned by ``train()``. This trainer has no data
        # generator of its own, so it stays None unless a caller sets it.
        self.data_gen = None

    @staticmethod
    def _random_horizontal_flip(x, p=0.5):
        """Flip each sample of ``x`` along its last axis with probability ``p``.

        Torch-only stand-in for ``kornia.augmentation.RandomHorizontalFlip()``
        (the previous code referenced ``K`` without importing kornia).
        Assumes image batches shaped (N, C, H, W), so the last axis is width.
        """
        flip = torch.rand(x.shape[0], device=x.device) < p
        # Reshape the per-sample mask so it broadcasts over the remaining dims.
        flip = flip.view(-1, *([1] * (x.dim() - 1)))
        return torch.where(flip, x.flip(-1), x)

    def _report(self, tag, sample_size=200):
        """Print clean and adversarial metrics prefixed with ``[tag]``; return (acc, fidelity)."""
        acc, fidelity = self.evaluate()
        asr, l2_noise, noise_per_pixel = self.adv_evaluate(sample_size)
        print(f'[{tag}] accuracy {acc} | fidelity {fidelity} | ASR {asr} | L2 noise {l2_noise}({noise_per_pixel})')
        return acc, fidelity

    def train(self):
        """Fit the substitute on ``sub_dataset`` and report metrics.

        Metrics are printed before training, every ``opt.print_freq`` epochs
        and after the last epoch. The substitute's ``state_dict`` is saved to
        ``{opt.work_dir}checkpoints/substitute_pseudo`` after every epoch.

        Returns:
            (substitute, data_gen): the trained substitute and ``self.data_gen``.
        """
        import os  # local import: only used to create the checkpoint directory

        if self.opt.use_gpu:
            self.substitute.cuda()
        dataloader = torch.utils.data.DataLoader(
            self.sub_dataset,
            batch_size=50,
            shuffle=True,
            num_workers=4,
        )
        substitute_optimizer = torch.optim.Adam(
            self.substitute.parameters(),
            lr=0.001,
        )
        # Built once instead of once per batch. ``dim=1`` matches the implicit
        # dim LogSoftmax picked for 2-D/4-D outputs (assumed (batch, classes)
        # logits). 'batchmean' is the reduction under which KLDivLoss equals
        # the actual KL divergence; 'mean' also divides by the class count.
        log_softmax = nn.LogSoftmax(dim=1)
        substitute_loss_function = nn.KLDivLoss(reduction='batchmean')

        ckpt_dir = f'{self.opt.work_dir}checkpoints/'
        os.makedirs(ckpt_dir, exist_ok=True)  # torch.save does not create directories
        sub_ckpt_path = f'{ckpt_dir}substitute_pseudo'
        # writer = SummaryWriter(f'runs/{self.opt.caption}')

        self._report('start')

        for epoch in range(self.n_epochs):
            # evaluate() switches the substitute to eval mode; restore training
            # mode every epoch so BatchNorm/Dropout layers train correctly.
            self.substitute.train()
            substitute_loss = None
            for _seed, data, prob in dataloader:
                if self.opt.use_gpu:
                    data = data.cuda()
                    prob = prob.cuda()
                self.substitute.zero_grad()
                sub_output = self.substitute(self._random_horizontal_flip(data))
                # NOTE(review): KLDivLoss (log_target=False) expects ``prob`` to
                # hold probabilities, not log-probabilities — confirm against
                # how sub_dataset builds it.
                substitute_loss = substitute_loss_function(log_softmax(sub_output), prob)
                substitute_loss.backward()
                substitute_optimizer.step()

            # An empty dataloader leaves no loss to report.
            loss_value = float('nan') if substitute_loss is None else substitute_loss.item()
            print(f'[substitute] epoch {epoch} | loss {loss_value}')
            if epoch % self.opt.print_freq == 0:
                self._report('substitute')
                # NOTE(review): these reference self.loop, which this class never sets.
                # writer.add_scalar("loss",substitute_loss,self.loop*self.n_epochs+epoch)
                # writer.add_scalar("accuracy",acc,self.loop*self.n_epochs+epoch)
                # writer.add_scalar("fidelity",fidelity,self.loop*self.n_epochs+epoch)
            torch.save(self.substitute.state_dict(), sub_ckpt_path)

        # writer.flush()
        # writer.close()
        self._report('substitute')

        return self.substitute, self.data_gen

    def evaluate(self):
        """Measure substitute accuracy and its agreement with the victim.

        Iterates ``self.eval_dataset.test_dataloader()``; each batch is indexed
        as ``batch[0]`` = inputs, ``batch[1]`` = integer labels. Both models are
        put in eval mode.

        Returns:
            (accuracy, fidelity): 0-dim CPU tensors with the fraction of samples
            the substitute labels correctly and the fraction on which its
            top-1 prediction matches the victim's. Both are 0 when the loader
            yields no samples.
        """
        self.substitute.eval()
        self.victim.eval()
        if self.opt.use_gpu:
            self.victim.cuda()
            self.substitute.cuda()
        dataloader = self.eval_dataset.test_dataloader()
        accs = torch.zeros(())
        fidelity = torch.zeros(())
        n_samples = 0
        with torch.no_grad():
            for data in dataloader:
                imgs = data[0]
                targets = data[1]
                if self.opt.use_gpu:
                    imgs = imgs.cuda()
                    targets = targets.cuda()
                n_samples += targets.shape[0]
                sub_pred = self.substitute(imgs).max(1)[1]
                victim_pred = self.victim(imgs).max(1)[1]
                accs += sub_pred.eq(targets).float().sum().cpu()
                fidelity += victim_pred.eq(sub_pred).float().sum().cpu()
        if n_samples == 0:
            return accs, fidelity
        return accs / n_samples, fidelity / n_samples

    def adv_evaluate(self, sample_size):
        """Measure how well BIM adversarial examples from the substitute transfer to the victim.

        ``sample_size`` test images are drawn without replacement. An L-inf
        Basic Iterative Attack (50 steps, budget ``opt.noise_eps``) is run
        against the substitute — or, when ``opt.source == 'attackgan'``, against
        ``classifier_dict[opt.eval_model]`` loaded via ``load_model_weights`` —
        and the adversarial images are classified by the victim. Only images
        the victim originally classifies correctly are counted. With
        ``opt.targeted`` the attack aims for class ``(label + 1) % n_classes``;
        otherwise any misclassification is a success.

        Args:
            sample_size: number of test images to attack (capped at the test set size).

        Returns:
            (asr, avg_l2_noise, l2_noise_per_pixel) as 0-dim CPU tensors: the
            success rate in percent; the summed L2 perturbation of successful
            attacks divided by the number of correctly classified images; and
            that value divided by the number of elements per image.
        """
        if self.opt.source == 'attackgan':
            eval_model = classifier_dict[self.opt.eval_model](
                n_outputs=self.opt.victim_n_classes
            )
            substitute = load_model_weights(self.substitute, eval_model)
        else:
            substitute = self.substitute
        if self.opt.use_gpu:
            self.victim.cuda()
            substitute.cuda()
        self.victim.eval()
        substitute.eval()

        # Falls back to the previously hard-coded 10 classes when the option is absent.
        n_classes = getattr(self.opt, 'victim_n_classes', 10)
        targeted = bool(self.opt.targeted)

        test_dataset = self.eval_dataset.test_dataset
        sample_size = min(sample_size, len(test_dataset))
        data_list = np.random.choice(len(test_dataset), sample_size, replace=False)
        dataloader = torch.utils.data.DataLoader(
            test_dataset, batch_size=50,
            sampler=sp.SubsetRandomSampler(data_list), num_workers=4,
        )
        # The attack must know whether it is targeted: advertorch minimises the
        # loss towards ``y`` only when targeted=True, so handing target labels
        # to an untargeted attack pushes inputs *away* from the target.
        # NOTE(review): clip range [-1, 1] assumes inputs normalised to that range.
        adversary_ghost = LinfBasicIterativeAttack(
            substitute, loss_fn=nn.CrossEntropyLoss(reduction='sum'), eps=self.opt.noise_eps,
            nb_iter=50, eps_iter=self.opt.noise_eps / 50, clip_min=-1.0, clip_max=1.0,
            targeted=targeted,
        )

        attack_success = torch.zeros(())
        total = torch.zeros(())
        l2_noise = torch.zeros(())
        elems_per_image = 1
        for inputs, labels in dataloader:
            if self.opt.use_gpu:
                inputs = inputs.cuda()
                labels = labels.cuda()
            elems_per_image = inputs[0].numel()
            with torch.no_grad():
                correct_predict = self.victim(inputs).max(1)[1] == labels
            attack_labels = (labels + 1) % n_classes if targeted else labels
            adv_inputs_ghost = adversary_ghost.perturb(inputs, attack_labels)
            with torch.no_grad():
                predicted = self.victim(adv_inputs_ghost).max(1)[1]
                if targeted:
                    fooled = predicted == attack_labels
                else:
                    fooled = predicted != labels
                success = correct_predict & fooled
                # Per-image L2 distance, same as torch.dist(adv[i], inputs[i], 2).
                noise = (adv_inputs_ghost - inputs).flatten(1).norm(p=2, dim=1)
            total += correct_predict.float().sum().cpu()
            attack_success += success.float().sum().cpu()
            l2_noise += noise[success].sum().cpu()

        # clamp_min avoids 0/0 when no sampled image was classified correctly.
        denom = total.clamp_min(1.0)
        asr = 100. * attack_success / denom
        # NOTE(review): averaged over correctly classified images, not over
        # successful attacks — kept as before; confirm this is the intended metric.
        avg_l2_noise = l2_noise / denom
        l2_noise_per_pixel = avg_l2_noise / elems_per_image
        return asr, avg_l2_noise, l2_noise_per_pixel