from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.optim.lr_scheduler import StepLR
from tqdm import tqdm
from torch.utils.data import DataLoader, Dataset, Subset
import random
import matplotlib.pyplot as plt
import numpy as np
import math
from collections import OrderedDict
import tensorflow as tf
from PIL import Image
import os
import itertools
from typing import List
from torch.cuda.amp import GradScaler, autocast

from ffcv.fields import IntField, RGBImageField
from ffcv.fields.decoders import IntDecoder, SimpleRGBImageDecoder
from ffcv.loader import Loader, OrderOption
from ffcv.pipeline.operation import Operation
from ffcv.transforms import RandomHorizontalFlip, Cutout, \
    RandomTranslate, Convert, ToDevice, ToTensor, ToTorchImage
from ffcv.transforms.common import Squeeze
from ffcv.writer import DatasetWriter
import gc

from ffcv.transforms import ToTensor, ToDevice, Squeeze, NormalizeImage, \
    RandomHorizontalFlip, ToTorchImage
from ffcv.fields.rgb_image import CenterCropRGBImageDecoder, \
    RandomResizedCropRGBImageDecoder
from ffcv.fields.basics import IntDecoder
import wandb
import heapq


# Module-wide switch for extra diagnostic printing (distance matrices,
# neighbour orders, generated triplets, per-neighbour test scores).
# Flip to True while debugging on tiny n.
debug = False


def to_chunks(it, size):
    """Split an iterable into consecutive tuples of ``ceil(size)`` items.

    The last tuple may be shorter. Iteration stops at the first empty chunk,
    so a chunk length of 0 yields nothing.

    Args:
        it: any iterable; it is consumed lazily.
        size: chunk length; non-integers are rounded up.

    Returns:
        An iterator over the chunks (tuples).
    """
    chunk_len = int(math.ceil(size))
    source = iter(it)

    def _chunks():
        # Pull chunk_len items at a time until the source is exhausted.
        while chunk := tuple(itertools.islice(source, chunk_len)):
            yield chunk

    return _chunks()


# Target device for the models and for every FFCV loader pipeline in this script.
device = torch.device("cuda", 1)


def generate_until(gen_f, pred):
    """Call ``gen_f()`` repeatedly and return the first result accepted by ``pred``.

    Loops forever if ``pred`` never accepts a generated value.
    """
    candidate = gen_f()
    while not pred(candidate):
        candidate = gen_f()
    return candidate


def get_embedding_dataloader(dataset_path, order, batch_size):
    """Build an FFCV loader that reads images and labels in a fixed order.

    Args:
        dataset_path: path to the FFCV ``.beton`` file.
        order: sample indices to read, in this order; passed to ``Loader`` as
            ``indices`` (``None`` is passed through unchanged).
        batch_size: samples per batch; the final batch may be smaller
            (``drop_last=False``).

    Returns:
        An ``ffcv.loader.Loader`` with an ``'image'`` and a ``'label'`` pipeline.
        Both move data to ``device``. Images are converted with
        ``ToTorchImage``, cast to float16 and normalised with per-channel
        mean/std on the 0-255 pixel scale.
    """
    cifar_mean = [125.307, 122.961, 113.8575]
    cifar_std = [51.5865, 50.847, 51.255]

    label_pipeline: List[Operation] = [
        IntDecoder(),
        ToTensor(),
        ToDevice(device),
        Squeeze(),
    ]
    image_pipeline: List[Operation] = [
        SimpleRGBImageDecoder(),
        ToTensor(),
        ToDevice(device, non_blocking=True),
        ToTorchImage(),
        Convert(torch.float16),
        transforms.Normalize(cifar_mean, cifar_std),
    ]
    pipelines = {'image': image_pipeline, 'label': label_pipeline}

    return Loader(
        dataset_path,
        batch_size=batch_size,
        num_workers=2,
        order=OrderOption.SEQUENTIAL,  # keep exactly the order given by `order`
        indices=order,
        drop_last=False,
        pipelines=pipelines,
    )


def ground_truth_model():
    """Return an ImageNet-pretrained ResNet-18 feature extractor on ``device``.

    The final ``fc`` layer is replaced with ``nn.Identity``, so the model
    outputs the 512-d features that ``fc`` would have consumed. The model is
    put in eval mode.
    """
    backbone = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1).to(device)
    backbone.fc = nn.Identity()  # strip the classifier head; keep the features
    backbone.eval()
    return backbone


