import seaborn as sns
from plotly.subplots import make_subplots
import numpy as np
import pandas as pd
import math

import json, yaml
import os
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import plotly.figure_factory as ff

import torch
import torch.nn.functional as F
from torch import nn
from torchdyn.core import NeuralODE
#from model import NODELatentSAE


def get_influence_at(model, e1, e2, I, input_step):
    """
    Evaluate the decoder cell once at the state (e1, e2, I) and collect the
    per-source contributions it reports.

    Args:
        model: object exposing ``model.decoder``. For non-NeuralODE decoders,
            ``model.decoder.cell(input_step, state)`` is called.
        e1, e2, I: scalar state values, packed into a (1, 3) float tensor.
        input_step: external input forwarded unchanged to the cell.

    Returns:
        dict src -> dict tgt -> float for src/tgt in ('E1', 'E2', 'I'), with
        0.0 for every entry the cell did not report. ``None`` when the decoder
        is a NeuralODE or when the cell does not return a contribution dict.
    """
    # NeuralODE decoders expose no per-source contributions.
    if isinstance(model.decoder, NeuralODE):
        return None

    state = torch.tensor([[e1, e2, I]]).float()
    result = model.decoder.cell(input_step, state)

    # The cell returns either the new state alone or (new_state, contributions).
    if not isinstance(result, tuple):
        return None
    _, results = result
    if results is None:
        return None

    neurons = ('E1', 'E2', 'I')
    influence = {}
    for src in neurons:
        # A source that is absent from the cell output contributes nothing,
        # mirroring how mlp_flow treats missing entries.
        contrib = results.get(src, {})
        influence[src] = {
            tgt: (contrib[tgt].item() if tgt in contrib else 0.0)
            for tgt in neurons
        }
    return influence
    
def mlp_flow(model, E1_min, E1_max, E2_min, E2_max, I, input_step=None, step=0.5):
    """
    Evaluate the model's one-step map on a regular (E1, E2) grid at a fixed I.

    Args:
        model: object exposing ``model.decoder`` (and ``model.cell`` when the
            decoder is a NeuralODE).
        E1_min, E1_max, E2_min, E2_max: half-open grid bounds passed to
            ``np.arange``.
        I: value used for the third state component at every grid point.
        input_step: external input forwarded to ``model.decoder.cell``.
        step: grid spacing along both axes (default 0.5, the previous
            hard-coded value).

    Returns:
        (hidden_input, hidden_output, influence_arrows)
        * hidden_input: grid states stacked along dim 1 -> (1, N, 3).
        * hidden_output: the model outputs stacked along dim 1
          (assumed to be (1, N, 3) as well -- TODO confirm against the cell).
        * influence_arrows: dict src -> tgt -> float tensor with one value
          per grid point, or ``None`` when the last evaluated grid point
          produced no contribution dict.
    """
    neurons = ('E1', 'E2', 'I')
    influence_arrows = {src: {tgt: [] for tgt in neurons} for src in neurons}
    is_node = isinstance(model.decoder, NeuralODE)

    hidden_input = []
    hidden_output = []
    results = None
    for e1 in np.arange(E1_min, E1_max, step):
        for e2 in np.arange(E2_min, E2_max, step):
            state = torch.tensor([[e1, e2, I]]).float()
            hidden_input.append(state)

            if is_node:
                # NOTE(review): the NODE path calls model.cell, not
                # model.decoder.<something>; confirm this is the intended hook.
                total_hidden = model.cell(state)
            else:
                cell_output = model.decoder.cell(input_step, state)
                if isinstance(cell_output, tuple):
                    total_hidden, results = cell_output
                else:
                    total_hidden = cell_output
                    results = None

            if results is not None:
                # Accept sources beyond E1/E2/I instead of raising KeyError.
                for src in results:
                    influence_arrows.setdefault(src, {tgt: [] for tgt in neurons})
                # Append for every known source so each list stays aligned
                # with the grid even when a source is missing at this point.
                for src, per_target in influence_arrows.items():
                    contrib = results.get(src, {})
                    for tgt in neurons:
                        if tgt in contrib:
                            per_target[tgt].append(
                                contrib[tgt].detach().cpu().squeeze().item())
                        else:
                            per_target[tgt].append(0.0)  # fill 0 if missing

            hidden_output.append(total_hidden)

    hidden_input = torch.stack(hidden_input, dim=1)        # (1, N, 3)
    hidden_output = torch.stack(hidden_output, dim=1)

    if results is None:
        return hidden_input, hidden_output, None

    # Convert all influence lists to tensors
    for per_target in influence_arrows.values():
        for tgt in per_target:
            per_target[tgt] = torch.tensor(per_target[tgt]).float()
    return hidden_input, hidden_output, influence_arrows

