import itertools
import math
from functools import partial

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from cifar_torch import cifar100
from baseline.greedy.reactnet import reactnet
from networkx.algorithms.dag import topological_sort
import networkx as nx
import torch.nn.functional as F


def make_graph(model):
    """Build the layer dependency graph reachable from the ``'fc'`` layer.

    Starting at ``'fc'``, every name returned by
    ``model.dependent_layers(layer)`` gets an edge pointing *to* ``layer``,
    and is then expanded in turn (depth-first, each layer expanded once).

    Args:
        model: Object exposing ``dependent_layers(layer_name)``, which yields
            the names that ``layer_name`` is linked to.

    Returns:
        A ``networkx.DiGraph`` whose edges run from each returned name to the
        layer it was returned for.
    """
    graph = nx.DiGraph()
    root = 'fc'
    seen = {root}
    # Explicit DFS stack of (layer, iterator over its dependencies); edges are
    # added in the same order a recursive walk would add them.
    pending = [(root, iter(model.dependent_layers(root)))]

    while pending:
        layer_name, deps = pending[-1]
        for dep in deps:
            graph.add_edge(dep, layer_name)
            if dep not in seen:
                seen.add(dep)
                pending.append((dep, iter(model.dependent_layers(dep))))
                # Descend into `dep` before looking at remaining siblings.
                break
        else:
            # All dependencies of this layer have been handled.
            pending.pop()

    return graph


def dataset_cov(train_loader, extract_patches, dim, sample_portion=1.0):
    """Estimate the covariance of patch features streamed from a data loader.

    Each batch is reduced to (count, mean, scatter matrix) and folded into
    running float64 statistics with the exact pairwise update

        S_ab    = S_a + S_b + outer(d, d) * n_a * n_b / (n_a + n_b)
        mean_ab = mean_a + d * n_b / (n_a + n_b),      d = mean_b - mean_a

    so the result matches a single covariance computed over all observations
    at once, without ever holding them in memory together.

    Args:
        train_loader: Sized iterable of ``(X, y)`` batches; only ``X`` is used.
        extract_patches: Callable taking the batch ``X`` (already moved to
            CUDA) and returning a 2-D tensor laid out as
            ``(dim, n_observations)``.
        dim: Number of feature rows produced by ``extract_patches``.
        sample_portion: Fraction of the loader's batches to consume, rounded
            up to a whole number of batches.

    Returns:
        A ``(dim, dim)`` float32 CUDA tensor with the unbiased (``n - 1``)
        covariance, i.e. the same normalisation as ``torch.cov``'s default.
        The result is all zeros when no observation (or only one) was seen.
    """
    dim = int(dim)
    count = 0  # observations merged so far (plain int: no device sync needed)
    mean = torch.zeros(dim, device='cuda', dtype=torch.float64)
    scatter = torch.zeros(dim, dim, device='cuda', dtype=torch.float64)
    iters = math.ceil(len(train_loader) * sample_portion)
    for X, _ in itertools.islice(tqdm(train_loader), iters):
        X = X.cuda(non_blocking=True)
        # Work in float64 throughout so batch statistics are as precise as
        # the accumulators they are merged into.
        patches = extract_patches(X).double()

        batch_count = patches.size(1)
        if batch_count == 0:
            continue
        batch_mean = patches.mean(1)
        centered = patches - batch_mean[:, None]
        batch_scatter = centered @ centered.T

        merged_count = count + batch_count
        delta = batch_mean - mean

        scatter += batch_scatter
        # Correction for the shift between the two means; vanishes for the
        # first batch because `count` is still 0.
        scatter += torch.outer(delta, delta) * (count * batch_count / merged_count)
        mean += delta * (batch_count / merged_count)
        # The running count must advance after every merge, otherwise later
        # batches are weighted as if only one batch had been seen before.
        count = merged_count

    # TODO reduce

    if count == 0:
        return scatter.float()
    # Normalise once at the end; guard the single-observation case against
    # a division by zero (its scatter matrix is zero anyway).
    return (scatter / max(count - 1, 1)).float()


@torch.no_grad()
def extract_patches(model, layer, X):
    """Return the inputs seen by ``layer``, unfolded into patch columns.

    The inputs are obtained from ``model.inputs_for(layer, X)``. For a
    ``torch.nn.Conv2d`` they are unfolded with the layer's own kernel size,
    dilation, padding and stride and flattened so that each column is one
    receptive-field patch. For any other layer type the inputs are returned
    unchanged.

    Args:
        model: Object exposing ``inputs_for(layer, X)``.
        layer: The module whose inputs are wanted.
        X: Batch passed through to ``model.inputs_for``.

    Returns:
        For ``Conv2d``: a 2-D tensor of shape
        ``(ch * kh * kw, batch * n_positions)``; otherwise whatever
        ``model.inputs_for`` returned.
    """
    inputs = model.inputs_for(layer, X)
    if not isinstance(layer, torch.nn.Conv2d):
        return inputs

    # F.unfold yields (batch, ch*kh*kw, n_positions).
    columns = F.unfold(
        inputs,
        kernel_size=layer.kernel_size,
        dilation=layer.dilation,
        padding=layer.padding,
        stride=layer.stride,
    )
    patch_dim = columns.shape[1]
    # Bring the patch dimension to the front, then merge batch and positions:
    # (batch, dim, pos) -> (dim, batch, pos) -> (dim, batch*pos).
    return columns.transpose(0, 1).reshape(patch_dim, -1)


