import torch
from torch import nn, softmax
from torch.utils.tensorboard import SummaryWriter
import torch.utils.data.sampler as sp
import torch.distributions as d
import numpy as np
from advertorch.attacks import LinfBasicIterativeAttack
from torchvision import transforms
import classifier
import loss
import sys
import kd_loss
sys.path.append('..')
import setup
from data import *
from utils import *
# import loss
import time
import os

class GANTrainer:
    """Train a fusion data generator and track a substitute model's certainty.

    The generator (``data_gen``) receives a tuple of clean images and produces
    one fused image. It is optimised to stay close (MSE) to every input image,
    with the i-th image weighted by ``beta ** i``. The substitute model is not
    optimised here; its logits on the fused image are only used to measure how
    confident its predictions are, which drives some early-stopping rules.

    Options read from ``opt``: ``work_dir``, ``use_gpu``, ``n_fuse``,
    ``noise_weight``, ``loss_threshhold``, ``same_certainty_epoch``,
    ``noise_fall_epoch`` and ``certainty_threshhold`` (spelling as used by the
    options object).
    """

    def __init__(self, opt, substitute, data_gen,
                 train_dataset, eval_dataset, n_epochs,
                 source, strategy='every',
                 model_name='fusiongan_300'):
        """Store the training configuration.

        Args:
            opt: Options object (see the class docstring for the fields read).
            substitute: Model whose logits are used to measure certainty.
            data_gen: Generator module mapping a tuple of images to one image.
            train_dataset: Dataset whose samples unpack as
                ``(clean_imgs, _, _, _)`` with ``clean_imgs`` a sequence of
                image tensors.
            eval_dataset: Stored as-is; not used by ``train``.
            n_epochs: Index of the last epoch to train (epochs are 1-based).
            source: Stored as-is; not used by ``train``.
            strategy: Stored as-is; not used by ``train``.
            model_name: Base name for the checkpoint files.
        """
        self.opt = opt
        self.substitute = substitute
        self.data_gen = data_gen
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.source = source
        self.strategy = strategy
        self.n_epochs = n_epochs
        self.model_name = model_name
        # Geometric weight applied to the i-th fused image in the noise loss.
        self.beta = 0.2
        # Only the horizontal flip is active; the pipeline is built here but
        # not applied anywhere in ``train``.
        self.augmentation = transforms.Compose(
            [
                transforms.RandomHorizontalFlip(),  # flips with 0.5 probability
            ]
        )

    @staticmethod
    def _latest_optimizer_checkpoint(ckpt_dir, root_name):
        """Return ``(filename, epoch)`` of the newest optimizer checkpoint.

        Only files named exactly ``<root_name>_<epoch>`` with an all-digit
        epoch are considered, so unrelated or suffix-less files can neither
        crash the epoch parsing nor be mistaken for a resumable run.
        Returns ``(None, 0)`` when no such file exists.
        """
        prefix = root_name + '_'
        best_name, best_epoch = None, 0
        for filename in os.listdir(ckpt_dir):
            if not filename.startswith(prefix):
                continue
            suffix = filename[len(prefix):]
            if not (suffix.isascii() and suffix.isdigit()):
                continue
            epoch = int(suffix)
            if best_name is None or epoch > best_epoch:
                best_name, best_epoch = filename, epoch
        return best_name, best_epoch

    def _to_device(self, tensor):
        """Move ``tensor`` to the GPU only when ``opt.use_gpu`` is set."""
        return tensor.cuda() if self.opt.use_gpu else tensor

    def _substitute_logits(self, images):
        """Run the substitute model on ``images`` and return its output.

        ``data_parallel`` needs CUDA devices, so it is only used on the GPU
        path; on CPU the module is called directly.
        """
        if self.opt.use_gpu:
            return nn.parallel.data_parallel(self.substitute, images)
        return self.substitute(images)

    def _generator_step(self, clean_imgs, optimizer):
        """Run one optimisation step of the generator on a single batch.

        Args:
            clean_imgs: Tuple of image tensors fed to the generator.
            optimizer: Optimizer holding the generator parameters.

        Returns:
            ``(gen_loss, noise_loss, certainty)`` where ``gen_loss`` is the
            detached loss tensor, ``noise_loss`` the unweighted noise loss as
            a float, and ``certainty`` the batch mean of the substitute's
            maximum softmax probability as a float.
        """
        fused_img = self.data_gen(clean_imgs)

        noise_loss = 0.0
        for i in range(self.opt.n_fuse):
            noise_loss = noise_loss + (
                nn.functional.mse_loss(fused_img, clean_imgs[i]) * self.beta ** i
            )
        # The uncertainty term (opt.uncertain_weight) is currently disabled,
        # so the generator loss consists of the noise term only.
        gen_loss = self.opt.noise_weight * noise_loss

        optimizer.zero_grad()
        gen_loss.backward()
        optimizer.step()

        # The substitute output does not contribute to the loss, so it is
        # evaluated without building an autograd graph. The fused image was
        # produced before the optimizer step, as in the forward pass above.
        with torch.no_grad():
            logits = self._substitute_logits(fused_img.detach())
            certainty = torch.softmax(logits, dim=1).max(dim=1)[0].mean().item()

        return gen_loss.detach(), float(noise_loss), certainty

    def train(self):
        """Train the generator, resuming from checkpoints when available.

        Checkpoints live in ``<opt.work_dir>checkpoints/``:

        * ``<model_name>_state_dict`` - generator weights after the last
          fully completed epoch.
        * ``optimizer_for_<model_name>_state_dict_<epoch>`` - optimizer state
          after ``<epoch>``; its presence marks a resumable run.
        * ``pre_fusiongan`` - generator weights written every epoch, before
          the early-stopping checks.

        If generator weights exist but no optimizer checkpoint does, the
        weights are loaded and returned without training.

        Early stopping (each rule is active only when its option is truthy):

        * ``loss_threshhold``: the epoch-average generator loss fell below
          it. The generator loss is the only loss this trainer computes.
        * ``same_certainty_epoch``: that many epochs passed without a new
          maximum of the epoch-average certainty.
        * ``noise_fall_epoch``: that many epochs passed without a new maximum
          of the epoch-average noise loss.
          NOTE(review): this reading of the option is inferred from its name;
          confirm.
        * ``certainty_threshhold``: the epoch-average certainty exceeded it.

        Returns:
            The generator module ``self.data_gen``.

        Raises:
            ValueError: If the training data loader yields no batches.
        """
        ckpt_dir = f'{self.opt.work_dir}checkpoints/'
        # A fresh work_dir has no checkpoint directory yet.
        os.makedirs(ckpt_dir, exist_ok=True)

        ckpt_path = f'{ckpt_dir}{self.model_name}_state_dict'
        model_exists = os.path.exists(ckpt_path)
        if model_exists:
            # Load onto CPU first so GPU-saved weights also load on CPU hosts;
            # load_state_dict copies them onto the module's own device.
            self.data_gen.load_state_dict(
                torch.load(ckpt_path, map_location='cpu')
            )

        root_optimizer_ckpt_path = 'optimizer_for_%s_state_dict' % self.model_name
        optimizer_ckpt_name, starting_epoch_n = self._latest_optimizer_checkpoint(
            ckpt_dir, root_optimizer_ckpt_path
        )
        training_was_in_progress = optimizer_ckpt_name is not None

        if model_exists and not training_was_in_progress:
            print('Opimizer state lost. Model loaded.')
            return self.data_gen

        # preparation
        if self.opt.use_gpu:
            self.substitute.cuda()
            self.data_gen.cuda()
        train_dataloader = torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=50,
            shuffle=True,
            num_workers=4
        )
        data_gen_optimizer = torch.optim.Adam(
            self.data_gen.parameters(),
            lr=0.001
        )
        if training_was_in_progress and model_exists:
            # Restore the optimizer state so a resumed run continues from
            # where it stopped instead of restarting Adam from scratch.
            data_gen_optimizer.load_state_dict(
                torch.load(f'{ckpt_dir}{optimizer_ckpt_name}', map_location='cpu')
            )
        self.substitute.train()
        self.data_gen.train()

        gen_ckpt_path = f'{ckpt_dir}pre_fusiongan'

        max_certainty = 0.0
        max_certainty_epoch = 0
        max_noise = 0.0
        max_noise_epoch = 0
        for epoch in range(starting_epoch_n + 1, self.n_epochs + 1):
            certainty_sum = 0.0
            noise_sum = 0.0
            loss_sum = 0.0
            num_batch = 0
            gen_loss = None
            for clean_imgs, _, _, _ in train_dataloader:
                num_batch += 1
                clean_imgs = tuple(self._to_device(img) for img in clean_imgs)
                gen_loss, batch_noise, batch_certainty = self._generator_step(
                    clean_imgs, data_gen_optimizer
                )
                loss_sum += float(gen_loss)
                noise_sum += batch_noise
                certainty_sum += batch_certainty

            if num_batch == 0:
                raise ValueError('train_dataset produced no batches')

            # Reports the loss of the last batch of the epoch.
            print(f'[data generator] epoch {epoch} | loss {gen_loss}')
            torch.save(self.data_gen.state_dict(), gen_ckpt_path)

            avg_loss = loss_sum / num_batch
            avg_certainty = certainty_sum / num_batch
            avg_noise = noise_sum / num_batch

            if self.opt.loss_threshhold:
                if avg_loss < self.opt.loss_threshhold:
                    break
            if avg_certainty > max_certainty:
                max_certainty = avg_certainty
                max_certainty_epoch = epoch
            if avg_noise > max_noise:
                max_noise = avg_noise
                max_noise_epoch = epoch
            if self.opt.same_certainty_epoch:
                if epoch - max_certainty_epoch >= self.opt.same_certainty_epoch:
                    break
            if self.opt.noise_fall_epoch:
                if epoch - max_noise_epoch >= self.opt.noise_fall_epoch:
                    break
            if self.opt.certainty_threshhold:
                if avg_certainty > self.opt.certainty_threshhold:
                    break

            # Persist the resumable state: generator weights plus an optimizer
            # checkpoint whose filename carries the completed epoch number.
            torch.save(self.data_gen.state_dict(), ckpt_path)
            new_checkpoint_name = '%s_%d' % (root_optimizer_ckpt_path, epoch)
            torch.save(
                data_gen_optimizer.state_dict(),
                f'{ckpt_dir}{new_checkpoint_name}'
            )
            # Keep only the newest optimizer checkpoint of this run.
            if (optimizer_ckpt_name is not None
                    and optimizer_ckpt_name != new_checkpoint_name):
                try:
                    os.unlink(f'{ckpt_dir}{optimizer_ckpt_name}')
                except FileNotFoundError:
                    pass
            optimizer_ckpt_name = new_checkpoint_name

        return self.data_gen

    def test(self):
        """Placeholder; evaluation is not implemented."""
        pass