def plot_flowfield(model, pred_latents, config, validate_only=False):
    """
    Compute flow fields at three fixed timepoints and render them.

    For each of the timepoints 10, 50 and 90 the trial-averaged I latent
    (column 2 of ``pred_latents``) is used as the fixed I value of an
    (E1, E2) grid evaluated by ``mlp_flow``.

    Args:
        model: model with a ``decoder`` attribute; switched to eval mode and
            moved to CPU when it exposes a non-CPU ``device`` attribute.
        pred_latents: array indexed as [trial, time, latent]; needs at least
            91 timepoints because index 90 is read.
        config: dict; ``config['noise']`` selects the external input vector.
        validate_only: when True the Plotly renderer is used and nothing is
            returned; otherwise the matplotlib renderer is used.

    Returns:
        The first figure returned by ``flowfield_matplotlib`` for timepoint
        90 when ``validate_only`` is False, else ``None``.
    """
    model_type = "NODE" if isinstance(model.decoder, NeuralODE) else "RNN"
    model.eval()

    # move model to cpu (device is compared as a string)
    if hasattr(model, 'device') and str(model.device) != 'cpu':
        model = model.to('cpu')

    E1_min, E1_max = get_range(pred_latents, neuron="E1")
    E2_min, E2_max = get_range(pred_latents, neuron="E2")

    # External input fed to the cell: one entry per noise channel.
    if config['noise'] == 0:
        input_tensor = torch.tensor([[0.]], dtype=torch.float32)
    elif config['noise'] == 1:
        input_tensor = torch.tensor([[1.]], dtype=torch.float32)
    else:
        input_tensor = torch.tensor([[1., 1.]], dtype=torch.float32)

    # Evaluate all three grids first, then render, as before.
    flows = []
    for n in (10, 50, 90):
        I_mean = np.mean(pred_latents[:, n, 2])
        hidden_input, hidden_output, influence = mlp_flow(
            model, E1_min, E1_max, E2_min, E2_max, I_mean, input_step=input_tensor)
        flows.append((n, hidden_input, hidden_output, influence))

    if validate_only:
        for n, hidden_input, hidden_output, influence in flows:
            flowfield_plotly(model, hidden_input, hidden_output, config=config,
                             pred_latents=pred_latents, arrows=influence, n=n,
                             model_type=model_type)
        return None

    plot_e1e2 = None
    for n, hidden_input, hidden_output, _ in flows:
        plot_e1e2, _, _, _ = flowfield_matplotlib(
            hidden_input, hidden_output, config, pred_latents, n=n)

    # The last iteration is timepoint 90, matching the previous return value.
    return plot_e1e2

def flowfield_plotly(model, input_flow, output_flow, config=None, pred_latents=None, arrows=None, n=None, model_type='NODE'):
    """
    Render 2D quiver plots for every neuron pair plus a 3D cone plot with
    Plotly and write them as PNG files under ``<BASE_PATH>/flowfield``.

    Args:
        model: passed to ``get_influence_at`` to draw per-source arrows.
        input_flow: grid states, indexed as [0, point, latent].
        output_flow: model outputs on the grid, same indexing.
        config: dict providing ``BASE_PATH``.
        pred_latents: optional array [trial, time, latent]; a fixed subset of
            trials is overlaid as trajectories. Skipped when ``None``.
        arrows: accepted for interface compatibility; not used here.
        n: tag appended to the 2D file names.
        model_type: "NODE" uses ``arrow_scale=0.3``; anything else uses
            ``scale=2`` for the quiver.
    """
    axis_index = {'E1': 0, 'E2': 1, 'I': 2}
    start = {name: input_flow[0, :, idx] for name, idx in axis_index.items()}
    delta = {name: output_flow[0, :, idx] - start[name]
             for name, idx in axis_index.items()}

    path = os.path.join(config['BASE_PATH'], 'flowfield')
    os.makedirs(path, exist_ok=True)  # Ensure the directory exists

    colors = {'E1': 'crimson', 'E2': 'forestgreen', 'I': 'goldenrod'}

    # The grid is evaluated at a single I value; read it from the first point.
    I_val = float(input_flow[0, 0, 2])

    # Only overlay trials that actually exist in pred_latents.
    if pred_latents is None:
        trials = []
    else:
        trials = [t for t in (0, 1, 2, 9, 10) if t < pred_latents.shape[0]]

    # Influence at each trajectory end point does not depend on the plotted
    # pair, so compute it once per trial.
    # NOTE(review): the input step is fixed to [[0.]] regardless of
    # config['noise'] -- confirm this is intended.
    influences = {}
    for trial_idx in trials:
        influences[trial_idx] = get_influence_at(
            model,
            pred_latents[trial_idx, -1, 0].item(),
            pred_latents[trial_idx, -1, 1].item(),
            I_val,
            torch.tensor([[0.]], dtype=torch.float32),
        )

    quiver_kwargs = {'arrow_scale': 0.3} if model_type == "NODE" else {'scale': 2}

    for neuron1, neuron2 in (('E1', 'E2'), ('E1', 'I'), ('E2', 'I')):
        col1, col2 = axis_index[neuron1], axis_index[neuron2]

        fig = ff.create_quiver(
            start[neuron1].detach().numpy(), start[neuron2].detach().numpy(),
            delta[neuron1].detach().numpy(), delta[neuron2].detach().numpy(),
            **quiver_kwargs)

        # Trajectories are drawn in the coordinates of the plotted pair.
        for trial_idx in trials:
            fig.add_trace(go.Scatter(
                x=pred_latents[trial_idx, :, col1],
                y=pred_latents[trial_idx, :, col2],
                mode='lines+markers',
                name=f'Trial {trial_idx+1}'
            ))

        for trial_idx in trials:
            inf = influences[trial_idx]
            if inf is None:
                continue
            x0 = pred_latents[trial_idx, -1, col1].item()
            y0 = pred_latents[trial_idx, -1, col2].item()

            # draw one arrow per source neuron, projected onto this plane
            for src in ('E1', 'E2', 'I'):
                dx = inf[src][neuron1]
                dy = inf[src][neuron2]

                fig.add_annotation(
                    x=x0+dx, y=y0+dy,
                    ax=x0,   ay=y0,
                    xref='x', yref='y', axref='x', ayref='y',
                    showarrow=True,
                    arrowhead=3,
                    arrowsize=1,
                    arrowwidth=2,
                    arrowcolor=colors[src],
                    standoff=2
                )

        fig.update_layout(
            title=f'Flow Field: {neuron1} vs {neuron2}',
            xaxis_title=neuron1,
            yaxis_title=neuron2
        )

        fig.write_image(f'{path}/{neuron1}_{neuron2}_{n}.png')

    # 3D flow field
    fig = go.Figure(data=go.Cone(
        x=start['E1'].detach().numpy(),
        y=start['E2'].detach().numpy(),
        z=start['I'].detach().numpy(),
        u=delta['E1'].detach().numpy(),
        v=delta['E2'].detach().numpy(),
        w=delta['I'].detach().numpy(),
        sizemode="absolute",
        sizeref=2
    ))

    fig.update_layout(
        title='3D Flow Field',
        scene=dict(
            xaxis_title='E1',
            yaxis_title='E2',
            zaxis_title='I'
        )
    )

    # NOTE(review): this file name carries no `n`, so successive calls
    # overwrite each other; kept unchanged to preserve the output path.
    fig.write_image(f'{path}/flow.png')



