import logging
import math

import numpy as np
import torch
import wandb
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure


from . import params


def gratings_orientation_analysis(model, wandb_run, logger: logging.Logger):
    """
    Run the gratings orientation analysis and log binned response plots to W&B.

    Generates gratings, computes per-z orientation tuning to find the most
    responsive angle of each component (NaN for non-tuned components), then bins
    z components by preferred angle and logs an average response by orientation
    plot.

    When ``angles`` is configured, one plot is logged per stimulus angle (a single
    scalar angle is accepted as well as a list/array). Otherwise a single plot
    aggregated over all stimulus angles is logged.

    Reads params.analysis_params.gratings_orientation_analysis if present, with defaults:
      spatial_frequency=3, contrast=1.0, bin_size=5.0, method='threshold',
      decay_threshold=0.25, angles=None

    Args:
        model: model forwarded to the task_vae analysis helpers.
        wandb_run: W&B run the figures are logged to.
        logger: logger for progress messages.

    Raises:
        Exception: whatever importing the task_vae utilities raises (logged, then re-raised).
    """
    logger.info('Running Gratings Orientation Analysis')
    try:
        from task_vae.utils import (
            generate_gratings,
            calculate_avg_abs_z_by_angle,
            find_most_responsive_angles,
            response_by_orientation,
            get_constrained_images,
            plot_avg_response_by_orientation,
        )
    except Exception as e:
        logger.error(f"Failed to import task_vae utils: {e}")
        raise

    # Access configuration safely to avoid KeyError from custom __getattr__ implementation
    try:
        cfg = getattr(params.analysis_params, 'gratings_orientation_analysis')
    except (AttributeError, KeyError):
        cfg = None

    def _safe_get(attr_name, default):
        # Read one option from cfg, falling back to `default` when absent.
        if cfg is None:
            return default
        try:
            return getattr(cfg, attr_name)
        except (AttributeError, KeyError):
            return default

    def _as_angle_list(value):
        # Normalize the `angles` option: None -> [], a scalar -> [scalar],
        # any iterable (list, tuple, array, config list) -> list of its items.
        # Must run before any len() call, which fails on scalars.
        if value is None:
            return []
        try:
            return list(value)
        except TypeError:
            return [value]

    spatial_frequency = _safe_get('spatial_frequency', 3)
    contrast = _safe_get('contrast', 1.0)
    bin_size = _safe_get('bin_size', 5.0)
    method = _safe_get('method', 'threshold')
    decay_threshold = _safe_get('decay_threshold', 0.25)
    angles = _safe_get('angles', None)

    logger.info('Generating gratings dataset...')
    data = generate_gratings(
        pca_model=None,
        hist_match=False,
        whiten=False,
        train_fraction=1,
        shuffle=True,
        noise_std=None,
        spherical_mask=False,
        n_cores=-1,
    )
    logger.info(f"Generated {len(data['train_images'])} training images")

    logger.info('Computing average absolute z by angle (tuning) ...')
    avg_abs_z_by_angle = calculate_avg_abs_z_by_angle(
        data,
        model,
        spatial_frequency=spatial_frequency,
        contrast=contrast,
        abs=True,
    )

    logger.info('Finding most responsive angles...')
    most_responsive_angles = find_most_responsive_angles(
        avg_abs_z_by_angle,
        method=method,
        decay_threshold=decay_threshold,
        return_max_values=False,
    )
    # Non-tuned components are NaN, so count the finite entries.
    num_responsive = int(np.sum(~np.isnan(most_responsive_angles)))
    logger.info(f"Responsive components: {num_responsive} / {len(most_responsive_angles)}")

    def _log_binned_response(stimulus_images, title, log_key):
        # Bin absolute responses to `stimulus_images` by each component's
        # preferred angle, plot the per-bin mean and log the figure under `log_key`.
        logger.info('Computing orientation-binned average responses...')
        orientation_responses = response_by_orientation(
            stimulus_images,
            model,
            most_responsive_angles,
            prior_image=None,
            absolute=True,
            bin_size=bin_size,
            moment='mean',
            threshold=0.0,
        )
        plot_avg_response_by_orientation(orientation_responses, title=title)
        # Assumes the plotting helper draws on the current pyplot figure.
        fig = plt.gcf()
        wandb_run.log({log_key: wandb.Image(fig)})
        plt.close(fig)

    angle_list = _as_angle_list(angles)
    if angle_list:
        # Generate per-angle datasets and plots
        for angle in angle_list:
            logger.info(f'Collecting stimulus images for response curve at angle={angle}...')
            stimulus_images = get_constrained_images(
                data,
                angle=float(angle),
                spatial_frequency=spatial_frequency,
                contrast=contrast,
            )
            logger.info(f"Found {len(stimulus_images)} images at angle={angle}, sf={spatial_frequency}, contrast={contrast}")
            title = f'Binned Orientation Response (sf={spatial_frequency}, c={contrast}, bin={bin_size}°, stim_angle={angle}°)'
            _log_binned_response(stimulus_images, title, f"gratings_orientation_analysis/angle_{int(angle)}")
    else:
        # Fallback: aggregate across all angles (previous behavior)
        logger.info('Collecting stimulus images for response curve (all angles)...')
        stimulus_images = get_constrained_images(
            data,
            spatial_frequency=spatial_frequency,
            contrast=contrast,
        )
        logger.info(f"Found {len(stimulus_images)} images at sf={spatial_frequency}, contrast={contrast}")
        title = f'Binned Orientation Response (sf={spatial_frequency}, c={contrast}, bin={bin_size}°)'
        _log_binned_response(stimulus_images, title, "gratings_orientation_analysis")
    logger.info('Gratings Orientation Analysis complete')


