from PIL import Image
import os
import torchvision.transforms as transforms
from typing import Iterable
from nesim.utils.hook import ForwardHook
from nesim.utils.getting_modules import get_module_by_name
import torch
from tqdm import tqdm
from einops import rearrange
from nesim.experiments.resnet import create_model_and_scaler

import torchvision.models as models
from nesim.utils.setting_attr import setattr_pytorch_model
from nesim.sparsity.conv import DownsampledConv2d
import math
from nesim.utils.model_info import count_model_parameters
from nesim.utils.l1_sparsity import apply_l1_sparsity_to_model
from nesim.losses.laplacian_pyramid.loss import LaplacianPyramidLoss
from nesim.experiments.resnet import create_val_loader
from nesim.utils.json_stuff import dict_to_json
import torch


# Preprocessing for single images loaded with PIL (used by ImageDataset):
# resize the shorter side to 224, center-crop to 224x224 so every image ends up
# with the same shape (images of different aspect ratios could otherwise not be
# stacked into a batch), convert to a float tensor and normalise with the
# standard ImageNet per-channel mean/std.
imagenet_transforms = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        ),
    ]
)

def apply_spatial_pooling(layer_outputs, spatial_pooling: str):
    """
    Pool a rank-4 activation tensor over its last two dimensions.

    Args:
        layer_outputs: tensor with ``ndim == 4``; dims 2 and 3 are reduced
            (for conv outputs these are presumably height and width).
        spatial_pooling: ``"max"`` (maximum), ``"mean"`` (average) or
            ``"norm"`` (L2 norm over both reduced dims together).

    Returns:
        A rank-2 tensor keeping dims 0 and 1 of ``layer_outputs``.

    Raises:
        AssertionError: if ``layer_outputs`` is not rank 4.
        ValueError: if ``spatial_pooling`` is not one of the values above.
    """
    assert layer_outputs.ndim == 4, (
        f"Expected a 4D tensor but got ndim={layer_outputs.ndim}"
    )
    if spatial_pooling == "max":
        output = torch.amax(layer_outputs, dim=(2, 3))
    elif spatial_pooling == "mean":
        output = torch.mean(layer_outputs, dim=(2, 3))
    elif spatial_pooling == "norm":
        # Same result as the deprecated torch.norm(x, dim=(2, 3)) (Frobenius
        # norm over two dims == vector 2-norm over the flattened dims).
        output = torch.linalg.vector_norm(layer_outputs, ord=2, dim=(2, 3))
    else:
        raise ValueError(f"Invalid spatial_pooling value: {spatial_pooling}")
    assert output.ndim == 2
    return output

def _load_resnet_checkpoint(
        arch: str,
        checkpoints_folder: str,
        model_name: str,
        epoch: "int | str",
):
    """
    Build a fresh ``arch`` model via ``create_model_and_scaler`` and load a checkpoint into it.

    The checkpoint path is ``<checkpoints_folder>/<model_name>/epoch_<epoch>.pt``
    when ``epoch`` is an int, and ``<checkpoints_folder>/<model_name>/final_weights.pt``
    for any other value (e.g. ``"final"``).

    Returns:
        The model with the loaded weights, in eval mode.

    Raises:
        AssertionError: if the checkpoint file does not exist.
    """
    filename = f"epoch_{epoch}.pt" if isinstance(epoch, int) else "final_weights.pt"
    checkpoint_path = os.path.join(checkpoints_folder, f"{model_name}", filename)
    assert os.path.exists(checkpoint_path), f"Invalid path: {checkpoint_path}"

    model, _scaler = create_model_and_scaler(
        arch,
        pretrained=False,
        distributed=False,
        use_blurpool=True,
        gpu=0
    )
    # load_state_dict copies the tensors into the model's existing parameters,
    # so deserialising on CPU is sufficient; it also avoids needing the device
    # the checkpoint was saved from and a transient second GPU copy.
    state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict=state_dict)
    model.eval()
    return model


def load_resnet18_checkpoint(
        checkpoints_folder: str, # "/home/XXXX-4/repos/nesim/training/imagenet/resnet18/checkpoints"
        model_name: str,
        epoch: "int | str"
):
    """Load a ResNet-18 checkpoint; ``epoch`` is an int or e.g. ``"final"`` for final_weights.pt."""
    return _load_resnet_checkpoint(
        "resnet18",
        checkpoints_folder=checkpoints_folder,
        model_name=model_name,
        epoch=epoch,
    )


