import pickle
from pathlib import Path

import datasets
import jax
import jax.numpy as jnp
import linprobe
import numpy as onp
import pyt
from jax import grad, jit, value_and_grad
from jax.example_libraries import optimizers
from matplotlib import pyplot as plt
from models.flax_models import LeNetLarge
from utils import load


def load_params_at_epoch(run, epoch, opt):
    """Return the model parameters stored in a run's checkpoint for `epoch`.

    Args:
        run: Name of the run directory under the hard-coded outputs root.
        epoch: Integer epoch; zero-padded to 8 digits in the checkpoint file name.
        opt: Optimizer constructor from `jax.example_libraries.optimizers`
            (e.g. `optimizers.sgd`); it must match the optimizer that produced
            the stored state so that `get_params` can read it.

    Returns:
        Whatever `get_params` extracts from the re-packed optimizer state.
    """
    path = (
        "/homes/ag2198/euclid-scratch/run-outputs/outputs/"
        f"{run}/checkpoints/checkpoint_{epoch:08d}.pkl"
    )
    # The checkpoint is a 5-tuple; only the last two entries are needed here.
    _, _, _, batch_stats, unpacked_opt_state = load(path)
    assert not batch_stats
    # The step size (1) is irrelevant: the optimizer is only built to obtain
    # its `get_params` accessor.
    _, _, get_params = opt(1)
    packed_opt_state = optimizers.pack_optimizer_state(unpacked_opt_state)
    return get_params(packed_opt_state)

# Switches selecting which artefacts this script produces (or reloads).
gen_patches = False
gen_filters = True
gen_linprobe = True

# Which pair of runs to analyse; the alternative value is kept for toggling.
EXPERIMENT = "data-aug"
# EXPERIMENT = "adam"
# When True, previously pickled results are read back instead of recomputed.
LOAD = True
# LOAD = False
EXTRACT_LAYER = 1
# model = LeNetLarge(num_outputs=10)
model = LeNetLarge(num_outputs=10, extract_layer=EXTRACT_LAYER)
# Softmax over the model output; not referenced elsewhere in this file as shown.
forward = lambda x, p: jax.nn.softmax(model.apply({"params": p}, x))


@jit
def loss_fn(x, p, feature_index):
    """Mean interior activation of one feature channel of the global `model`.

    Args:
        x: Input passed straight to `model.apply`.
        p: Parameters passed as the "params" collection.
        feature_index: Index into the last axis of the model output.

    Returns:
        The mean of the selected channel with a 2-element border removed along
        axes 1 and 2. The indexing implies a rank-4 output (presumably
        batch, height, width, feature - confirm against the model).
    """
    features = model.apply({"params": p}, x)
    # Drop a 2-wide margin so border artifacts do not contribute to the loss.
    interior = features[:, 2:-2, 2:-2, feature_index]
    return jnp.mean(interior)


def get_image(params, lr, N, key, feature_index):
    """Synthesise an input that increases `loss_fn` for one feature by gradient ascent.

    Starts from a random (1, 32, 32, 3) image with values in [128, 192) and
    takes `N` steps of size `lr` along the L2-normalised gradient of `loss_fn`
    with respect to the image, clipping to [0, 255] after every step.

    Args:
        params: Model parameters forwarded to `loss_fn`.
        lr: Step size applied to the normalised gradient.
        N: Number of ascent steps; 0 returns the initial image unchanged.
        key: `jax.random` PRNG key used for the initial image.
        feature_index: Feature channel forwarded to `loss_fn`.

    Returns:
        `(image, (init_loss, loss))`. `loss` is the value computed at the start
        of the final step (i.e. before the last update), or `init_loss` when
        `N == 0`.
    """
    image = jax.random.uniform(key, shape=(1, 32, 32, 3)) * 64 + 128
    init_loss = loss_fn(image, params, feature_index)
    # Build the differentiated function once instead of on every iteration.
    loss_and_grad = value_and_grad(loss_fn)
    # Defined up front so that N == 0 does not leave `loss` unbound.
    loss = init_loss
    for _ in range(N):
        loss, grads = loss_and_grad(image, params, feature_index)
        ngrads = grads / pyt.l2norm(grads)
        # Bounds are passed positionally so this works with both the older
        # (a_min/a_max) and newer (min/max) jnp.clip keyword names.
        image = jnp.clip(image + lr * ngrads, 0, 255)
    return image, (init_loss, loss)