def generation(model, wandb_run, logger: logging.Logger):
    """
    Sample from the model prior and log the samples to W&B.

    ``.functional.generate`` returns a sequence of groups of outputs (indexed by
    ``temp_i``; presumably one group per sampling temperature — confirm against
    ``generate``). Each group is logged as a list of images under the key
    ``"generation {temp_i}"``.
    """
    from .functional import generate

    logger.info('Generating Samples from Prior')
    outputs = generate(model, logger)
    for temp_i, temp_outputs in enumerate(outputs):
        samples = [
            wandb.Image(output, caption=f'Prior sample {sample_i}')
            for sample_i, output in enumerate(temp_outputs)
        ]
        # No explicit `step=`: W&B ignores data logged at a step lower than the
        # run's current step, which would silently drop these samples whenever
        # anything was logged earlier. The key already identifies the group.
        wandb_run.log({f"generation {temp_i}": samples})
    logger.info('Generation successful')


def extrapolate(model, loader, wandb_run, logger: logging.Logger):
    """
    Extrapolate sequences with the model and log the results to W&B.

    Reads ``n_samples`` and ``seq_len`` from params.analysis_params.extrapolation,
    calls ``.functional.extrapolate`` and, for each of the first ``n_samples``
    items, logs the original, an interleaved means/predictions sequence and the
    means under ``"extrapolation_{i}"``.
    """
    from .functional import extrapolate as extrapolate_func
    logger.info('Generating Samples with Extrapolation')
    n_samples = params.analysis_params.extrapolation.n_samples
    seq_len = params.analysis_params.extrapolation.seq_len
    original, predictions, means = extrapolate_func(model, loader, seq_len, n_samples)
    # The first axis of the data shape (presumably the frame axis) is extended
    # by the seq_len extrapolated steps.
    means_shape = (params.data_params.shape[0] + seq_len, *params.data_params.shape[1:])

    # Interleave along the first axis: even slots take means, odd slots take
    # predictions. zeros (not empty) so that an odd length leaves the final
    # slot blank instead of holding uninitialized memory.
    means_preds = torch.zeros((n_samples, *means_shape))
    for i in range(means_shape[0] // 2):
        means_preds[:, 2 * i] = means[:n_samples, i]
        means_preds[:, 2 * i + 1] = predictions[:n_samples, i]

    for i in range(n_samples):
        o = wandb.Image(original[i].reshape(params.data_params.shape), caption=f'Original {i}')
        p = wandb.Image(means_preds[i].reshape(means_shape), caption=f'Means + Predictions {i}')
        m = wandb.Image(means[i].reshape(means_shape), caption=f'Mean {i}')
        wandb_run.log({f"extrapolation_{i}": [o, p, m]})
    logger.info('Extrapolation generation successful')


def mei(model, wandb_run, logger: logging.Logger):
    """
    Generate a Most Exciting Input for every configured MEI operation and log it.

    Each entry of params.analysis_params.mei supplies ``objective``, ``use_mean``,
    ``type`` and ``config``, which are forwarded to ``generate_mei``; the
    resulting image is logged under ``"MEI {op_name}"``.
    """
    logger.info('Generating Most Exciting Inputs (MEI)')
    for op_name, op in params.analysis_params.mei.items():
        mei_result = generate_mei(
            model,
            objective=op["objective"],
            use_mean=op["use_mean"],
            mei_type=op["type"],
            mei_config=op["config"],
        )
        mei_image = mei_result.get_image().detach().cpu().numpy()
        wandb_run.log({f"MEI {op_name}": wandb.Image(mei_image)})
    logger.info('MEI generation successful')


def generate_mei(model, objective, use_mean, mei_type, mei_config):
    """
    Build a ``meitorch`` MEI object around the model and run one generation method.

    Args:
        model: called as ``model(inputs, use_mean=use_mean)`` and expected to
            return ``(computed, _)``.
        objective: callable applied to ``computed``. It returns either a tensor
            (the activation to maximize) or a dict with ``'objective'`` and
            ``'activation'`` keys.
        use_mean: forwarded to the model call.
        mei_type: ``'pixel'``, ``'distribution'`` or ``'transform'``.
        mei_config: keyword arguments for the selected generation method.

    Returns:
        Whatever the selected ``MEI.generate_*`` method returns.

    Raises:
        ValueError: unknown ``mei_type``, or an objective dict missing required keys.
        TypeError: the objective returned neither a tensor nor a dict.
    """
    from meitorch.mei import MEI

    def operation(inputs):
        model.eval()
        computed, _ = model(inputs, use_mean=use_mean)
        objective_result = objective(computed)
        if isinstance(objective_result, torch.Tensor):
            # Negate for 'objective' (presumably MEI minimizes it), so the
            # activation itself is maximized.
            return dict(objective=-objective_result,
                        activation=objective_result)
        if isinstance(objective_result, dict):
            if 'objective' not in objective_result or 'activation' not in objective_result:
                raise ValueError('objective_result must contain keys "objective" and "activation"')
            return objective_result
        # Previously fell through and returned None, failing later inside MEI.
        raise TypeError(
            f'objective must return a torch.Tensor or a dict, got {type(objective_result).__name__}')

    mei_object = MEI(operation=operation, shape=params.data_params.shape)

    # Map types to method *names* so that only the selected method is looked up.
    method_names = {
        'pixel': 'generate_pixel_mei',
        'distribution': 'generate_variational_mei',
        'transform': 'generate_transformation_based',
    }
    if mei_type not in method_names:
        raise ValueError(f'Unknown MEI type {mei_type}')
    return getattr(mei_object, method_names[mei_type])(**mei_config)


def white_noise_analysis(model, wandb_run, logger: logging.Logger):
    """
    Estimate receptive fields via white-noise reverse correlation and log them as figures.

    For every ``target_block`` in params.analysis_params.white_noise_analysis
    (each config providing ``n_samples`` and ``sigma``), receptive fields are
    computed with ``generate_white_noise_analysis``. They are tiled 20 per row
    into one or more figures, each logged under
    ``"white noise analysis {target_block} - {im}"``.
    """
    logger.info('Generating Samples with White Noise Analysis')
    shape = params.data_params.shape
    n_cols = 20
    for target_block, config in params.analysis_params.white_noise_analysis.items():
        n_samples = config['n_samples']
        sigma = config['sigma']
        receptive_fields = \
            generate_white_noise_analysis(model, target_block, shape, n_samples, sigma)

        n_dims = receptive_fields.shape[0]
        # Receptive fields per figure scale with the image area. Keep at least
        # one, so small images do not cause a division by zero.
        dims_per_image = max(1, int(np.prod(shape[-2:])) // 40)
        # ceil, so no empty trailing figure is logged when n_dims divides evenly.
        n_images = math.ceil(n_dims / dims_per_image)

        for im in range(n_images):
            start_dim = im * dims_per_image
            n_dims_im = min(dims_per_image, n_dims - start_dim)
            n_rows = math.ceil(n_dims_im / n_cols)
            w = int(shape[-2])
            h = max(1, int(shape[-1] / n_cols * n_rows))

            fig = figure(figsize=(w, h))
            plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
            for i in range(n_dims_im):
                dim = start_dim + i  # global index, unique across figures
                ax = fig.add_subplot(n_rows, n_cols, i + 1)
                ax.imshow(receptive_fields[dim].reshape(shape[-2:]), cmap="gray")
                ax.set_title(f"dim {dim}")
                ax.axis("off")

            fig.tight_layout()
            wandb_run.log(
                {f"white noise analysis {target_block} - {im}": fig},
            )
            plt.close(fig)
    logger.info('White Noise Analysis successful')


def generate_white_noise_analysis(model, target_block, shape, n_samples=100, sigma=0.6):
    """
    Estimate linear receptive fields of ``target_block`` by white-noise reverse correlation.

    Draws ``n_samples`` standard-normal noise images of ``shape`` and smooths each
    with a Gaussian filter of width ``sigma``. The images are passed through the
    model in batches of params.analysis_params.batch_size, up to ``target_block``
    with ``use_mean=True``. The flattened activations are then correlated with the
    noise.

    Returns:
        np.ndarray of shape ``(n_units, prod(shape))``, where ``n_units`` is the
        flattened per-sample size of ``target_block``'s activation, scaled by
        ``1 / sqrt(n_samples)``.
    """
    # Import the subpackage explicitly: a bare `import scipy` does not
    # guarantee that `scipy.ndimage` is loaded on all SciPy releases.
    from scipy import ndimage

    n_pixels = int(np.prod(shape))
    white_noise = np.random.normal(size=(n_samples, n_pixels),
                                   loc=0.0, scale=1.).astype(np.float32)

    # Smooth each noise image with a Gaussian filter of width `sigma`.
    # NOTE(review): sigma is applied to every axis of `shape`, including a
    # channel axis if one is present — confirm that is intended.
    for i in range(n_samples):
        white_noise[i, :] = ndimage.gaussian_filter(
            white_noise[i, :].reshape(shape), sigma=sigma).reshape(n_pixels)

    with torch.no_grad():
        model.eval()
        # Probe pass only to discover the activation shape of target_block.
        computed, _ = model(torch.ones(1, *shape, device=params.device), stop_at=target_block)
        target_block_dim = computed[target_block].shape[1:]
        target_block_values = torch.zeros((n_samples, *target_block_dim), device=params.device)

        batch_size = params.analysis_params.batch_size
        for i in range(0, n_samples, batch_size):
            batch = white_noise[i:i + batch_size, :].reshape(-1, *shape)
            computed_target, _ = model(torch.tensor(batch, device=params.device),
                                       use_mean=True, stop_at=target_block)
            target_block_values[i:i + batch_size] = computed_target[target_block]

        target_block_values = torch.flatten(target_block_values, start_dim=1)
        # (n_units, n_samples) @ (n_samples, n_pixels) -> (n_units, n_pixels)
        receptive_fields = np.matmul(
            target_block_values.T.cpu().numpy(), white_noise
        ) / np.sqrt(n_samples)
        return receptive_fields


def decodability(model, labeled_loader, wandb_run, logger: logging.Logger):
    """
    Train decoders on model activations and log their loss curves and accuracies.

    Per decoded block, a loss-history line plot is logged under
    ``"decodability_loss_history/{decode_from}"``. A single bar chart of the
    accuracies across blocks is logged under ``"decodability_accuracies"``.
    """
    logger.info('Computing Decodability')
    results = calculate_decodability(model, labeled_loader)

    for decode_from, (loss_history, _) in results.items():
        # loss history -> line plot
        table = wandb.Table(data=[[step, loss] for step, loss in enumerate(loss_history)],
                            columns=["step", "loss"])
        # One key per block: logging every block under a single key leaves
        # only the most recently logged chart visible.
        wandb_run.log({f"decodability_loss_history/{decode_from}":
                       wandb.plot.line(table, "step", "loss",
                                       title=f"Decodability Loss History {decode_from}")})

    # accuracy -> bar plot
    accuracy_rows = [[decode_from, accuracy] for decode_from, (_, accuracy) in results.items()]
    table = wandb.Table(data=accuracy_rows, columns=["decode_from", "accuracy"])
    wandb_run.log({"decodability_accuracies":
                   wandb.plot.bar(table, "decode_from", "accuracy",
                                  title="Decodability Accuracies")})
    logger.info('Decodability calculation successful')


def calculate_decodability(model, labeled_loader):
    """
    Train one decoder per configured block on that block's activations.

    ``labeled_loader`` yields ``(inputs, labels)`` batches. The model is run on
    each batch with ``use_mean=True``, and the activations of every block listed
    in params.analysis_params.decodability are collected. Then, per block, a
    decoder is built and trained with ``train_decoder``. The block's config
    supplies ``model``, ``optimizer``, ``learning_rate``, ``loss``, ``epochs``
    and ``batch_size``.

    Returns:
        dict mapping block name -> ``(loss_history, accuracy)`` from ``train_decoder``.
    """
    from .elements.dataset import FunctionalDataset

    decodability_cfg = params.analysis_params.decodability
    decode_from_list = list(decodability_cfg.keys())
    X = {layer: [] for layer in decode_from_list}
    Y = []

    model.eval()
    with torch.no_grad():
        for inputs, labels in labeled_loader:
            computed, _ = model(inputs.to(params.device), use_mean=True)
            for decode_from in decode_from_list:
                # detach/cpu so .numpy() works for grad-tracking or GPU tensors
                X[decode_from].append(computed[decode_from].detach().cpu().numpy())
            # Labels are appended once per batch (not once per block), so Y
            # stays row-aligned with every X[block].
            if isinstance(labels, torch.Tensor):
                labels = labels.detach().cpu().numpy()
            Y.append(np.asarray(labels))

    X = {block: np.concatenate(block_inputs, axis=0)
         for block, block_inputs in X.items()}
    Y = np.concatenate(Y, axis=0)

    results = dict()
    for decode_from in decode_from_list:
        block_cfg = decodability_cfg[decode_from]
        decoder_model = block_cfg["model"]()
        decoding_dataset = FunctionalDataset(data=X[decode_from], labels=Y)
        optimizer = block_cfg["optimizer"](
            decoder_model.parameters(),
            lr=block_cfg["learning_rate"])
        loss = block_cfg["loss"]()
        loss_history, accuracy = train_decoder(
            decoder_model, optimizer, loss, block_cfg["epochs"],
            block_cfg["batch_size"], decoding_dataset)
        results[decode_from] = (loss_history, accuracy)
    return results


def train_decoder(decoder_model, optimizer, loss, epochs, batch_size, dataset):
    """
    Train ``decoder_model`` on ``dataset`` with shuffled mini-batches.

    Args:
        decoder_model: module mapping a batch of inputs to predictions.
        optimizer: optimizer over ``decoder_model``'s parameters.
        loss: callable ``loss(predictions, targets)`` returning a scalar tensor.
        epochs: number of passes over ``dataset``.
        batch_size: mini-batch size for the internal ``DataLoader``.
        dataset: dataset yielding ``(inputs, targets)`` pairs.

    Returns:
        ``(loss_history, accuracy)``: the per-batch training losses as Python
        floats, and an accuracy placeholder that is currently always ``0``.
    """
    batches = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    history = []
    for _ in range(epochs):
        for inputs, targets in batches:
            optimizer.zero_grad()
            batch_loss = loss(decoder_model(inputs), targets)
            history.append(batch_loss.item())
            batch_loss.backward()
            optimizer.step()

    # TODO: evaluate the trained decoder and compute a real accuracy.
    return history, 0


def latent_step_analysis(model, dataloader, wandb_run, logger: logging.Logger):
    """
    Visualize per-latent output changes for the first batch of ``dataloader``.

    For every ``target_block`` in params.analysis_params.latent_step_analysis,
    the per-latent visualizations from ``generate_latent_step_analysis`` are
    scored with ``active_filters`` (threshold 40e3). They are then tiled into
    figures of 4 rows x 20 columns, titled with the latent index and its score.
    Each figure is logged under ``"latent step analysis {target_block} - image {im}"``.
    """
    logger.info('Generating Samples with Latent Step Analysis')
    first_batch = next(iter(dataloader))
    item_shape = first_batch.shape[1:]
    n_cols, n_rows = 20, 4
    per_figure = n_cols * n_rows

    for target_block, config in params.analysis_params.latent_step_analysis.items():
        visuals = generate_latent_step_analysis(model, first_batch, target_block, **config)
        scores, _n_active, _active_ratio = active_filters(
            visuals, params.analysis_params.batch_size, 40e3, wandb_run=wandb_run)
        n_latents = len(visuals)

        for fig_idx in range(math.ceil(n_latents / per_figure)):
            first_latent = fig_idx * per_figure
            count = min(per_figure, n_latents - first_latent)
            fig_w = int(item_shape[-2])
            fig_h = int(item_shape[-1] / min(item_shape[-1], n_cols) * n_rows)

            fig = figure(figsize=(fig_w, fig_h))
            plt.subplots_adjust(left=0, right=1, top=0.92, bottom=0.05, hspace=0.4)
            for offset in range(count):
                latent_idx = first_latent + offset
                ax = fig.add_subplot(n_rows, n_cols, offset + 1)
                ax.imshow(visuals[latent_idx].reshape(params.data_params.shape[-2:]), cmap="gray")
                ax.set_title(f"latent {latent_idx}\nscr: {scores[latent_idx]:.2f}", fontsize=8)
                ax.axis("off")

            wandb_run.log({f"latent step analysis {target_block} - image {fig_idx}": fig})
            plt.close(fig)
    logger.info('Latent Step Analysis Successful Images')


def generate_latent_step_analysis(model, sample, target_block, diff=1, value=1):
    """
    Measure how the batch-averaged model output changes when each unit of a block is stepped.

    ``sample`` is encoded up to ``target_block`` (``use_mean=True``). For every
    flattened unit of that block's activation, ``value`` is added to the unit
    for every item in the batch. The remaining model is then run, and
    ``mean(output) - diff * mean(baseline output)`` is recorded, with both means
    taken over the batch axis.

    Returns:
        list of numpy arrays, one per unit of ``target_block``.
    """
    def snapshot(computed):
        # Fresh clones for every forward pass, presumably because the model
        # may modify the dict/tensors it is given.
        return {key: tensor.clone() for key, tensor in computed.items()}

    with torch.no_grad():
        model.eval()
        target_computed, _ = model(sample.to(params.device), use_mean=True, stop_at=target_block)
        baseline_code = target_computed[target_block]
        code_shape = baseline_code.shape
        n_units = np.prod(code_shape[1:])

        baseline_out, _ = model(snapshot(target_computed), use_mean=True)
        baseline_mean = torch.mean(baseline_out['output'], dim=0)

        visualizations = []
        for unit in range(n_units):
            step = torch.zeros([1, n_units], device=params.device)
            step[0, unit] = value
            # Broadcast the one-hot step over the batch dimension.
            target_computed[target_block] = baseline_code + step.reshape(code_shape[1:])

            stepped_out, _ = model(snapshot(target_computed), use_mean=True)
            stepped_mean = torch.mean(stepped_out['output'], dim=0)
            visualizations.append((stepped_mean - diff * baseline_mean).detach().cpu().numpy())

        return visualizations
    
    
# wavelet uncertainty
def feri_score(batch):
    '''
    Compute the Feri uncertainty score for every image in ``batch``.

    The score is the position variance times the wave-number variance of an
    image. Its continuous analogue is said to be minimized by Gabor filters.
    The code was adapted by F. Csikor, as the name suggests; here it is
    extended to support batches.

    Each item may be a flattened square image (1-D), a 2-D image, a 3-D array
    (only ``[0]`` is used) or a 4-D array (only ``[0, 0]`` is used).

    Returns:
        list with one score per item of ``batch``. Lower means the image's
        energy is more concentrated in both position and frequency.
    '''
    # Use the top-level ndimage functions; the `ndimage.measurements`
    # namespace is deprecated.
    from scipy import ndimage

    def pos_var(img):
        # Variance of the pixel coordinates about the centre of mass,
        # weighted by |img|^2.
        weights = np.abs(img * np.conjugate(img))
        c_of_mass = ndimage.center_of_mass(weights)
        ii, jj = np.meshgrid(np.arange(img.shape[0]) - c_of_mass[0],
                             np.arange(img.shape[1]) - c_of_mass[1],
                             indexing='ij')
        sqdist_from_c_of_mass = np.square(ii) + np.square(jj)
        return ndimage.sum(sqdist_from_c_of_mass * weights) / ndimage.sum(weights)

    def uncertainty(img):
        if img.ndim == 1:
            dx = int(np.sqrt(img.shape[0]))
            img = img.reshape((dx, dx))
        if img.ndim == 3:
            img = img[0, :, :]
        if img.ndim == 4:
            img = img[0, 0, :, :]
        position_variance = pos_var(img)
        # NOTE(review): the spectrum is not fftshift-ed, so the wave-number
        # spread is measured on the raw FFT index grid — confirm this is intended.
        wave_number_variance = pos_var(np.fft.fft2(img))
        return position_variance * wave_number_variance

    return [uncertainty(item) for item in batch]

def active_filters(receptive_fields, batch_size, threshold=40e3, global_step=None, wandb_run=None):
    '''
    Score receptive fields with the Feri metric and count the "active" ones.

    A filter is counted as active when its Feri score is strictly below
    ``threshold``.

    Args:
        receptive_fields: sequence, array or tensor of filters; each item is
            scored by ``feri_score``. Tensors are detached and moved to CPU first.
        batch_size: number of filters passed to each ``feri_score`` call.
        threshold: activity cutoff on the Feri score.
        global_step: if given, scalar metrics and a ``wandb.Histogram`` are
            logged at this step. Otherwise a histogram plot and a summary table
            are logged.
        wandb_run: optional W&B run; nothing is logged when ``None``.

    Returns:
        ``(scores, n_active, active_ratio)``. ``active_ratio`` is ``0.0`` when
        there are no filters.
    '''
    if isinstance(receptive_fields, torch.Tensor):
        receptive_fields = receptive_fields.detach().cpu().numpy()

    scores = []
    n_batches = math.ceil(len(receptive_fields) / batch_size)
    for i in range(n_batches):
        batch = receptive_fields[i * batch_size:(i + 1) * batch_size]
        scores.extend(feri_score(batch))

    n_active = sum(1 for score in scores if score < threshold)
    # Guard the ratio so an empty input does not raise ZeroDivisionError.
    active_ratio = n_active / len(scores) if scores else 0.0

    if wandb_run is not None:
        try:
            if global_step is None:
                histogram_table = wandb.Table(data=[[s] for s in scores], columns=["uncertainty"])
                wandb_run.log({
                    "Feri score histogram":
                        wandb.plot.histogram(histogram_table, "uncertainty",
                                             title="Filters by uncertainty histogram"),
                    "active filters":
                        wandb.Table(data=[[n_active, active_ratio]],
                                    columns=["active filters", "active filters ratio"]),
                })
            else:
                histogram = wandb.Histogram(np_histogram=np.histogram(scores, bins=100))
                wandb_run.log({"filters/active filters": n_active,
                               "filters/active filters ratio": active_ratio,
                               "filters/Feri score histogram": histogram},
                              step=global_step)
        except ValueError as e:
            # Best-effort logging (e.g. np.histogram rejects non-finite scores):
            # report the problem but still return the computed scores.
            logging.getLogger(__name__).warning(
                'Failed to log active filter statistics to wandb: %s', e)

    return scores, n_active, active_ratio


        
    





