from pathlib import Path

import matplotlib.pyplot as plt
import seaborn as sns
import torch
from torch import nn
from torch import optim
from tqdm import tqdm

from data.load import get_cl
from models.conv import conv_base, MLP
from utils.utils import Monitor

# --- Experiment configuration -------------------------------------------------
# Set to True to (re)train the MLP heads on top of the frozen conv base. The
# accuracy plot in this script is produced regardless of this flag.
run_eval = False
n_splits = 10
# Output logits of the MLP head; matches n_classes=100 passed to get_cl below.
output_size = 100
# NOTE: only a default; the training loop below rebinds dropout_prob per run
# (0.0 for "regular", 0.8 for "dropout").
dropout_prob = 0.8
# Presumably the feature size produced by conv_base -- TODO confirm.
input_size = 1024
hidden_size = 400
n_samples = 100
num_epochs = 200

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if run_eval:
    train_dataloaders, test_dataloaders = get_cl(
        "cifar100",
        n_tasks=10,
        n_classes=100,
        batch_size=256,
        dir_store="store/datasets",
    )
    # Train one MLP head without dropout and one with dropout, monitoring each.
    for dropout_prob, name in zip([0.0, 0.8], ["regular", "dropout"]):
        net = MLP(hidden_size, output_size, input_size, dropout_prob).to(device)
        criterion = nn.CrossEntropyLoss()

        # Pretrained feature extractor, kept in eval mode. map_location lets a
        # checkpoint that was saved on a GPU be loaded on a CPU-only machine.
        base = conv_base().to(device)
        base.load_state_dict(
            torch.load("store/save_model/save_models_full", map_location=device)
        )
        base.eval()

        monitor = Monitor(device=device, base=base, name=name)

        # Only the MLP head's parameters are optimised; the base is never updated.
        optimizer = optim.Adam(net.parameters(), lr=0.0001)

        # Only the first task's train/test split is used here.
        dataloader_train = train_dataloaders[0]
        for epoch in tqdm(range(num_epochs)):
            for inputs, labels in dataloader_train:
                # Re-enter train mode every step: monitor.record is handed the
                # model and may switch it to eval mode -- TODO confirm in Monitor.
                net.train()
                inputs, labels = inputs.to(device), labels.to(device)
                # Extract features without tracking gradients through the base.
                with torch.no_grad():
                    inputs = base(inputs)
                optimizer.zero_grad()
                preds = net(inputs)
                loss = criterion(preds, labels)
                loss.backward()
                optimizer.step()

                # Record metrics after every optimisation step.
                monitor.record(
                    test_dataloader=test_dataloaders[0],
                    train_dataloader=train_dataloaders[0],
                    model=net,
                )

        monitor.print_all()

# --- Plot: test accuracy vs. test batch size for each method ------------------
sns.set_style("whitegrid")
sns.set_context("talk")
# set_palette installs "deep" as the default colour cycle (color_palette would
# only return it). Every curve below passes an explicit colour, so this only
# affects artists drawn without one.
sns.set_palette("deep")

# Test batch sizes on the x-axis (drawn on a log2 scale below).
x = [1, 2, 4, 8, 16, 32]

# method -> (accuracy per batch size as a fraction in [0, 1], linestyle, marker, colour).
# The dashed entries repeat one value for every batch size, i.e. flat lines.
methods = {
    "Oracle": (torch.tensor([0.774] * 6), "--", "o", "blue"),
    "Finetune": (torch.tensor([0.0815] * 6), "--", "^", "cyan"),
    "Joint": (torch.tensor([0.38236] * 6), "--", "s", "purple"),
    "ER": (torch.tensor([0.2079] * 6), "-", "D", "green"),  # Diamond marker
    # NOTE(review): the +0.01 lifts EWC off Finetune (0.08185 vs 0.0815) so the
    # two curves don't overlap; it is a visual offset, not a measured value.
    "EWC": (torch.tensor([0.08185] * 6) + 0.01, "-", "p", "orange"),  # Pentagon marker
    "Drift": (
        torch.tensor([0.3242, 0.4079, 0.4739, 0.6544, 0.7466, 0.7737]),
        "-",
        "*",
        "red",
    ),  # Star marker
    # NOTE(review): shares "blue" with Oracle; the two are distinguished only
    # by linestyle and marker.
    "GenCl": (
        torch.tensor([0.2777, 0.3191, 0.4067, 0.509399, 0.6222, 0.67409]),
        "-",
        "X",
        "blue",
    ),  # X marker
}