def get_image_patches(params, dataset, model, extract_layer, num_features=50):
    """Find, for each feature, the training-image patch with the highest activation.

    Args:
        params: Parameters passed to `model.apply` as the "params" collection.
        dataset: Object whose `train_loader` yields `(image_batch, label)` pairs.
        model: Object with an `apply(variables, images)` method returning a
            rank-4 array whose last axis indexes features.
        extract_layer: 1 or 2; selects the border width, patch half-width and
            the activation-to-image coordinate scale.
        num_features: Number of leading feature channels to inspect.

    Returns:
        A list of `num_features` tuples `(max_activation, patch)`; `patch` is a
        NumPy array cut from the image around the maximising location.

    Raises:
        KeyError: If `extract_layer` is not 1 or 2.
    """
    # extract_layer -> (border, patch half-width, activation-to-image scale).
    # Looked up once, so an unsupported layer fails before any batch is processed.
    border, half_width, scale = {1: (2, 2, 1), 2: (4, 7, 2)}[extract_layer]

    best = [(-float("inf"), None)] * num_features
    for image, _label in dataset.train_loader:
        orig_output = onp.array(model.apply({"params": params}, image))
        # Cut off edges so the patch around any maximum is a full square.
        output = orig_output[:, border:-border, border:-border]
        output_maxes = output.max(axis=(0, 1, 2))
        for feature_index in range(num_features):
            if output_maxes[feature_index] > best[feature_index][0]:
                i, r, c = onp.unravel_index(
                    output[:, :, :, feature_index].argmax(), output.shape[:-1]
                )
                # Translate back to coordinates in the uncropped activation map.
                r += border
                c += border
                assert orig_output[i, r, c, feature_index] == output_maxes[feature_index]
                # Map activation coordinates to image coordinates.
                r *= scale
                c *= scale
                patch = onp.array(
                    image[i, r - half_width:r + half_width + 1, c - half_width:c + half_width + 1]
                )
                best[feature_index] = (output_maxes[feature_index], patch)
    return best

def get_features(params, model, dataset, num_classes=10):
    """Compute flattened model outputs and one-hot labels for the training set.

    Args:
        params: Parameters passed to `model.apply` as the "params" collection.
        model: Object with an `apply(variables, images)` method.
        dataset: Object exposing `train_targets` (integer class labels) and a
            `train_loader` yielding `(images, _)` batches.
        num_classes: Width of the one-hot encoding.

    Returns:
        `(features_flat, labels)` as NumPy arrays: per-example model outputs
        flattened to 2-D, and the one-hot encoded `train_targets`.
    """
    one_hot = jax.nn.one_hot(dataset.train_targets, num_classes=num_classes)
    labels = onp.asarray(one_hot)
    print("labels done")

    # One NumPy array of model outputs per batch, joined along the batch axis.
    per_batch = [
        onp.asarray(model.apply({"params": params}, images))
        for images, _ in dataset.train_loader
    ]
    features = onp.concatenate(per_batch)
    print("features done")

    n_examples = features.shape[0]
    features_flat = features.reshape((n_examples, -1))
    print("features converted")
    return features_flat, labels