def load_resnet50_checkpoint(
        checkpoints_folder: str, # "/research/XXXX-1/toponets_resnet50_imagenet_checkpoints"
        model_name: str,
        epoch: "int | str"
):
    """Load a ResNet-50 checkpoint; ``epoch`` is an int or e.g. ``"final"`` for final_weights.pt."""
    return _load_resnet_checkpoint(
        "resnet50",
        checkpoints_folder=checkpoints_folder,
        model_name=model_name,
        epoch=epoch,
    )

class ImageDataset:
    """
    Map-style dataset over image files with integer labels.

    Each item is ``{"image": image_transform_fn(rgb_pil_image), "label": labels[idx]}``.
    """

    def __init__(
        self,
        image_filenames: list[str],
        labels: list[int],
        label_names: dict[int, str],
        image_transform_fn=None,
    ):
        """
        Args:
            image_filenames: paths of the image files.
            labels: label for each file, aligned index-by-index with ``image_filenames``.
            label_names: mapping from label to a readable name (stored only).
            image_transform_fn: optional callable applied to each RGB PIL image;
                defaults to ``imagenet_transforms``.
        """
        assert len(image_filenames) == len(labels), (
            f"Got {len(image_filenames)} image filenames but {len(labels)} labels"
        )
        self.image_filenames = image_filenames
        self.labels = labels
        self.label_names = label_names
        self.image_transform_fn = (
            imagenet_transforms if image_transform_fn is None else image_transform_fn
        )

    def __len__(self) -> int:
        # Required by torch's default (sequential/random) samplers.
        return len(self.image_filenames)

    def __getitem__(self, idx: int):
        path = self.image_filenames[idx]
        assert os.path.exists(path), f"Image file not found: {path}"
        # The context manager closes the file handle Image.open keeps open;
        # converting to RGB makes grayscale/CMYK/RGBA files yield 3 channels,
        # which the 3-channel Normalize in the default transform expects.
        with Image.open(path) as image:
            rgb_image = image.convert("RGB")
        return {
            "image": self.image_transform_fn(rgb_image),
            "label": self.labels[idx],
        }


