import random

import torch
import numpy as np

def to_torch(x):
    """Convert a NumPy array into a float32 tensor with an extra trailing axis.

    The returned tensor has shape ``x.shape + (1,)``.
    """
    tensor = torch.from_numpy(x).to(torch.float32)
    return tensor[..., None]


class SINE:
    """Generator of sinusoid regression tasks, ``y = A * sin(phase + f * x)``.

    Each row of a sampled batch is one task with its own amplitude ``A`` and
    phase. The frequency ``f`` is currently fixed at 1.0. The ``mode``
    argument of :meth:`sample_data` selects how task amplitudes (and, for
    ``"correlated"``, phases) are distributed.
    """

    # Widths of the low / high amplitude bands used by the "train", "skewed"
    # and "inverse" modes:
    #   low band  = [amplitude_min, amplitude_min + 0.95]
    #   high band = [amplitude_max - 0.05, amplitude_max]
    _LOW_BAND_WIDTH = 0.95
    _HIGH_BAND_WIDTH = 0.05

    def __init__(self, x_min=-5.0, x_max=5.0):
        """Store the input range and the amplitude / phase bounds.

        Args:
            x_min: Lower bound of the sampled inputs ``X``.
            x_max: Upper bound of the sampled inputs ``X``.
        """
        self.x_min = x_min
        self.x_max = x_max

        self.amplitude_min = 0.1
        self.amplitude_max = 5.0

        self.phase_min = 0.0
        self.phase_max = np.pi

    def true_function(self, X, amplitude, phase, frequency):
        """Evaluate ``amplitude * sin(phase + frequency * X)`` (NumPy broadcasting)."""
        return amplitude * np.sin(phase + frequency * X)

    # ------------------------------------------------------------------ #
    # Amplitude / phase samplers
    # ------------------------------------------------------------------ #
    def _clip_amplitude(self, values):
        """Clip sampled amplitudes into ``[amplitude_min, amplitude_max]``."""
        return np.clip(values, self.amplitude_min, self.amplitude_max)

    def _sample_low_band(self, size=None):
        """Draw amplitude(s) uniformly from the low band."""
        return np.random.uniform(
            self.amplitude_min, self.amplitude_min + self._LOW_BAND_WIDTH, size=size
        )

    def _sample_high_band(self, size=None):
        """Draw amplitude(s) uniformly from the high band."""
        return np.random.uniform(
            self.amplitude_max - self._HIGH_BAND_WIDTH, self.amplitude_max, size=size
        )

    def _sample_skewed_amplitude(self, batch_size, low_prob=0.95):
        """Per task: low band with probability ``low_prob``, else high band.

        NOTE: the band choice uses Python's ``random`` module while the values
        come from ``np.random``; both must be seeded for reproducibility.
        """
        amplitude = [
            self._sample_low_band() if random.random() < low_prob else self._sample_high_band()
            for _ in range(batch_size)
        ]
        return np.array(amplitude).reshape(-1, 1)

    def _sample_inverse_amplitude(self, batch_size, low_fraction=0.05):
        """Deterministic split: ``int(batch_size * low_fraction)`` low-band tasks
        first, then the remaining tasks from the high band."""
        n_low = int(batch_size * low_fraction)
        # Take the high-band count as the remainder. Truncating
        # int(batch_size * 0.95) separately could yield fewer than batch_size
        # rows (15 for batch_size=16) and break broadcasting downstream.
        n_high = batch_size - n_low
        amplitude = np.concatenate(
            [self._sample_low_band(n_low), self._sample_high_band(n_high)]
        )
        return amplitude.reshape(-1, 1)

    def _sample_gaussian_amplitude(self, batch_size, mean, std=1.0):
        """Draw amplitudes from N(mean, std) and clip them into range."""
        return self._clip_amplitude(np.random.normal(mean, std, size=(batch_size, 1)))

    def _sample_bimodal_amplitude(self, batch_size):
        """Equal-weight mixture of N(1.0, 0.5) and N(4.0, 0.5), clipped into range."""
        weights = np.array([0.5, 0.5])
        means = np.array([1.0, 4.0])
        stds = np.array([0.5, 0.5])  # 1-D standard deviations, not covariances
        components = np.random.choice(2, size=batch_size, p=weights)
        samples = np.random.normal(means[components], stds[components])
        return self._clip_amplitude(samples).reshape(-1, 1)

    def _sample_correlated(self, batch_size, step=0.1, noise_std=0.05):
        """Random walk in (amplitude, phase) space.

        The first task is uniform within the bounds. Each following task moves
        by +/-``step`` per coordinate (direction chosen by a fair coin), plus
        N(0, ``noise_std``) noise, clipped back into the bounds.

        Returns:
            ``(amplitude, phase)``, each of shape ``(batch_size, 1)``.
        """
        low = np.array([self.amplitude_min, self.phase_min])
        high = np.array([self.amplitude_max, self.phase_max])
        current = np.random.uniform(low=low, high=high, size=2)
        walk = [current]
        for _ in range(batch_size - 1):
            directions = np.random.uniform(0.0, 1.0, 2)
            drift = np.where(directions < 0.5, -step, step)
            current = np.clip(
                current + drift + np.random.normal(0, noise_std, size=2), low, high
            )
            walk.append(current)
        walk = np.stack(walk)
        return walk[:, 0].reshape(-1, 1), walk[:, 1].reshape(-1, 1)

    def _sample_amplitude(self, batch_size, mode):
        """Sample a ``(batch_size, 1)`` amplitude column for any non-correlated mode.

        Raises:
            ValueError: if ``mode`` is not recognised.
        """
        if mode in ("train", "skewed"):
            return self._sample_skewed_amplitude(batch_size)
        if mode in ("uniform", "plot"):
            # "plot" only changes how X is built; it still needs amplitudes,
            # so it draws them from the full uniform range.
            return np.random.uniform(
                self.amplitude_min, self.amplitude_max, size=(batch_size, 1)
            )
        if mode == "inverse":
            return self._sample_inverse_amplitude(batch_size)
        if mode == "bimodal":
            return self._sample_bimodal_amplitude(batch_size)

        # "gaussian<mean>", e.g. "gaussian2.5" -> N(2.5, 1.0), clipped.
        prefix = "gaussian"
        if isinstance(mode, str) and mode.startswith(prefix):
            try:
                mean = float(mode[len(prefix):])
            except ValueError:
                mean = None
            if mean is not None and np.isfinite(mean):
                return self._sample_gaussian_amplitude(batch_size, mean)

        raise ValueError(f"Unknown sampling mode: {mode!r}")

    # ------------------------------------------------------------------ #
    # Public API
    # ------------------------------------------------------------------ #
    def sample_data(self, batch_size=16, num_samples=5, mode="train"):
        """Sample a batch of sinusoid tasks and their input/output points.

        Modes:
            "correlated": amplitude/phase follow a clipped random walk.
            "train", "skewed": 95% low-band / 5% high-band amplitudes, per task.
            "inverse": 5% low-band tasks followed by 95% high-band tasks.
            "uniform": amplitude ~ U[amplitude_min, amplitude_max].
            "gaussian<mean>": amplitude ~ N(mean, 1), clipped.
            "bimodal": equal mixture of N(1, 0.5) and N(4, 0.5), clipped.
            "plot": uniform amplitudes; X is 1000 evenly spaced points.

        Args:
            batch_size: Number of tasks (rows).
            num_samples: Random input points per task (ignored for "plot").
            mode: Sampling mode, see above.

        Returns:
            ``(X, y)`` as produced by ``to_torch``: float tensors of shape
            ``(batch_size, num_points, 1)``.

        Raises:
            ValueError: if ``mode`` is not recognised.
        """
        if mode == "correlated":
            amplitude, phase = self._sample_correlated(batch_size)
        else:
            amplitude = self._sample_amplitude(batch_size, mode)
            phase = np.random.uniform(self.phase_min, self.phase_max, size=(batch_size, 1))
        frequency = 1.0  # fixed; frequency is not randomised at present

        if mode == "plot":
            X = np.expand_dims(
                np.linspace(self.x_min, self.x_max, num=1000), axis=0
            ).repeat(batch_size, axis=0)
        else:
            X = np.random.uniform(self.x_min, self.x_max, size=(batch_size, num_samples))
        y = self.true_function(X, amplitude, phase, frequency)

        return to_torch(X), to_torch(y)