def train_linprobe(params, dataset, epochs=100):
    """Train a linear probe on the global `model`'s outputs and return its test accuracy.

    The probe is driven through the project `linprobe` module (`init`,
    `train_step`, `get_acc`); `params` stay fixed throughout.

    Args:
        params: Model parameters passed to `model.apply` as the "params" collection.
        dataset: Object exposing `train_loader`, `test_data` and `test_targets`.
        epochs: Number of passes over `dataset.train_loader`; must be at least 1.

    Returns:
        The accuracy reported by `linprobe.get_acc` after the final epoch.

    Raises:
        ValueError: If `epochs` is less than 1 (there would be no accuracy to return).
    """
    if epochs < 1:
        raise ValueError(f"epochs must be >= 1, got {epochs}")

    # `params` never change, so the test features are the same every epoch.
    test_features = model.apply({"params": params}, dataset.test_data)
    # Used only for the progress message below; presumably mirrors the schedule
    # applied inside `linprobe` - TODO confirm.
    lr_schedule = optimizers.exponential_decay(0.1, decay_rate=0.005, decay_steps=1000)

    opt_state, opt_update, get_params = None, None, None
    step = 0
    for epoch in range(epochs):
        for images, labels in dataset.train_loader:
            features_x = model.apply({"params": params}, images)
            if opt_state is None:
                # The probe is initialised lazily from the first batch.
                opt_state, opt_update, get_params = linprobe.init((features_x, labels), learning_rate=0.1)
            _, opt_state = linprobe.train_step(opt_state, features_x, labels, step, opt_update, get_params)
            step += 1
        acc = linprobe.get_acc(get_params(opt_state), test_features, dataset.test_targets)
        print(f"Epoch {epoch+1}, Acc: {100*acc:.2f}%, Step: {step}, LR: {lr_schedule(step):.6f}")
    return acc


# Dataset shared by the patch extraction and linear-probe code below.
# Flipping, augmentation and randomisation are all switched off and no data
# limit is applied; batches hold 5000 examples.
dataset = datasets.CIFAR10(
    batch_size=5000,
    data_location=Path("/homes/ag2198/data"),
    include_flip=False,
    data_aug=None,
    key=None,
    randomise=False,
    data_limit=None,
)

# Per-experiment settings:
#   big_k        - the large K value compared against K=1
#   k_str        - LaTeX form of big_k for plot labels (raw string: "\c" is not
#                  a valid escape sequence in a normal string literal)
#   runs         - run name for each K
#   total_epochs - last checkpoint epoch to load
#   STEP         - spacing between loaded checkpoint epochs
if EXPERIMENT == "data-aug":
    # exp with getting close to full performance
    big_k = 5000000
    k_str = r"5 \cdot 10^6"
    runs = {1: "peach-river-1270", big_k: "deep-oath-1267"}
    total_epochs = 200000
    # STEP = 10000
    STEP = 400
elif EXPERIMENT == "adam":
    big_k = 200000
    k_str = r"2 \cdot 10^5"
    runs = {1: "gallant-flower-1272", big_k: "earthy-deluge-1271"}
    total_epochs = 20000
    STEP = 1000
else:
    # Fail here rather than with a NameError on `big_k` further down.
    raise ValueError(f"Unknown EXPERIMENT: {EXPERIMENT!r}")
# Defaults for the get_image() arguments of the same names.
lr = 40
N = 100

# Either reload previously pickled results or recompute them from checkpoints.
if LOAD:
    # NOTE: pickle.load can execute arbitrary code; only read cache files that
    # were written by this script.
    if gen_filters:
        with open(f"filters-{EXPERIMENT}.pkl", "rb") as f:
            all_filters = pickle.load(f)
    if gen_patches:
        with open(f"patches-{EXPERIMENT}-L{EXTRACT_LAYER}.pkl", "rb") as f:
            all_patches = pickle.load(f)
    if gen_linprobe:
        with open(f"linprobe-{EXPERIMENT}-L{EXTRACT_LAYER}.pkl", "rb") as f:
            all_linprobe = pickle.load(f)