class EvalSuite:
    """
    Evaluation helpers over a dataloader yielding ``(images, labels)`` batches.

    Provides forward-hook activation collection, per-channel class selectivity
    statistics, and (class-wise) accuracy.
    """

    def __init__(self, dataloader: Iterable):
        """
        Args:
            dataloader: iterable of ``(images, labels)`` where ``images`` is a 4D
                tensor and ``labels`` a 1D tensor with the same batch size.
                One batch is drawn here to validate that format.
        """
        self.dataloader = dataloader
        sample = next(iter(dataloader))
        assert torch.is_tensor(sample[0])
        assert torch.is_tensor(sample[1])

        assert sample[0].ndim == 4, f"Expected image batch to have 4 dims but got: {sample[0].ndim}"
        assert sample[1].ndim == 1, f"Expected label batch to have 1 dim but got: {sample[1].ndim}"
        assert sample[0].shape[0] == sample[1].shape[0], (
            f"Expected batch sizes of images and labels to match but got: "
            f"{sample[0].shape[0]} and {sample[1].shape[0]}"
        )

    def _iter_batches(self, max_num_batches=None, progress=False):
        """Yield at most ``max_num_batches`` batches (all when None), with an optional tqdm bar."""
        from itertools import islice

        try:
            total = len(self.dataloader)
        except TypeError:  # plain iterables / generators have no len()
            total = None
        if max_num_batches is not None and total is not None:
            total = min(total, max_num_batches)
        yield from tqdm(
            islice(self.dataloader, max_num_batches),
            total=total,
            disable=not progress,
        )

    @staticmethod
    def _model_device(model):
        """Device of the model's first parameter (CPU for a parameter-less model)."""
        first_param = next(model.parameters(), None)
        return torch.device("cpu") if first_param is None else first_param.device

    def get_hook_outputs(self, model, layer_names: list[str], max_num_batches = None, device="cuda:0", spatial_pooling: str = None, progress = False):
        """
        Run ``model`` over the dataloader and record the output of each named layer.

        NOTE: ``model.half()`` converts the model to fp16 in place. The model is
        not moved here; it is expected to already live on ``device``.

        Args:
            model: module whose submodules are looked up by ``layer_names``.
            layer_names: names resolved with ``get_module_by_name``.
            max_num_batches: stop after this many batches (None = all).
            device: device the image batches are moved to.
            spatial_pooling: if given, passed to ``apply_spatial_pooling`` for each output.
            progress: show a tqdm progress bar.

        Returns:
            ``(hook_outputs, labels)``: dict layer name -> float32 CPU tensor of the
            outputs concatenated over batches, and the concatenated labels as long.

        Raises:
            AssertionError: if a hook produced no tensor or an output contains NaN.
        """
        # NOTE(review): the hooks are never removed; if ForwardHook registers a
        # persistent hook it stays attached to the model after this call.
        hooks = {
            layer_name: ForwardHook(module=get_module_by_name(module=model, name=layer_name))
            for layer_name in layer_names
        }
        hook_outputs = {layer_name: [] for layer_name in layer_names}
        all_labels = []

        model = model.half()

        with torch.no_grad():
            for images, labels in self._iter_batches(max_num_batches, progress):
                # The forward pass only serves to populate the hooks.
                model(images.half().to(device))
                all_labels.append(labels)

                for layer_name in layer_names:
                    output = hooks[layer_name].output
                    assert output is not None, f"Expected hook output to NOT be None for layer: {layer_name}"
                    assert torch.is_tensor(output), f"Expected hook output to be tensor but got: {output}"
                    output = output.float()
                    assert not torch.isnan(output).any(), f"Found NaN values in hook output for layer: {layer_name}"
                    if spatial_pooling is not None:
                        output = apply_spatial_pooling(output, spatial_pooling=spatial_pooling)
                    hook_outputs[layer_name].append(output.cpu().detach())

        all_labels = torch.cat(all_labels, dim=0).long()
        for key in hook_outputs:
            hook_outputs[key] = torch.cat(hook_outputs[key], dim=0)
        return hook_outputs, all_labels

    def collect_activations(
            self, hook_outputs: dict, layer_name: str, labels: torch.tensor, target_classes: list[int], other_classes : list, device: str
    ):
        """
        Split per-channel activations of ``layer_name`` into target / other samples.

        4D outputs are averaged over dims 2 and 3 (spatial); 2D outputs (already
        pooled, e.g. via ``get_hook_outputs(spatial_pooling=...)``) are used as is.

        Returns:
            ``(target_activations, other_activations)``, each of shape
            ``[num_selected_samples, channels]``, selected by whether the
            sample's label is in ``target_classes`` / ``other_classes``.
        """
        layer_outputs = hook_outputs[layer_name].to(device)
        assert layer_outputs.ndim in (2, 4), f"Expected a 2D or 4D tensor but got ndim={layer_outputs.ndim}"
        assert labels.ndim == 1
        assert labels.shape[0] == layer_outputs.shape[0]
        labels = labels.to(layer_outputs.device)

        if layer_outputs.ndim == 4:
            channel_activations = layer_outputs.mean(dim=(2, 3))
        else:
            channel_activations = layer_outputs

        target_mask = torch.isin(labels, torch.as_tensor(target_classes, device=labels.device))
        other_mask = torch.isin(labels, torch.as_tensor(other_classes, device=labels.device))
        return channel_activations[target_mask], channel_activations[other_mask]

    def collect_mean_activations(self, hook_outputs: dict, layer_name: str, labels: torch.tensor, target_classes: list[int], other_classes : list, device: str):
        """Per-channel mean activation over target samples and over other samples."""
        target_activations, other_activations = self.collect_activations(
            hook_outputs=hook_outputs,
            layer_name=layer_name,
            labels=labels,
            target_classes=target_classes,
            other_classes=other_classes,
            device=device
        )
        return target_activations.mean(dim=0), other_activations.mean(dim=0)

    def compute_selectivity_all_channels(self, hook_outputs: dict, layer_name: str, labels: torch.tensor, target_classes: list[int], other_classes : list, device: str):
        """Per-channel ``(mu_target - mu_other) / (mu_target + mu_other)`` as a numpy array."""
        mean_target_activation, mean_other_activation = self.collect_mean_activations(
            hook_outputs=hook_outputs,
            layer_name=layer_name,
            labels=labels,
            target_classes=target_classes,
            other_classes=other_classes,
            device=device
        )
        selectivity = (mean_target_activation - mean_other_activation) / (
            mean_target_activation + mean_other_activation
        )
        return selectivity.cpu().numpy()

    def compute_tvals(self, hook_outputs: dict, layer_name: str, labels: torch.tensor, target_classes: list[int], other_classes : list, device: str):
        """
        Per-channel Welch's t-test (unequal variances) between target and other samples.

        Returns:
            ``(t_statistic, p_value)`` arrays with one entry per channel.
        """
        from scipy import stats

        target_activation, other_activation = self.collect_activations(
            hook_outputs=hook_outputs,
            layer_name=layer_name,
            labels=labels,
            target_classes=target_classes,
            other_classes=other_classes,
            device=device
        )
        t_statistic, p_value = stats.ttest_ind(
            target_activation.cpu().numpy(), other_activation.cpu().numpy(), equal_var=False
        )
        return t_statistic, p_value

    def compute_dprime_all_channels(self, hook_outputs: dict, layer_name: str, labels: torch.tensor, target_classes: list[int], other_classes : list, device: str):
        """
        Per-channel d' = (mu_target - mu_other) / sqrt((var_target + var_other) / 2).

        Means and (unbiased) variances are taken across samples, separately for
        every channel. Returns a numpy array with one value per channel.
        """
        target_activations, other_activations = self.collect_activations(
            hook_outputs=hook_outputs,
            layer_name=layer_name,
            labels=labels,
            target_classes=target_classes,
            other_classes=other_classes,
            device=device
        )
        mean_difference = target_activations.mean(dim=0) - other_activations.mean(dim=0)
        pooled_std = ((target_activations.var(dim=0) + other_activations.var(dim=0)) / 2) ** 0.5
        return (mean_difference / pooled_std).cpu().numpy()

    def compute_delta_all_channels(self, hook_outputs: dict, layer_name: str, labels: torch.tensor, target_classes: list[int], other_classes : list, device: str):
        """Per-channel ``mu_target - mu_other`` as a numpy array."""
        mean_target_activation, mean_other_activation = self.collect_mean_activations(
            hook_outputs=hook_outputs,
            layer_name=layer_name,
            labels=labels,
            target_classes=target_classes,
            other_classes=other_classes,
            device=device
        )
        return (mean_target_activation - mean_other_activation).cpu().numpy()

    def compute_accuracy(self, model, max_num_batches = None, progress = False):
        """
        Top-1 accuracy of ``model`` (converted to fp16 in place) over the dataloader.

        Image batches are moved/cast to the model's device and fp16.

        Raises:
            ValueError: if the dataloader yields no samples.
        """
        model = model.half()
        model_device = self._model_device(model)
        correct = 0
        total = 0

        with torch.no_grad():
            for images, labels in self._iter_batches(max_num_batches, progress):
                logits = model(images.to(model_device, torch.float16))
                predicted = torch.argmax(logits, dim=1)
                correct += (predicted == labels.to(predicted.device)).sum().item()
                total += labels.size(0)

        # Drop the last batch's tensors first so empty_cache() can actually
        # release their cached GPU blocks (no-op when CUDA is not initialised).
        images = labels = logits = predicted = None
        torch.cuda.empty_cache()

        if total == 0:
            raise ValueError("compute_accuracy: the dataloader yielded no samples")
        return correct / total

    def compute_classwise_accuracy(self, model, num_classes = 1000, max_num_batches=None, progress=True):
        """
        Per-class top-1 accuracy of ``model`` (converted to fp16 in place).

        Returns:
            dict class index -> accuracy, or ``0`` for classes with no samples.
        """
        model = model.half()
        model_device = self._model_device(model)
        class_correct = torch.zeros(num_classes, dtype=torch.long)
        class_total = torch.zeros(num_classes, dtype=torch.long)

        with torch.no_grad():
            for images, labels in self._iter_batches(max_num_batches, progress):
                logits = model(images.to(model_device, torch.float16))
                predicted = torch.argmax(logits, dim=1).cpu()
                labels = labels.cpu().long()
                # Vectorised per-class counting instead of one .item() per sample.
                class_total += torch.bincount(labels, minlength=num_classes)
                class_correct += torch.bincount(labels[predicted == labels], minlength=num_classes)

        correct_counts = class_correct.tolist()
        total_counts = class_total.tolist()
        return {
            i: correct_counts[i] / total_counts[i] if total_counts[i] > 0 else 0
            for i in range(num_classes)
        }
    


