import argparse
import csv
import os
import time
from types import SimpleNamespace

from PIL import Image
import torch

from models import get_random_crop, load_model, to_tensor, to_pil
from setup_env import initialize_DDP


def get_args():
    """Parse the command-line options of the attack script.

    Returns:
        argparse.Namespace with:
            launcher: name passed to ``initialize_DDP`` (default ``'pytorch'``).
            local_rank: int, accepted as ``--local_rank`` or ``--local-rank``
                (default 0).
            eps: int perturbation bound used by ``main`` (default 16).
    """
    option_specs = (
        (('--launcher',),
         {'type': str,
          'default': 'pytorch',
          'help': 'should be either `slurm` or `pytorch`'}),
        (('--local_rank', '--local-rank'), {'type': int, 'default': 0}),
        (('--eps',), {'type': int, 'default': 16}),
    )

    cli = argparse.ArgumentParser()
    for flags, options in option_specs:
        cli.add_argument(*flags, **options)
    return cli.parse_args()


def read_csv(path):
    """Read a comma-separated file and return its data rows.

    The first row is treated as a header and discarded.

    Args:
        path: Path to the CSV file (decoded as UTF-8).

    Returns:
        A list of rows, each a list of strings, excluding the header row.
        Empty when the file contains at most one row.
    """
    # The csv module must see raw line endings (newline=""); otherwise
    # newlines embedded inside quoted fields are translated/mis-parsed.
    with open(path, newline="", encoding="utf-8") as f:
        reader = csv.reader(f, delimiter=",")
        next(reader, None)  # skip the header row, if any
        return list(reader)


def read_dataset(path="./data/nips2017_adv_dev/"):
    """Load the image index of the NIPS 2017 adversarial dev set.

    Reads ``{path}/categories.csv`` (two columns: category id, category
    text) and ``{path}/images.csv``; ``read_csv`` drops each header row.
    For every image row, column 0 is the image id, column 6 the ground-truth
    label id and column 7 the target label id.

    Args:
        path: Dataset root directory.

    Returns:
        A list of ``SimpleNamespace`` records in ``images.csv`` order with
        attributes ``image_id``, ``image_path``, ``gt`` and ``target`` (int
        label ids), ``gt_text`` and ``target_text`` (category texts).

    Raises:
        KeyError: if an image row references a label id that is absent from
            ``categories.csv``.
    """
    id_to_text = {cat_id: cat_text
                  for cat_id, cat_text in read_csv(f"{path}/categories.csv")}

    def make_record(row):
        # Category text is looked up with the raw string id, while the
        # stored numeric ids are converted to int.
        return SimpleNamespace(
            image_id=row[0],
            image_path=f"{path}/images/{row[0]}.png",
            gt=int(row[6]),
            target=int(row[7]),
            gt_text=id_to_text[row[6]],
            target_text=id_to_text[row[7]],
        )

    return [make_record(row) for row in read_csv(f"{path}/images.csv")]



def _load_models(model_ids, device="cuda"):
    """Load each id in ``model_ids`` onto ``device``, printing progress.

    Duplicate ids are loaded once per occurrence; the returned list has the
    same order and length as ``model_ids``.
    """
    models = []
    for model_id in model_ids:
        models.append(load_model(model_id=model_id, device=device))
        print(f"Load {len(models)} / {len(model_ids)} models")
    return models