else:
    # Each dict maps K -> list with one entry per loaded checkpoint epoch.
    all_filters = {}
    all_patches = {}
    all_mine = {}  # populated with empty lists only; unused in this file as shown
    all_linprobe = {}
    for k in (big_k, 1):
        all_filters[k] = []
        all_patches[k] = []
        all_mine[k] = []
        all_linprobe[k] = []
        run = runs[k]
        # for epoch in [0, 100, 1000, 10000, 100000]:
        # for epoch in [0]:
        for epoch in range(0, total_epochs+1, STEP):
            print("Epoch", epoch)
            # Only the big-K run of the "adam" experiment is unpacked with Adam.
            opt = optimizers.adam if k == big_k and EXPERIMENT == "adam" else optimizers.sgd
            params = load_params_at_epoch(run, epoch, opt)
            filters = params.params["Conv_0"]["kernel"]
            all_filters[k].append(filters)

            if gen_linprobe:
                acc = train_linprobe(params, dataset)
                print("Epoch:", epoch, "\tLin probe acc:", acc)
                all_linprobe[k].append(acc)

            if gen_filters and EXTRACT_LAYER == 1:
                fig = plt.figure(figsize=(10, 10))
                r, c = 7, 8
                for i in range(50):
                    image = filters[:, :, :, i]
                    fig.add_subplot(r, c, i+1)
                    # Shift by 0.5 before display.
                    plt.imshow(image + 0.5)
                    plt.axis('off')
                    plt.title(f"Filter {i}")
                fig.tight_layout()
                fig.savefig(f"images-{EXPERIMENT}/perepoch/filters_k{k}-L{EXTRACT_LAYER}_epoch{epoch}.png")
                # One figure is created per checkpoint; close it so figures do
                # not accumulate in memory over the whole sweep.
                plt.close(fig)

            if gen_patches:
                patches = get_image_patches(params, dataset, model, extract_layer=EXTRACT_LAYER)
                # Stack the per-feature patches and move the feature axis last.
                all_patches[k].append(onp.array([p[1] for p in patches]).transpose((1,2,3,0)))
                fig = plt.figure(figsize=(10, 10), facecolor="gray")
                r, c = 7, 8
                for i in range(50):
                    val, image = patches[i]
                    # print(f"Filter {i}: {val}, {image.shape}")
                    fig.add_subplot(r, c, i+1)
                    plt.imshow(image/256)
                    plt.axis('off')
                    plt.title(f"Filter {i}")
                fig.tight_layout()
                fig.savefig(f"images-{EXPERIMENT}/perepoch/filter_patches_L{EXTRACT_LAYER}_k{k}_epoch{epoch}.png")
                plt.close(fig)

        if EXTRACT_LAYER == 1:
            fig = plt.figure(figsize=(10, 10))
            r, c = 5, 1
            # One row per checkpoint after the swap: shape (25, n_checkpoints, 3, 50).
            timeseries = onp.stack([f.reshape(25, 3, 50) for f in all_filters[k]]).swapaxes(0, 1)
            # for i in range(r):
            for i, f in enumerate((2, 8, 10, 19, 20)):
                fig.add_subplot(r, c, i+1)
                plt.imshow(timeseries[:, :, :, f] + 0.5)
                plt.axis("off")
                # plt.title(f"Filter {i}")
            fig.tight_layout()
            fig.savefig(f"images-{EXPERIMENT}/filters_k{k}.png")
            plt.close(fig)

    if gen_linprobe:
        # Convert accuracies to plain Python values before they are pickled.
        for k in all_linprobe:
            all_linprobe[k] = [x.tolist() for x in all_linprobe[k]]


def _identity(a):
    """Return the argument unchanged (default transform for `make_image_sequence`)."""
    return a


def add_border(f):
    """Wrap `f` so its result gets one extra slice of ones appended along axis 0.

    The appended slice has shape `(1,) + x.shape[1:]`, taken from the
    untransformed input `x`.
    """
    def bordered(x):
        separator = onp.ones((1,) + x.shape[1:])
        return onp.concatenate([f(x), separator])

    return bordered


