"""
This script requires a lot of memory because the generated sequences are long: each sample holds max(t) * Delta values (10^9 with the settings in plot_all).
"""
import torch
import matplotlib.pyplot as plt
import numpy as np
import pickle
import os


# Font sizes shared by every figure this script produces.
title_font_size = 7
axis_label_font_size = 8
ticks_font_size = 6

# Legend styling options.
legend_settings = dict(fontsize=5, frameon=True, handlelength=0.7)

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

def generate_sample(length, device):
    """Draw `length` standard-normal values as a tensor allocated on `device`."""
    noise = torch.randn(length, device=device)
    return noise

def calculate_entropy(sample, start, end, tau, apply_transformation=True):
    """Return the softmax entropy of the logits taken from ``sample[start:end]``.

    With ``apply_transformation`` set, the value at position ``t`` is first
    mapped to ``-2*log(t/tau + 1) + value * sqrt(2*log(t/tau + 1) + 1)``;
    otherwise the slice is used unchanged. The result is a Python float.
    """
    positions = torch.arange(start, end, device=sample.device)
    logits = sample[start:end]

    if apply_transformation:
        # The same logarithm drives both the shift and the scale of the logits.
        log_term = torch.log(positions / tau + 1)
        shift = -2 * log_term
        variance = 2 * log_term + 1
        logits = shift + logits * torch.sqrt(variance)

    # H = logsumexp(x) - sum(softmax(x) * x). Subtracting the logsumexp before
    # exponentiating keeps the softmax numerically stable.
    log_norm = torch.logsumexp(logits, dim=0)
    probs = torch.exp(logits - log_norm)
    result = log_norm - torch.sum(probs * logits)
    return result.item()

def _cached_entropy_matches(data, t_values):
    """Return True when the cached `data` was computed for exactly `t_values`."""
    required = ('mean_entropy', 'std_entropy', 'stderr_entropy')
    if not isinstance(data, dict) or any(key not in data for key in required):
        return False
    if 't_values' in data:
        return list(data['t_values']) == t_values
    # Legacy cache files carry no t grid; the length is the only available check.
    return all(len(data[key]) == len(t_values) for key in required)


def get_entropy_data(t_values_ent, tau, Delta, n_samples=1000, apply_transformation=True):
    """Calculate, or load from a pickle cache, entropy statistics per t value.

    For each of `n_samples` random samples and each `t` in `t_values_ent`, the
    entropy of the window ``[t, int(t * Delta))`` is computed with
    `calculate_entropy`. Results are cached in the working directory under a
    name built from `tau`, `Delta`, `n_samples` and `apply_transformation`.

    The file name does not encode `t_values_ent`, so the t grid is stored in
    the cache and checked on load. A cache written for a different grid is
    recalculated and overwritten instead of being returned silently.

    Args:
        t_values_ent: Window start positions; must be non-empty.
        tau: Time scale passed through to `calculate_entropy`.
        Delta: Ratio between a window's end and its start.
        n_samples: Number of independent random samples to average over.
        apply_transformation: Forwarded to `calculate_entropy`; also selects
            the ``_std`` cache file when false.

    Returns:
        Tuple ``(mean_entropy, std_entropy, stderr_entropy)``, each an array
        with one entry per value in `t_values_ent`.

    Raises:
        ValueError: If `t_values_ent` is empty and no matching cache exists.
    """
    t_values = list(t_values_ent)
    suffix = '' if apply_transformation else '_std'
    filename = f'./entropy_data_tau{tau}_Delta{Delta}_nsamples{n_samples}{suffix}.pkl'

    # Reuse previously computed data when it matches the requested t grid.
    if os.path.exists(filename):
        print(f"Loading entropy data from {filename}")
        # NOTE: pickle.load can execute arbitrary code. This is acceptable only
        # because the file is a local cache written by this function.
        with open(filename, 'rb') as f:
            data = pickle.load(f)
        if _cached_entropy_matches(data, t_values):
            return data['mean_entropy'], data['std_entropy'], data['stderr_entropy']
        print(f"Cached data in {filename} does not match t_values_ent; recalculating")

    print(f"Calculating entropy data for tau={tau}, Delta={Delta}")
    # Every window is a slice of one sample, so size it for the largest window.
    max_length = int(max(t_values) * Delta)

    # Row i holds the entropies of sample i, one column per t value.
    entropy_values = np.empty((n_samples, len(t_values)))
    for i in range(n_samples):
        sample = generate_sample(length=max_length, device=device)
        for j, t in enumerate(t_values):
            entropy_values[i, j] = calculate_entropy(
                sample, start=t, end=int(t * Delta), tau=tau,
                apply_transformation=apply_transformation,
            )

    # Mean, standard deviation and standard error across samples.
    mean_entropy = np.mean(entropy_values, axis=0)
    std_entropy = np.std(entropy_values, axis=0)
    stderr_entropy = std_entropy / np.sqrt(n_samples)

    # Save the data together with the t grid it belongs to.
    data = {
        'mean_entropy': mean_entropy,
        'std_entropy': std_entropy,
        'stderr_entropy': stderr_entropy,
        't_values': t_values,
    }
    with open(filename, 'wb') as f:
        pickle.dump(data, f)
    print(f"Saved entropy data to {filename}")

    return mean_entropy, std_entropy, stderr_entropy