# Plot
plt.figure(figsize=(10, 6))
for method, (tensor, linestyle, marker, color) in methods.items():
    sns.lineplot(
        x=x,
        # Stored as fractions; scale to percent to match the "%" y-axis label.
        y=tensor.view(-1).numpy() * 100,
        label=method,
        linestyle=linestyle,
        marker=marker,
        color=color,
    )

# Formatting
plt.xlabel("Test batch size")
plt.ylabel("Accuracy%")
plt.title("Test accuracy")
# Anchor the legend's lower-right corner outside the axes, to the right.
plt.legend(loc="lower right", bbox_to_anchor=(1.3, 0.5))
plt.grid(True)
plt.xscale("log", base=2)
plt.xticks(x, x)  # label ticks with the raw batch sizes
# tight_layout computes every subplot margin itself, so a manual
# subplots_adjust before it would simply be overwritten.
plt.tight_layout()

# Save (creating the output directory if needed) and show.
# bbox_inches="tight" keeps the out-of-axes legend from being clipped.
fig_path = Path("store/figs/accuracy.png")
fig_path.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(fig_path, bbox_inches="tight")
plt.show()

#
#
#
# # Set Seaborn style and context
# def moving_average(data, window_size):
#     return np.convolve(data, np.ones(window_size)/window_size, mode='valid')
#
# # Set Seaborn style and context
# sns.set_style("whitegrid")
# sns.set_context("talk")
#
# # Define colors
# colors = {'regular': 'blue', 'dropout': 'red'}
#
# print_name = {'regular': 'Regular', 'dropout': 'Dropout'}
#
#
#
#
# for name in ['regular', 'dropout']:
#     plt.figure(figsize=(10, 6))
#     sns.set_context("talk")
#     embeddings = torch.load(f'./store/tensor/{name}_representation.pt')
#     embeddings = embeddings[:, :3]
#
#     timestamp = np.arange(0, embeddings.shape[0], 10)
#     embeddings = embeddings[timestamp, :]
#     for i in range(embeddings.shape[1]):
#         embedding = embeddings[:, i]
#         # Plot original data with reduced alpha, dashed line style, and specified color
#         # sns.lineplot(x=timestamp, y=embedding.numpy(), label=None, linestyle='--')
#
#         # # Apply moving average and plot with specified color
#         window_size = 10
#         smoothed_embedding = moving_average(embedding.numpy(), window_size=window_size)
#         sns.lineplot(x=timestamp[window_size-1:], y=smoothed_embedding, label=f'embedding {i}')
#
#
#
#     # Title and labels
#     plt.title(f"{print_name[name]}: Embedding of first three dimensions")
#     plt.xlabel("Time")
#     plt.ylabel("Value")
#     plt.legend()
#
#     plt.savefig(f'store/figs/{name}_embedding.png')
#     plt.show()
#
#
# plt.figure(figsize=(10, 6))
#
# for name in ['regular', 'dropout']:
#     representation = torch.load(f'./store/tensor/{name}_representation.pt')
#     start = int(0.5 * len(representation))
#     cos_dist = torch.tensor([F.cosine_similarity(representation[start], representation[i], dim=0) for i in
#                              np.arange(start, representation.shape[0], 5)])
#
#     # Plot original data with reduced alpha, dashed line style, and specified color
#     sns.lineplot(x=range(len(cos_dist)), y=cos_dist.numpy(), label=None, linestyle='--', alpha=0.3)
#
#     # Apply moving average and plot with specified color
#     smoothed_cos_dist = moving_average(cos_dist.numpy(), window_size=5)
#     sns.lineplot(x=range(len(smoothed_cos_dist)), y=smoothed_cos_dist, label=print_name[name], color=colors[name])
#
# # Set y-axis limits
# plt.ylim(0, 1)
#
# # Title and labels
# plt.title("Cosine similarity over time")
# plt.xlabel("Time")
# plt.ylabel("Cosine distance")
# plt.legend()
#
# plt.savefig('store/figs/cos_dist_comparison.png')
# plt.show()
#
