import torch
import numpy as np
import yaml
import argparse
from torch.autograd.functional import jacobian
from tqdm import tqdm

import pickle
import json
import os
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, Subset
from torch.utils.data import DataLoader, TensorDataset
from collections import defaultdict
import random
import matplotlib
matplotlib.use('Agg') # in environments like solar where we don't have plt.show(), not having this line might lead to errors

class RandomLabelDataset(Dataset):
    """Dataset wrapper that replaces every label with a seeded random one.

    Images are taken from the wrapped dataset unchanged and its labels are
    discarded. Labels are drawn uniformly from ``[0, num_classes)`` once, at
    construction time, so a given index always yields the same label.

    Args:
        dataset: indexable dataset whose items are ``(image, label)`` pairs.
        seed: seed for both the optional subset selection and the label draw.
        num_samples: if given and smaller than ``len(dataset)``, only a seeded
            random subset of this many items is wrapped.
        num_classes: number of distinct labels to draw from.
    """

    def __init__(self, dataset, seed, num_samples=None, num_classes=10):
        self.num_classes = num_classes

        # Optionally shrink to a reproducible random subset of the input.
        wants_subset = num_samples is not None and num_samples < len(dataset)
        if wants_subset:
            subset_rng = np.random.RandomState(seed)
            chosen = subset_rng.permutation(len(dataset))[:num_samples]
            dataset = Subset(dataset, chosen)
        self.dataset = dataset

        # A fresh generator with the same seed draws one label per kept item.
        label_rng = np.random.RandomState(seed)
        self.random_labels = label_rng.randint(0, num_classes, len(self.dataset))

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # The wrapped dataset's own label is deliberately thrown away.
        image, _original_label = self.dataset[idx]
        return image, self.random_labels[idx]
    

def save_yaml(config, save_path):
    """Dump ``config`` as YAML into the file at ``save_path`` (opened in write mode)."""
    with open(save_path, 'w') as out_stream:
        yaml.dump(config, stream=out_stream)

def load_yaml(file_path):
    """Read the file at ``file_path`` and return its content parsed by ``yaml.safe_load``."""
    with open(file_path, 'r') as in_stream:
        parsed = yaml.safe_load(in_stream)
    return parsed

def dict_to_namespace(obj):
    """
    Recursively turn dictionaries into argparse.Namespace objects so that
    their entries can be read with dot access.

    Lists are rebuilt with every element converted the same way; any other
    value (string, int, float, bool, None, ...) is returned unchanged.
    """
    if isinstance(obj, list):
        return [dict_to_namespace(item) for item in obj]
    if not isinstance(obj, dict):
        # Leaf value: nothing to convert.
        return obj

    namespace = argparse.Namespace()
    for name, child in obj.items():
        setattr(namespace, name, dict_to_namespace(child))
    return namespace

def namespace_to_dict(ns):
    """
    Recursively convert an argparse.Namespace (potentially nested)
    back into a plain dictionary, so it can be dumped into YAML.

    Namespaces are converted wherever they appear: as attribute values, as
    list elements, and inside lists nested to any depth. Every other value is
    kept as-is.

    Args:
        ns: the namespace to convert (anything accepted by ``vars``).

    Returns:
        dict mapping each attribute name of ``ns`` to its converted value.
    """
    return {key: _namespace_value_to_plain(value) for key, value in vars(ns).items()}


def _namespace_value_to_plain(value):
    """Convert one value: namespaces become dicts, lists are converted element-wise."""
    if isinstance(value, argparse.Namespace):
        return namespace_to_dict(value)
    if isinstance(value, list):
        # Recurse so that namespaces inside nested lists are converted too.
        return [_namespace_value_to_plain(item) for item in value]
    return value


def str2bool(v):
    """
    Parse a boolean given as a command-line style string.

    A bool is returned unchanged. Otherwise ``v`` is lower-cased and mapped:
    'yes', 'true', 't', 'y', '1' -> True; 'no', 'false', 'f', 'n', '0' -> False.

    Raises:
        argparse.ArgumentTypeError: if the text is none of the accepted words.
    """
    if isinstance(v, bool):
        return v

    token = v.lower()
    if token in ('yes', 'true', 't', 'y', '1'):
        return True
    if token in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')

def load_config(config_path):
    """
    Read the YAML file at config_path and return its content as a nested
    argparse.Namespace, so values can be reached with dot notation
    (e.g., config.training.dataset).
    """
    return dict_to_namespace(load_yaml(config_path))