def test_model_orig(model, val_loader: DataLoader):
    """Evaluate ``model`` on ``val_loader`` and report mean loss and accuracy.

    Puts the model in eval mode, runs every batch on the current CUDA device
    without gradients, and prints the average cross-entropy loss followed by
    the top-1 accuracy (four decimals).

    Args:
        model: Callable module mapping a CUDA input batch to class logits.
        val_loader: Loader yielding ``(X, y)`` batches; ``y`` is flattened
            before use.

    Returns:
        ``(loss, acc)`` as Python floats: summed cross-entropy divided by the
        number of samples, and the fraction of correct arg-max predictions.
    """
    model.eval()
    print("Test")
    criterion = torch.nn.CrossEntropyLoss(reduction='sum')

    # Accumulators live on the GPU so the loop never has to synchronise.
    n_seen = torch.tensor(0, device='cuda')
    total_loss = torch.tensor(0.0, device='cuda')
    n_correct = torch.tensor(0, device='cuda')

    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs = inputs.cuda(non_blocking=True)
            targets = targets.cuda(non_blocking=True).flatten()
            logits = model(inputs)

            total_loss += criterion(logits, targets)
            n_correct += (logits.argmax(dim=1) == targets).sum()
            n_seen += inputs.shape[0]

        loss = (total_loss / n_seen).item()
        acc = (n_correct / n_seen).item()

    print(f'{loss}')
    print(f'{acc:.4f}')
    return loss, acc


def load_model(path):
    """Create a 100-class ``reactnet`` on CUDA and load checkpoint weights.

    The checkpoint is read onto the CPU, its ``'epoch'`` and
    ``'best_top1_acc'`` entries are printed, and its ``'state_dict'`` is
    loaded non-strictly after removing a leading ``'module.'`` from each key
    (presumably left by a DataParallel-style wrapper at training time).

    Args:
        path: Location of the checkpoint file passed to ``torch.load``.

    Returns:
        The model with the checkpoint weights applied.
    """
    model = reactnet(num_classes=100).cuda()
    checkpoint = torch.load(path, map_location='cpu')
    print(checkpoint['epoch'], checkpoint['best_top1_acc'])

    # Strip the wrapper prefix only at the start of a key: a blanket
    # str.replace would also mangle names that merely contain 'module.'
    # (e.g. 'submodule.weight'), and strict=False would then hide the damage.
    prefix = 'module.'
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in checkpoint['state_dict'].items()
    }

    load_result = model.load_state_dict(state_dict, strict=False)
    # Loading stays best-effort, but mismatches are no longer silent.
    if load_result.missing_keys or load_result.unexpected_keys:
        print(
            f"load_state_dict mismatch: missing={load_result.missing_keys}, "
            f"unexpected={load_result.unexpected_keys}"
        )
    return model


