
import numpy as np

import matplotlib.pyplot as plt
import torch

from dataset import read_dataset
from awp import MLPBinary2Logits
from fgsm import perturb_dataset

import torch
import torch.nn as nn
import torch.nn.functional as F

from math import sqrt
from typing import Dict, Optional

import os


# Datasets to process. Other names that have been used with this script:
# 'iris', 'penguin', 'digits4', 'seeds', 'wine', 'fico'.
datasets = ['compas']


def ctor(d):
    """Build a fresh MLPBinary2Logits for inputs of dimension ``d``.

    The saved state dicts are loaded with ``strict=True`` further down, so
    this architecture has to match the checkpoints on disk. Other
    (hidden, depth) settings that were used here before: (16, 2), (8, 2),
    (10, 3), plus (25, 3) for 'seeds' and (25, 2) for 'wine'.
    """
    return MLPBinary2Logits(d=d, hidden=20, depth=4, dropout=0.0)


# Perturbation magnitudes handed to perturb_dataset, one per sweep step.
# An earlier hand-picked grid was [0, 0.1, 0.2, 0.4, 0.8, 1.6].
etas = np.linspace(0.0, 1.0, num=21)

# Selects the saved-model folder (./saved_models/<dataset>_<eps>) and is part
# of every output file name.
eps = 0.5

# Norm argument forwarded to perturb_dataset; 2 was the other value used.
attack_norm = float('inf')


def lossf(model, data, labels):
    """Return the cross-entropy of ``model(data)`` against ``labels`` as a float."""
    logits = model(data)
    return F.cross_entropy(logits, labels).item()


lossf_name = "cross_entropy"

# Disabled alternative metric, named "misclassification":
#   torch.mean((model.predict(data) == labels).float()).item()
# NOTE(review): that expression is the fraction of *matching* predictions
# (accuracy), not the misclassification rate - confirm before re-enabling.

# Common suffix for all result/figure file names produced by this run.
prefix = f"_{lossf_name}_{attack_norm}_{eps}"

print(prefix)

def norm(w):
    """Return the Euclidean (L2) norm of ``w`` as a Python float.

    ``w`` must support elementwise ``w * w`` (e.g. a NumPy array).
    """
    sum_of_squares = np.sum(w * w)
    return sqrt(sum_of_squares)

def dot(w, wopt):
    """Return the inner product of ``w`` and ``wopt`` divided by ``norm(wopt)``.

    This is the scalar projection of ``w`` onto the direction of ``wopt``.
    """
    inner_product = np.sum(w * wopt)
    reference_length = norm(wopt)
    return inner_product / reference_length

