#from accelerate import Accelerator
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils.checkpoints import get_latest_checkpoint
import yaml
from models.resnet_wrapper import ResNetWithHead
from data.dataloader import get_dataloader
import os
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from umap import UMAP
from tqdm import tqdm
from sklearn.decomposition import PCA
import numpy as np
import seaborn as sns

def extract_latent(model, dataloader):
    """Collect projector outputs, labels and backbone outputs over a dataloader.

    The model is moved to CUDA when available (CPU otherwise) and switched to
    eval mode; every batch is pushed through ``model.backbone`` and then
    ``model.projector`` with gradient tracking disabled.

    Args:
        model: Object exposing ``backbone`` and ``projector`` callables plus
            the ``to`` / ``eval`` methods used here.
        dataloader: Iterable yielding ``(x, y)`` batches of tensors.

    Returns:
        A ``(latents, labels, features)`` tuple of NumPy arrays, each the
        concatenation of the per-batch tensors: projector outputs, the ``y``
        batches, and backbone outputs respectively.
    """
    target = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(target)
    model.eval()

    latent_chunks, label_chunks, feature_chunks = [], [], []
    with torch.no_grad():
        for inputs, targets in tqdm(dataloader, leave=False):
            backbone_out = model.backbone(inputs.to(target))
            projected = model.projector(backbone_out)
            # Move everything back to CPU so results can be turned into NumPy arrays.
            latent_chunks.append(projected.cpu())
            feature_chunks.append(backbone_out.cpu())
            label_chunks.append(targets.cpu())

    latent_array = torch.cat(latent_chunks).numpy()
    label_array = torch.cat(label_chunks).numpy()
    feature_array = torch.cat(feature_chunks).numpy()
    return latent_array, label_array, feature_array

# (label, threshold) pairs for the cumulative explained-variance report; the
# label is the percentage text printed for each threshold.
_VARIANCE_THRESHOLDS = (
    ("95", 0.95),
    ("99", 0.99),
    ("99.9", 0.999),
    ("99.99", 0.9999),
    ("99.999", 0.99999),
    ("99.9999", 0.999999),
    ("99.99999", 0.9999999),
)


def effective_dimension(cumulative_variance, threshold):
    """Return how many leading components are needed to reach ``threshold``.

    Args:
        cumulative_variance: Non-decreasing 1-D array of cumulative explained
            variance ratios.
        threshold: Target cumulative ratio.

    Returns:
        The 1-based count of components whose cumulative ratio first reaches
        ``threshold``. The count is clamped to the number of components, so a
        cumulative sum that falls just short of the threshold (e.g. through
        floating-point rounding) yields "all components" rather than an
        out-of-range value.
    """
    count = int(np.searchsorted(cumulative_variance, threshold)) + 1
    return min(count, len(cumulative_variance))


def _save_figure(fig, path, show=False):
    """Save ``fig`` to ``path``, optionally display it, then close it."""
    fig.savefig(path)
    if show:
        plt.show()
    # Close explicitly: figures are created inside nested loops and would
    # otherwise accumulate for the lifetime of the process.
    plt.close(fig)


if __name__ == '__main__':
    with open("configs/eval.yaml", "r") as f:
        config = yaml.safe_load(f)
    dataloader = get_dataloader(config['dataset'], split='valid')
    epoch = config['eval']['epoch']
    for backbone in config['backbones']:
        name = backbone['name']
        for latent_dim in backbone['latent_dims']:
            ds = config['base_model']['dataset']
            model = ResNetWithHead(backbone_name=name, latent_dim=latent_dim, num_classes=2)

            # Read the checkpoint once and map it to CPU so checkpoints saved
            # on a GPU also load on CPU-only hosts; extract_latent moves the
            # model to the evaluation device afterwards.
            # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
            checkpoint_path = f"checkpoints/base/{ds}/cross_entropy/{name}_latent{latent_dim}_epoch{epoch}.pt"
            checkpoint = torch.load(checkpoint_path, map_location='cpu')
            model.backbone.load_state_dict(checkpoint['backbone'])
            model.projector.load_state_dict(checkpoint['projector'])

            # Debug output: mean of the running statistics of every BatchNorm2d layer.
            for module in model.backbone.modules():
                if isinstance(module, torch.nn.BatchNorm2d):
                    print(module.running_mean.mean().item(), module.running_var.mean().item())

            latents, labels, features = extract_latent(model, dataloader)
            print(features.max(), features.min(), features.std())

            pca = PCA()
            pca.fit(latents)
            cumulative_variance = np.cumsum(pca.explained_variance_ratio_)

            for label, threshold in _VARIANCE_THRESHOLDS:
                ed = effective_dimension(cumulative_variance, threshold)
                print(f"Effective Dimension at {label}%: {ed}")

            tsne = TSNE(n_components=2, random_state=42)
            z = tsne.fit_transform(latents)
            umap_z = UMAP(n_components=2, random_state=42).fit_transform(latents)

            fig = plt.figure(figsize=(8, 4))
            sns.histplot(latents.flatten(), bins=50, kde=True)
            plt.title("Distribution of Latent Activations")
            _save_figure(fig, f"latent_distribution_{ds}_{name}_{latent_dim}.png", show=True)

            fig = plt.figure(figsize=(8, 4))
            sns.histplot(features.flatten(), bins=50, kde=True)
            plt.title("Distribution of Backbone Features (Pre-Projector)")
            _save_figure(fig, f"backbone_feature_distribution_{ds}_{name}_{latent_dim}.png", show=True)

            fig = plt.figure(figsize=(8, 4))
            plt.plot(pca.explained_variance_ratio_)
            plt.title("PCA Variance Spectrum")
            _save_figure(fig, f"pca_spectrum_{ds}_{name}_{latent_dim}.png", show=True)

            # The 2-D embeddings are only written to disk, not shown.
            fig = plt.figure(figsize=(8, 6))
            scatter = plt.scatter(z[:, 0], z[:, 1], c=labels, cmap='coolwarm', alpha=0.7)
            plt.legend(*scatter.legend_elements(), title="Class")
            plt.title(f"t-SNE of Latent Representations, {name}-{latent_dim}")
            _save_figure(fig, f"latent_tsne_{ds}_{name}_{latent_dim}_{epoch}.png")

            fig = plt.figure(figsize=(8, 6))
            scatter = plt.scatter(umap_z[:, 0], umap_z[:, 1], c=labels, cmap='coolwarm', alpha=0.7)
            plt.legend(*scatter.legend_elements(), title="Class")
            plt.title(f"UMAP of Latent Representations, {name}-{latent_dim}")
            _save_figure(fig, f"latent_umap_{ds}_{name}_{latent_dim}_{epoch}.png")