def get_embeddings(model, data_loader):
    """Run ``model`` over every batch of ``data_loader`` and stack the outputs.

    Inference runs without gradient tracking and under autocast. Each batch's
    output is moved to the CPU; the result is one float32 tensor with rows in
    loader order.

    NOTE: the model's train/eval mode is not changed here. If the model is in
    train mode, BatchNorm layers use (and update) batch statistics.
    """
    chunks = []
    with torch.no_grad():
        for batch in data_loader:
            with autocast():
                batch_out = get_outputs(model, batch)
            chunks.append(batch_out.detach().cpu())
    return torch.cat(chunks).float()


def get_neighbor_order(embeddings):
    """Rank, for every point, all other points by Euclidean distance.

    Args:
        embeddings: 2-D tensor of shape (num_points, dim). It must be on the CPU
            because the result is converted with ``.numpy()``.

    Returns:
        Integer numpy array of shape (num_points, num_points - 1). Row ``u``
        holds the indices of every point other than ``u``, nearest first.
    """
    dists = torch.cdist(embeddings, embeddings)
    n_points = dists.shape[0]
    order = dists.argsort(dim=1)
    # Remove u from its own row by index instead of assuming it sorts first.
    # With duplicate embeddings (distance-0 ties), or with cdist's matmul-based
    # path (which can leave a small non-zero diagonal), another point may come
    # before u. Slicing off column 0 would then drop a real neighbour and keep u.
    not_self = order != torch.arange(n_points, device=order.device).unsqueeze(1)
    res = order[not_self].reshape(n_points, max(n_points - 1, 0)).numpy()
    if debug:
        print(dists)
        print(res)
    return res


def gen_order_with_ground_truth(*, neighbor_order, indices, n, k):
    """Build a shuffled, flattened list of (anchor, positive, negative) triplets.

    For each anchor position ``u``, the neighbour at rank ``v`` (for ``v < k``)
    is the positive. Every strictly farther neighbour, at rank ``w > v``, is a
    negative.

    Args:
        neighbor_order: 2-D array where ``neighbor_order[u, r]`` is the position
            (0..n-1) of u's r-th nearest neighbour. It needs at least
            ``n - 1`` columns.
        indices: maps positions to dataset sample indices.
        n: number of points (rows of ``neighbor_order`` that are used).
        k: number of nearest neighbours used as positives.

    Returns:
        A flat list ``[a0, p0, q0, a1, p1, q1, ...]`` of dataset indices. Each
        consecutive group of 3 is one triplet, and triplet order is randomised
        with the ``random`` module.

    NOTE: the triplet count is sum over v < k of (n - 2 - v) per anchor, i.e.
    roughly n * k * n overall, which gets large for big n and k.
    """
    tuples = []
    for u in range(n):
        for v in range(k):
            # Start at v + 1. With w == v, positive and negative would be the
            # same sample, giving a degenerate triplet with a constant (margin)
            # loss that is never counted as correct.
            for w in range(v + 1, n - 1):
                tuples.append((indices[u], indices[neighbor_order[u, v]], indices[neighbor_order[u, w]]))
    if debug:
        print(indices)
        print(tuples)
    random.shuffle(tuples)
    return [x
            for t in tuples
            for x in t]


def knn_dataloader_with_ground_truth(*, dataset_path, neighbor_order, indices, n, k, batch_size):
    """Build a loader that streams (anchor, positive, negative) image triplets.

    Triplets come from ``gen_order_with_ground_truth`` as a flat index list in
    which every 3 consecutive samples form one triplet. They are read in that
    order, using the same pipelines as ``get_embedding_dataloader``.

    Args:
        dataset_path: path to the FFCV ``.beton`` file.
        neighbor_order, indices, n, k: forwarded to ``gen_order_with_ground_truth``.
        batch_size: requested batch size. It is rounded down to a multiple of 3,
            so no triplet is split across batches.

    Returns:
        ``(loader, n_vals)``, where ``n_vals`` is the total number of images the
        loader yields (3 per triplet).

    Raises:
        ValueError: if ``batch_size`` is smaller than 3, since a batch must
            hold at least one triplet.
    """
    order = gen_order_with_ground_truth(neighbor_order=neighbor_order, indices=indices, n=n, k=k)

    if batch_size < 3:
        raise ValueError(f"batch_size must be at least 3 to hold one triplet, got {batch_size}")
    batch_size = batch_size - batch_size % 3

    loader = get_embedding_dataloader(dataset_path, order=order, batch_size=batch_size)
    # `order` is already flattened (3 entries per triplet), so its length is the
    # number of images; multiplying by 3 again would overcount.
    return loader, len(order)


