import torch
import torchvision
from torchvision import datasets, models, transforms
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import torch.nn as nn
from PIL import Image
import os
from tqdm.auto import tqdm
import numpy as np
from torchvision.utils import make_grid
from matplotlib import pyplot as plt
from matplotlib import image as mpimg
from utils import *
from attacks import *
from dataset_info.imagenet import imagenet_classes
import seaborn as sns
from dataset_info.country_list import original_country_list, translated_country_list
import torch.optim as optim
from pytorch_lightning import LightningModule, Trainer
from torchmetrics.functional.classification.accuracy import accuracy
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from pathlib import Path
from torch.optim.lr_scheduler import OneCycleLR
from torch.optim.swa_utils import AveragedModel, update_bn
from pytorch_lightning.plugins import DDPPlugin

# --- Filesystem locations -------------------------------------------------
# Root of the clean ImageNet images (one sub-folder per class is joined onto
# this path when anchor images are loaded in the __main__ section).
imagenet_folder = "/home/ubuntu/datasets/imagenet/imagenet_images/"
# Where anchor perturbations (attack minus clean image) are stored as PNGs.
anchor_folder = "/home/ubuntu/datasets/imagenet/anchor_adv_images/"
output_folder = "results"

# --- Attack configuration -------------------------------------------------
# Step size handed to the attack generators.
stepsize = 0.1
# True selects PGD_attack in training_step; False selects the PieAPP attacks.
PGD = False

# --- Training configuration -----------------------------------------------
# Stored as hparams.batch_size by LitImagenetModel and reused in the
# sampling-weight formula.
size_of_batch = 16

class LitImagenetModel(LightningModule):
    """ResNet-18 ImageNet classifier wrapped as a LightningModule.

    When ``algorithm_2`` is true, every training step additionally generates
    an adversarial version of the batch and writes it to disk. From the second
    epoch on, the saved adversarial images are concatenated with the original
    training set and a ``WeightedRandomSampler`` is built over the result.

    Every dataset used here is expected to yield 5-tuples
    ``(image, label, name, epoch, batch)``; the last two entries are only read
    when the sampling weights are computed.

    The class reads module-level names it does not define itself:
    ``size_of_batch``, ``stepsize``, ``PGD``, ``anchor_folder`` and, when
    ``algorithm_2`` is true, ``anchors`` and ``adv_imagenet_folder`` (both are
    assigned in the ``__main__`` section of this file). Helpers such as
    ``norm``, ``PGD_attack``, ``PieAPP_GD_attack_batch*``,
    ``custom_imagenet_dataset`` and ``data_transforms`` are not defined in
    this file either (presumably provided by the wildcard imports).

    Args:
        celoss: Loss callable applied to ``(logits, labels)``.
        train_set: Initial (clean) training dataset.
        val_set: Validation dataset.
        test_set: Test dataset.
        algorithm_2: If true, generate and store attacks during training.
        sampler: Optional sampler for the training DataLoader; when falsy the
            training data is shuffled instead.
    """

    def __init__(self, celoss, train_set, val_set, test_set, algorithm_2, sampler):
        super().__init__()