def downsample_resnet(model, layer_names: list[str], downsample_factor = 9.0, max_loss = None):
    """
    Swap each named conv layer of ``model`` for a ``DownsampledConv2d``.

    ``factor_h`` and ``factor_w`` of the replacement are both
    ``sqrt(downsample_factor)``.

    - ``max_loss is None``: every listed layer is replaced, and the original
      layer is moved to CPU before its reference is dropped.
    - otherwise: a layer is replaced only when its ``LaplacianPyramidLoss``
      (computed with the same factors) is strictly below ``max_loss``; other
      layers are left untouched and reported with a print.

    The model is modified in place (via ``setattr_pytorch_model``) and returned.
    """
    for layer_name in layer_names:
        conv = get_module_by_name(module=model, name=layer_name)

        if max_loss is not None:
            weight_device = conv.weight.device
            per_axis_factor = math.sqrt(downsample_factor)
            loss = LaplacianPyramidLoss(
                layer=conv,
                device=weight_device,
                factor_h=[per_axis_factor],
                factor_w=[per_axis_factor],
            ).get_loss().item()
            # Written as "not <" so a NaN loss is treated as too high, as before.
            if not loss < max_loss:
                print(f"Not downsampling layer: {layer_name} because it has a high topo loss ({loss} > {max_loss})")
                continue
            replacement = DownsampledConv2d(
                conv_layer=conv,
                factor_h=per_axis_factor,
                factor_w=per_axis_factor,
            )
        else:
            per_axis_factor = math.sqrt(downsample_factor)
            replacement = DownsampledConv2d(
                conv_layer=conv,
                factor_h=per_axis_factor,
                factor_w=per_axis_factor,
            )
            # Presumably to release device memory held by the replaced layer.
            conv.cpu()
            del conv

        setattr_pytorch_model(model=model, name=layer_name, item=replacement)

    return model

