from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.optim.lr_scheduler import StepLR
from tqdm import tqdm
from torch.utils.data import DataLoader, Dataset, Subset
import random
import matplotlib.pyplot as plt
import numpy as np
import math
from collections import OrderedDict
import tensorflow as tf
from PIL import Image
import os
import itertools
from typing import List
from torch.cuda.amp import GradScaler, autocast

from ffcv.fields import IntField, RGBImageField
from ffcv.fields.decoders import IntDecoder, SimpleRGBImageDecoder
from ffcv.loader import Loader, OrderOption
from ffcv.pipeline.operation import Operation
from ffcv.transforms import RandomHorizontalFlip, Cutout, \
    RandomTranslate, Convert, ToDevice, ToTensor, ToTorchImage
from ffcv.transforms.common import Squeeze
from ffcv.writer import DatasetWriter
import gc

from ffcv.transforms import ToTensor, ToDevice, Squeeze, NormalizeImage, \
    RandomHorizontalFlip, ToTorchImage
from ffcv.fields.rgb_image import CenterCropRGBImageDecoder, \
    RandomResizedCropRGBImageDecoder
from ffcv.fields.basics import IntDecoder
from pathlib import Path
import wandb
from tqdm import tqdm
import heapq
import networkx as nx
import sys

def eq(a, b, *, eps=1e-9):
    """Return whether ``a`` and ``b`` lie strictly within ``eps`` of each other.

    A gap of exactly ``eps`` counts as *not* equal.
    """
    gap = abs(a - b)
    return gap < eps

def to_chunks(it, size):
    """Lazily split ``it`` into tuples of ``ceil(size)`` consecutive items.

    The final tuple may be shorter. A size of 0 yields nothing.
    """
    chunk_len = int(math.ceil(size))
    source = iter(it)

    def next_chunk():
        return tuple(itertools.islice(source, chunk_len))

    # iter(callable, sentinel) stops as soon as an empty tuple is produced.
    return iter(next_chunk, ())

# Module-level target device shared by the FFCV pipelines (ToDevice) and the
# model code below. It is hard-coded to CUDA; no CPU fallback is attempted.
device = torch.device("cuda")

def generate_until(gen_f, pred):
    """Call ``gen_f()`` repeatedly and return the first value accepted by ``pred``.

    Values are generated one at a time and each is tested before the next
    call, so ``gen_f`` is invoked exactly as many times as needed. The loop has
    no iteration cap.
    """
    candidate = gen_f()
    while not pred(candidate):
        candidate = gen_f()
    return candidate

def gen_order(*, labels, label_map, n_samples, n_negatives):
    """Generate a flat list of point indices grouped as anchor/positive/negatives.

    For each of ``n_samples`` groups the following are appended, in order:

    * a uniformly random anchor ``a``;
    * a positive drawn from ``label_map[labels[a]]`` that differs from ``a``;
    * ``n_negatives`` uniformly random points whose label differs from ``labels[a]``.

    Args:
        labels: indexable sequence, point index -> label.
        label_map: mapping, label -> sequence of point indices carrying that label.
        n_samples: number of groups to generate.
        n_negatives: number of negatives per group.

    Returns:
        A list of ``n_samples * (2 + n_negatives)`` point indices.

    Raises:
        ValueError: if a group can never be completed. This happens when the
            anchor's class has no member other than the anchor, or when every
            point shares one label while negatives are requested. Without these
            checks the rejection-sampling loops would spin forever.
    """
    n_points = len(labels)
    # With fewer than two distinct labels no point can ever serve as a negative.
    if n_samples > 0 and n_negatives > 0 and len(set(labels)) < 2:
        raise ValueError("negatives requested but every point has the same label")
    order = []
    for _ in range(n_samples):
        a = random.randrange(n_points)
        label = labels[a]
        same_class = label_map[label]
        # Deterministic check (no random draws), so valid inputs still see the
        # same random sequence as before. It usually short-circuits immediately.
        if all(p == a for p in same_class):
            raise ValueError(f"point {a} has no other point with label {label!r}")
        order.append(a)
        # The lambdas are called immediately inside generate_until, so binding
        # the loop variables `a`/`label` late is safe here.
        order.append(generate_until(
            lambda: random.choice(same_class),
            lambda p: p != a,
        ))
        for _ in range(n_negatives):
            order.append(generate_until(
                lambda: random.randrange(n_points),
                lambda n: labels[n] != label,
            ))
    return order