#         self.model = models.resnet50(pretrained=True).train()
        self.model = models.resnet18(pretrained = True).train()
        for param in self.model.parameters():
            param.requires_grad=True
        self.celoss = celoss
        self.hparams.batch_size = size_of_batch
        self.hparams.lr = 0.1
        # Kept separately because ``training_set`` is replaced by a
        # concatenation with the adversarial images in on_train_epoch_start.
        self.initial_training_set = train_set
        self.training_set = train_set
        self.val_set = val_set
        self.test_set = test_set
        # Size of the clean training set: the "N" of the weight formula and
        # the number of samples drawn per epoch by the weighted sampler.
        self.size = len(train_set)
        self.best_val_acc = 0
        self.algorithm_2 = algorithm_2
        self.sampler = sampler

    def forward(self, x):
        """Return logits for ``x`` without tracking gradients (inference only)."""
        # in lightning, forward defines the prediction/inference actions
        with torch.no_grad():
            outputs = self.model(norm(x))
        return outputs

    def load_model(self, path):
        """Load the weights stored under ``'state_dict'`` in a checkpoint file.

        Args:
            path: Checkpoint path readable by ``torch.load``.

        Raises:
            IndexError: If a checkpoint key contains no ``'.'``.
        """
        old_dict = torch.load(path)['state_dict']
        model_dict = self.model.state_dict()

        # 1. drop the first dotted component of every key (presumably the
        #    "model." attribute prefix added when this module was saved)
        new_dict = {k.split('.', 1)[1]: v for k, v in old_dict.items()}
        # 2. overwrite entries in the existing state dict
        model_dict.update(new_dict)
        # 3. load the new state dict
        self.model.load_state_dict(model_dict)


    def training_step(self, batch, batch_idx):
        """Run one training step on the clean batch.

        With ``algorithm_2`` enabled, an attack is generated for the batch
        first and saved to disk; the loss itself is always computed on the
        clean images.

        Returns:
            Dict with the loss under ``"loss"`` and a log dict under ``"log"``.
        """
        # training_step defined the train loop.
        # It is independent of forward
        image, label, name, _, _ = batch

        # if in algorithm 2, we generate attacks
        if self.algorithm_2:
            if PGD:
                attack = PGD_attack(self.model, self.celoss, stepsize, image, label).detach()
            else:
                # NOTE(review): both calls run and the second result replaces
                # the first, so only the non-anchor attack is saved below.
                # The original comment suggests they were meant as
                # alternatives; confirm which one is intended and drop the
                # other.
                attack = PieAPP_GD_attack_batch_useanchors(self.model, self.celoss, stepsize, image, label, name, anchor_folder).detach()
                # if not using anchors use this:
                attack = PieAPP_GD_attack_batch(self.model, self.celoss, stepsize, image, label, "", name, conf=0.0, lam=1.0).detach()

            for j in range(attack.shape[0]):
                if name[j] in anchors:
                    # For anchor images, store the perturbation only.
                    perturbation = transforms.ToPILImage()((attack[j]-image[j]).detach().cpu())
                    perturbation.save(os.path.join(anchor_folder,name[j]+'.png'))

                # One output folder per class; exist_ok makes the call safe
                # to repeat, so no prior existence check is needed.
                class_dir = adv_imagenet_folder+imagenet_classes[label[j].item()]
                Path(class_dir).mkdir(parents=True, exist_ok=True)
                adv_image = transforms.ToPILImage()(attack[j].detach().cpu())
                # save the attack image with its current epoch and batch number to compute the weight
                path = os.path.join(class_dir,name[j]+':'+str(self.current_epoch)+':'+str(batch_idx)+'.png')
                adv_image.save(path)

        # the loss is computed on the clean images in both modes
        outputs = self.model(norm(image))

        loss = self.celoss(outputs, label)
        # Logging to TensorBoard by default
        self.log("train_loss", loss.detach(), prog_bar=True)
        logs={"train_loss": loss.detach()}

        return {
            "loss": loss,
            "log": logs }

    def validation_step(self, batch, batch_idx):
        """Compute and log loss and accuracy for one validation batch.

        Returns:
            Dict with ``"loss"``, ``"acc"`` and a ``"log"`` dict.
        """
        x, y, name, _, _ = batch
        logits = self.model(norm(x))
        loss = self.celoss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y)

        # Calling self.log will surface up scalars in TensorBoard
        self.log('val_loss', loss.detach(), prog_bar=True)
        self.log('val_acc', acc.detach(), prog_bar=True)

        log={"val_loss": loss.detach(),
            'val_acc': acc.detach() }

        return {
            "loss": loss,
            "acc": acc,
            "log": log }

    def optimizer_zero_grad(self, epoch, batch_idx, optimizer, optimizer_idx):
        """Reset gradients by setting them to ``None`` instead of zero-filling."""
        optimizer.zero_grad(set_to_none=True)

    def test_step(self, batch, batch_idx):
        """Evaluate one test batch using the validation logic."""
        # Here we just reuse the validation_step for testing
        return self.validation_step(batch, batch_idx)

    def configure_optimizers(self):
        """Create the SGD optimizer and a per-step OneCycleLR schedule.

        Returns:
            Dict with ``'optimizer'`` and ``'lr_scheduler'`` entries.
        """
        optimizer = torch.optim.SGD(
            self.parameters(),
            lr=self.hparams.lr,
            momentum=0.9,
            weight_decay=5e-4,
        )
        # NOTE(review): floor division, while the training DataLoader does not
        # drop the last partial batch; confirm the scheduler's total step
        # count is sufficient for the trainer configuration in use.
        steps_per_epoch = len(self.training_set) // self.hparams.batch_size
        scheduler_dict = {
            'scheduler': OneCycleLR(
                optimizer,
                0.1,
                epochs=self.trainer.max_epochs,
                steps_per_epoch=steps_per_epoch,
            ),
            'interval': 'step',
        }
        return {'optimizer': optimizer, 'lr_scheduler': scheduler_dict}

    def train_dataloader(self):
        """Return the training DataLoader, sampler-driven when a sampler is set."""
        if self.sampler:
            return DataLoader(self.training_set, batch_size=self.hparams.batch_size, sampler = self.sampler, num_workers=8, pin_memory=True)
        return DataLoader(self.training_set, batch_size=self.hparams.batch_size, shuffle=True, num_workers=32, pin_memory=True)

    def val_dataloader(self):
        """Return the validation DataLoader (fixed order)."""
        return DataLoader(self.val_set, batch_size=64, shuffle=False, num_workers=32, pin_memory=True)

    def test_dataloader(self):
        """Return the DataLoader over the held-out test set (fixed order)."""
        # Use the test split handed to the constructor; the validation split
        # is already served by val_dataloader.
        return DataLoader(self.test_set, batch_size=64, shuffle=False, num_workers=32, pin_memory=True)

    def on_train_epoch_start(self):
        """Rebuild the training set and sampler from the saved attack images.

        Does nothing in the first epoch or when ``algorithm_2`` is false.
        Otherwise the adversarial images on disk are prepended to the clean
        training set, every item gets a sampling weight, and ``self.sampler``
        is replaced by a ``WeightedRandomSampler`` drawing ``self.size``
        samples with replacement.
        """
        # after one epoch of training, we compute the weight and do a weighted sampling
        if self.current_epoch>0 and self.algorithm_2:
            adv_dataset = custom_imagenet_dataset(adv_imagenet_folder, transform = data_transforms['train'], country= False)
            self.training_set = torch.utils.data.ConcatDataset([adv_dataset, self.initial_training_set])
            weights= []
            print("trainset size", len(self.training_set))
            dd = DataLoader(self.training_set, batch_size=128, shuffle=False, num_workers=32, pin_memory=True)
            for (_, _, _, epoch, batch) in tqdm(dd):
                # weights are initialized as all 1
                weight = torch.ones(batch.shape[0])
                # update the weights for images this batch
                # weights in Algorithm2: P_i= (k−1)N+i, here i is the batch number, k is epoch number, N is size
                # Items whose epoch field is 0 keep weight 1.
                generated = epoch != 0
                weight[generated] = (self.hparams.batch_size*batch[generated]+(epoch[generated]-1)*self.size).float()
                weights.append(weight)
            weights = torch.cat(weights)
            self.sampler = WeightedRandomSampler(weights, self.size, replacement=True, generator=None)
            # NOTE(review): a DataLoader built here would simply be discarded,
            # so none is created. The new sampler is only used once the
            # Trainer asks for train_dataloader() again (e.g. via its
            # dataloader-reload option); confirm the Trainer is configured
            # accordingly.

