import os
import torch
from utils import normalize

class Synthetic:
    """Synthetic conditional dataset ``(X, y)`` cached on disk as ``.pt`` tensors.

    On construction the four tensors ``X_train``, ``y_train``, ``X_test`` and
    ``y_test`` are loaded from ``cfg['dataset_path']`` when all four files are
    present there, and generated and saved there when none of them is.

    Recognised ``cfg`` keys:
        dataset_path: directory holding (or receiving) the ``.pt`` files. Required.
        x_dim: number of features per sample (default 256).
        y_dim: number of targets per sample (default 128).
        n_train: number of generated training samples (default 10000).
        n_test: number of generated test samples (default 2000).

    Raises:
        KeyError: if ``cfg`` has no ``'dataset_path'`` entry.
        FileNotFoundError: if only some of the four tensor files exist.
        ValueError: if data must be generated and ``x_dim < max(32, y_dim)``.
    """

    # Attribute names double as file stems: ``<name>.pt`` inside ``data_path``.
    _TENSOR_NAMES = ('X_train', 'y_train', 'X_test', 'y_test')

    def __init__(self, cfg):
        self.cfg = cfg
        self.x_dim = cfg.get('x_dim', 256)
        self.y_dim = cfg.get('y_dim', 128)

        self.data_path = cfg['dataset_path']

        # Decide on the files rather than on the directory: a directory that
        # exists but is empty (e.g. an interrupted first run) must still
        # trigger generation instead of a failing load.
        missing = [
            name for name in self._TENSOR_NAMES
            if not os.path.isfile(self.__tensor_path(name))
        ]
        if not missing:
            self.__load_data()
        elif len(missing) == len(self._TENSOR_NAMES):
            self.__generate_data()
        else:
            # A partial set is never overwritten; the user has to resolve it.
            raise FileNotFoundError(
                f"incomplete dataset in {self.data_path!r}: missing "
                + ", ".join(f"{name}.pt" for name in missing)
            )

    def __tensor_path(self, name):
        """Return the path of the ``.pt`` file that stores tensor ``name``."""
        return os.path.join(self.data_path, f'{name}.pt')

    def __generate_data(self):
        """Sample the train/test tensors, save them to disk and keep them on ``self``."""
        # The generator slices x[:16], x[:32] and x[:y_dim]; a shorter feature
        # vector would fail with a shape error somewhere inside the sample loop.
        min_x_dim = max(32, self.y_dim)
        if self.x_dim < min_x_dim:
            raise ValueError(
                f"x_dim={self.x_dim} is too small: the generator needs "
                f"x_dim >= max(32, y_dim) = {min_x_dim}"
            )

        n_train = self.cfg.get('n_train', 10000)
        n_test = self.cfg.get('n_test', 2000)

        os.makedirs(self.data_path, exist_ok=True)

        # NOTE(review): randn * 6 - 3 draws from a normal with mean -3 and
        # std 6. If a uniform range [-3, 3) was intended, torch.rand would be
        # needed instead -- confirm before changing, it alters the dataset.
        X_train = torch.randn((n_train, self.x_dim)) * 6 - 3
        y_train = self.__sample_conditional_y(X_train)

        X_test = torch.randn((n_test, self.x_dim)) * 6 - 3
        y_test = self.__sample_conditional_y(X_test)

        tensors = (X_train, y_train, X_test, y_test)
        for name, tensor in zip(self._TENSOR_NAMES, tensors):
            torch.save(tensor, self.__tensor_path(name))
            setattr(self, name, tensor)

    def __sample_conditional_y(self, X):
        """Sample one target row per row of ``X``.

        Args:
            X: 2-D tensor with one sample per row.

        Returns:
            Tensor of shape ``(X.shape[0], self.y_dim)``.
        """
        n = X.shape[0]
        y = torch.zeros((n, self.y_dim))

        for i in range(n):
            # Region-specific sample plus a small global noise term.
            y[i] = self.__sample_one(X[i]) + 0.05 * torch.randn(self.y_dim)

        return y

    def __sample_one(self, x):
        """Draw one target row for feature vector ``x``, without the global noise.

        The generation mode depends on ``r``, the norm of the first 8 features.
        Every weight matrix below is drawn afresh on each call.
        """
        r = torch.norm(x[:8])
        # Four sums over consecutive groups of 4 features (x[0:4] ... x[12:16]).
        angle_components = x[:16].reshape(4, 4).sum(dim=1)

        if r < 1.0:  # Region 1: two modes
            # The mode is a deterministic function of x[0] + x[1].
            mode_selector = torch.sigmoid(x[0] + x[1])
            if mode_selector > 0.7:
                # Mode 1: tanh of a random projection of x[:8]
                weight = torch.randn(self.y_dim, 8) * 0.1
                return torch.tanh(weight @ x[:8]) + 0.1 * torch.randn(self.y_dim)
            # Mode 2: sin of a random projection of x[8:16]
            weight = torch.randn(self.y_dim, 8) * 0.15
            return torch.sin(weight @ x[8:16]) + 0.2 * torch.randn(self.y_dim)

        if r < 2.0:  # Region 2: heteroscedastic Gaussian
            # Mean: random projection of x[:16] whose scale depends on x.
            scale = 0.5 + 0.3 * torch.abs(torch.sin(angle_components.sum()))
            mean_weight = torch.randn(self.y_dim, 16) * scale
            mean = mean_weight @ x[:16]

            # Standard deviation varies with x, between 0.1 and 0.5.
            std = 0.1 + 0.4 * torch.sigmoid(angle_components.mean())

            return mean + std * torch.randn(self.y_dim)

        # Region 3: mixture whose proportions are softmax(x[:3]).
        probs = torch.softmax(x[:3], dim=0)
        mode = torch.multinomial(probs, 1).item()

        if mode == 0:
            # Mode A: entries with mask == 1 (rand > 0.7) get std 0.5, the rest 0.1.
            mask = (torch.rand(self.y_dim) > 0.7).float()
            return (
                mask * (0.5 * torch.randn(self.y_dim))
                + (1 - mask) * 0.1 * torch.randn(self.y_dim)
            )
        if mode == 1:
            # Mode B: random linear projection of x[:32]
            weight = torch.randn(self.y_dim, 32) * 0.08
            return weight @ x[:32] + 0.15 * torch.randn(self.y_dim)
        # Mode C: piecewise constant, +/-2 by the sign of the first y_dim features
        return 2.0 * torch.sign(x[:self.y_dim]) + 0.2 * torch.randn(self.y_dim)

    def __load_data(self):
        """Load the four cached tensors from ``data_path`` onto ``self``."""
        for name in self._TENSOR_NAMES:
            setattr(self, name, torch.load(self.__tensor_path(name)))

    def get_loader(self, split=None):
        """Alias of :meth:`get_data`.

        The default ``split=None`` matches no accessor and therefore raises
        ``AttributeError``; pass ``'train'`` or ``'test'``.
        """
        return self.get_data(split)

    def get_data(self, split):
        """Return the result of the ``<split>_data()`` method.

        Args:
            split: ``'train'`` or ``'test'``; any other name works if a
                matching ``<split>_data`` method exists (e.g. on a subclass).

        Raises:
            AttributeError: if no callable ``<split>_data`` attribute exists.
        """
        # Attribute lookup instead of eval(): same dispatch, but `split` can
        # no longer be used to execute arbitrary code.
        accessor = getattr(self, f'{split}_data', None)
        if not callable(accessor):
            raise AttributeError(
                f"unknown split {split!r}: "
                f"{type(self).__name__} has no '{split}_data' method"
            )
        return accessor()

    def train_data(self):
        """Return ``(X_train, y_train)``, each passed through ``normalize`` on its own."""
        return normalize(self.X_train), normalize(self.y_train)

    def test_data(self):
        """Return ``(X_test, y_test)``, each passed through ``normalize`` on its own."""
        return normalize(self.X_test), normalize(self.y_test)