def make_image_sequence(xs, f=_identity):
    """Join `xs` along axis 0 with ones separators, then swap axes 0 and 1.

    Args:
        xs: Iterable of arrays that share all dimensions after the first.
        f: Transform applied to each array before its separator is appended.

    Returns:
        The concatenation of every `f(x)` followed by its separator, with the
        first two axes exchanged.
    """
    with_border = add_border(f)
    frames = [with_border(x) for x in xs]
    return onp.swapaxes(onp.concatenate(frames), 0, 1)

# Persist the linear-probe accuracies and plot them for the current layer.
if gen_linprobe:
    with open(f"linprobe-{EXPERIMENT}-L{EXTRACT_LAYER}.pkl", "wb") as f:
        pickle.dump(all_linprobe, f)

    fig = plt.figure(figsize=(10, 6))
    fig.add_subplot(1, 1, 1)
    # Checkpoint epochs scaled by 400 for the "Step" axis (presumably 400
    # optimizer steps per epoch - TODO confirm).
    steps_axis = onp.arange(total_epochs+1, step=STEP)*400
    for k_value, legend_label in ((1, "K=1"), (big_k, f"K={big_k}")):
        plt.plot(steps_axis, all_linprobe[k_value], label=legend_label)
    plt.xlabel("Step")
    plt.ylabel("Linear probe test accuracy")
    plt.title(f"Layer {EXTRACT_LAYER} Features")
    plt.legend()
    for extension in ("png", "pdf"):
        fig.savefig(f"images-{EXPERIMENT}/lin_probe_L{EXTRACT_LAYER}.{extension}")

# Save the first-layer filters and draw K=1 vs big-K filter timelines.
if EXTRACT_LAYER == 1 and gen_filters:
    with open(f"filters-{EXPERIMENT}.pkl", "wb") as f:
        pickle.dump(all_filters, f)

    def _save_filter_comparison(series_k1, series_big_k, filter_ids, figsize, stem):
        """Draw one row per filter (K=1 strip above the big-K strip) and save as PNG and PDF.

        `series_k1` and `series_big_k` are outputs of `make_image_sequence`;
        `stem` is the output path without extension.
        """
        fig = plt.figure(figsize=figsize)
        n_rows = len(filter_ids)
        # One-pixel row of ones separating the two strips.
        separator = onp.ones((1,) + series_k1.shape[1:-1])
        for row, filter_id in enumerate(filter_ids):
            fig.add_subplot(n_rows, 1, row + 1)
            plt.imshow(onp.concatenate((
                series_k1[:, :, :, filter_id],
                separator,
                series_big_k[:, :, :, filter_id],
            )))
            plt.ylabel(f"Filter {filter_id}")
            # Tick positions 2 and 8 are the strip centres when each strip is
            # 5 pixels tall (assumes 5x5 filters - TODO confirm).
            plt.yticks([2, 8], ["$K=1$", f"$K={k_str}$"])
            plt.xticks([], [])
            ax = plt.gca()
            for tick in ax.axes.get_yticklines():
                tick.set_visible(False)
            for side in ("top", "right", "bottom", "left"):
                ax.spines[side].set_visible(False)
        fig.tight_layout()
        fig.savefig(f"{stem}.png")
        fig.savefig(f"{stem}.pdf")
        # Release the figure; several are produced in a row.
        plt.close(fig)

    # The sequences do not depend on which filters are shown, so build them once.
    timeseries1 = make_image_sequence(all_filters[1], lambda x: x + 0.5)
    timeseries2 = make_image_sequence(all_filters[big_k], lambda x: x + 0.5)
    filters_per_figure = 5
    for start in range(0, 50, filters_per_figure):
        _save_filter_comparison(
            timeseries1,
            timeseries2,
            range(start, start + filters_per_figure),
            (10, 4),
            f"images-{EXPERIMENT}/filter_{start}_comparison",
        )

    # Same comparison on the change of each filter relative to the first
    # loaded checkpoint.
    all_filters_diff = {}
    for k, filters in all_filters.items():
        base = filters[0]
        all_filters_diff[k] = [f-base for f in filters]
    _save_filter_comparison(
        make_image_sequence(all_filters_diff[1], lambda x: x + 0.5),
        make_image_sequence(all_filters_diff[big_k], lambda x: x + 0.5),
        range(0, 3),
        (7, 2),
        f"images-{EXPERIMENT}/filter_diff_paper_comparison",
    )

