import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt


class ThreeBitFF:
    """Synthetic flip-flop task: trial generator, plotting helper and torch loaders.

    ``params`` is a plain dict. Keys read by this class: ``seed``, ``n_batch``,
    ``n_time``, ``n_bits``, ``p_flip``, ``n_fps`` and, optionally, ``verbose``,
    ``aux_loss``, ``input_reproducing`` and (only when ``aux_loss`` is truthy)
    ``dim_output``.
    """

    def __init__(self, params):
        """Store the task parameters and seed a private NumPy RNG.

        Args:
            params: Task-configuration dict (see class docstring).

        NOTE: when ``aux_loss`` is truthy, ``params['dim_output']`` is
        incremented *in the caller's dict* (``self.params`` is that same
        object). The side effect is kept for backward compatibility -- code
        elsewhere may read the adjusted value -- but constructing two tasks
        from the same dict increments it twice.
        """
        self.params = params
        self.verbose = params.get('verbose', False)
        # Per-instance RNG so trial generation is reproducible from `seed`.
        self.rng = np.random.RandomState(params['seed'])
        self.is_trained = False

        # One extra output channel carries the auxiliary target.
        if params.get('aux_loss', False):
            params['dim_output'] += 1

    @staticmethod
    def _hold_last_pulse(pulses):
        """Forward-fill a (batch, time, channel) array along the time axis.

        Each entry becomes the most recent nonzero value at or before that
        step; entries before the first nonzero value stay 0.
        """
        n_time = pulses.shape[1]
        time_idx = np.arange(n_time).reshape(1, n_time, 1)
        # Running index of the latest nonzero step. Where none has occurred
        # yet the index is 0, and pulses[:, 0] is then necessarily 0 too.
        last_idx = np.maximum.accumulate(
            np.where(pulses != 0, time_idx, 0), axis=1)
        return np.take_along_axis(pulses, last_idx, axis=1)

    @staticmethod
    def _steps_since_last_pulse(pulses):
        """For a (batch, time) array, count steps since the latest nonzero
        entry at or before each step (0 before the first nonzero entry)."""
        time_idx = np.arange(pulses.shape[1])
        last = np.maximum.accumulate(
            np.where(pulses != 0, time_idx, -1), axis=1)
        return np.where(last >= 0, time_idx - last, 0)

    def generate_trials(self):
        """Generate one batch of synthetic flip-flop trials.

        Inputs are sparse signed pulses. At each (trial, time, bit) the pulse
        magnitude is drawn from ``Binomial(n_fps // 2, p_flip)`` (so it can
        exceed 1 when ``n_fps >= 4``), and the sign is a fair coin. Time step 0
        always carries a magnitude-1 pulse, so every bit starts in a defined
        state. The target for each bit holds the value of its most recent
        pulse.

        Returns:
            dict with
              ``'inputs'``: integer array ``(n_batch, n_time, n_bits)``;
              ``'outputs'``: ``(n_batch, n_time, n_bits)`` held pulse values,
              or ``inputs`` itself when ``input_reproducing`` is set. With
              ``aux_loss`` a final channel is appended (the array becomes
              float64) holding the number of steps since the most recent
              pulse on bit 0.
        """
        p = self.params
        n_batch, n_time, n_bits = p['n_batch'], p['n_time'], p['n_bits']
        p_flip, n_fps = p['p_flip'], p['n_fps']
        aux = p.get('aux_loss', False)

        # Unsigned pulse magnitudes, then random signs (draw order matters
        # for reproducibility from the seed).
        unsigned = self.rng.binomial(n_fps // 2, p_flip, (n_batch, n_time, n_bits))
        unsigned[:, 0, :] = 1
        signs = 2 * self.rng.binomial(1, 0.5, unsigned.shape) - 1
        inputs = unsigned * signs

        if p.get('input_reproducing', False):
            # Target is the input itself; the held-state computation is skipped.
            output = inputs
        else:
            output = self._hold_last_pulse(inputs)

        # Optional auxiliary target: steps since the last pulse on bit 0.
        if aux:
            steps = self._steps_since_last_pulse(inputs[:, :, 0])
            aux_out = steps[:, :, np.newaxis].astype(np.float64)
            output = np.concatenate([output, aux_out], axis=-1)

        return {'inputs': inputs, 'outputs': output}

    def visualize_trials(self, stim, n_examples=3):
        """Plot bit 0 of the first trials: input solid, target dashed.

        Args:
            stim: Dict with ``'inputs'`` and ``'outputs'`` arrays indexed as
                ``[trial, time, channel]``, e.g. from ``generate_trials``.
            n_examples: Maximum number of trials to show (default 3). It is
                clipped to the number of trials available.

        Raises:
            ValueError: if ``stim`` contains no trials.
        """
        n_show = min(n_examples, len(stim['inputs']))
        if n_show < 1:
            raise ValueError('stim contains no trials to plot')
        # squeeze=False keeps `axes` 2-D even when only one row is drawn.
        _, axes = plt.subplots(n_show, 1, figsize=(3, 4), sharey=True,
                               squeeze=False)
        axes = axes[:, 0]
        for i, ax in enumerate(axes):
            ax.plot(stim['inputs'][i, :, 0], linestyle='-')
            ax.plot(stim['outputs'][i, :, 0], linestyle='--')
            # Keep the bottom panel's axes so its ticks and label are visible.
            if i < n_show - 1:
                ax.axis('off')
        axes[-1].set_yticks([-1, 1])
        axes[-1].set_xlabel('Time')
        plt.show()

    def get_train_loader(self):
        """Return a DataLoader over one freshly generated batch of trials.

        The loader yields the whole batch at once (``batch_size`` equals the
        number of trials), shuffling trial order on each pass.
        """
        x, y = self.generate_train_data()
        dataset = TensorDataset(x, y)
        return DataLoader(dataset, batch_size=x.shape[0], shuffle=True)

    def generate_train_data(self):
        """Generate one batch of trials and return ``(inputs, targets)`` as
        float32 tensors with the same shapes as ``generate_trials`` arrays."""
        stim = self.generate_trials()
        x = torch.from_numpy(stim['inputs'].astype(np.float32))
        y = torch.from_numpy(stim['outputs'].astype(np.float32))
        return x, y
