import torch
import torchvision
import plotly.express as px
from tqdm import tqdm
import einops
import numpy as np
import os
from vit_prisma.models.base_vit import HookedViT
from vit_prisma.sae.training.activations_store import VisionActivationsStore
from torch.utils.data import DataLoader
from dataclasses import dataclass
from vit_prisma.sae.config import VisionModelSAERunnerConfig
from vit_prisma.dataloaders.imagenet_dataset import get_imagenet_index_to_name
import matplotlib.pyplot as plt
from vit_prisma.sae.sae import StandardSparseAutoencoder as SparseAutoencoder
from typing import List, Dict, Tuple
from tqdm import tqdm

# Inference-only script: switch autograd off globally so no graph is recorded.
torch.set_grad_enabled(False)

# Vision model whose activations are fed to the SAEs. It is loaded once at
# import time and placed on the GPU; every SAE evaluated below shares it.
model_name = "open-clip:laion/CLIP-ViT-B-32-DataComp.XL-s13B-b90K"
model = HookedViT.from_pretrained(model_name).to('cuda')

@dataclass
class EvalConfig(VisionModelSAERunnerConfig):
    """Settings for the SAE feature-visualisation run.

    Extends ``VisionModelSAERunnerConfig`` with the values this script reads:
    the SAE checkpoint path, dataset locations, device, and evaluation size.
    The dataset paths are placeholders and must be replaced with real
    locations before running.
    """

    # Checkpoint to evaluate; overwritten per SAE by the evaluation loop.
    sae_path: str = ''
    model_name: str = "open-clip:laion/CLIP-ViT-B-32-DataComp.XL-s13B-b90K"
    model_type: str = "clip"
    # Side length of one ViT patch in pixels (used as 224 // patch_size).
    patch_size: int = 32

    # Annotated so the dataclass machinery treats it as a field; a bare
    # assignment is only a class attribute and is ignored by __init__.
    dataset_path: str = "IMGNET_PATH"
    dataset_train_path: str = "IMGNET_TRAIN_PATH"
    dataset_val_path: str = "IMGNET_VAL_PATH"

    verbose: bool = True

    # Device string passed to ``.to(...)`` for images and the SAE.
    device: str = 'cuda'

    # Maximum number of images to scan, and the batch size used for the
    # progress-bar totals.
    eval_max: int = 50_000
    batch_size: int = 32


from vit_prisma.transforms.model_transforms import get_clip_val_transforms

# Preprocessing applied to every validation image before it reaches the model.
data_transforms = get_clip_val_transforms()

from torchvision.datasets import DatasetFolder
from torchvision.datasets.folder import default_loader

class IndexedDataset(torch.utils.data.Dataset):
    """Dataset wrapper that appends the sample's index to every item.

    The wrapped dataset must yield ``(image, label)`` pairs; this wrapper
    yields ``(image, label, index)`` so that a sample seen inside a shuffled
    batch can be traced back to its position in the dataset.
    """

    def __init__(self, dataset):
        # Keep a reference only; the wrapped dataset is neither copied nor modified.
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        sample, target = self.dataset[index]
        return sample, target, index


# Validation images, wrapped so every sample also carries its dataset index.
# The root comes from EvalConfig: there is no module-level `dataset_val_path`,
# so referencing that bare name would raise NameError.
val_data = IndexedDataset(DatasetFolder(
            root=EvalConfig.dataset_val_path,
            loader=default_loader,
            extensions=('.jpg', '.jpeg', '.png'),
            transform=data_transforms
))



print(f"Validation data length: {len(val_data)}")

# Batch size is taken from the config so it stays in sync with the progress-bar
# totals computed from cfg.batch_size.
val_dataloader = DataLoader(val_data, batch_size=EvalConfig.batch_size, shuffle=True, num_workers=4)

saes = [
     # (IMG_OUTPUT_DIR, SAE_CHECKPOINT_PATH) -> Tuple format, for the desired layers
    ]

# Lookup used to turn a class index (as a string key) into a readable label.
ind_to_name = get_imagenet_index_to_name()