def get_embedding_dataloader(dataset_path, batch_size, *, num_workers=2):
    """Build a sequential FFCV loader over ``dataset_path`` for embedding extraction.

    Image pipeline:
    * decode;
    * move to ``device``;
    * convert to a torch image tensor in float16;
    * normalise with the CIFAR-100 per-channel mean/std, scaled by 255 to match
      raw 0-255 pixel values.

    Label pipeline: decode ints, move to ``device``, squeeze.

    The order is sequential and the last partial batch is kept
    (``drop_last=False``), so concatenating the batches preserves the dataset's
    record order.

    Args:
        dataset_path: path to an FFCV dataset file with ``image`` and ``label`` fields.
        batch_size: number of records per batch.
        num_workers: loader worker count (defaults to the previously hard-coded 2).

    Returns:
        An ``ffcv.loader.Loader``.
    """
    CIFAR100_MEAN = [255 * x for x in [0.5071, 0.4865, 0.4409]]
    CIFAR100_STD = [255 * x for x in [0.2673, 0.2564, 0.2762]]

    label_pipeline: List[Operation] = [IntDecoder(), ToTensor(), ToDevice(device), Squeeze()]
    image_pipeline: List[Operation] = [
        SimpleRGBImageDecoder(),
        ToTensor(),
        ToDevice(device, non_blocking=True),
        ToTorchImage(),
        Convert(torch.float16),
        transforms.Normalize(CIFAR100_MEAN, CIFAR100_STD),
    ]

    return Loader(dataset_path,
                  batch_size=batch_size,
                  num_workers=num_workers,
                  order=OrderOption.SEQUENTIAL,
                  drop_last=False,
                  pipelines={'image': image_pipeline, 'label': label_pipeline})

def get_outputs(model, batch):
    """Run ``model`` on the first element of ``batch`` after moving it to ``device``."""
    images = batch[0]
    on_device = images.to(device)
    return model(on_device)

def getembeddings(dataset_path):
    """Return ImageNet ResNet-18 penultimate-layer features for every record.

    The classifier head is replaced by ``nn.Identity`` so the model emits the
    pooled features. Inference runs under ``no_grad`` with ``autocast``. Each
    batch's output is moved to the CPU, and the concatenated result is returned
    as a float32 tensor in loader order.
    """
    backbone = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1).to(device)
    backbone.fc = nn.Identity()
    backbone.eval()
    feature_chunks = []
    loader = get_embedding_dataloader(dataset_path, batch_size=1000)
    with torch.no_grad():
        for batch in loader:
            with autocast():
                features = get_outputs(backbone, batch)
            feature_chunks.append(features.detach().cpu())
    stacked = torch.cat(feature_chunks)
    return stacked.float()

def eprint(*args, **kwargs):
    """Print like ``print`` but to standard error.

    Passing ``file=`` as well is rejected by ``print`` itself.
    """
    stream = sys.stderr
    print(*args, file=stream, **kwargs)

def arboricity(g):
    degree = [len(neigh) for neigh in g]
    deleted = [False] * len(g)
    heap = [(d, i) for i, d in enumerate(degree)]
    heapq.heapify(heap)
    res = 0
    while len(heap) > 0:
        _, u = heapq.heappop(heap)
        if deleted[u]:
            continue
        deleted[u] = True
        res = max(res, degree[u])
        for v in g[u]:
            degree[v] -= 1
            heapq.heappush(heap, (degree[v], v))
    return res


def construct_graph(tuples):
    """Build undirected adjacency lists linking each tuple's head to its other members.

    Vertices are ``0..max index``. For every tuple ``(u, v1, v2, ...)``, the
    edges ``u-v1``, ``u-v2``, ... are added in both directions. Duplicate edges
    are kept, so an adjacency list may repeat a neighbour.

    Args:
        tuples: an iterable of non-empty integer sequences. One-shot iterables
            such as generators are accepted.

    Returns:
        A list of adjacency lists, or ``[]`` when ``tuples`` is empty.
    """
    # Materialise once: the input is traversed twice (size, then edges), and a
    # generator would otherwise be exhausted by the first pass.
    tuples = list(tuples)
    if not tuples:
        return []
    n = 1 + max(max(t) for t in tuples)
    g = [[] for _ in range(n)]
    for head, *rest in tuples:
        for v in rest:
            g[head].append(v)
            g[v].append(head)
    return g