def _attack_image(image, texts, models, eps, steps=500, lr=10,
                  loss_threshold=0.1):
    """Optimise an additive perturbation of ``image`` against ``models``.

    Each step sums ``model.compute_loss(image=..., **texts)`` over all
    models and takes one Adam step on the perturbation. For each model the
    loss is evaluated, with probability 1/2, on a random crop from
    ``get_random_crop`` instead of the full image. After every step the
    perturbation is clamped to ``[-eps, eps]`` and ``image + perturbation``
    to ``[0, 255]``.

    Args:
        image: 4-D tensor unpacked as (_, _, H, W); the [0, 255] clamp
            presumably means 0-255 pixel values -- confirm with ``to_tensor``.
        texts: Extra keyword arguments for ``compute_loss`` (not mutated).
        models: Objects exposing ``compute_loss(**kwargs)`` that returns a
            scalar tensor.
        eps: L-infinity bound on the perturbation.
        steps: Maximum number of optimisation steps.
        lr: Adam learning rate.
        loss_threshold: Stop early once the summed loss drops below this.

    Returns:
        (adv_image, total_loss): the rounded adversarial image, detached from
        the autograd graph, and the summed loss of the last step taken.
    """
    _, _, height, width = image.shape
    inputs = dict(texts)
    adv = torch.zeros_like(image, requires_grad=True)
    optimizer = torch.optim.Adam([adv], lr=lr)
    total_loss = 0

    for _ in range(steps):
        optimizer.zero_grad()
        adv_image = image + adv
        total_loss = 0

        for model in models:
            # Coin flip: attack a random crop or the full image.
            if torch.randn(1).item() > 0:
                slice_h, slice_w = get_random_crop(height, width)
                inputs["image"] = adv_image[..., slice_h, slice_w]
            else:
                inputs["image"] = adv_image

            loss = model.compute_loss(**inputs)
            # Gradients from every model accumulate in adv.grad (zeroed once
            # per step); backpropagating per model presumably keeps only one
            # model's graph alive at a time.
            loss.backward()
            total_loss += loss.item()

        optimizer.step()

        # Project back onto the feasible set without recording autograd ops
        # (replaces the deprecated direct `.data` mutation).
        with torch.no_grad():
            clipped = adv.clamp(-eps, eps)
            adv.copy_((image + clipped).clamp(0, 255) - image)

        if total_loss < loss_threshold:
            break

    # Detach so downstream conversion (e.g. to PIL/numpy) never sees a tensor
    # that requires grad.
    adv_image = (image + adv.detach()).round()
    return adv_image, total_loss


def main():
    """Run the ensemble attack over this rank's shard of the dataset.

    For every sample the perturbation is optimised with ``_attack_image``
    using the texts "a photo of <target>" and "a photo of <ground truth>".
    What ``compute_loss`` does with them is defined by the models module.
    The result is written to ``results/norm_{eps}/{image_id}.png``.
    """
    args = get_args()
    # NOTE(review): local_rank is unused; tensors go to the default "cuda"
    # device, presumably selected per rank by initialize_DDP -- confirm.
    rank, _local_rank, world_size = initialize_DDP(args.launcher)

    # Round-robin shard: this rank handles every world_size-th sample.
    dataset = [sample for idx, sample in enumerate(read_dataset())
               if idx % world_size == rank]

    # NOTE(review): "open_clip/ViT-H-14-quickgelu" is listed twice, so it is
    # loaded twice and its loss counts twice per step (with independent random
    # crops). Kept as-is; confirm this weighting is intended.
    model_list = [
        "open_clip/ViT-H-14-378-quickgelu",
        "open_clip/ViT-H-14-quickgelu",
        "open_clip/ViT-SO400M-14-SigLIP-384",
        "open_clip/ViT-SO400M-14-SigLIP",
        "open_clip/ViT-L-16-SigLIP-384",
        "open_clip/ViT-bigG-14",
        "open_clip/ViT-H-14-CLIPA-336",
        "open_clip/ViT-H-14-quickgelu",
    ]
    models = _load_models(model_list, device="cuda")

    out_dir = f"results/norm_{args.eps}"
    os.makedirs(out_dir, exist_ok=True)  # also creates "results/"

    start_time = time.time()
    for data_idx, data in enumerate(dataset):
        image = to_tensor(data.image_path).to("cuda")
        texts = {
            "target_text": f"a photo of {data.target_text}",
            "untarget_text": f"a photo of {data.gt_text}",
        }

        adv_image, total_loss = _attack_image(image, texts, models, args.eps)
        to_pil(adv_image).save(f"{out_dir}/{data.image_id}.png")

        done = data_idx + 1
        speed = (time.time() - start_time) / done
        eta = speed * (len(dataset) - done) / 3600
        print(f"Processed {done} / {len(dataset)} samples, "
              f"loss: {total_loss: .3f}, "
              f"eta: {eta: .2f} hours.")


if __name__ == '__main__':
    # Run the attack only when executed as a script, not when imported.
    main()