if __name__ == '__main__':
    # Script entry point: build the fixed train/val/test split, generate the
    # anchor perturbations that are still missing, then train.

    # ``time`` is used for the training timer below but is not imported at the
    # top of this file; import it here so the script does not depend on a
    # wildcard import leaking it.
    import time

    # NOTE(review): ``Subset``, ``Dataset_from_subset``, ``data_transforms``
    # and ``custom_imagenet_dataset`` are not defined or imported by name in
    # this file; presumably they come from the wildcard imports.
    imagenet_dataset = custom_imagenet_dataset(imagenet_folder, transform = None, country= False)
    size = len(imagenet_dataset)

    # we fix the split when we train the model and test the model on fairness
    # 60% train / 20% validation / remainder test, from a seeded permutation
    split_length= [int(size*0.6), int(size*0.2)+int(size*0.6)]
    indices = np.random.RandomState(seed=12).permutation(size)
    training_idx, val_idx, test_idx = indices[:split_length[0]], indices[split_length[0]:split_length[1]],indices[split_length[1]:]
    train_set = Subset(imagenet_dataset, training_idx)
    val_set = Subset(imagenet_dataset, val_idx)
    test_set = Subset(imagenet_dataset, test_idx)

    initial_training_set = Dataset_from_subset(train_set, data_transforms['train'])
    val_set = Dataset_from_subset(val_set, data_transforms['val'])
    test_set = Dataset_from_subset(test_set, data_transforms['val'])
    initial_size = len(initial_training_set)

    criterion = nn.CrossEntropyLoss()
    algorithm_2 = True
    adv_imagenet_folder = "/home/ubuntu/datasets/imagenet/adv_imagenet_images_resnet/"
    Path(adv_imagenet_folder).mkdir(parents=True, exist_ok=True)
    # number of models to sample
    number_runs=50

    # if not using kmeans to speed up, comment out from here
    anchors = torch.load("images_kmeans_results/anchor_set.pt")
    # generate attacks for the anchors
    device = "cuda"
    model = models.resnet50(pretrained=True).train().to(device)
    for i in range(1000):
        centers_path = "./images_kmeans_results/"+str(i)+".pt"
        if not os.path.isfile(centers_path):
            continue
        # De-duplicate the anchor names stored for this class.
        anchor_set = list(set(torch.load(centers_path).values()))
        class_name = imagenet_classes[i]
        label = torch.tensor(i).to(device).unsqueeze(0)
        for anchor_name in anchor_set:
            anchor_path = os.path.join(anchor_folder,anchor_name+'.png')
            # Skip anchors whose perturbation was already generated.
            if os.path.isfile(anchor_path):
                continue
            # a mistake in generating anchor names, just ignore it
            if anchor_name=="St":
                continue
            # The source image may be stored as .jpg, .gif or .png; try them
            # in that order and close the file handle once it is converted.
            img = None
            for ext in (".jpg", ".gif", ".png"):
                try:
                    with Image.open(os.path.join(imagenet_folder, class_name, anchor_name+ext)) as source_image:
                        img = data_transforms['val'](source_image).to(device)
                    break
                except FileNotFoundError:
                    continue
            if img is None:
                raise FileNotFoundError(
                    "no .jpg/.gif/.png image found for anchor "
                    + os.path.join(imagenet_folder, class_name, anchor_name)
                )
            # Expand single-channel images to three channels and drop any
            # channel beyond the third before adding the batch dimension.
            if img.shape[0]==1:
                img = img.repeat(3, 1, 1)
            img = img[:3].unsqueeze(0)
            attack = PieAPP_GD_attack_batch(model, criterion, stepsize, img, label, "", "", 0).detach()
            # Only the perturbation (attack minus clean image) is stored.
            x_np=transforms.ToPILImage()((attack-img).squeeze().detach().cpu())
            x_np.save(anchor_path)
    # The anchor model is not used below; release it before training starts.
    del model
    torch.cuda.empty_cache()
    # if not using kmeans to speed up, comment out till here

    if algorithm_2:
        training_set = initial_training_set
        sampler = None
        number_runs = 1
    else:
        # Train on the stored adversarial images plus the clean training set,
        # sampled with the same weight formula used in on_train_epoch_start.
        adv_dataset = custom_imagenet_dataset(adv_imagenet_folder, transform = data_transforms['train'], country= False)
        training_set = torch.utils.data.ConcatDataset([adv_dataset, initial_training_set])
        dd = DataLoader(training_set, batch_size=128, shuffle=False, num_workers=8, pin_memory=True)
        weights= []
        for (_, _, _, epoch, batch) in tqdm(dd):
            weight = torch.ones(batch.shape[0])
            generated = epoch != 0
            weight[generated] = (size_of_batch*batch[generated]+(epoch[generated]-1)*initial_size).float()
            weights.append(weight)
        weights = torch.cat(weights)
        torch.save(weights, "adv_imagenet_images_resnet_weights.pt")
        sampler = WeightedRandomSampler(weights, 2*initial_size, replacement=True, generator=None)

    print("train set size:", len(training_set))

    for v in range(number_runs):
        if algorithm_2:
            model_path = "imagenet_models_resnet/alg_2/"
        else:
            model_path = "imagenet_models_resnet/DRO_sample/version_"+str(v)+"/"

        Path(model_path).mkdir(parents=True, exist_ok=True)
        # Init our model
        imagenet_model = LitImagenetModel(criterion, training_set, val_set, test_set, algorithm_2, sampler = sampler)

        # keeps the 3 checkpoints with the highest val_acc in model_path,
        # named from the "imagenet_{epoch:02d}_{val_acc:.2f}" template
        checkpoint_callback = ModelCheckpoint(
            monitor="val_acc",
            dirpath= model_path,
            filename="imagenet_{epoch:02d}_{val_acc:.2f}",
            save_weights_only = True,
            save_top_k=3,
            mode="max",
        )

        #     Initialize a trainer
        trainer = Trainer(
            gpus=torch.cuda.device_count(),
            max_epochs=1,
            accelerator='ddp',
            precision=16,
            val_check_interval=0.25,
            stochastic_weight_avg=True,
            accumulate_grad_batches=4,
            num_sanity_val_steps=0,
            auto_scale_batch_size= "power",
            log_every_n_steps=10,
            callbacks=[checkpoint_callback],
            replace_sampler_ddp = False,
        )
        # Train the model ⚡
        tic = time.time()
        trainer.fit(imagenet_model)
        toc = time.time()
        print(f"Training in {toc - tic:0.4f} seconds.")