def gen_samples_with_ground_truth(embeddings, *, n_points, n_samples, n_negatives):
    """Sample triplet-style queries whose positives are defined by embedding similarity.

    A fixed pool of ``n_points`` distinct row indices is drawn once. For each of
    ``n_samples`` queries, ``n_negatives + 2`` distinct pool points are drawn:
    the first is the anchor and the rest are candidates. The candidate with the
    largest dot product with the anchor becomes the positive and is moved to
    position 1. Draws where the top two similarities are tied (within 1e-9) are
    rejected and redrawn.

    Args:
        embeddings: 2-D torch tensor, one row per point.
        n_points: size of the candidate pool sampled from the rows.
        n_samples: number of queries to return.
        n_negatives: number of negatives per query (0 is allowed).

    Returns:
        A list of ``[anchor, positive, *negatives]`` index lists.

    Note:
        If almost every draw is tied (e.g. many duplicate rows), rejection
        sampling can take arbitrarily long.
    """
    tuples = []
    possible_points = random.sample(range(embeddings.shape[0]), k=n_points)
    for _ in range(n_samples):
        while True:
            # Redraw on every attempt: retrying the same draw after a tie
            # would reject it again forever.
            a, *candidates = random.sample(possible_points, k=n_negatives + 2)
            similarities = embeddings[a] @ embeddings[candidates].T
            sim_order = torch.argsort(similarities, descending=True)
            if len(candidates) > 1:
                # Compare in Python floats. Adding 1e-9 directly to a float32
                # tensor value is absorbed by rounding, so that form never
                # detected ties between values of ordinary magnitude.
                top_sim, second_sim = similarities[sim_order[:2]].tolist()
                if top_sim - second_sim < 1e-9:
                    continue
            best = int(sim_order[0])
            candidates[0], candidates[best] = candidates[best], candidates[0]
            tuples.append([a] + candidates)
            break
    return tuples


def contrastive_samples_with_ground_truth(dataset_path, embeddings, *, n_points, n_samples, n_negatives):
    """Sample ground-truth queries from ``embeddings``.

    This delegates to ``gen_samples_with_ground_truth``. ``dataset_path`` is
    accepted for call-site compatibility but is not read.
    """
    sampling_options = {
        "n_points": n_points,
        "n_samples": n_samples,
        "n_negatives": n_negatives,
    }
    return gen_samples_with_ground_truth(embeddings, **sampling_options)

def get_labels(dataset):
    """Return ``dataset.targets`` as a plain Python list.

    The targets are routed through a torch tensor, so the elements come back
    as Python scalars.
    """
    targets = torch.tensor(dataset.targets)
    return targets.tolist()


def our_embeddings(queries):
    """Construct explicit vectors intended to satisfy every triplet in ``queries``.

    Each query ``(x, y, z)`` asks that ``x`` be closer to ``y`` than to ``z``.
    The ``__main__`` block checks this with Euclidean distance. Steps:

    1. Build an undirected graph ``G`` with edges ``x-y`` and ``x-z``. Compute a
       smallest-last vertex order and greedily colour ``G`` in that order.
    2. Build a DAG over unordered point pairs with an edge ``{x,y} -> {x,z}``
       per query and topologically sort it. Each pair gets an integer target
       inner product equal to its index in the *reversed* topological order.
       This gives ``{x,y}`` a larger target than ``{x,z}``. (networkx raises if
       the pair constraints contain a cycle.)
    3. Visit vertices in the processing order. Solve a square least-squares
       system for a vector whose dot products with the already-placed
       neighbours hit their targets. The system is padded with random rows
       targeting 1.
    4. Equalise all norms by appending ``dif * e_colour`` in an ``n_colors``-dim
       block. Adjacent vertices have different colours, so this does not change
       their dot products. With equal norms ``R``,
       ``||x-y||^2 = 2R^2 - 2<x,y>``, so a larger dot product means a smaller
       distance.

    Args:
        queries: iterable of 3-tuples of hashable point ids. It is iterated
            twice, so it must be re-iterable. Only one negative per query is
            supported (``x, y, z`` unpacking).

    Returns:
        dict mapping each point id in ``G`` to a 1-D numpy array of length
        ``max_prev_neighbors + n_colors`` (the same for every point).

    Raises:
        ValueError: from ``max`` when ``queries`` is empty.

    Notes:
        The result is random (``np.random`` draws in step 3). The constraints
        are met exactly only if each least-squares system is nonsingular, which
        the random padding presumably makes likely — not verified here.
    """
    eprint("start embeddings")
    #print("111")
    G = nx.Graph()
    for x, y, z in queries:
      G.add_edges_from([(x, y), (x, z)])
    #print("222")
    # Smallest-last order: vertices are later placed, and greedily coloured, in this order.
    processing_order = list(nx.coloring.strategy_smallest_last(G, None))
# colors = nx.coloring.greedy_color(G, strategy=nx.coloring.strategy_smallest_last)
    # Reuse the precomputed order so the colouring and the placement order agree.
    colors = nx.coloring.greedy_color(G, strategy=lambda _a, _b: processing_order)
    n_colors = max(colors.values()) + 1
    eprint("colors:", n_colors)

    # DAG over unordered pairs: pair {x,y} must end up "closer" than pair {x,z}.
    pairsG = nx.DiGraph()
    for x, y, z in queries:
      a, b = sorted([x, y])
      c, d = sorted([x, z])
      pairsG.add_edge((a, b), (c, d))
    top_order = list(nx.topological_sort(pairsG))