@torch.no_grad()
def get_feature_probability(images, model, sparse_autoencoder):
    """Return a 0/1 matrix marking which SAE features fired on each token.

    The model is run on ``images`` while caching only the SAE's hook point.
    The first token of every sequence is dropped, the remaining tokens are
    passed through the SAE, and the second element of the SAE output (the
    feature activations) is thresholded at non-zero magnitude.

    Returns:
        Float tensor with the first two dimensions of the feature activations
        merged into one; entries are 1.0 where a feature is non-zero, else 0.0.
    """
    hook_name = sparse_autoencoder.cfg.hook_point
    _, activation_cache = model.run_with_cache(images, names_filter=[hook_name])

    # Skip position 0 of the sequence axis before encoding.
    tokens_without_first = activation_cache[hook_name][:, 1:, :]
    _, feature_acts, *_ = sparse_autoencoder(tokens_without_first)

    fired = feature_acts.abs() > 0
    return fired.float().flatten(0, 1)


def process_dataset(val_dataloader, model, sparse_autoencoder, cfg):
    """Count how often each SAE feature fires over the whole dataloader.

    Args:
        val_dataloader: Iterable of batches whose first element is the image tensor.
        model: Model passed through to ``get_feature_probability``.
        sparse_autoencoder: SAE passed through to ``get_feature_probability``.
        cfg: Config providing ``device``, ``eval_max`` and ``batch_size``
            (the latter two only size the progress bar; the full dataloader
            is always consumed).

    Returns:
        ``(total_acts, total_tokens)`` where ``total_acts[f]`` is the number of
        tokens on which feature ``f`` was active and ``total_tokens`` is the
        number of tokens seen, so ``total_acts / total_tokens`` is the firing
        frequency. ``total_acts`` is ``None`` if the dataloader is empty.
    """
    total_acts = None
    total_tokens = 0

    for batch in tqdm(val_dataloader, total=cfg.eval_max // cfg.batch_size):
        images = batch[0].to(cfg.device)
        sae_activations = get_feature_probability(images, model, sparse_autoencoder)

        # Accumulate counts, not per-batch means: the caller divides by the
        # total token count, so summing means would shrink every frequency
        # by roughly the number of tokens per batch.
        batch_counts = sae_activations.sum(0)
        total_acts = batch_counts if total_acts is None else total_acts + batch_counts

        total_tokens += sae_activations.shape[0]

    return total_acts, total_tokens


def calculate_log_frequencies(total_acts, total_tokens):
    """Convert accumulated activations into base-10 log frequencies.

    Divides ``total_acts`` by ``total_tokens``, takes ``log10`` element-wise
    and returns the result as a NumPy array on the CPU. Entries that are
    exactly zero map to ``-inf``.
    """
    return torch.log10(total_acts / total_tokens).cpu().numpy()

def to_numpy(tensor):
    """
    Convert ``tensor`` to a NumPy array.

    NumPy arrays are returned unchanged; lists, tuples and Python scalars
    (int, float, bool, str) go through ``np.array``; torch tensors and
    parameters are detached and moved to the CPU first. Any other type
    raises ``ValueError``.
    """
    if isinstance(tensor, np.ndarray):
        return tensor
    if isinstance(tensor, (torch.Tensor, torch.nn.parameter.Parameter)):
        return tensor.detach().cpu().numpy()
    if isinstance(tensor, (list, tuple, int, float, bool, str)):
        return np.array(tensor)
    raise ValueError(f"Input to to_numpy has invalid type: {type(tensor)}")


@torch.no_grad()
def compute_feature_activations(
        images: torch.Tensor,
        model: torch.nn.Module,
        sparse_autoencoder: torch.nn.Module,
        encoder_weights: torch.Tensor,
        encoder_biases: torch.Tensor,
        feature_ids: List[int],
        feature_categories: List[str],
        top_k: int = 10
    ) -> Dict[int, Tuple[torch.Tensor, torch.Tensor]]:
    """
    Find the images in a batch that most strongly activate each selected feature.

    Args:
        images: Input images
        model: The main model
        sparse_autoencoder: The sparse autoencoder (provides ``cfg.hook_point``,
            ``cfg.cls_token_only`` and ``b_dec``)
        encoder_weights: Encoder weight columns for the selected features
        encoder_biases: Encoder biases for the selected features
        feature_ids: List of feature IDs to analyze; position ``i`` corresponds
            to column ``i`` of ``encoder_weights``
        feature_categories: Category label per feature; a label containing
            ``"CLS_"`` ranks images by the first token's activation, any other
            label ranks by the mean activation over all tokens
        top_k: Maximum number of top activations to return per feature. If the
            batch holds fewer images, all of them are returned.

    Returns:
        Dictionary mapping feature IDs to tuples of (top_indices, top_values),
        where the indices are positions within the batch.
    """
    hook_point = sparse_autoencoder.cfg.hook_point
    _, cache = model.run_with_cache(images, names_filter=[hook_point])

    layer_activations = cache[hook_point]
    if sparse_autoencoder.cfg.cls_token_only:
        layer_activations = layer_activations[:, 0:1, :]

    # Encoder restricted to the selected features. A batched matmul over the
    # (batch, seq, d_in) activations replaces the flatten/einsum/unflatten
    # round trip and yields the same (batch, seq, n_features) tensor.
    sae_input = layer_activations - sparse_autoencoder.b_dec
    feature_activations = torch.nn.functional.relu(
        sae_input @ encoder_weights + encoder_biases
    )

    cls_token_activations = feature_activations[:, 0, :]
    mean_image_activations = feature_activations.mean(1)

    # topk raises if k exceeds the number of candidates (e.g. a short final batch).
    k = min(top_k, feature_activations.shape[0])

    top_activations = {}
    for i, (feature_id, feature_category) in enumerate(zip(feature_ids, feature_categories)):
        if "CLS_" in feature_category:
            per_image = cls_token_activations[:, i]
        else:
            per_image = mean_image_activations[:, i]
        top_values, top_indices = per_image.topk(k)
        top_activations[feature_id] = (top_indices, top_values)

    return top_activations

def find_top_activations(
        val_dataloader: torch.utils.data.DataLoader,
        model: torch.nn.Module,
        sparse_autoencoder: torch.nn.Module,
        cfg: object,
        interesting_features_indices: List[int],
        interesting_features_category: List[str],
        top_k: int = 16,
    ) -> Dict[int, Tuple[torch.Tensor, torch.Tensor]]:
    """
    Find the top activations for interesting features across the validation dataset.

    A running top-k is kept per feature: each batch's best candidates are
    merged with the best seen so far and the merged list is cut back to
    ``top_k``. Scanning stops once ``cfg.eval_max`` samples were processed.

    Args:
        val_dataloader: Validation data loader yielding (images, labels, dataset_indices)
        model: The main model
        sparse_autoencoder: The sparse autoencoder (provides ``W_enc`` and ``b_enc``)
        cfg: Configuration object (uses ``eval_max``, ``batch_size``, ``device``)
        interesting_features_indices: Indices of interesting features
        interesting_features_category: Categories of interesting features
        top_k: Number of top activations to keep per feature

    Returns:
        Dictionary mapping feature IDs to tuples of (top_values, top_indices),
        both on the CPU; the indices are dataset indices. Features for which
        nothing was processed map to a pair of empty tensors.
    """
    max_samples = cfg.eval_max

    top_activations = {i: (None, None) for i in interesting_features_indices}
    # Slice the encoder once so each batch only computes the selected features.
    encoder_biases = sparse_autoencoder.b_enc[interesting_features_indices]
    encoder_weights = sparse_autoencoder.W_enc[:, interesting_features_indices]

    processed_samples = 0
    for batch_images, _, batch_indices in tqdm(val_dataloader, total=max_samples // cfg.batch_size):
        batch_images = batch_images.to(cfg.device)
        batch_indices = batch_indices.to(cfg.device)
        batch_size = batch_images.shape[0]

        batch_activations = compute_feature_activations(
            batch_images, model, sparse_autoencoder, encoder_weights, encoder_biases,
            interesting_features_indices, interesting_features_category, top_k
        )

        for feature_id in interesting_features_indices:
            new_indices, new_values = batch_activations[feature_id]
            # Translate within-batch positions into dataset indices.
            new_indices = batch_indices[new_indices]

            best_values, best_indices = top_activations[feature_id]
            if best_values is None:
                top_activations[feature_id] = (new_values, new_indices)
            else:
                combined_values = torch.cat((best_values, new_values))
                combined_indices = torch.cat((best_indices, new_indices))
                # topk raises when k exceeds the number of candidates, which
                # can happen while fewer than top_k samples have been seen.
                k = min(top_k, combined_values.numel())
                _, top_k_indices = torch.topk(combined_values, k)
                top_activations[feature_id] = (combined_values[top_k_indices], combined_indices[top_k_indices])

        processed_samples += batch_size
        if processed_samples >= max_samples:
            break

    results = {}
    for i, (values, indices) in top_activations.items():
        if values is None:
            # No batch was processed (empty dataloader): return empty results
            # instead of failing on ``None.detach()``.
            results[i] = (torch.empty(0), torch.empty(0, dtype=torch.long))
        else:
            results[i] = (values.detach().cpu(), indices.detach().cpu())
    return results


@torch.no_grad()
def get_heatmap(
            image,
            model,
            sparse_autoencoder,
            feature_id,
            device=None,
    ):
    """Compute one feature's encoder pre-activation for every non-first token.

    The image is run through the model as a batch of one, the first token of
    the cached hook-point activations is dropped, the decoder bias is
    subtracted and the result is projected onto the encoder column of
    ``feature_id``. Neither the encoder bias nor a ReLU is applied.

    Args:
        image: A single image tensor without a batch dimension.
        model: Model exposing ``run_with_cache``.
        sparse_autoencoder: SAE providing ``cfg.hook_point``, ``b_dec`` and ``W_enc``.
        feature_id: Column of ``W_enc`` to project onto.
        device: Device to move the image to. Defaults to the device holding
            ``sparse_autoencoder.W_enc`` so the function does not depend on
            a module-level config object.

    Returns:
        A new float32 tensor with one value per remaining token.
    """
    if device is None:
        device = sparse_autoencoder.W_enc.device
    image = image.to(device)

    # Cache only the hook the SAE reads instead of every activation.
    hook_point = sparse_autoencoder.cfg.hook_point
    _, cache = model.run_with_cache(image.unsqueeze(0), names_filter=[hook_point])

    # (batch, seq - 1, d) -> (batch * (seq - 1), d), skipping sequence position 0.
    post_reshaped = cache[hook_point][:, 1:, :].flatten(0, 1)
    sae_in = post_reshaped - sparse_autoencoder.b_dec
    acts = sae_in @ sparse_autoencoder.W_enc[:, feature_id]

    # torch.tensor(existing_tensor) warns; clone explicitly to keep returning a copy.
    return acts.detach().clone().to(torch.float32)
        
def image_patch_heatmap(activation_values, image_size=224, pixel_num=14):
    """Expand per-patch activations into an image-sized heatmap.

    Args:
        activation_values: Tensor holding either ``pixel_num ** 2`` values
            (patch tokens only) or ``pixel_num ** 2 + 1`` values (a leading
            extra token followed by the patch tokens). In the second case the
            first value is discarded.
        image_size: Side length of the square output in pixels.
        pixel_num: Number of patches along each side of the image.

    Returns:
        ``(image_size, image_size)`` NumPy array in which every patch-sized
        square is filled with that patch's activation. Pixels beyond
        ``pixel_num * (image_size // pixel_num)`` stay zero.

    Raises:
        ValueError: If the number of values matches neither accepted length.
    """
    activation_values = activation_values.detach().cpu().numpy().reshape(-1)

    num_patches = pixel_num * pixel_num
    if activation_values.size == num_patches + 1:
        # Leading non-patch token present: drop it.
        activation_values = activation_values[1:]
    elif activation_values.size != num_patches:
        # Unconditionally dropping element 0 would silently misalign (or fail
        # to reshape) inputs that already contain patch tokens only.
        raise ValueError(
            f"Expected {num_patches} or {num_patches + 1} activation values "
            f"for a {pixel_num}x{pixel_num} grid, got {activation_values.size}"
        )
    patch_grid = activation_values.reshape(pixel_num, pixel_num)

    # Create a heatmap overlay by repeating each grid cell patch_size times
    # along both axes.
    heatmap = np.zeros((image_size, image_size))
    patch_size = image_size // pixel_num
    covered = pixel_num * patch_size
    heatmap[:covered, :covered] = np.repeat(
        np.repeat(patch_grid, patch_size, axis=0), patch_size, axis=1
    )

    return heatmap


        # Removing axes

def denormalize(image, mean, std):
    """Undo a per-channel ``(x - mean) / std`` normalisation.

    ``mean`` and ``std`` are turned into tensors shaped ``(-1, 1, 1)`` on the
    image's device and applied as ``image * std + mean``.
    """
    shift, scale = (
        torch.tensor(stat).view(-1, 1, 1).to(image.device) for stat in (mean, std)
    )
    return image * scale + shift



# For every configured SAE: measure how often each feature fires, sample
# features from each log-frequency bucket, and save a grid of the
# top-activating validation images (with a patch heatmap overlay) per feature.
for sae_obj in saes:

    # sae_obj[0] is the image output directory, sae_obj[1] the SAE checkpoint.
    cfg = EvalConfig()

    cfg.sae_path = sae_obj[1]

    sparse_autoencoder = SparseAutoencoder(cfg).load_from_pretrained(cfg.sae_path)
    sparse_autoencoder.to(cfg.device)
    sparse_autoencoder.eval()

    total_acts, total_tokens = process_dataset(val_dataloader, model, sparse_autoencoder, cfg)

    log_frequencies = calculate_log_frequencies(total_acts, total_tokens)

    log_freq = torch.Tensor(log_frequencies)

    # Half-open [lower, upper) buckets over log10 firing frequency. The two
    # unbounded buckets come last because their labels are patched by
    # position below.
    intervals = [
        (-8, -6),
        (-6, -5),
        (-5, -4),
        (-4, -3),
        (-3, -2),
        (-2, -1),
        (-float('inf'), -8),
        (-1, float('inf')),
    ]

    conditions = [torch.logical_and(log_freq >= lower, log_freq < upper) for lower, upper in intervals]
    condition_texts = [
        f"TOTAL_logfreq_[{lower},{upper}]" for lower, upper in intervals
    ]

    condition_texts[-2] = condition_texts[-2].replace('-inf', '-∞')
    condition_texts[-1] = condition_texts[-1].replace('inf', '∞')

    interesting_features_indices = []
    interesting_features_values = []
    interesting_features_category = []
    number_features_per = 50
    for condition, condition_text in zip(conditions, condition_texts):
        potential_indices = torch.nonzero(condition, as_tuple=True)[0]

        # Random sample of at most number_features_per features per bucket.
        # Everything here is only turned into Python lists, so it stays on
        # the CPU rather than being moved to a hard-coded device.
        sampled_indices = potential_indices[torch.randperm(len(potential_indices))[:number_features_per]]

        interesting_features_indices += sampled_indices.tolist()
        interesting_features_values += log_freq[sampled_indices].tolist()
        interesting_features_category += [condition_text] * len(sampled_indices)

    print(set(interesting_features_category))

    top_activations_per_feature = find_top_activations(
        val_dataloader, model, sparse_autoencoder, cfg,
        interesting_features_indices, interesting_features_category
    )

    # Per-channel statistics used to undo the input normalisation for display.
    # NOTE(review): these are the usual ImageNet values, while the images are
    # preprocessed by get_clip_val_transforms() — confirm the two agree.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    for feature_ids, cat, logfreq in tqdm(zip(top_activations_per_feature.keys(), interesting_features_category, interesting_features_values), total=len(interesting_features_category)):
        max_vals, max_inds = top_activations_per_feature[feature_ids]
        dataset_indices = max_inds.tolist()
        if not dataset_indices:
            # Nothing to plot; an empty grid would divide by zero below.
            continue

        # Load every image once; the same tensor feeds the heatmap and the plot.
        images = []
        gt_labels = []
        for bid in dataset_indices:
            image, label, _ = val_data[bid]
            images.append(image)
            gt_labels.append(ind_to_name[str(label)][1])

        grid_size = int(np.ceil(np.sqrt(len(images))))
        n_rows = int(np.ceil(len(images) / grid_size))
        # squeeze=False keeps axs two-dimensional even for a single row or a
        # single cell, so axs.flatten() and axs[row, col] always work.
        fig, axs = plt.subplots(n_rows, grid_size, figsize=(15, 15), squeeze=False)
        name = f"Category: {cat},  Feature: {feature_ids}"
        fig.suptitle(name)
        for ax in axs.flatten():
            ax.axis('off')

        plotted_indices = set()
        for i, (image_tensor, label, val, bid) in enumerate(zip(images, gt_labels, max_vals, dataset_indices)):
            # Draw each dataset image at most once per feature.
            if bid in plotted_indices:
                continue
            plotted_indices.add(bid)

            row, col = divmod(i, grid_size)
            heatmap = get_heatmap(image_tensor, model, sparse_autoencoder, feature_ids)
            heatmap = image_patch_heatmap(heatmap, pixel_num=224//cfg.patch_size)

            display = denormalize(image_tensor, mean, std)
            display = display.clamp(0, 1).numpy().transpose(1, 2, 0)

            axs[row, col].imshow(display)
            axs[row, col].imshow(heatmap, cmap='viridis', alpha=0.3)  # Overlaying the heatmap
            axs[row, col].set_title(f"{label} {val.item():0.03f} ")
            axs[row, col].axis('off')

        plt.tight_layout()

        folder = os.path.join(sae_obj[0], f"{cat}")
        os.makedirs(folder, exist_ok=True)
        fig.savefig(os.path.join(folder, f"neglogfreq_{-logfreq}feauture_id:{feature_ids}.png"))
        plt.close(fig)

    # Release this SAE's GPU memory before loading the next one.
    del sparse_autoencoder
    torch.cuda.empty_cache()