@torch.no_grad()
def flatten_state(
    model: torch.nn.Module,
    *,
    device: Optional[torch.device] = None,
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """
    Flatten all parameters of ``model`` into a single 1D tensor.

    Parameters are visited in ``model.parameters()`` order; each one is
    detached, moved/cast to ``device``/``dtype``, reshaped to 1D and the
    pieces are concatenated. Buffers are not included.

    Args:
        model: Module whose parameters are flattened.
        device: Target device for the result; ``None`` leaves each
            parameter on the device it already lives on.
        dtype: Target dtype for the result.

    Returns:
        A 1D tensor holding every parameter value. If the model has no
        parameters, an empty 1D tensor of the requested dtype/device is
        returned (``torch.cat`` would otherwise raise on an empty list).
    """
    pieces = [
        p.detach().to(device=device, dtype=dtype).reshape(-1)
        for p in model.parameters()
        if p is not None
    ]

    if not pieces:
        return torch.empty(0, device=device, dtype=dtype)

    return torch.cat(pieces, dim=0)


@torch.no_grad()
def weight_similarity(
    model_a: torch.nn.Module,
    model_b: torch.nn.Module,
    *,
    device: Optional[torch.device] = None,
    dtype: torch.dtype = torch.float32,
    eps: float = 1e-12,
) -> float:
    """
    Compute the cosine similarity between the flattened weights of two models.

    Both models are flattened with ``flatten_state`` and must therefore have
    the same total number of parameter values.

    Args:
        model_a: First model.
        model_b: Second model.
        device: Forwarded to ``flatten_state``.
        dtype: Forwarded to ``flatten_state``.
        eps: Lower bound applied to the product of the two L2 norms so that
            an all-zero weight vector does not cause a division by zero.

    Returns:
        The cosine similarity as a Python float.
    """
    wa = flatten_state(model_a, device=device, dtype=dtype)
    wb = flatten_state(model_b, device=device, dtype=dtype)

    # Cosine similarity: <wa, wb> / max(||wa|| * ||wb||, eps)
    denom = (wa.norm(p=2) * wb.norm(p=2)).clamp_min(eps)
    cosine = (wa @ wb / denom).item()

    return cosine


@torch.no_grad()
def pattern_similarity(model_a: torch.nn.Module, model_b: torch.nn.Module, X) -> float:
    """Return the fraction of predictions on ``X`` where both models agree.

    Each model's ``predict(X)`` output is compared elementwise and the mean
    of the resulting matches is returned as a Python float.
    """
    agreement = model_a.predict(X) == model_b.predict(X)
    return agreement.float().mean().item()

def _save_line_plot(curves, *, title, xlabel, ylabel, path):
    """Draw each (xs, ys, label) curve, save the figure to ``path``, then clear it.

    A legend is added only when at least one curve carries a label.
    """
    for xs, ys, label in curves:
        if label is None:
            plt.plot(xs, ys)
        else:
            plt.plot(xs, ys, label=label)

    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    if any(label is not None for _, _, label in curves):
        plt.legend()

    plt.savefig(path)
    plt.clf()


# Create the output directories up front so a missing folder cannot abort the
# run after all models have already been evaluated.
os.makedirs("figures", exist_ok=True)
os.makedirs("results", exist_ok=True)

for dataset in datasets:
    print(dataset)

    # Only the first label array (Y0) is used by this script.
    X, Y0, _ = read_dataset(f'datasets/{dataset}')
    n, d = X.shape

    X = torch.from_numpy(X).float()
    y = torch.from_numpy(Y0).long()

    # The saved base model is the initial candidate for the best model.
    folder = f"./saved_models/{dataset}_{eps}"
    checkpoint = torch.load(f"{folder}/base_model.pt")
    model_opt = ctor(d)
    model_opt.load_state_dict(checkpoint['state_dict'], strict=True)
    lossopt = lossf(model_opt, X, y)

    # Collect every saved snapshot; model_opt/lossopt end up as the model with
    # the lowest loss on the original data among all collected models.
    rashomon_models = [model_opt]
    rashomon_losses = [lossopt]
    for i in range(n):
        for c in range(2):
            path = f"{folder}/model_i={i}_c={c}.pt"
            if not os.path.exists(path):
                continue
            snapshots = torch.load(path)['stats']['snapshots']
            for state_dict in snapshots:
                model = ctor(d)
                model.load_state_dict(state_dict, strict=True)
                loss = lossf(model, X, y)
                rashomon_models.append(model)
                rashomon_losses.append(loss)

                if loss < lossopt:
                    model_opt = model
                    lossopt = loss
    print(f"Models in Rashomon set: {len(rashomon_models)}")

    plt.hist(rashomon_losses, bins=30)
    plt.title(f"{dataset} Dataset")
    plt.xlabel(f"{lossf_name} Loss on Original Dataset")
    plt.ylabel("Count")
    plt.savefig(f'figures/{dataset}{prefix}_rashomon_loss_hist.png')
    plt.clf()

    models = []           # best collected model on the perturbed data, per eta
    losses = []           # loss of that model on the original data
    losses_adv = []       # loss of that model on the perturbed data
    losses_orig_adv = []  # loss of model_opt on the perturbed data
    for eta in etas:
        # The perturbation is always computed against model_opt.
        Xadv = perturb_dataset(model_opt, X, y, eta, norm=attack_norm)

        # Pick the collected model with the lowest loss on the perturbed data
        # (the first one wins on ties).
        loss_adv = float('inf')
        model_best = None
        for model in rashomon_models:
            loss = lossf(model, Xadv, y)
            if loss < loss_adv:
                model_best = model
                loss_adv = loss

        models.append(model_best)
        losses.append(lossf(model_best, X, y))
        losses_adv.append(loss_adv)
        losses_orig_adv.append(lossf(model_opt, Xadv, y))

    print("Computing Similarities")

    weight_sims = [weight_similarity(model, model_opt) for model in models]
    pattern_sims = [pattern_similarity(model, model_opt, X) for model in models]

    # The arrays are passed positionally, so they are stored as arr_0 .. arr_6
    # in exactly this order; keep the order stable for existing readers.
    np.savez(
        f"results/{dataset}{prefix}_data",
        etas,
        losses,
        lossopt,
        losses_adv,
        losses_orig_adv,
        weight_sims,
        pattern_sims,
    )

    print("Computing Adversarial Deltas")

    title = f"{dataset} Dataset"

    # Weight similarity (weight_similarity returns the plain cosine similarity).
    _save_line_plot(
        [(etas, weight_sims, None)],
        title=title,
        xlabel="Eta",
        ylabel="Cosine similarity of weights",
        path=f'figures/{dataset}{prefix}_weight_similarity.png',
    )

    # Pattern similarity
    _save_line_plot(
        [(etas, pattern_sims, None)],
        title=title,
        xlabel="Eta",
        ylabel="Pattern Similarity",
        path=f'figures/{dataset}{prefix}_pattern_similarity.png',
    )

    # Loss on the original data, with the optimal model as a flat reference line.
    _save_line_plot(
        [
            (etas, losses, "Adversarial Model"),
            ([etas[0], etas[-1]], [lossopt, lossopt], "Optimal Model"),
        ],
        title=title,
        xlabel="Eta",
        ylabel="Loss of Adversarial Model on Original Dataset",
        path=f'figures/{dataset}{prefix}_loss.png',
    )

    # Loss on the perturbed data
    _save_line_plot(
        [
            (etas, losses_adv, "Adversarial Model"),
            (etas, losses_orig_adv, "Optimal Model"),
        ],
        title=title,
        xlabel="Eta",
        ylabel="Loss on Adversarial Dataset",
        path=f'figures/{dataset}{prefix}_advloss.png',
    )