class ResnetEfficiencyEval:
    """
    Measure validation accuracy of a ResNet before and after weight compression.

    ``run`` evaluates the unmodified model, then for each downsample factor
    ``f`` evaluates the model after ``apply_l1_sparsity_to_model`` masks a
    fraction ``1 - 1/f`` of the weights in ``topo_layer_names`` and, optionally,
    after those layers are replaced via ``downsample_resnet``. Results are
    written to a JSON file.
    """

    def __init__(
        self,
        val_dataset_path = "/research/datasets/imagenet_ffcv/val_500_0.50_90.ffcv",
        batch_size = 128,
        checkpoints_folder: str = "/home/XXXX-4/repos/nesim/training/imagenet/resnet18/checkpoints",
        device = "cuda:0",
        mode = "resnet18"
    ):
        """
        Args:
            val_dataset_path: validation set path passed to ``create_val_loader``.
            batch_size: validation batch size.
            checkpoints_folder: folder with one sub-folder per model name.
            device: device the evaluated models are moved to.
            mode: architecture, ``"resnet18"`` or ``"resnet50"``.
        """
        assert mode in ["resnet18", "resnet50"]
        self.checkpoints_folder = checkpoints_folder
        self.val_dataset_path = val_dataset_path
        self.batch_size = batch_size
        self.device = device
        self.mode = mode

    def load_resnet(self, model_name, epoch):
        """
        Return a model in eval mode for ``self.mode``.

        ``model_name == "pretrained"`` loads torchvision's default weights;
        any other name loads ``<checkpoints_folder>/<model_name>`` at ``epoch``.
        """
        if model_name == "pretrained":
            constructors = {"resnet18": models.resnet18, "resnet50": models.resnet50}
            return constructors[self.mode](weights="DEFAULT").eval()

        checkpoint_loader = (
            load_resnet18_checkpoint if self.mode == "resnet18" else load_resnet50_checkpoint
        )
        return checkpoint_loader(
            checkpoints_folder=self.checkpoints_folder,
            model_name=model_name,
            epoch=epoch
        )

    def downsample_eval(self, model_name, epoch, topo_layer_names, downsample_factor, max_num_batches, progress):
        """
        Load a fresh model, downsample ``topo_layer_names`` and evaluate it.

        Requires ``self.eval_suite``, which ``run`` creates.

        Returns:
            ``(val_acc, count_model_parameters(model)[0])``.
        """
        torch.cuda.empty_cache()
        model = self.load_resnet(
            model_name=model_name,
            epoch=epoch
        )
        model = downsample_resnet(
            model=model,
            layer_names=topo_layer_names,
            downsample_factor=downsample_factor,
            max_loss=None
        )
        model.to(self.device)

        val_acc = self.eval_suite.compute_accuracy(
            model=model,
            max_num_batches=max_num_batches,
            progress=progress
        )
        return val_acc, count_model_parameters(model=model)[0]

    @torch.no_grad()
    def run(
        self,
        model_name: str,
        downsample_factors: list[float],
        topo_layer_names: list[str],
        output_json_filename: str = "results.json",
        epoch: "int | str" = "final",
        max_num_batches = None,
        progress = True,
        run_downsampling_eval = False,
        num_workers: int = 16,
    ):
        """
        Evaluate ``model_name`` for every factor in ``downsample_factors`` and save the results.

        The JSON has key ``"l1"`` (and ``"downsampling"`` when
        ``run_downsampling_eval``), each a list with one dict per factor.

        Args:
            model_name: checkpoint sub-folder name, or ``"pretrained"``.
            downsample_factors: compression factors ``f`` to evaluate.
            topo_layer_names: layers that are sparsified / downsampled.
            output_json_filename: where ``dict_to_json`` writes the results.
            epoch: checkpoint epoch (int) or e.g. ``"final"`` for the final weights.
            max_num_batches: evaluate on at most this many batches (None = all).
            progress: show progress bars.
            run_downsampling_eval: also evaluate ``downsample_resnet`` models.
            num_workers: dataloader worker count passed to ``create_val_loader``.
        """
        val_dataloader = create_val_loader(
            val_dataset=self.val_dataset_path,
            num_workers=num_workers,
            batch_size=self.batch_size,
            resolution=224,
            distributed=False,
            gpu=0
        )
        self.eval_suite = EvalSuite(
            dataloader=val_dataloader,
        )

        results = {}

        model = self.load_resnet(
            model_name=model_name,
            epoch=epoch
        )
        model.eval().to(self.device)

        val_acc = self.eval_suite.compute_accuracy(
            model=model,
            max_num_batches=max_num_batches,
            progress=progress
        )
        del model
        print(f"[Original] Model: {model_name} Val acc: {val_acc}")

        if run_downsampling_eval:
            results["downsampling"] = []
        results["l1"] = []

        for downsample_factor in downsample_factors:

            if run_downsampling_eval:
                val_acc, parameters = self.downsample_eval(
                    model_name=model_name,
                    topo_layer_names=topo_layer_names,
                    downsample_factor=downsample_factor,
                    max_num_batches=max_num_batches,
                    progress=progress,
                    epoch=epoch
                )
                print(f"[{model_name} down {downsample_factor}] acc: {val_acc}")
                results["downsampling"].append(
                    {
                        "downsample_factor": downsample_factor,
                        "val_acc": val_acc,
                        "parameters": parameters,
                    }
                )

            # Keep 1/f of the weights, i.e. mask a fraction 1 - 1/f (percentage
            # form kept so the stored float is unchanged).
            fraction_of_masked_weights = (100 - (100/downsample_factor))/100
            model = self.load_resnet(
                model_name=model_name,
                epoch=epoch
            )
            torch.cuda.empty_cache()
            model, total_num_masked_weights = apply_l1_sparsity_to_model(
                model=model,
                fraction_of_masked_weights=fraction_of_masked_weights,
                layer_names=topo_layer_names,
                return_num_masked_weights=True
            )
            model.eval().to(self.device)

            val_acc = self.eval_suite.compute_accuracy(
                model=model,
                max_num_batches=max_num_batches,
                progress=progress
            )
            parameters = count_model_parameters(model=model)[0]
            del model

            results["l1"].append(
                {
                    "fraction_of_masked_weights": fraction_of_masked_weights,
                    "val_acc": val_acc,
                    "parameters": parameters,
                    "sparsity": fraction_of_masked_weights,
                    "downsample_factor": downsample_factor
                }
            )

            print(f"[{model_name} l1 {downsample_factor}] acc: {val_acc}")
            torch.cuda.empty_cache()

        del val_dataloader
        dict_to_json(
            results,
            output_json_filename
        )
        print(f"Saved: {output_json_filename}")