def set_seed(seed):
    """
    Seed the Python, PyTorch (CPU and all CUDA devices) and NumPy global random
    generators with ``seed``, and switch cuDNN to deterministic mode with
    benchmarking turned off.
    """
    for seed_fn in (random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    np.random.seed(seed)

def save_file(dir_path, file_name, content = None, **kwargs):
    """
    Save ``content`` to ``dir_path/file_name``; the writer is chosen from the
    file extension (matched case-insensitively).

    ``dir_path`` is created when it does not exist yet. For an unsupported
    extension nothing is written and a warning is printed.

    Args:
        dir_path: the directory where we want to save our file. An empty string
            means the current working directory.
        file_name: the name of our file; the text after its last '.' selects
            the format.
        content: the content we want to save in our file. Per format:
            npz         -> dict (stored as named arrays) or an iterable of
                           arrays (stored positionally)
            npy         -> anything ``np.save`` accepts
            pkl/pickle  -> any picklable object
            txt         -> str
            pth         -> anything ``torch.save`` accepts (e.g. a state_dict)
            json        -> JSON-serializable object
            csv         -> object with a ``to_csv`` method (e.g. a pandas DataFrame)
            png         -> ignored; the current matplotlib figure is saved and
                           then closed
        kwargs:
            compress: if this key is present (whatever its value), an npz file
                is written with ``np.savez_compressed`` from the ``features``
                and ``labels`` keyword arguments instead of ``content``.
            features, labels: arrays stored when ``compress`` is given.
            pickle_protocol: protocol passed to ``pickle.dump``.
    """
    file_encoding = file_name.split('.')[-1].lower()
    supported_encodings = ('npz', 'pkl', 'png', 'npy', 'txt', 'pickle', 'pth', 'json', 'csv')

    print(f'saving {file_name} ...')
    if file_encoding not in supported_encodings:
        # Nothing can be written for an unknown format; say so instead of
        # returning as if the save had happened.
        print(f'WARNING: unsupported file extension "{file_encoding}", {file_name} was not saved')
        return

    # exist_ok avoids the check-then-create race; an empty dir_path means
    # "current directory", which os.makedirs cannot create.
    if dir_path:
        os.makedirs(dir_path, exist_ok=True)

    save_dir = os.path.join(dir_path, file_name)

    if file_encoding == 'npz':
        print(f"Saving file as NPZ format to {save_dir}")
        if 'compress' in kwargs:
            np.savez_compressed(save_dir, features = kwargs['features'], labels = kwargs['labels'])
        elif isinstance(content, dict):
            np.savez(save_dir, **content)
        else:
            np.savez(save_dir, *content)

    elif file_encoding in ('pkl', 'pickle'):
        print(f"Saving file as pickle format to {save_dir}")
        with open(save_dir, 'wb') as file:
            # protocol=None makes pickle use its default protocol.
            pickle.dump(content, file, protocol=kwargs.get('pickle_protocol'))

    elif file_encoding == 'png':
        print(f"Saving plot as PNG image to {save_dir}")
        # Saves whatever figure is current in pyplot, then releases it.
        plt.savefig(save_dir)
        plt.close()

    elif file_encoding == 'npy':
        print(f"Saving NumPy array to {save_dir}")
        np.save(save_dir, content)

    elif file_encoding == 'txt':
        print(f"Saving text file to {save_dir}")
        with open(save_dir, 'w') as file:
            file.write(content)

    elif file_encoding == 'pth':
        print(f"Saving PyTorch model state to {save_dir}")
        torch.save(content, save_dir)

    elif file_encoding == 'json':
        print(f"Saving JSON file to {save_dir}")
        with open(save_dir, 'w') as file:
            json.dump(content, file, indent=4)

    elif file_encoding == 'csv':
        print(f"Saving DataFrame as CSV to {save_dir}")
        content.to_csv(save_dir)  # content is assumed to be a pandas DataFrame

    print(f'{file_name} saved')

def compute_jacobian(model, images, labels):
    """
    Return, for every sample in the batch, the derivative of that sample's own
    model output with respect to that sample's own input.

    The labels are stored on the model as ``model.label`` before the model is
    called (presumably so the forward pass can select the correct-class logit
    -- confirm against the model definition). This attribute stays set after
    the call.

    Parameters:
        model: callable passed to ``torch.autograd.functional.jacobian``; it is
               invoked as ``model(images)``. Assumed to return one value per
               sample, i.e. a tensor of shape [batch_size].
        images (torch.Tensor): input batch with the batch dimension first,
               e.g. flattened images of shape [batch_size, num_pixels].
        labels (torch.Tensor): labels for the batch; here they are only
               assigned to ``model.label``.

    Returns:
        torch.Tensor: the entries ``[i, i, :]`` of the full Jacobian. For the
                     shapes above this is [batch_size, num_pixels].

    Note:
        ``jacobian`` returns a tensor shaped like output-shape followed by
        input-shape, here [batch_size, batch_size, num_pixels], where entry
        [i, j, :] is the derivative of output i with respect to input sample j.
        Only the i == j entries are kept.
    """
    model.label = labels
    with torch.no_grad():
        full_jacobian = jacobian(model, images)

    # Pair every output with its own input sample: take [k, k, :] for each k.
    sample_ids = torch.arange(images.size(0), device=full_jacobian.device)
    return full_jacobian[sample_ids, sample_ids, :]

    
def process_jacobian__in_batches(model, train_loader, criterion, device, sample_limit = None, batch_size=48, verbose = False):
    """
    Gather the data from a loader, then compute model outputs and per-sample
    Jacobians over it in fixed-size batches, plus one loss over everything.

    Args:
        model: model object; must provide ``sudo_forward(images)`` and
            ``update_statistic(...)`` and be usable by ``compute_jacobian``
        train_loader: iterable yielding ``(images, labels)`` tensor batches
        criterion: loss function, called as ``criterion(outputs, labels)``
        device: device the collected tensors are moved to
        sample_limit: stop reading the loader once at least this many samples
            were collected. The check runs after each loader batch, so the
            final count can exceed the limit; None reads the whole loader.
        batch_size: number of samples per Jacobian computation
        verbose: show tqdm progress bars when True

    Returns:
        Tuple ``(all_images, all_labels, all_outputs, overall_loss,
        overall_jacobian)``. The same values are also handed to
        ``model.update_statistic`` before returning.
    """
    hide_bars = not verbose

    # Stage 1: pull the data out of the loader and move it to the device.
    image_chunks = []
    label_chunks = []
    with torch.no_grad():  # collection itself needs no gradients
        collected = 0
        loader_bar = tqdm(train_loader, desc= 'accumulating all the data for overall Jacobian Computation', disable=hide_bars)
        for images, labels in loader_bar:
            image_chunks.append(images.to(device))
            label_chunks.append(labels.to(device))
            collected += images.size(0)
            if sample_limit is not None and collected >= sample_limit:
                break

    all_images = torch.cat(image_chunks, dim=0)
    all_labels = torch.cat(label_chunks, dim=0)

    # Stage 2: outputs and Jacobians, one slice of ``batch_size`` at a time.
    output_chunks = []
    jacobian_chunks = []
    starts = range(0, len(all_images), batch_size)
    for start in tqdm(starts, desc= 'Computing Jacobian for each Batch', disable=hide_bars):
        stop = start + batch_size
        chunk_labels = all_labels[start:stop]
        chunk_images = all_images[start:stop]
        # Flatten every sample to a vector and track gradients w.r.t. it.
        chunk_images = chunk_images.view(chunk_images.size(0), -1).requires_grad_()

        # The forward pass runs before compute_jacobian, which sets model.label.
        output_chunks.append(model.sudo_forward(chunk_images))
        jacobian_chunks.append(
            compute_jacobian(model=model, images=chunk_images, labels=chunk_labels)
        )

    all_outputs = torch.cat(output_chunks, dim=0)
    overall_jacobian = torch.cat(jacobian_chunks, dim=0)

    # One loss over the full collected set.
    overall_loss = criterion(all_outputs, all_labels)

    model.update_statistic(all_images, all_outputs, overall_loss, all_labels, overall_jacobian)

    return all_images, all_labels, all_outputs, overall_loss, overall_jacobian

##### Dataset creation functions

def get_limited_dataset(dataset, limit, seed):
    """
    Return a seeded random Subset of ``limit`` items from ``dataset``.

    The dataset itself is returned untouched when ``limit`` is falsy
    (None, 0) or is not smaller than the dataset size.
    """
    if not limit or not (limit < len(dataset)):
        return dataset

    shuffled = np.random.RandomState(seed).permutation(len(dataset))
    return Subset(dataset, shuffled[:limit])


def flatten_image(img):
    """
    Copy ``img`` (PIL image, ndarray or anything ``np.array`` accepts) into a
    NumPy array and return it flattened to one dimension, e.g. (H, W, C)
    becomes (H*W*C,).
    """
    pixels = np.array(img)
    return np.reshape(pixels, -1)

def apply_label_permutation(dataset, label_perm,  flatten = True):
    """
    Build a TensorDataset from ``dataset`` in which every label is remapped
    as old_label -> label_perm[old_label].

    ``dataset`` is read by index (``dataset[i]`` for i in range(len(dataset)))
    and must yield (data, label) pairs. With ``flatten`` set, each data item
    goes through ``flatten_image`` first; otherwise it is stored as-is.
    Data is stored as float32, labels as long.
    """
    samples = [dataset[index] for index in range(len(dataset))]

    inputs = [flatten_image(image) if flatten else image for image, _ in samples]
    remapped = [label_perm[label] for _, label in samples]

    data_tensor = torch.tensor(np.array(inputs), dtype=torch.float32)
    target_tensor = torch.tensor(remapped, dtype=torch.long)
    return TensorDataset(data_tensor, target_tensor)

def permute_mnist_dataset(dataset, permutation):
    """
    Reorder the pixels of every image in ``dataset`` with ``permutation`` and
    return a TensorDataset of the permuted images with their original labels.

    Each image is flattened, indexed with ``permutation`` and reshaped to
    (1, 28, 28), so the permutation must select 784 pixels. Images are stored
    as float32, labels as long.
    """
    samples = [dataset[index] for index in range(len(dataset))]

    shuffled_images = [
        flatten_image(image)[permutation].reshape(1, 28, 28)
        for image, _ in samples
    ]
    labels = [target for _, target in samples]

    image_tensor = torch.tensor(np.array(shuffled_images), dtype=torch.float32)
    label_tensor = torch.tensor(np.array(labels), dtype=torch.long)
    return TensorDataset(image_tensor, label_tensor)

def get_label_permutation(num_classes=10, seed=0):
    """
    Return a seeded random ordering of the labels 0 .. num_classes-1
    as a NumPy array.
    """
    return np.random.RandomState(seed).permutation(num_classes)


def get_permutation(seed=0, size=784):
    """
    Return a seeded random ordering of the indices 0 .. size-1 as a NumPy
    array (the default size of 784 matches a flattened 28x28 image).
    """
    return np.random.RandomState(seed).permutation(size)

def create_balanced_dataset(dataset, samples_per_class=None, seed=43):
    """
    Creates a balanced dataset with a specified number of samples per class.

    Sampling uses a generator local to this call, so the global NumPy and
    PyTorch random states are left untouched when ``seed`` is given. For a
    given seed the selected indices are the ones produced by seeding NumPy's
    global generator with that seed and drawing from it.

    Args:
        dataset: The original dataset (e.g., CIFAR-10). Must expose its labels
            as a ``targets`` attribute (list, array or tensor of scalars).
        samples_per_class: Number of samples to keep per class. If None, the
            dataset is returned unchanged. Classes with fewer samples keep
            all of them.
        seed: Random seed for reproducible sampling. If None, indices are
            drawn from NumPy's global generator.

    Returns:
        Subset of the original dataset with balanced classes; indices are
        grouped by class in ascending label order.
    """
    if samples_per_class is None:
        return dataset

    # A private generator keeps this helper from reseeding the process-wide RNGs.
    rng = np.random.RandomState(seed) if seed is not None else np.random

    # as_tensor accepts lists, arrays and tensors without a copy-construct
    # warning; tolist() yields plain Python scalars usable as dict keys.
    labels = torch.as_tensor(dataset.targets).tolist()

    # Group sample positions by class label.
    class_indices = defaultdict(list)
    for idx, label in enumerate(labels):
        class_indices[label].append(idx)

    balanced_indices = []
    for class_label in sorted(class_indices):  # sorted for reproducibility
        indices = class_indices[class_label]
        keep = min(samples_per_class, len(indices))
        selected = rng.choice(indices, size=keep, replace=False)
        # Store plain ints rather than NumPy scalars.
        balanced_indices.extend(selected.tolist())

    return Subset(dataset, balanced_indices)


def plot_params_histogram(params, bins=50, title='C Parameter Distribution', 
                         xlabel='Parameter Value', ylabel='Frequency',
                         figsize=(10, 6), density=False):
    """
    Plot a histogram of parameter values with customizable options.

    A new pyplot figure is created and drawn on. Dashed vertical lines mark
    the mean (red) and mean +/- one standard deviation (green). The figure is
    not closed here and remains the current pyplot figure, so the caller is
    responsible for saving and/or closing it.

    Args:
        params (list or array): List of parameter values to plot
        bins (int): Number of bins for the histogram
        title (str): Title of the plot
        xlabel (str): Label for x-axis
        ylabel (str): Label for y-axis
        figsize (tuple): Figure size as (width, height)
        density (bool): If True, plot density instead of frequency

    Returns:
        The matplotlib figure that was drawn on.
    """
    fig = plt.figure(figsize=figsize)

    # Plot histogram
    plt.hist(params, bins=bins, density=density, alpha=0.7, color='blue', edgecolor='black')

    # Add mean and std lines; only the upper std line carries the legend
    # entry so that "Mean ± Std" is listed once.
    mean = np.mean(params)
    std = np.std(params)
    plt.axvline(mean, color='red', linestyle='dashed', linewidth=2, label=f'Mean: {mean:.3f}')
    plt.axvline(mean + std, color='green', linestyle='dashed', linewidth=1, label='Mean ± Std')
    plt.axvline(mean - std, color='green', linestyle='dashed', linewidth=1)

    # Customize plot
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.grid(True, alpha=0.3)
    plt.legend()

    return fig