# Save the maximally-activating patches and draw K=1 vs big-K patch timelines.
if gen_patches:
    with open(f"patches-{EXPERIMENT}-L{EXTRACT_LAYER}.pkl", "wb") as f:
        pickle.dump(all_patches, f)
    # The sequences do not depend on which filters are shown, so build them once.
    timeseries1 = make_image_sequence(all_patches[1], lambda x:x/255)
    timeseries2 = make_image_sequence(all_patches[big_k], lambda x:x/255)
    # Each strip is `strip_height` pixels tall with a one-pixel separator between
    # them; put the tick labels at the strip centres. For 5-pixel patches this
    # gives [2, 8].
    strip_height = timeseries1.shape[0]
    tick_positions = [strip_height // 2, strip_height + 1 + strip_height // 2]
    separator = onp.ones((1,)+timeseries1.shape[1:-1])
    for start in range(0, 50, 10):
        FILTERS = range(start, start+10)
        # FILTERS = (2, 8, 10, 19, 20
        fig = plt.figure(figsize=(10, 10))
        r, c = len(FILTERS), 1
        for i, f in enumerate(FILTERS):
            fig.add_subplot(r, c, i+1)
            plt.imshow(onp.concatenate((timeseries1[:,:,:,f], separator, timeseries2[:,:,:,f])))
            plt.ylabel(f"Filter {f}")
            # k_str contains LaTeX, so the labels need $...$ to be rendered
            # as math rather than printed literally.
            plt.yticks(tick_positions, ["$K=1$", f"$K={k_str}$"])
            plt.xticks([], [])
            ax = plt.gca()
            for tick in ax.axes.get_yticklines():
                tick.set_visible(False)
            for side in ("top", "right", "bottom", "left"):
                ax.spines[side].set_visible(False)
        fig.tight_layout()
        fig.savefig(f"images-{EXPERIMENT}/patch_L{EXTRACT_LAYER}_{start}_comparison.png")
        fig.savefig(f"images-{EXPERIMENT}/patch_L{EXTRACT_LAYER}_{start}_comparison.pdf")
        # Release the figure before the next one is created.
        plt.close(fig)
        
# Side-by-side linear-probe accuracy curves for layers 1 and 2, read from the
# per-layer pickle files.
fig, axes = plt.subplots(1, 2, sharey=True)
fig.set_size_inches(7, 2)
axes[0].set_ylabel("Linear probe test accuracy")
for panel, layer in zip(axes, (1, 2)):
    with open(f"linprobe-{EXPERIMENT}-L{layer}.pkl", "rb") as f:
        linprobes = pickle.load(f)
    # Checkpoint epochs scaled by 400 for the "Step" axis (presumably 400
    # optimizer steps per epoch - TODO confirm).
    steps_axis = onp.arange(total_epochs+1, step=STEP)*400
    panel.plot(steps_axis, linprobes[1], label="$K=1$", c="b")
    panel.plot(steps_axis, linprobes[big_k], label=f"$K={k_str}$", c="r")
    panel.set_xlabel("Step")
    panel.set_title(f"Layer {layer} Features")
# Only the right-hand panel carries the legend.
axes[1].legend()
fig.tight_layout()
for extension in ("png", "pdf"):
    fig.savefig(f"images-{EXPERIMENT}/lin_probe.{extension}")