def main(scale, sl, n_layers=None):
    """Perturb layer weights along input-covariance directions, then evaluate.

    Layers are visited in topological order of the dependency graph. For each
    one (``'fc'`` is skipped) the covariance of its input patches is computed
    on the training set using the model *as modified so far*, the weights are
    projected onto the covariance's singular vectors, Gaussian noise scaled
    by ``scale`` is added to the projected rows selected by ``sl``, and the
    result is projected back and written into the layer.

    Args:
        scale: Multiplier applied to the standard-normal noise.
        sl: Index or slice choosing which singular directions receive noise.
        n_layers: Number of layers (in topological order) to process;
            ``None`` means all of them.

    Returns:
        ``(loss, acc)`` from ``test_model_orig`` on the validation loader.
    """
    torch.cuda.set_device('cuda:2')

    train_loader, val_loader = cifar100(batch_size=200, workers=4)
    model = load_model("/d1/xxx/DBQ/baseline/1_step1/models_run1/model_best.pth.tar")

    layers = list(topological_sort(make_graph(model)))
    if n_layers is None:
        n_layers = len(layers)
    modules = dict(model.named_modules())

    for layer_name in itertools.islice(layers, n_layers):
        print(layer_name)
        if layer_name == 'fc':
            continue

        # A '*_down' entry stands for two modules named '<name>1' and
        # '<name>2'; they are perturbed jointly using the first one's inputs.
        suffixes = ('1', '2') if layer_name.endswith('_down') else ('',)
        cur_layers = [modules[layer_name + suffix] for suffix in suffixes]
        head = cur_layers[0]

        dim = np.prod(head.weight.shape[1:])
        cov = dataset_cov(train_loader, partial(extract_patches, model, head), dim)
        _, _, V = torch.svd(cov)

        out_sizes = [layer.weight.shape[0] for layer in cur_layers]
        stacked = torch.cat([layer.weight.data for layer in cur_layers], dim=0)
        stacked = stacked.view(stacked.shape[0], -1)  # (output, input)

        # Rows of `coeffs` are the weights expressed in the basis given by
        # the columns of V.
        coeffs = V.T @ stacked.T
        noise = torch.normal(0, 1, size=coeffs[sl].shape, device=coeffs.device)
        coeffs[sl] += noise * scale

        perturbed = (V @ coeffs).T
        for layer, chunk in zip(cur_layers, torch.split(perturbed, out_sizes, dim=0)):
            print(layer.weight.shape, chunk.shape)
            layer.weight.data.copy_(chunk.reshape(layer.weight.shape))

    return test_model_orig(model, val_loader)


def main_2(scale, sl, n_layers=None):
    """Two-pass variant of ``main``: gather all bases first, then perturb.

    Pass 1 computes, for every selected layer except ``'fc'``, the singular
    vectors of its input-patch covariance on the *unmodified* model. Pass 2
    projects each layer's weights onto its stored basis, adds Gaussian noise
    scaled by ``scale`` to the projected rows selected by ``sl``, projects
    back and overwrites the layer weights.

    Args:
        scale: Multiplier applied to the standard-normal noise.
        sl: Index or slice choosing which singular directions receive noise.
        n_layers: Number of layers (in topological order) to process;
            ``None`` means all of them.

    Returns:
        ``(loss, acc)`` from ``test_model_orig`` on the validation loader.
    """
    torch.cuda.set_device('cuda:2')

    train_loader, val_loader = cifar100(batch_size=200, workers=4)
    model = load_model("/d1/xxx/DBQ/baseline/1_step1/models_run1/model_best.pth.tar")

    layers = list(topological_sort(make_graph(model)))
    if n_layers is None:
        n_layers = len(layers)
    modules = dict(model.named_modules())

    def resolve(layer_name):
        # A '*_down' entry stands for the two modules '<name>1' and '<name>2'.
        suffixes = ('1', '2') if layer_name.endswith('_down') else ('',)
        return [modules[layer_name + suffix] for suffix in suffixes]

    # Pass 1: covariance bases, all measured before any weight is touched.
    Vs = {}
    for layer_name in itertools.islice(layers, n_layers):
        print(layer_name)
        if layer_name == 'fc':
            continue

        head = resolve(layer_name)[0]
        dim = np.prod(head.weight.shape[1:])
        cov = dataset_cov(train_loader, partial(extract_patches, model, head), dim)
        _, _, basis = torch.svd(cov)
        Vs[layer_name] = basis

    # Pass 2: inject noise in the stored bases and write the weights back.
    for layer_name in itertools.islice(layers, n_layers):
        print(layer_name)
        if layer_name == 'fc':
            continue

        cur_layers = resolve(layer_name)
        out_sizes = [layer.weight.shape[0] for layer in cur_layers]
        stacked = torch.cat([layer.weight.data for layer in cur_layers], dim=0)
        stacked = stacked.view(stacked.shape[0], -1)  # (output, input)

        V = Vs[layer_name]
        coeffs = V.T @ stacked.T
        noise = torch.normal(0, 1, size=coeffs[sl].shape, device=coeffs.device)
        coeffs[sl] += noise * scale

        perturbed = (V @ coeffs).T
        for layer, chunk in zip(cur_layers, torch.split(perturbed, out_sizes, dim=0)):
            print(layer.weight.shape, chunk.shape)
            layer.weight.data.copy_(chunk.reshape(layer.weight.shape))

    return test_model_orig(model, val_loader)


if __name__ == "__main__":
    # Sweep: 3 trials x {first five, last five} directions x noise scales.
    # Every run's ((trial, slice, scale), (loss, acc)) record is printed and
    # the full list is saved at the end.
    scales = np.array([0.01, 0.02, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 1]) * 2

    results = []
    for trial in range(3):
        direction_slices = [slice(5), slice(-5, None)]
        for sl, scale in itertools.product(direction_slices, scales):
            outcome = main_2(scale, sl)
            record = ((trial, sl, scale), outcome)
            print(record)
            results.append(record)

    torch.save(results, "analysis_result_2.pth")