#print(top_order)
    # Target inner product per pair (both orientations). Pairs earlier in the
    # topological order get LARGER values, i.e. they must be closer.
    edge_order = {}
    for i, (a, b) in enumerate(reversed(top_order)):
      edge_order[a, b] = i
      edge_order[b, a] = i

    # Largest number of already-processed neighbours of any vertex, plus one.
    # This is the inner dimension, so every system below gets at least one
    # random padding row.
    max_prev_neighbors = 0
    was = set()
    for u in processing_order:
      was.add(u)
      max_prev_neighbors = max(max_prev_neighbors, sum(1 for v in G.neighbors(u) if v in was))
      #if sum(1 for v in G.neighbors(u) if v in was) > 10:
      #    print(u, list(v for v in G.neighbors(u) if v in was))
    max_prev_neighbors += 1

    eprint("max_prev_neighbors:", max_prev_neighbors)
    inner_embeddings = {}
    for u in processing_order:
      # Rows: embeddings of already-placed neighbours; rhs: their target dot products.
      A = []
      b = []
      for v in G.neighbors(u):
        if v in inner_embeddings:
            A.append(inner_embeddings[v])
            b.append(edge_order[u, v])
      # Pad to a square system with noisy scaled random-sign rows targeting 1.
      while len(b) < max_prev_neighbors:
          b.append(1)
          vec = (np.random.randint(2, size=max_prev_neighbors) * 2 - 1) / np.sqrt(max_prev_neighbors)
          A.append(vec + np.random.normal(size=max_prev_neighbors) / 10)
          #A.append(np.random.normal(size=max_prev_neighbors) / np.sqrt(max_prev_neighbors))
      e = np.linalg.lstsq(A, b, rcond=None)[0]
      inner_embeddings[u] = e

    #for (u, v), d in edge_order.items():
    #  assert eq(d, np.dot(inner_embeddings[u], inner_embeddings[v]), eps=0.5), f"{d} {np.dot(inner_embeddings[u], inner_embeddings[v])}"

    # One-hot vector of length `dim` with a 1 at `pos`.
    def basis(pos, dim):
      res = np.zeros(dim)
      res[pos] = 1
      return res

    max_norm = max(np.linalg.norm(e)
                  for e in inner_embeddings.values())
    embeddings = {}

    # Lift every vector to norm ~max_norm via its colour coordinate. The 1e-13
    # keeps the sqrt argument from going negative through rounding for the
    # vertex that attains max_norm.
    for u in G.nodes():
      dif = np.sqrt(1e-13 + max_norm ** 2 - np.linalg.norm(inner_embeddings[u]) ** 2)
      embeddings[u] = np.concatenate([inner_embeddings[u], dif * basis(colors[u], n_colors)])

    eprint("Finished embeddings")

    #for u in G.nodes():
    #    assert eq(max_norm, np.linalg.norm(embeddings[u]), eps=1e-3), f"{max_norm} {np.linalg.norm(embeddings[u])}"

    #for (u, v), d in edge_order.items():
    #    assert eq(d, np.dot(embeddings[u], embeddings[v]), eps=0.5), f"{d} {np.dot(embeddings[u], embeddings[v])}"

    return embeddings

if __name__ == "__main__":
    # Benchmark: for growing numbers of queries (m) over growing random subsets
    # of CIFAR-100 training points (n):
    # - build explicit embeddings with our_embeddings;
    # - report their dimension;
    # - check that every triplet is satisfied.
    train_beton = "/tmp/cifar100_train.beton"
    cifar100_train_embeddings = getembeddings(train_beton)

    n_runs = 10
    # our_embeddings unpacks each query as (x, y, z), so exactly one negative.
    n_negatives = 1
    for _ in range(n_runs):
        for n_samples in [100, 1000, 10 ** 4, 10 ** 5, 10 ** 6, 10 ** 7]:
            for n_points in [128, 256, 512, 1024, 2048]:
                queries = contrastive_samples_with_ground_truth(
                    train_beton,
                    cifar100_train_embeddings,
                    n_points=n_points,
                    n_samples=n_samples,
                    n_negatives=n_negatives,
                )
                embeddings = our_embeddings(queries)
                # Every returned vector has the same length; read it off the first anchor.
                dim = len(embeddings[queries[0][0]])
                print("m:", n_samples, "n:", n_points, "dim:", dim, flush=True)
                eprint("m:", n_samples, "n:", n_points, "dim:", dim)
                for x, y, z in queries:
                    # Explicit raise instead of `assert` so the check still runs under `python -O`.
                    if not (np.linalg.norm(embeddings[x] - embeddings[y])
                            < np.linalg.norm(embeddings[x] - embeddings[z])):
                        raise AssertionError(f"triplet ({x}, {y}, {z}) violated")