def plot_prediction(data, prediction, config=None, target='E1'):
    """
    Plot population-averaged original vs. predicted activity per trial.

    Both inputs are reduced with ``get_spike`` for the ``target`` population.
    At most 16 trials are shown; when more exist, 16 evenly spaced trial
    indices are sampled. The figure is saved to
    ``<BASE_PATH>/prediction/<target>_prediction.png``.

    Args:
        data: original activity, in any form ``get_spike`` accepts.
        prediction: predicted activity, same form as ``data``.
        config: dict providing ``BASE_PATH`` (also forwarded to ``get_spike``).
        target: population name.

    Returns:
        The matplotlib figure (left open for the caller).
    """
    original_trials = get_spike(data, target, config=config)
    predicted_trials = get_spike(prediction, target, config=config)

    n_trials_total = len(original_trials)
    max_trials = 16
    if n_trials_total > max_trials:
        # Evenly sample trial indices to keep representative coverage
        indices = np.linspace(0, n_trials_total - 1, max_trials, dtype=int)
    else:
        indices = np.arange(n_trials_total)
    n_trials = len(indices)
    n_cols = 4
    n_rows = max(1, (n_trials + n_cols - 1) // n_cols)

    # squeeze=False keeps a 2D axes array even for a single row, so
    # flattening works for any number of trials.
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(20, n_rows * 5), squeeze=False)
    flat_axes = axs.flatten()

    for ax, trial in zip(flat_axes, indices):
        ax.plot(original_trials[trial], label='Original')
        ax.plot(predicted_trials[trial], label='Pred', linestyle='--')
        ax.set_title(f'Trial {trial + 1}')
        ax.set_xlabel('Time Points')
        ax.set_ylabel('Average Spikes')

    # Remove empty subplots
    for ax in flat_axes[n_trials:]:
        fig.delaxes(ax)

    fig.tight_layout(pad=2.0)
    # Attach the legend to the last populated panel; the last grid cell may
    # have been removed above and would then show no legend at all.
    if n_trials > 0:
        flat_axes[n_trials - 1].legend(loc='upper right', bbox_to_anchor=(1.1, 1.05))

    path = os.path.join(config['BASE_PATH'], 'prediction')
    os.makedirs(path, exist_ok=True)  # Ensure the directory exists
    filepath = f'{path}/{target}_prediction.png'

    fig.savefig(filepath)
    return fig