def plot_all():
    """Plot expected entropy against log10(t) and save one PDF per variant.

    Two figures are written to the working directory: ``./plot.pdf`` for the
    transformed samples and ``./plot_std.pdf`` for the untransformed ones.
    Each figure has two panels: the mean entropy (left) and its square
    (right), both with a shaded error band and a dashed linear fit. The data
    for every (tau, Delta) pair comes from `get_entropy_data`.
    """
    t_values_ent = [1, 10, 100, 1000, 10000, 100000, 1000000, 10000000, 100000000]
    log10_t_ent = np.log10(t_values_ent)
    # One colour per Delta curve: blue, magenta, amber.
    colors = ['#648fff', '#dc267f', '#ffb000']

    for apply_transformation, suffix in [(True, ''), (False, '_std')]:
        fig, axs = plt.subplots(1, 2, figsize=(5.5, 2.2))

        # Generate plots for all combinations (largest Delta first, which
        # fixes the order of the legend entries).
        for i, (tau, Delta) in enumerate([(10, 10), (10, 5), (10, 2)]):
            print(f"Processing tau={tau}, Delta={Delta}, transformed={apply_transformation}")

            # The standard deviation is returned as well but is not plotted.
            mean_entropy, _, stderr_entropy = get_entropy_data(
                t_values_ent, tau, Delta, apply_transformation=apply_transformation
            )

            # Left panel: mean entropy with a +/- one standard error band.
            axs[0].plot(log10_t_ent, mean_entropy, 'o-', color=colors[i],
                        label=f'Δ={Delta}', markersize=2, linewidth=0.75)
            axs[0].fill_between(log10_t_ent,
                                mean_entropy - stderr_entropy,
                                mean_entropy + stderr_entropy,
                                alpha=0.2, color=colors[i])

            # Least-squares line through the means.
            m_ent, b_ent = np.polyfit(log10_t_ent, mean_entropy, 1)
            axs[0].plot(log10_t_ent, m_ent * log10_t_ent + b_ent, '--',
                        color=colors[i], linewidth=0.25)

            # Right panel: squared mean entropy. The band uses first-order
            # error propagation, d(m^2) = 2*|m|*dm.
            entropy_squared = mean_entropy ** 2
            entropy_squared_error = 2 * np.abs(mean_entropy) * stderr_entropy
            axs[1].plot(log10_t_ent, entropy_squared, 'o-', color=colors[i],
                        markersize=2, linewidth=0.8)
            axs[1].fill_between(log10_t_ent,
                                entropy_squared - entropy_squared_error,
                                entropy_squared + entropy_squared_error,
                                alpha=0.2, color=colors[i])

            # Least-squares line through the squared means.
            m_ent_sq, b_ent_sq = np.polyfit(log10_t_ent, entropy_squared, 1)
            axs[1].plot(log10_t_ent, m_ent_sq * log10_t_ent + b_ent_sq, '--',
                        color=colors[i], linewidth=0.5)

        # Raw strings keep the backslash in "\Delta" without relying on an
        # invalid escape sequence.
        axs[0].set_xlabel('log₁₀(t)', fontsize=axis_label_font_size - 2)
        axs[0].set_ylabel(r'E$[H_{t}^{t\Delta}]$', fontsize=axis_label_font_size)
        axs[0].set_title('Expected entropy', fontsize=title_font_size)
        axs[0].legend(fontsize=legend_settings['fontsize'])

        axs[1].set_xlabel('log₁₀(t)', fontsize=axis_label_font_size - 2)
        axs[1].set_ylabel(r'E$[H_{t}^{t\Delta}]^2$', fontsize=axis_label_font_size)
        axs[1].set_title('Square expected entropy', fontsize=title_font_size)
        for ax in axs:
            ax.tick_params(axis='both', labelsize=ticks_font_size)

        fig.tight_layout()
        filename = f'./plot{suffix}.pdf'
        fig.savefig(filename)
        # Close this figure explicitly so it is released before the next one.
        plt.close(fig)
        print(f"Saved plot to {filename}")

# Script entry point: generate and save all plots when the file is run directly.
if __name__ == "__main__":
    plot_all()