import os
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from matplotlib.table import Table
import shutil
from torch.utils.data import DataLoader
from IPython.display import display
from PIL import Image
import numpy as np

from src.data_ncb.precompute_encodings import EncodingPrecomputer
from src.data_ncb.argparser_precompute_encodings import get_parser
from src.data_ncb.CLEVR_Hans_image_dataset import CLEVRHansDataset
from src.trainer.trainer_framework import BaseTrainer
from src.pixel_space.trainer_framework_pixel import BasePixelSpaceTrainer
from src.config.config_dataclass import TrainerConfig, DatasetConfig, ModelConfig, BooleanConfig, FeatureSelectorConfig

class FeatureAnalyzer(BaseTrainer):
    """Class for analyzing feature selection in Merlin-Arthur models"""
    
    def __init__(
        self, 
        trainer_config: TrainerConfig,
        dataset_config: DatasetConfig,
        model_config: ModelConfig, 
        bool_config: BooleanConfig,
        feature_selector_config: FeatureSelectorConfig,
        num_slots,
        num_blocks,
        input_dim,
        image_folder
    ):
        """Set up the analyzer on top of the shared trainer initialisation.

        The five configuration objects are handed to the base class
        unchanged, together with ``logger=None``. The remaining arguments
        are simply stored on the instance.

        Args:
            trainer_config: Configuration for training parameters
            dataset_config: Configuration for dataset parameters
            model_config: Configuration for model parameters
            bool_config: Configuration for boolean parameters
            feature_selector_config: Configuration for feature selector parameters
            num_slots: Number of slots; used in this class when reshaping
                encodings and masked inputs
            num_blocks: Number of blocks per slot; used when reshaping
                masked inputs
            input_dim: Stored as ``self.input_dim``; not read by the methods
                of this class
            image_folder: Folder searched for exemplar images named
                ``block{block}_{concept}.png``
        """
        base_kwargs = {
            "trainer_config": trainer_config,
            "dataset_config": dataset_config,
            "model_config": model_config,
            "bool_config": bool_config,
            "feature_selector_config": feature_selector_config,
        }
        # No logger is attached for analysis runs
        super().__init__(logger=None, **base_kwargs)

        # Analysis-specific settings
        self.image_folder = image_folder
        self.input_dim = input_dim
        self.num_blocks = num_blocks
        self.num_slots = num_slots
    
    def setup_data(self):
        """Loads the data from the dataset."""
        # Create datasets with wrapper

        _, self.val_loader, self.test_loader = self.precomputer.load_images()
        self.num_classes = self.val_loader.dataset.n_classes

    def setup_ncb(self):
        """Build the encoding precomputer and expose its models.

        Starts from the parser defaults (an empty argument list is parsed so
        the real command line is never read), overrides the fields listed
        below, and constructs ``EncodingPrecomputer``. Afterwards the
        ``result_dir`` is deleted if it exists, and the precomputer's model
        and that model's inner model are stored as ``self.ncb`` and
        ``self.sysbinder``.
        """
        args = get_parser().parse_args([])

        overrides = {
            "enc_type": 'concept_slot',
            "data_dir": self.data_dir,
            "sysbinder_path": "path/to/your/models/ncb/trainedmodels_NCBrepo/CLEVR-4/retbind_seed_0/best_model.pt",
            # Temporary directory that will be deleted
            "result_dir": "tmp",
            "retrieval_corpus_path": "path/to/your/models/ncb/trainedmodels_NCBrepo/CLEVR-4/retbind_seed_0/block_concept_dicts.pkl",
            "num_workers": self.num_workers,
            "batch_size": self.batch_size,
        }
        for field, value in overrides.items():
            setattr(args, field, value)

        precomputer = EncodingPrecomputer(args)
        self.precomputer = precomputer

        # Clean up the temporary result directory, if one is present
        result_dir = args.result_dir
        if os.path.exists(result_dir):
            shutil.rmtree(result_dir)

        ncb_model = precomputer.model
        self.ncb = ncb_model
        self.sysbinder = ncb_model.model
    
    def feature_analysis(self, split='val'):
        """Analyzes example images to investigate which concepts in the encodings
        were chosen by Merlin and Morgana.
        """
        self.model.eval()
        self.merlin.eval()
        self.morgana.eval()

        if split == 'val':
            loader = self.val_loader
        elif split == 'test':
            loader = self.test_loader
        else:
            raise ValueError(f"Invalid split: {split}")
        
        # Extract one batch from the loader
        images, _, labels, fnames = next(iter(loader))
        images = images.to(self.device)
        labels = labels.to(self.device).long()

        # Get encodings and predictions
        bs_encs, cs_encs = self._get_encodings(images)
        one_hot_padded_encs = self._convert_into_one_hot_padded_encs(cs_encs)
        masked_inputs_merlin, masked_inputs_morgana, preds_merlin, preds_morgana = self.evaluate(one_hot_padded_encs, labels)

        # Extract selected features
        features_merlin, features_morgana = self._extract_selected_features(masked_inputs_merlin, masked_inputs_morgana)
        
        # Store the results as instance variables for easier access
        self.analysis_results = {
            'bs_encs': bs_encs,
            'cs_encs': cs_encs,
            'fnames': fnames,
            'labels': labels,
            'preds_merlin': preds_merlin,
            'preds_morgana': preds_morgana,
            'features_merlin': features_merlin,
            'features_morgana': features_morgana
        }
        
        return bs_encs, cs_encs, fnames, labels, preds_merlin, preds_morgana, features_merlin, features_morgana

    def _get_encodings(self, images):
        """Gets block-slot and concept-slot encodings for the images."""
        bs_encs, _, _, _ = self.sysbinder.encode(images)
        representations = torch.stack([self.ncb.retrieve_discrete_representation(s) for s in bs_encs])
        cs_encs = representations[..., 0]
        return bs_encs, cs_encs

    def _extract_selected_features(self, masked_inputs_merlin, masked_inputs_morgana):
        """Extracts the selected features from masked inputs."""
        # Reshape masked inputs
        enc_masked_merlin = masked_inputs_merlin.view(-1, self.num_slots, self.num_blocks, masked_inputs_merlin.shape[-1]//self.num_blocks)
        enc_masked_morgana = masked_inputs_morgana.view(-1, self.num_slots, self.num_blocks, masked_inputs_morgana.shape[-1]//self.num_blocks)

        # Find non-zero elements (selected features)
        has_one_merlin = torch.nonzero(enc_masked_merlin, as_tuple=False)
        has_one_morgana = torch.nonzero(enc_masked_morgana, as_tuple=False)

        has_one_merlin = has_one_merlin.view(-1, self.mask_size, 4)
        has_one_morgana = has_one_morgana.view(-1, self.mask_size, 4)

        # Delete first column (batch index)
        features_merlin = has_one_merlin[:, :, 1:]
        features_morgana = has_one_morgana[:, :, 1:]
        
        return features_merlin, features_morgana

    def _convert_into_one_hot_padded_encs(self, cs_encs):
        one_hot_padded_encs = []
        block_size = max(self.ncb.prior_num_concepts)
        for enc in cs_encs:
            one_hot_padded_enc = []
            for slot in enc:
                slot = slot.to(torch.long)
                one_hot_per_slot = nn.functional.one_hot(slot, num_classes=block_size)
                one_hot_padded_enc.append(one_hot_per_slot)
            one_hot_tensor = torch.stack(one_hot_padded_enc, dim=0)
            one_hot_padded_encs.append(one_hot_tensor)
        one_hot_padded_encs = torch.stack(one_hot_padded_encs, dim=0)
        one_hot_padded_encs = one_hot_padded_encs.view(self.batch_size, self.num_slots, -1).to(self.device).float()

        return one_hot_padded_encs
    
    def evaluate(self, one_hot_padded_encs, labels):
        if self.approach == "sfw":
            # SFW needs gradients for mask optimization
            with torch.enable_grad():
                continuous_mask_merlin = self.merlin(one_hot_padded_encs, labels, self.model)
                continuous_mask_morgana = self.morgana(one_hot_padded_encs, labels, self.model)
        else:  # learn_fs
            with torch.no_grad():
                continuous_mask_merlin = self.merlin(one_hot_padded_encs)
                continuous_mask_morgana = self.morgana(one_hot_padded_encs)

        with torch.no_grad():
            # Convert to binary masks using top-k selection
            binary_mask_merlin = self.merlin.get_binary_mask(continuous_mask_merlin)
            binary_mask_morgana = self.morgana.get_binary_mask(continuous_mask_morgana)
        
            # Apply masks and get predictions
            masked_inputs_merlin = self.merlin.apply_mask(one_hot_padded_encs, binary_mask_merlin)
            masked_inputs_morgana = self.morgana.apply_mask(one_hot_padded_encs, binary_mask_morgana)
            
            logits_merlin = self.model(masked_inputs_merlin)
            logits_morgana = self.model(masked_inputs_morgana)

            _, preds_merlin = torch.max(logits_merlin, 1)
            _, preds_morgana = torch.max(logits_morgana, 1)

        return masked_inputs_merlin, masked_inputs_morgana, preds_merlin, preds_morgana

    def plot_results(self, i, all_images=False):
        """Visualizes the results from the concept feature analysis.
        
        Args:
            i: Index of the example to visualize
            all_images: If True, additionally outputs all individual images at their original size
                without any titles or annotations (useful for paper figures)
        """
        bs_encs = self.analysis_results['bs_encs']
        cs_encs = self.analysis_results['cs_encs']
        fnames = self.analysis_results['fnames']
        labels = self.analysis_results['labels']
        preds_merlin = self.analysis_results['preds_merlin']
        preds_morgana = self.analysis_results['preds_morgana']
        features_merlin = self.analysis_results['features_merlin']
        features_morgana = self.analysis_results['features_morgana']
        
        # Print basic information
        print(f'File: {fnames[i]}')
        print(f'Class: {labels[i].item()}')
        print(f'Prediction for merlin: {preds_merlin[i].item()}')
        print(f'Prediction for morgana: {preds_morgana[i].item()}')
        
        # Extract data for the specific example
        encs = cs_encs[i].clone().to(torch.int)
        features_merlin_i = features_merlin[i].clone().tolist()
        features_morgana_i = features_morgana[i].clone().tolist()
        
        # Generate visualizations
        self.plot_original_image(i, fnames)
        self.plot_cs_encodings(encs)
        self.plot_reconstructed_images(self.get_reconstructed_images(i, bs_encs))
        self.plot_feature_selection(features_merlin_i, features_morgana_i)
        self.plot_clustered_exemplars(features_merlin_i, title="Clustered exemplars for Merlin's features")
        self.plot_clustered_exemplars(features_morgana_i, title="Clustered exemplars for Morgana's features")

        # If all_images flag is set, output individual high-quality images
        if all_images:
            def show_image_without_margins(img):
                from IPython.display import display
                from PIL import Image
                import numpy as np
                
                # Convert to PIL Image
                if isinstance(img, np.ndarray):
                    if img.dtype == np.float32 or img.dtype == np.float64:
                        img = (img * 255).astype(np.uint8)
                    img = Image.fromarray(img)
                display(img)

            # Original image
            show_image_without_margins(mpimg.imread(fnames[i]))
            
            # Plot reconstructions individually
            images = self.get_reconstructed_images(i, bs_encs)
            for img in images:
                show_image_without_margins(img)
            
            # Plot exemplar images individually
            features_merlin_i = features_merlin[i].clone().tolist()
            features_morgana_i = features_morgana[i].clone().tolist()
            
            # Get unique block-concept pairs for Merlin
            block_concept_merlin = [row[1:] for row in features_merlin_i]
            block_concept_merlin = list(map(list, set(map(tuple, block_concept_merlin))))
            block_concept_merlin = sorted(block_concept_merlin, key=lambda x: (x[0], x[1]))
            
            # Plot Merlin's exemplars
            for pair in block_concept_merlin:
                block, concept = pair
                image_filename = f"block{block}_{concept}.png"
                image_path = os.path.join(self.image_folder, image_filename)
                if os.path.exists(image_path):
                    show_image_without_margins(mpimg.imread(image_path))
            
            # Get unique block-concept pairs for Morgana
            block_concept_morgana = [row[1:] for row in features_morgana_i]
            block_concept_morgana = list(map(list, set(map(tuple, block_concept_morgana))))
            block_concept_morgana = sorted(block_concept_morgana, key=lambda x: (x[0], x[1]))
            
            # Plot Morgana's exemplars
            for pair in block_concept_morgana:
                block, concept = pair
                image_filename = f"block{block}_{concept}.png"
                image_path = os.path.join(self.image_folder, image_filename)
                if os.path.exists(image_path):
                    show_image_without_margins(mpimg.imread(image_path))

    def plot_original_image(self, i, fnames):
        """Show the image stored at ``fnames[i]`` in a small titled figure.

        Args:
            i: Index into ``fnames``.
            fnames: Sequence of image file paths.
        """
        original = mpimg.imread(fnames[i])

        _, ax = plt.subplots(figsize=(3, 3))
        ax.imshow(original)
        ax.set_title('Original image')
        # Hide axis labels
        ax.axis('off')
        plt.show()

    def plot_cs_encodings(self, cs_encs):
        """Plots the concept-slot encodings for one image as a table.

        Rows are slots and columns are blocks; the first row and first column
        hold gray header cells.

        Args:
            cs_encs: 2-D tensor whose entries support ``.item()``; its two
                dimensions are unpacked as ``(slots, blocks)``.
        """
        fig, ax = plt.subplots(figsize=(12, 2))
        ax.axis('off')
        ax.set_title("Concept-slot encoding", fontsize=14, pad=10)
        table = Table(ax, bbox=[0, 0, 1, 1])
        nrows, ncols = cs_encs.shape

        # Header row
        for col in range(ncols):
            table.add_cell(0, col + 1, 1, 1, text=f"Block {col}", loc='center', facecolor='lightgray')

        # One header cell plus the data cells per slot
        for row in range(nrows):
            table.add_cell(row + 1, 0, 1, 1, text=f"Slot {row}", loc='center', facecolor='lightgray')
            for col in range(ncols):
                table.add_cell(row + 1, col + 1, 1, 1, text=str(cs_encs[row, col].item()), loc='center')

        # These are table-wide settings, so apply them exactly once. They used
        # to sit inside a per-cell loop, which compounded the scale factor.
        table.auto_set_font_size(False)
        table.set_fontsize(8)
        # NOTE(review): with ``bbox`` set, matplotlib presumably rescales the
        # table to the bbox when drawing, so this factor may have no visible
        # effect; confirm before removing it.
        table.scale(1.2, 1.2)

        ax.add_table(table)
        plt.show()

    def get_reconstructed_images(self, i, bs_encs, value=10000):
        """Reconstructs every object from the block-slot encodings."""
        bs_enc = bs_encs[i].unsqueeze(0)
        images = []
        
        # Generate reconstructions for each slot
        for slot_to_keep in range(4):
            slots_to_value = [j for j in range(4) if j != slot_to_keep]
            bs_enc_modified = bs_enc.clone()
            bs_enc_modified[:, slots_to_value, :] = value
            dec_image = self.sysbinder.decode(bs_enc_modified)
            dec_image = dec_image[0].to('cpu').permute(1, 2, 0).detach().numpy()
            images.append(dec_image)
        
        return images
    
    def plot_reconstructed_images(self, images):
        """Plots the reconstructed images per object, one panel per slot.

        The number of panels follows ``len(images)`` instead of being fixed
        to four; for four images the figure is the same as before
        (12 x 9 inches, titles ``Slot 0`` .. ``Slot 3``).

        Args:
            images: Sequence of image arrays accepted by ``imshow``, ordered
                by slot. Nothing is drawn for an empty sequence.
        """
        num_images = len(images)
        if num_images == 0:
            return

        # squeeze=False keeps ``axes`` 2-D even for a single panel
        fig, axes = plt.subplots(
            1, num_images, figsize=(3 * num_images, 9), squeeze=False
        )

        for slot, (ax, image) in enumerate(zip(axes[0], images)):
            ax.imshow(image)
            ax.set_title(f'Slot {slot}')
            ax.axis('off')  # Remove axes for a cleaner look

        fig.suptitle('Reconstructions per slot', fontsize=14, y=0.7)

        plt.tight_layout()  # Adjust spacing
        plt.show()

    def plot_feature_selection(self, features_merlin, features_morgana):
        """Draw Merlin's and Morgana's selected features as two tables.

        Args:
            features_merlin: Rows of ``[slot, block, concept]`` for Merlin.
            features_morgana: Rows of ``[slot, block, concept]`` for Morgana.
        """
        column_names = ("Slots", "Blocks", "Concepts")
        panels = (
            (features_merlin, "Merlin features"),
            (features_morgana, "Morgana features"),
        )

        _, axs = plt.subplots(1, 2, figsize=(10, 4))

        # Left panel: Merlin, right panel: Morgana
        for ax, (rows, caption) in zip(axs, panels):
            self._plot_table(ax, rows, list(column_names), caption)

        plt.tight_layout()
        plt.show()

    def _plot_table(self, ax, data, headers, title):
        """Draw ``data`` as a table with a gray header row on ``ax``.

        Args:
            ax: Matplotlib axes to draw on; its axis decorations are turned off.
            data: Iterable of rows; every cell is rendered via ``str``.
            headers: Column titles for the header row.
            title: Title set on ``ax``.
        """
        ax.axis('off')  # Turn off the axes
        table = Table(ax, bbox=[0, 0, 1, 1])

        # Header row
        for col, header in enumerate(headers):
            table.add_cell(0, col, 1, 1, text=header, loc='center', facecolor='lightgray')

        # Data rows start below the header
        for row_idx, row in enumerate(data, start=1):
            for col, cell in enumerate(row):
                table.add_cell(row_idx, col, 1, 1, text=str(cell), loc='center')

        ax.add_table(table)
        ax.set_title(title, fontsize=12)

    def plot_clustered_exemplars(self, features, title):
        """Show one exemplar image per distinct (block, concept) pair.

        The slot column is dropped from ``features``, and the remaining pairs
        are de-duplicated and sorted. For every pair the file
        ``block{block}_{concept}.png`` is looked up in ``self.image_folder``.
        A missing file is replaced by a red "Image Not Found" note.

        Args:
            features: Rows of ``[slot, block, concept]``.
            title: Figure title.
        """
        unique_pairs = sorted({tuple(row[1:]) for row in features})
        n_pairs = len(unique_pairs)

        fig, axs = plt.subplots(1, n_pairs, figsize=(20, 8))
        # A single subplot is returned as a bare Axes (mask size 1)
        if n_pairs == 1:
            axs = [axs]

        # Vertical title position depends on how many panels there are
        if n_pairs > 4:
            title_y = 0.7
        elif n_pairs < 3:
            title_y = 1
        else:
            title_y = 0.8
        fig.suptitle(title, fontsize=14, y=title_y)

        for ax, (block, concept) in zip(axs, unique_pairs):
            image_path = os.path.join(self.image_folder, f"block{block}_{concept}.png")

            if not os.path.exists(image_path):
                ax.axis('off')  # Hide axis if image is not found
                ax.text(0.5, 0.5, "Image Not Found", ha='center', va='center', fontsize=12, color='red')
                continue

            ax.imshow(mpimg.imread(image_path))
            ax.axis('off')  # Hide axis
            ax.set_title(f"Block {block} Concept {concept}", fontsize=10)

            # List the slots in which this pair was selected
            slots = self._find_slots_for_pair(features, block, concept)
            ax.text(10, 420, f"Used for slot: {', '.join(map(str, slots))}", fontsize=10, color='black')

        plt.show()

    def _find_slots_for_pair(self, features, block, concept):
        """Finds which slots are used for the given block and concept.
        """
        slots = []
        for row in features:
            if row[1] == block and row[2] == concept:
                slots.append(row[0])
        return slots
    

class FeatureAnalyzerPixel(BasePixelSpaceTrainer):
    """Class for analyzing features in the pixel space case"""
    def __init__(
        self, 
        trainer_config: TrainerConfig,
        dataset_config: DatasetConfig,
        model_config: ModelConfig, 
        bool_config: BooleanConfig,
        feature_selector_config: FeatureSelectorConfig
    ):
        """Initialise the base trainer, then load the data and set up the model.

        The configuration objects are passed to the base class unchanged,
        together with ``logger=None``. Afterwards ``_load_images`` and
        ``setup_model`` are called in that order.

        Args:
            trainer_config: Configuration for the trainer
            dataset_config: Configuration for the dataset
            model_config: Configuration for the model
            bool_config: Configuration for boolean parameters
            feature_selector_config: Configuration for feature selection
        """
        base_kwargs = {
            "trainer_config": trainer_config,
            "dataset_config": dataset_config,
            "model_config": model_config,
            "bool_config": bool_config,
            "feature_selector_config": feature_selector_config,
        }
        # No logger is attached for analysis runs
        super().__init__(logger=None, **base_kwargs)

        self._load_images()
        self.setup_model()

    def _load_images(self):
        """Create the train/val/test datasets and their DataLoaders.

        The confounder version passed to ``CLEVRHansDataset`` is the
        second-to-last component of ``self.data_dir`` when split on the OS
        path separator. Only the training loader shuffles; all loaders use
        ``self.batch_size``, ``self.num_workers`` and pinned memory.
        """
        print(f'Loading images from {self.data_dir}')

        conf_version = self.data_dir.split(os.path.sep)[-2]

        def build_dataset(split_name):
            return CLEVRHansDataset(
                self.data_dir, split_name, lexi=True, conf_vers=conf_version
            )

        def build_loader(dataset, shuffle):
            return DataLoader(
                dataset,
                batch_size=self.batch_size,
                shuffle=shuffle,
                num_workers=self.num_workers,
                pin_memory=True
            )

        # All datasets are created before any loader is assigned
        dataset_train, dataset_val, dataset_test = [
            build_dataset(split_name) for split_name in ("train", "val", "test")
        ]

        self.train_loader = build_loader(dataset_train, True)
        self.val_loader = build_loader(dataset_val, False)
        self.test_loader = build_loader(dataset_test, False)

    def feature_analysis(self, split='val'):
        """Analyzes example images to investigate which pixels
        were chosen by Merlin and Morgana.
        """

        self.model.eval()
        self.merlin.eval()
        self.morgana.eval()

        if split == 'val':
            loader = self.val_loader
        elif split == 'test':
            loader = self.test_loader
        else:
            raise ValueError(f"Invalid split: {split}")
        
        # Extract one batch from the loader
        images, _, labels, fnames = next(iter(loader))
        images = images.to(self.device)
        labels = labels.to(self.device).long()

        if self.approach == 'sfw':
            # Temporarily enable gradients for mask optimization
            with torch.enable_grad():
                # Optimize masks (in eval mode) - requires gradients
                continuous_mask_merlin = self.merlin(images, labels, self.model)
                continuous_mask_morgana = self.morgana(images, labels, self.model)
        elif self.approach == 'unet':
            # Disable gradients for the validation process
            with torch.no_grad():
                continuous_mask_merlin = self.merlin(images)
                continuous_mask_morgana = self.morgana(images)

                continuous_mask_merlin = self.merlin.normalize_l1(continuous_mask_merlin, self.mask_size)
                continuous_mask_morgana = self.morgana.normalize_l1(continuous_mask_morgana, self.mask_size)
        
        # Disable gradients for the rest of the validation process
        with torch.no_grad():
            # Convert to binary masks using top-k selection
            binary_mask_merlin = self.merlin.get_binary_mask(continuous_mask_merlin)
            binary_mask_morgana = self.morgana.get_binary_mask(continuous_mask_morgana)

            # Apply masks and get predictions
            masked_inputs_merlin = self.merlin.apply_mask(images, binary_mask_merlin)
            masked_inputs_morgana = self.morgana.apply_mask(images, binary_mask_morgana)
            
            logits_merlin = self.model(masked_inputs_merlin)
            logits_morgana = self.model(masked_inputs_morgana)

            _, preds_merlin = torch.max(logits_merlin, 1)
            _, preds_morgana = torch.max(logits_morgana, 1)

            #apply mask without random noise
            img_masked_merlin_no_noise = binary_mask_merlin * images
            img_masked_morgana_no_noise = binary_mask_morgana * images

        #store results
        self.analysis_results = {
            'fnames': fnames,
            'labels': labels,
            'preds_merlin': preds_merlin,
            'preds_morgana': preds_morgana,
            'masked_images_merlin': img_masked_merlin_no_noise.cpu(),
            'masked_images_morgana': img_masked_morgana_no_noise.cpu()
        }

        return fnames, labels, preds_merlin, preds_morgana, img_masked_merlin_no_noise.cpu(), img_masked_morgana_no_noise.cpu()

    def plot_results(self, i, all_images=False):
        """Visualizes the results from the pixel-space feature analysis.
        
        Args:
            i: Index of the example to visualize
            all_images: If True, additionally outputs all individual images at their original size
                without any titles or annotations (useful for paper figures)
        """
        fnames = self.analysis_results['fnames']
        labels = self.analysis_results['labels']
        preds_merlin = self.analysis_results['preds_merlin']
        preds_morgana = self.analysis_results['preds_morgana']
        masked_images_merlin = self.analysis_results['masked_images_merlin']
        masked_images_morgana = self.analysis_results['masked_images_morgana']

        # Print basic information
        print(f'File: {fnames[i]}')
        print(f'Class: {labels[i].item()}')
        print(f'Prediction for merlin: {preds_merlin[i].item()}')
        print(f'Prediction for morgana: {preds_morgana[i].item()}')

        fig, axes = plt.subplots(1, 3, figsize=(9, 3))  # 1 row, 3 columns
        
        # Original image
        img = mpimg.imread(fnames[i])

        axes[0].imshow(img)
        axes[0].set_title("Original image")
        axes[0].axis('off')  # Hide axis labels

        # Merlin mask
        axes[1].imshow(masked_images_merlin[i].permute(1, 2, 0).detach().numpy())
        axes[1].set_title("Merlin mask")
        axes[1].axis('off')  # Hide axis labels

        # Morgana mask
        axes[2].imshow(masked_images_morgana[i].permute(1, 2, 0).detach().numpy())
        axes[2].set_title("Morgana mask")
        axes[2].axis('off')  # Hide axis labels

        # Adjust layout for better spacing
        plt.tight_layout()
        plt.show()

        # If all_images flag is set, output individual high-quality images
        if all_images:
            def show_image_without_margins(img):
                from IPython.display import display
                from PIL import Image
                import numpy as np
                
                # Convert to PIL Image
                if isinstance(img, np.ndarray):
                    if img.dtype == np.float32 or img.dtype == np.float64:
                        img = (img * 255).astype(np.uint8)
                    img = Image.fromarray(img)
                display(img)

            # Original image
            show_image_without_margins(mpimg.imread(fnames[i]))
            
            # Merlin's masked image
            img = masked_images_merlin[i].permute(1, 2, 0).detach().numpy()
            show_image_without_margins(img)
            
            # Morgana's masked image
            img = masked_images_morgana[i].permute(1, 2, 0).detach().numpy()
            show_image_without_margins(img)