def plot_lograte(data, n_cols=4, config=None, target='E1'):
    """
    Plot the population-averaged signal of ``target`` for each trial.

    ``data`` is reduced with ``get_spike``. At most 16 trials are shown; when
    more exist, 16 evenly spaced trial indices are sampled. The figure is
    saved to ``<BASE_PATH>/lograte/<target>_lograte.png``.

    Args:
        data: activity in any form ``get_spike`` accepts.
        n_cols: number of subplot columns.
        config: dict providing ``BASE_PATH`` (also forwarded to ``get_spike``).
        target: population name.

    Returns:
        The matplotlib figure (left open for the caller).
    """
    original_trials = get_spike(data, target, config=config)

    n_trials_total = len(original_trials)
    max_trials = 16
    if n_trials_total > max_trials:
        indices = np.linspace(0, n_trials_total - 1, max_trials, dtype=int)
    else:
        indices = np.arange(n_trials_total)
    n_trials = len(indices)
    n_rows = max(1, (n_trials + n_cols - 1) // n_cols)

    # squeeze=False guarantees an array of axes even for a 1x1 grid.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 10), squeeze=False)
    fig.suptitle('Spike Data for Each Trial', fontsize=16)

    axes = axes.flatten()  # Flatten the axes array for easy indexing

    for ax, trial_idx in zip(axes, indices):
        ax.plot(original_trials[trial_idx], label='Logrates')
        ax.set_title(f'Trial {trial_idx + 1}')
        ax.set_xlabel('Timepoints')
        ax.set_ylabel('Lograte Value')
        ax.legend()

    # Drop unused panels, as the other per-trial plots in this file do.
    for ax in axes[n_trials:]:
        fig.delaxes(ax)

    fig.tight_layout(rect=[0, 0, 1, 0.95])

    path = os.path.join(config['BASE_PATH'], 'lograte')
    os.makedirs(path, exist_ok=True)  # Ensure the directory exists
    filepath = f'{path}/{target}_lograte.png'

    fig.savefig(filepath)
    return fig
    

def flowfield_matplotlib(input_flow, output_flow, config=None, pred_latents=None, n=None):
    """
    Render 2D quiver plots for every neuron pair plus a 3D quiver plot with
    matplotlib and save them as PNG files under ``<BASE_PATH>/flowfield``.

    Args:
        input_flow: grid states, indexed as [0, point, latent].
        output_flow: model outputs on the grid, same indexing.
        config: dict providing ``BASE_PATH``.
        pred_latents: optional array [trial, time, latent]; all trials are
            scattered on top of each 2D plot.
        n: tag appended to the output file names.

    Returns:
        Four figures: E1-vs-E2, E1-vs-I, E2-vs-I and the 3D plot (all left
        open for the caller).
    """
    axis_index = {'E1': 0, 'E2': 1, 'I': 2}
    start = {name: input_flow[0, :, idx].detach().numpy()
             for name, idx in axis_index.items()}
    delta = {name: (output_flow[0, :, idx] - input_flow[0, :, idx]).detach().numpy()
             for name, idx in axis_index.items()}

    path = os.path.join(config['BASE_PATH'], 'flowfield')
    os.makedirs(path, exist_ok=True)  # Ensure the directory exists

    plots = []
    for neuron1, neuron2 in (('E1', 'E2'), ('E1', 'I'), ('E2', 'I')):
        col1, col2 = axis_index[neuron1], axis_index[neuron2]

        # 2D quiver plot
        fig, ax = plt.subplots(figsize=(8, 6))
        ax.quiver(start[neuron1], start[neuron2], delta[neuron1], delta[neuron2],
                  angles='xy', scale_units='xy', scale=2.,
                  color='r', alpha=0.7)

        if pred_latents is not None:
            # Scatter trajectories in the coordinates of the plotted pair.
            for i in range(pred_latents.shape[0]):
                ax.scatter(pred_latents[i, :, col1],
                           pred_latents[i, :, col2],
                           label=f'Trial {i+1}', alpha=0.6)

        ax.set_title(f'Flow Field: {neuron1} vs {neuron2}')
        ax.set_xlabel(neuron1)
        ax.set_ylabel(neuron2)
        ax.grid(True)
        plots.append(fig)
        fig.savefig(os.path.join(path, f'{neuron1}_{neuron2}_{n}.png'))

    # 3D flow field plot
    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(111, projection='3d')
    ax.quiver(start['E1'], start['E2'], start['I'],
              delta['E1'], delta['E2'], delta['I'],
              length=0.2, normalize=True, color='b')

    ax.set_title('3D Flow Field')
    ax.set_xlabel('E1')
    ax.set_ylabel('E2')
    ax.set_zlabel('I')
    plots.append(fig)

    fig.savefig(os.path.join(path, f'flow_3D_{n}.png'))

    return plots[0], plots[1], plots[2], plots[3]  # Return the 2D and 3D plots



def plot_units(data, config=None):
    """
    Plot the first three units (labelled E1, E2, I) over time for up to 16
    trials and save the grid to ``<BASE_PATH>/units/units.png``.

    Args:
        data: 3D array whose shape unpacks as (trials, timepoints, units).
        config: dict providing ``BASE_PATH``.
    """
    trial_count, _timepoints, _units = data.shape
    shown = min(trial_count, 16)

    # Four columns; enough rows to hold every shown trial.
    n_cols = 4
    n_rows = (shown + n_cols - 1) // n_cols

    fig, grid = plt.subplots(n_rows, n_cols, figsize=(20, 15))
    panels = grid.flatten()

    unit_labels = ('E1', 'E2', 'I')
    for trial in range(shown):
        panel = panels[trial]
        for unit, label in enumerate(unit_labels):
            panel.plot(data[trial, :, unit], label=label)

        panel.set_title(f'Trial {trial + 1}')
        panel.set_xlabel('Timepoints')
        panel.set_ylabel('Activation')
        panel.legend()

    # Drop the panels that received no trial.
    for spare in panels[shown:]:
        fig.delaxes(spare)

    out_dir = os.path.join(config['BASE_PATH'], 'units')
    os.makedirs(out_dir, exist_ok=True)
    target_file = f'{out_dir}/units.png'

    plt.tight_layout()
    plt.savefig(target_file)
    plt.close()


# def plot_unit(data, config=None, target="E1"):
#     num_trials, num_timepoints, num_units = data.shape

#     if num_trials > 16:
#         num_trials = 16

#     # Determine the grid size for subplots
#     cols = 4  # Number of columns
#     rows = (num_trials + cols - 1) // cols  # Calculate the number of rows needed

#     # Create a figure with subplots
#     fig, axes = plt.subplots(rows, cols, figsize=(20, 15))

#     # Flatten the axes array for easy iteration
#     axes = axes.flatten()

#     # Plotting each unit over time for each trial
#     for trial in range(num_trials):
#         ax = axes[trial]
#         if target=="E1":
#             ax.plot(data[trial, :, 0], label=f'E1')
#         elif target=="E2":
#             ax.plot(data[trial, :, 1], label=f'E2')
#         else:
#             ax.plot(data[trial, :, 2], label=f'I')
#         ax.set_title(f'Trial {trial + 1}')
#         ax.set_xlabel('Timepoints')
#         ax.set_ylabel('Activation')
#         ax.legend()

#     # Hide any unused subplots
#     for i in range(num_trials, len(axes)):
#         fig.delaxes(axes[i])

#     path = config['BASE_PATH'] + f'/{config["epoch"]}Epoch/{config["lr"]}LR/{config["encoder_size"]}Encoder/{config["data"]}'
#     os.makedirs(path, exist_ok=True)  # Ensure the directory exists
#     filepath = f'{path}/unit/{config["filename_tag"]}_{target}.png'

#     plt.tight_layout()
#     plt.savefig(filepath)  # Save the figure as a PNG file
#     plt.close()


def plot_loss(loss, config=None):
    """
    Draw the loss curve on the current pyplot figure, save it to
    ``<BASE_PATH>/loss/loss.png`` and close the figure.

    Args:
        loss: sequence of loss values, one per epoch.
        config: dict providing ``BASE_PATH``.
    """
    epochs = range(len(loss))
    plt.plot(epochs, loss, 'b', label='Loss')
    plt.title('Loss vs Epoch')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()

    out_dir = os.path.join(config['BASE_PATH'], 'loss')
    os.makedirs(out_dir, exist_ok=True)

    plt.savefig(f'{out_dir}/loss.png')
    plt.close()


def neuron_matrix_save(data, t):
    """
    Turn a nested dict of tensors into a DataFrame of means and save it.

    Outer keys of ``data`` become columns; the inner keys of the first outer
    entry become the index. Each cell holds ``data[col][row].mean().item()``,
    or None when that column has no such row. The frame is written to
    ``matrix_data-<t>.csv`` in the current working directory.

    Returns:
        The DataFrame that was written.
    """
    col_names = list(data.keys())
    row_names = list(data[col_names[0]].keys())

    def cell(col, row):
        # Columns lacking this row are left empty.
        if row in data[col]:
            return data[col][row].mean().item()
        return None

    table = [[cell(col, row) for col in col_names] for row in row_names]
    matrix_df = pd.DataFrame(table, index=row_names, columns=col_names)

    # Save as CSV
    matrix_df.to_csv(f"matrix_data-{t}.csv")
    return matrix_df

def get_range(latents, neuron="E1", resize=3):
    """
    Return integer (min, max) bounds of one latent column, padded by ``resize``.

    "E1" reads column 0, "E2" column 1, and any other name column 2. The
    extrema are truncated with ``int`` before the padding is applied.
    """
    column = {"E1": 0, "E2": 1}.get(neuron, 2)
    values = latents[:, :, column]
    lower = int(np.min(values)) - resize
    upper = int(np.max(values)) + resize
    return lower, upper


def plot_dependency(model, pred_latents, config):
    """
    Plot, for every (source neuron, target neuron) pair, the response of the
    target's mini-MLP to the source over the observed latent range, plus one
    panel per noise channel. The grid is saved to
    ``<BASE_PATH>/neuron response/Neuron_Response.png``.

    Args:
        model: mapping-like object; ``model[name]`` must return the mini-MLP
            of neuron ``name`` (exposing ``params_1``, ``params_2``,
            ``biases``, ``a_target_*`` and ``noise_params_1``).
        pred_latents: array [trial, time, latent] used to pick x ranges.
        config: dict providing ``neurons`` (names are taken from its keys),
            ``BASE_PATH`` and optionally ``noise`` (number of noise channels).
    """
    def plot_addition(ax, x, MLP, neuron="E1", target_neuron="E1"):
        # Self-term of the target MLP: tanh layer followed by a linear layer.
        hidden = torch.tanh(F.linear(torch.from_numpy(x[:, None]).float(),
                                     MLP.a_target_1, bias=MLP.a_target_bias))
        response = F.linear(hidden, MLP.a_target_2)

        if "I" in neuron:
            response = -response

        ax.plot(x, response.detach().numpy().squeeze(), label='Addition')
        ax.legend()

    def plot(ax, x, MLP, neuron="E1", target_neuron="E1"):
        # Weights are passed through softplus before use.
        weight_1 = F.softplus(MLP.params_1[f'{neuron}_1'])
        weight_2 = F.softplus(MLP.params_2[f'{neuron}_2'])
        bias = MLP.biases[f'{neuron}_bias']

        x_n = torch.tanh(F.linear(torch.from_numpy(x[:, None]).float(), weight_1, bias=bias))
        x_n = F.softplus(F.linear(x_n, weight_2))

        # Names containing "I" are plotted with a negative sign.
        if "I" in neuron:
            x_n = -x_n

        ax.plot(x, x_n.detach().numpy().squeeze(), label='SoftPlus')
        ax.set_title(f'f({neuron}) vs {neuron} - {target_neuron} MLP')
        ax.set_xlabel(neuron)
        ax.set_ylabel(f'f({neuron})')

        if neuron == target_neuron:
            # Use this panel's own x values rather than an enclosing variable.
            plot_addition(ax, x, MLP, neuron=target_neuron, target_neuron=target_neuron)

    def plot_noise(ax, x, MLP, noise="1", target_neuron="E1"):
        weight = MLP.noise_params_1[f'{noise}_1']
        response = F.linear(torch.from_numpy(x).float(), weight)

        ax.axhline(y=response.detach().numpy().squeeze(), color='r', linestyle='--')
        ax.set_title(f'Noise ({noise}) - {target_neuron} MLP')
        ax.set_xlabel("Noise Output")
        ax.set_ylabel("Noise Input")

    neuron_names = list(config['neurons'].keys())
    n_neurons = len(neuron_names)
    n_noises = config.get('noise', 0)

    x_ranges = []
    for neuron in neuron_names:
        x_min, x_max = get_range(pred_latents, neuron=neuron)
        x_ranges.append((np.arange(x_min, x_max, 0.1), neuron))

    # One row per source neuron; one column per target neuron followed by one
    # column per noise channel. squeeze=False keeps the axes array 2D.
    fig, axes = plt.subplots(n_neurons, n_neurons + n_noises,
                             figsize=(15, 15), squeeze=False)
    fig.suptitle("Neuron Response Dependency Plots", fontsize=16)

    for i, (x_range, neuron) in enumerate(x_ranges):
        for j, target_neuron in enumerate(neuron_names):
            plot(axes[i, j], x_range, model[target_neuron],
                 neuron=neuron, target_neuron=target_neuron)

        for j in range(n_noises):
            # plot_noise takes the channel as `noise`.
            plot_noise(axes[i, n_neurons + j], np.array([1]), model[neuron],
                       noise=j, target_neuron=neuron)

    # Adjust layout for better spacing (leave room for the suptitle)
    fig.tight_layout(rect=[0, 0.03, 1, 0.95])

    path = os.path.join(config['BASE_PATH'], 'neuron response')
    os.makedirs(path, exist_ok=True)

    filepath = f'{path}/Neuron_Response.png'

    fig.savefig(filepath)
    plt.close(fig)



def plot_heatmaps(data, config, file_tag=''):
    """
    Min-max normalise ``data`` along axis 1 and save it as a seaborn heatmap.

    The image goes to ``<BASE_PATH>/heatmaps/L1/Heatmap<file_tag>.png`` when
    ``config['l1_reg']`` is not None, otherwise to
    ``<BASE_PATH>/heatmaps/Heatmap/Heatmap<file_tag>.png``.

    Returns:
        The matplotlib figure (left open for the caller).
    """
    out_dir = os.path.join(config['BASE_PATH'], 'heatmaps')
    os.makedirs(out_dir, exist_ok=True)

    # Scale each row to the [0, 1] range.
    row_min = data.min(axis=1, keepdims=True)
    row_max = data.max(axis=1, keepdims=True)
    scaled = (data - row_min) / (row_max - row_min)  #Divide by sum?

    fig = plt.figure(figsize=(10,10))
    sns.heatmap(scaled, annot=False, cmap='coolwarm', fmt=".2f", cbar=True)
    plt.title("Heatmap")

    # Only the sub-directory differs between the two cases.
    subdir = 'L1' if config['l1_reg'] is not None else 'Heatmap'
    out_dir = os.path.join(out_dir, subdir)
    os.makedirs(out_dir, exist_ok=True)

    plt.savefig(f'{out_dir}/Heatmap{file_tag}.png')
    # plt.close()
    return fig

def plot_influence(data, config):
    """
    Draw a grouped bar chart of the mean influence matrix and save it as SVG
    and PNG under ``<BASE_PATH>/Influence``.

    ``data`` is first reduced by ``neuron_matrix_save(data, 't')``, which also
    writes ``matrix_data-t.csv``. One bar group is drawn per DataFrame column
    and one bar per DataFrame index entry.

    NOTE(review): columns are labelled "Target Neuron" and index entries
    "Source Neuron"; verify this matches the key order of ``data``.

    Args:
        data: nested dict of tensors accepted by ``neuron_matrix_save``.
        config: dict providing ``BASE_PATH``.
    """
    df = neuron_matrix_save(data, 't')

    num_targets = len(df.columns)
    num_sources = len(df.index)
    bar_width = 0.25
    group_spacing = 0.4  # Space between each target group
    x_base = np.arange(num_targets) * (num_sources * bar_width + group_spacing)
    colors = sns.color_palette("Set2", n_colors=num_sources)

    fig, ax = plt.subplots(figsize=(10, 6))

    # One bar series per source neuron
    for i, (source, color) in enumerate(zip(df.index, colors)):
        bars = ax.bar(
            x_base + i * bar_width,
            df.loc[source],
            width=bar_width,
            label=f'{source}',
            color=color,
            edgecolor='black'
        )
        for bar in bars:
            height = bar.get_height()
            if not np.isfinite(height):
                continue  # nothing sensible to annotate
            # Put the value label outside the bar for both signs.
            below = height < 0
            ax.text(
                bar.get_x() + bar.get_width() / 2,
                height - 0.02 if below else height + 0.02,
                f'{height:.2f}',
                ha='center',
                va='top' if below else 'bottom',
                fontsize=10
            )

    # Centre each tick under its group for any number of sources
    # (equals x_base + bar_width when there are three).
    ax.set_xticks(x_base + bar_width * (num_sources - 1) / 2)
    ax.set_xticklabels(df.columns, fontsize=12)
    ax.set_xlabel("Target Neuron", fontsize=14)
    ax.set_ylabel("Influence Value", fontsize=14)
    ax.set_title("Neuron-to-Neuron Influence Matrix", fontsize=16, fontweight='bold')
    ax.legend(title="Source Neuron", fontsize=11, title_fontsize=12)

    ax.yaxis.grid(True, linestyle='--', alpha=0.7)
    plt.tight_layout()

    path = os.path.join(config['BASE_PATH'], 'Influence')
    os.makedirs(path, exist_ok=True)
    png_filepath = f'{path}/Influence.png'
    svg_filepath = f'{path}/Influence.svg'

    plt.savefig(svg_filepath, format="svg")
    plt.savefig(png_filepath)
    plt.show()
    plt.close()

def _to_3d_tensor(data):
    """
    Coerce ``data`` into a 3D torch.Tensor.

    * Tensor / ndarray: 3D is returned as is (ndarray converted with
      ``torch.from_numpy``); 2D gets a leading dimension of size 1.
    * list: each item (tensor or array-like) must be 2D or 3D; 2D items get
      a leading dimension, then everything is concatenated along dim 0.

    Raises:
        ValueError: for an empty list or an unsupported number of dimensions.
        TypeError: for any other input type.
    """
    if isinstance(data, np.ndarray):
        # Arrays follow exactly the tensor rules below.
        data = torch.from_numpy(data)

    if isinstance(data, torch.Tensor):
        rank = data.dim()
        if rank == 3:
            return data
        if rank == 2:
            return data.unsqueeze(0)
        raise ValueError(f"Unsupported tensor dims for data: {data.dim()}")

    if not isinstance(data, list):
        raise TypeError("Unsupported data type. Expected Tensor, ndarray, or list")
    if len(data) == 0:
        raise ValueError("Empty list provided for data")

    pieces = []
    for item in data:
        piece = item if isinstance(item, torch.Tensor) else torch.from_numpy(np.asarray(item))
        rank = piece.dim()
        if rank == 2:
            piece = piece.unsqueeze(0)  # [T, N] -> [1, T, N]
        elif rank != 3:
            raise ValueError("Each list item must be 2D [time, neurons] or 3D [batch, time, neurons]")
        pieces.append(piece)

    # Concatenate along the leading dimension
    return torch.cat(pieces, dim=0)


def get_spike(data, target='E1', neuron=None, config=None):
    """
    Extract, per trial and timepoint, either the population mean or a single
    neuron's value for the ``target`` population.

    The neuron index range of each population is read from
    ``./config/dataset/<config['data_type']>.yaml`` (key ``neurons``); when
    ``config`` is missing or has no ``data_type``, the ``wang_100T`` dataset
    file is used and a warning is printed.

    Args:
        data: tensor, ndarray or list accepted by ``_to_3d_tensor``.
        target: population name; must be a key of the YAML ``neurons`` map.
        neuron: ``None`` for the population mean, otherwise a non-negative
            index relative to the population start.
        config: optional dict with ``data_type``.

    Returns:
        A list with one list per trial, holding one Python number per
        timepoint. An empty population yields 0.0 everywhere.

    Raises:
        ValueError: if ``target`` is not a known population.
        IndexError: if ``neuron`` falls outside the population.
    """
    if config is None or 'data_type' not in config:
        config_file_path = "./config/dataset/wang_100T.yaml"
        print(f"Warning: Using default {config_file_path} in get_spike due to missing config or data_type.")
    else:
        config_file_path = f"./config/dataset/{config['data_type']}.yaml"

    with open(config_file_path, "r") as f:
        data_config = yaml.safe_load(f)

    neurons_config = data_config["neurons"]

    data_3d = _to_3d_tensor(data)

    if target not in neurons_config:
        raise ValueError(f"Target '{target}' not found in neurons_config.")

    start, end = map(int, neurons_config[target])

    if neuron is None:
        population = data_3d[:, :, start:end]
        if population.shape[-1] == 0:
            # Default for an empty population (avoids a NaN mean).
            n_trials, timepoints = population.shape[:2]
            return [[0.0] * timepoints for _ in range(n_trials)]
        # One vectorised mean replaces the per-(trial, timepoint) loop.
        return population.float().mean(dim=-1).cpu().tolist()

    idx = start + neuron
    # Reject negative offsets too: they would silently read a neuron that
    # belongs to another population.
    if neuron < 0 or idx >= end:
        raise IndexError(f"Neuron index {neuron} out of range for population '{target}'")
    return data_3d[:, :, idx].cpu().tolist()



 


def _pseudo_r2_per_neuron(preds, targets):
    """
    Compute one pseudo-R² value per neuron (last axis).

    Inputs may be anything ``_to_3d_tensor`` accepts; they are cast to float
    and all leading axes are flattened, so sums run over every
    (trial, timepoint) sample. Per the original note, the terms follow
    ``metrics.pseudo_r2_score``.

    Returns:
        1D tensor with one entry per neuron.
    """
    def _samples_by_neuron(values):
        cube = _to_3d_tensor(values).float()
        return cube.reshape(-1, cube.shape[-1])

    rates = _samples_by_neuron(preds)
    counts = _samples_by_neuron(targets)

    eps = 1e-10
    ll_model = torch.sum(counts * torch.log(rates + eps) - rates, dim=0)
    # NOTE(review): this term subtracts the predictions rather than the
    # targets; kept as is because it is stated to mirror metrics.pseudo_r2_score.
    ll_saturated = torch.sum(counts * torch.log(counts + eps) - rates, dim=0)
    mean_count = torch.mean(counts, dim=0)
    ll_null = torch.sum(counts * torch.log(mean_count + eps) - mean_count, dim=0)

    numerator = ll_model - ll_saturated
    denominator = ll_null - ll_saturated + 1e-12
    return 1 - numerator / denominator


def plot_prediction_by_pseudo_r2(spikes, pred_rates, config=None, target='E1', top_k=3, bottom_k=3, trial_index=0):
    """
    Plot example neurons with the highest and lowest pseudo-R² within a
    target population and save the figure to
    ``<BASE_PATH>/prediction/<target>_pseudo_r2_examples.png``.

    Args:
        spikes: true activity (list, ndarray or tensor).
        pred_rates: predicted rates, same layout as ``spikes``.
        config: dict providing ``BASE_PATH`` and either ``neurons`` (name ->
            (start, end)) or ``data_type`` to locate the dataset YAML.
        target: population name.
        top_k: number of best-scoring neurons to show.
        bottom_k: number of worst-scoring neurons to show.
        trial_index: trial whose traces are plotted.

    Returns:
        The matplotlib figure; an empty figure when no pseudo-R² value in the
        population is finite.

    Raises:
        ValueError: if ``config`` is None.
    """
    if config is None:
        raise ValueError("config is required")

    true_3d = _to_3d_tensor(spikes).float()
    pred_3d = _to_3d_tensor(pred_rates).float()

    # Determine neuron index range for population
    if 'neurons' in config and target in config['neurons']:
        start, end = map(int, config['neurons'][target])
    else:
        # Fallback to dataset yaml used by get_spike
        yml_path = f"./config/dataset/{config['data_type']}.yaml"
        with open(yml_path, "r") as f:
            data_config = yaml.safe_load(f)
        start, end = map(int, data_config['neurons'][target])

    # Per-neuron pseudo R² across all trials/timepoints
    pr2_all = _pseudo_r2_per_neuron(pred_3d, true_3d)
    pr2_slice = pr2_all[start:end]

    finite_mask = torch.isfinite(pr2_slice)
    if not torch.any(finite_mask):
        print("Warning: No finite pseudo-R² values; skipping example plot")
        return plt.figure()

    pr2_valid = pr2_slice[finite_mask]
    # Build the index vector on the same device and with the same length as
    # the slice so the boolean mask always applies.
    idx_valid = torch.arange(start, start + pr2_slice.numel(),
                             device=pr2_slice.device)[finite_mask]

    k_top = max(0, min(top_k, pr2_valid.numel()))
    k_bottom = max(0, min(bottom_k, pr2_valid.numel()))
    order = torch.argsort(pr2_valid)  # ascending
    low_ids = idx_valid[order[:k_bottom]].tolist()
    # Tensors do not support negative-step slicing; flip instead. Slicing
    # after the flip also yields an empty selection for k_top == 0.
    high_ids = idx_valid[torch.flip(order, dims=[0])[:k_top]].tolist()

    # One subplot row per non-empty group, best neurons first.
    groups = [(title, ids)
              for title, ids in (('High pR2', high_ids), ('Low pR2', low_ids))
              if ids]
    n_rows = max(1, len(groups))
    n_cols = max((len(ids) for _, ids in groups), default=1)

    # squeeze=False always returns a 2D axes array.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(5*n_cols, 4*n_rows), squeeze=False)

    for row, (title_prefix, neuron_ids) in enumerate(groups):
        for col, neuron_idx in enumerate(neuron_ids):
            ax = axes[row, col]
            t_true = true_3d[trial_index, :, neuron_idx].detach().cpu().numpy()
            t_pred = pred_3d[trial_index, :, neuron_idx].detach().cpu().numpy()
            pr2_val = pr2_all[neuron_idx].item()
            ax.plot(t_true, label='True')
            ax.plot(t_pred, label='Pred', linestyle='--')
            ax.set_title(f"{title_prefix} n={neuron_idx} pr2={pr2_val:.3f}")
            ax.set_xlabel('Time')
            ax.set_ylabel('Rate / Spike')
        # Hide any extra axes in this row
        for col in range(len(neuron_ids), n_cols):
            fig.delaxes(axes[row, col])

    # Single legend
    handles, labels = axes[0, 0].get_legend_handles_labels()
    fig.legend(handles, labels, loc='upper right')
    fig.suptitle(f"{target}: Example neurons by pseudo-R² (trial {trial_index})")
    fig.tight_layout(rect=[0, 0, 1, 0.95])

    path = os.path.join(config['BASE_PATH'], 'prediction')
    os.makedirs(path, exist_ok=True)
    filepath = f'{path}/{target}_pseudo_r2_examples.png'
    fig.savefig(filepath)
    return fig
