import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from functools import cached_property
import itertools
import matplotlib.pyplot as plt


class SineWaveDataset(Dataset):
    """Dataset of multi-channel sine waves built from a grid of frequencies.

    Each sample assigns one frequency per channel, chosen from the Cartesian
    product of ``frequencies`` with itself (repetition allowed).

    * ``X`` has shape ``(sequence_length, num_channels)``: column ``j`` holds
      channel ``j``'s frequency, constant over time (or ``X`` is the target
      itself when ``params['input_reproducing']`` is truthy).
    * ``y`` has shape ``(sequence_length, num_channels)``: column ``j`` is
      ``sin(2*pi*f_j*t)`` rounded to 2 decimals.

    ``params`` keys read here: ``num_channels``, ``freq_range`` (low, high),
    ``freq_num``, ``dt``, ``sequence_length``, ``num_samples`` and the
    optional ``input_reproducing``.

    If the product has more than ``num_samples`` combinations, a random
    subset (without replacement, from NumPy's global RNG) is used.
    """

    def __init__(self, params):
        self.params = params
        self.num_channels = self.dim_in = self.dim_out = params['num_channels']
        self.freq_range = params['freq_range']
        self.freq_num = params['freq_num']
        # Frequency grid, rounded to 2 decimals.
        self.frequencies = np.round(
            np.linspace(self.freq_range[0], self.freq_range[1], self.freq_num),
            2
        )
        self.dt = params['dt']
        self.sequence_length = params['sequence_length']
        self.num_samples = params['num_samples']
        # Build the time axis from an integer count: a float-step
        # np.arange(0, n * dt, dt) can yield n + 1 points through rounding
        # (e.g. n=3, dt=0.1), which would misalign y with X.
        self.t = np.arange(self.sequence_length) * self.dt

        if self.num_channels > self.freq_num:
            raise ValueError("num_channels must not exceed number of frequencies")

        # Draw the sample subset now instead of lazily: each DataLoader worker
        # holds its own copy of the dataset (and, in recent PyTorch, its own
        # NumPy seed), so a lazy draw inside workers could select a different
        # subset per worker.
        _ = self.freq_index

    @cached_property
    def all_freq_combinations(self):
        """All length-``num_channels`` tuples of ``frequencies`` in
        ``itertools.product`` order.

        Materializes ``freq_num ** num_channels`` tuples. Indexing does not
        rely on this list (see ``_frequencies_at``); it is provided for callers
        that want the full grid.
        """
        return list(itertools.product(self.frequencies, repeat=self.num_channels))

    @property
    def _num_combinations(self):
        """Size of the frequency product, computed without materializing it."""
        return len(self.frequencies) ** self.num_channels

    @cached_property
    def freq_index(self):
        """Flat indices (in ``all_freq_combinations`` order) used as samples."""
        total = self._num_combinations
        if total <= self.num_samples:
            return np.arange(total)
        return np.random.choice(total, self.num_samples, replace=False)

    def _frequencies_at(self, idx):
        """Return the per-channel frequencies of sample ``idx`` as a 1-D array.

        ``np.unravel_index`` decodes the flat index in C order (last channel
        varying fastest), which is exactly ``itertools.product`` ordering, so
        the full combination list never has to be built.
        """
        digits = np.unravel_index(
            self.freq_index[idx], (len(self.frequencies),) * self.num_channels
        )
        return self.frequencies[np.asarray(digits, dtype=np.intp)]

    def __len__(self):
        return len(self.freq_index)

    def __getitem__(self, idx):
        freq = self._frequencies_at(idx)

        # inputs: column j is channel j's frequency, repeated over time
        X = torch.tensor(freq, dtype=torch.float32) \
               .unsqueeze(0) \
               .repeat(self.sequence_length, 1)

        # targets: column j is sin(2*pi*f_j*t), rounded to 2 decimals
        y_vals = np.sin(2 * np.pi * freq[:, None] * self.t)
        y = torch.tensor(np.round(y_vals, 2), dtype=torch.float32).T

        if self.params.get('input_reproducing', False):
            X = y

        return X, y

    def get_train_loader(self, batch_size: int = 32, shuffle: bool = True,
                         **loader_kwargs):
        """Wrap this dataset in a ``DataLoader``.

        Extra keyword arguments (e.g. ``num_workers``) are forwarded to
        ``DataLoader`` unchanged.
        """
        return DataLoader(self, batch_size=batch_size, shuffle=shuffle,
                          **loader_kwargs)

    def plot_samples(self, idx: int = 0):
        """Scatter-plot every channel of sample ``idx`` against time."""
        _, y = self[idx]
        plt.figure(figsize=(10, 6))
        # Iterate over channels (columns of y). Labels come from the sample's
        # frequencies, not from X, which equals y when input_reproducing is set.
        for j, freq in enumerate(self._frequencies_at(idx)):
            plt.scatter(self.t, y[:, j].numpy(), label=f'Freq {freq:.2f} Hz')
        plt.xlabel('Time')
        plt.ylabel('Amplitude')
        plt.legend()
        plt.tight_layout()
        plt.show()