def get_batch_length(batch):
    """Return the number of samples in ``batch``, measured on its second entry (the labels)."""
    labels = batch[1]
    return len(labels)

def get_outputs(model, batch):
    """Apply ``model`` to the first entry of ``batch`` (the images) after moving it to ``device``."""
    images = batch[0].to(device)
    return model(images)

# Shared triplet criterion with PyTorch defaults (margin=1.0, p=2, mean reduction).
loss_f = nn.TripletMarginLoss()


def contrastive_loss_acc(outputs, n_negatives):
    """Compute triplet loss and triplet accuracy for a batch of stacked triplets.

    Rows ``3i``, ``3i + 1`` and ``3i + 2`` of ``outputs`` are the anchor,
    positive and negative of triplet ``i``.

    Args:
        outputs: 2-D tensor (3 * num_triplets, dim).
        n_negatives: number of negatives per anchor; only 1 is supported.

    Returns:
        ``(loss, acc)``. ``loss`` comes from ``loss_f``. ``acc`` is a 0-dim float
        tensor: the fraction of triplets whose anchor is strictly closer (L2) to
        the positive than to the negative.
    """
    assert n_negatives == 1
    n_rows = outputs.shape[0]
    assert n_rows % 3 == 0
    triplets = outputs.reshape(n_rows // 3, 3, *outputs.shape[1:])
    assert triplets.dim() == 3
    assert triplets.shape[1] == 3
    anchor, positive, negative = triplets.unbind(dim=1)
    loss = loss_f(anchor, positive, negative)
    dist_pos = torch.linalg.vector_norm(anchor - positive, dim=1)
    dist_neg = torch.linalg.vector_norm(anchor - negative, dim=1)
    acc = (dist_pos < dist_neg).float().mean()
    return loss, acc


def train(model, train_loader, optimizer, epoch, n_vals, loss_acc_f):
    """Train ``model`` for one pass over ``train_loader``.

    Args:
        model: network applied to each batch through ``get_outputs``.
        train_loader: iterable of batches (wrapped in a tqdm progress bar).
        optimizer: optimizer stepped once per batch.
        epoch: epoch number, used only in the progress-bar description.
        n_vals: number of output rows to train on. Rows beyond this count are
            discarded, and the pass stops once the count is exceeded.
        loss_acc_f: callable taking the outputs (cast to float32) and returning
            ``(loss, acc)``; ``acc`` may be a Python number or a 0-dim tensor.

    After every batch, the running sample-weighted loss and accuracy (percent)
    are logged to wandb under ``"train"`` and shown on the progress bar.
    """
    model.train()
    total_loss = 0.0
    total_acc = 0.0
    n_points = 0
    remaining = n_vals

    progress_bar = tqdm(train_loader, position=0, leave=True, miniters=10)
    for batch in progress_bar:
        optimizer.zero_grad()
        with autocast():
            outputs = get_outputs(model, batch)
        n_rows = outputs.shape[0]
        # Rows of this batch that still fall within the expected n_vals.
        n_keep = max(0, min(n_rows, remaining))
        remaining -= n_rows
        if n_keep == 0:
            # Nothing valid left: avoid computing a loss over an empty tensor.
            break
        outputs = outputs[:n_keep]
        # Tensors produced under autocast may be float16; compute the loss in float32.
        loss, acc = loss_acc_f(outputs.float())
        # Weight by the rows actually used, not by the pre-truncation batch size.
        n_points += n_keep
        total_loss += loss.item() * n_keep
        # float() turns a 0-dim tensor into a plain number, so the logged and
        # printed accuracy is a value rather than a tensor repr.
        total_acc += float(acc) * n_keep
        loss.backward()
        optimizer.step()

        mean_loss = total_loss / n_points
        mean_acc = 100 * total_acc / n_points
        wandb.log({"train": {
            "loss": mean_loss,
            "acc": mean_acc,
        }})
        progress_bar.set_description(
            f'Train Epoch: {epoch}  Loss: {mean_loss:.6f} Accuracy: {mean_acc:.2f}%',
            refresh=False)

        if remaining < 0:
            break


def test(model, emb_dataloader, true_neighbor_order, *, n, k):
    """Measure how well ``model`` preserves each point's true nearest-neighbour ranks.

    For every point ``u`` and each of its true top-k neighbours ``v`` (true rank
    ``j``), the score is ``|j - rank of v in u's neighbour list under the model's
    embeddings|``. The result is the sum over all pairs divided by ``n * k``
    (the "moving distance"; 0 means the top-k ranks are reproduced exactly).

    Args:
        model: network evaluated through ``get_embeddings``.
        emb_dataloader: loader yielding exactly the ``n`` evaluated points.
        true_neighbor_order: array of shape (n, n - 1) with ground-truth
            neighbour positions, nearest first.
        n: number of points.
        k: number of true neighbours scored per point.

    Returns:
        The moving distance as a numpy float.
    """
    our_embeddings = get_embeddings(model, emb_dataloader)
    assert our_embeddings.shape[0] == n, f"{our_embeddings.shape[0]} {n}"
    our_neighbor_order = get_neighbor_order(our_embeddings)

    rows = np.arange(n)[:, None]
    # pos[u, v] = rank of v among u's neighbours under our embedding. A single
    # numpy scatter replaces an O(n^2) Python loop; pos[u, u] stays 0.
    pos = np.zeros((n, n))
    pos[rows, our_neighbor_order[:, :n - 1]] = np.arange(n - 1)

    true_top_k = np.asarray(true_neighbor_order)[:n, :k]
    true_ranks = np.arange(true_top_k.shape[1])
    displacement = np.abs(true_ranks - pos[rows, true_top_k])
    if debug:
        for u in range(n):
            for j, v in enumerate(true_top_k[u]):
                print(v, pos[u, v], displacement[u, j])

    total_loss = displacement.sum() / (n * k)

    print(f'Test loss: {total_loss:.4f}')
    return total_loss


if __name__ == "__main__":
    # Guard: fill in the wandb key below before running.
    assert False, "We use wandb for logging. Please specify your wandb key below"
    wandb.login(key="")
    dataset_path = "/tmp/cifar_train.beton"
    # Ground-truth embedding for every training image, from the pretrained ResNet-18.
    cifar10_train_embeddings = get_embeddings(ground_truth_model(),
                                              get_embedding_dataloader(dataset_path, order=None, batch_size=1000))

    n_runs = 10
    epochs = 6
    test_every = 5  # evaluate at epochs 0, test_every, ..., before that epoch trains
    lr = 0.01
    dim = 128
    n_negatives = 1
    for n in [10, 100, 1000]:
        for k in [1, 2, 4, 8, 16, 32, 64]:
            if k >= n:
                continue
            for _ in range(n_runs):
                wandb.init(
                    project="knn_cifar10", entity="", reinit=True, name=f"{dim=} {lr=} {n=} {k=}",
                    config={"dim": dim, "lr": lr, "n": n, "k": k}
                )
                # Random subset of n images and their ground-truth neighbour ranking.
                indices = random.sample(range(cifar10_train_embeddings.shape[0]), k=n)
                neighbor_order = get_neighbor_order(cifar10_train_embeddings[indices])

                # Fresh ResNet-18 whose head projects to a dim-dimensional embedding.
                model = models.resnet18()
                model.fc = nn.Linear(512, dim, bias=False)
                model = model.to(device)
                optimizer = optim.Adadelta(model.parameters(), lr=lr)

                emb_dataloader = get_embedding_dataloader(dataset_path, order=indices, batch_size=1000)
                train_loader, n_vals = knn_dataloader_with_ground_truth(dataset_path=dataset_path, neighbor_order=neighbor_order, indices=indices, n=n, k=k, batch_size=500)
                for epoch in range(epochs):
                    if epoch % test_every == 0:
                        test_loss = test(model, emb_dataloader, neighbor_order, n=n, k=k)
                        wandb.log({"test": {"moving_distance": test_loss}}, commit=False)
                    train(model, train_loader, optimizer, epoch, n_vals, lambda o: contrastive_loss_acc(o, n_negatives))

                # Evaluate after the last epoch so the reported final number
                # reflects the fully trained model.
                test_loss = test(model, emb_dataloader, neighbor_order, n=n, k=k)
                wandb.log({"test": {"moving_distance": test_loss}})
                if debug:
                    print(neighbor_order)
                    print()
                    print(get_neighbor_order(get_embeddings(model, emb_dataloader)))
                print(f"{dim=} {n=} {k=} Final moving distance: {test_loss}")
                print("-----------------------------------------------")
                wandb.finish()

