""" A dataset reader that generates toy datasets from functions

Copyright (c) 2025 Anonymous Authors
"""
import os
from typing import Dict, List, Optional, Set, Tuple, Union
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from .reader import Reader


class RandomMLP(nn.Module):
    """Two-layer perceptron whose weights are freshly randomised on creation.

    Both linear layers receive Xavier-uniform weights and zero biases. The
    activation is applied after *each* layer, including the output layer.

    Args:
        input_dim: size of the last dimension of the input.
        hidden_dim: width of the hidden layer.
        output_dim: size of the last dimension of the output.
        activation_fn: activation *class* (not an instance); it is
            instantiated once and shared by both layers.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, activation_fn=nn.Tanh):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.activation = activation_fn()

        # Overwrite PyTorch's default Linear initialisation. Layers are visited
        # in the same fixed order so RNG consumption stays reproducible.
        for linear in (self.fc1, self.fc2):
            nn.init.xavier_uniform_(linear.weight)
            nn.init.zeros_(linear.bias)

    def forward(self, x):
        hidden = self.activation(self.fc1(x))
        return self.activation(self.fc2(hidden))
    

def non_linear_augmentation(xs, activation_fn=nn.Tanh, seed=None):
    """
    Warp points through a single, freshly initialised random MLP.

    A ``RandomMLP`` with ``input_dim == hidden_dim == output_dim == xs.shape[1]``
    (two linear layers, each followed by ``activation_fn``) is built and
    ``xs`` is passed through it without gradient tracking.

    Args:
        xs (np.ndarray | torch.Tensor | array-like): 2-D input of shape
            (num_points, dim). Tensors are detached and moved to CPU.
        activation_fn (type[nn.Module]): Activation class (not an instance)
            used after each layer of the MLP.
        seed (int, optional): If given, the MLP weights are drawn from the
            CPU torch RNG seeded with this value, so repeated calls apply the
            *same* warp. The global CPU RNG state is restored afterwards.
            If ``None`` (default), the weights come from the current global
            torch RNG, so every call applies a different warp.

    Returns:
        np.ndarray: float32 array of shape (num_points, dim).
    """
    # The previous version only handled objects with ``.astype`` (numpy),
    # although the docstring advertised torch tensors. Accept both.
    if isinstance(xs, torch.Tensor):
        inputs = xs.detach().to(device="cpu", dtype=torch.float32)
    else:
        inputs = torch.tensor(np.asarray(xs, dtype=np.float32))
    dim = inputs.shape[1]

    def _build_mlp():
        # Square MLP so the warped points keep the input dimensionality.
        return RandomMLP(input_dim=dim, hidden_dim=dim, output_dim=dim, activation_fn=activation_fn)

    if seed is None:
        mlp = _build_mlp()
    else:
        # fork_rng(devices=[]) snapshots and restores only the CPU RNG. Seeding
        # the CPU default generator directly avoids touching any CUDA RNG state.
        with torch.random.fork_rng(devices=[]):
            torch.default_generator.manual_seed(seed)
            mlp = _build_mlp()

    with torch.no_grad():
        warped = mlp(inputs)
    return warped.numpy()


def spiral_data(num_samples, sample_radius=1, noise_strength=0, non_linear_aug=False):
    """Sample a two-armed spiral binary classification dataset.

    Angles are drawn uniformly from [0, 2*pi). The radius grows linearly from
    0 to ``sample_radius`` over that range. The second arm reuses the same
    angles shifted by pi/2. Optional warping and uniform noise (scaled by
    ``noise_strength``) are then applied. Finally the points are mapped by
    ``x / 2 + 0.5``, which lands in [0, 1] when ``sample_radius == 1`` and
    there is no noise.

    Args:
        num_samples: total number of points; must be even (half per arm).
        sample_radius: maximum spiral radius, also the half-width of the
            uniform noise before it is scaled by ``noise_strength``.
        noise_strength: multiplier applied to the uniform noise.
        non_linear_aug: if True, warp the points with
            ``non_linear_augmentation`` before the noise is added.

    Returns:
        (xs, ys): float32 array of shape (num_samples, 2) and int64 labels.
        The first arm is labelled 1 and the second 0.
    """
    assert (num_samples % 2) == 0, "spiral data is defined over even num_samples"

    per_arm = num_samples // 2
    theta = np.random.uniform(0, 2 * np.pi, per_arm)
    radius = theta / (2 * np.pi) * sample_radius

    def _arm(phase):
        # Polar -> Cartesian with (sin, cos) ordering for the two coordinates.
        return np.stack((radius * np.sin(phase), radius * np.cos(phase)), axis=1)

    xs = np.concatenate((_arm(theta), _arm(theta + np.pi / 2)), axis=0)

    if non_linear_aug:
        xs = non_linear_augmentation(xs)

    # One independent uniform draw per coordinate, first coordinate drawn first.
    noise = np.stack(
        [np.random.uniform(-sample_radius, sample_radius, num_samples) * noise_strength
         for _ in range(2)],
        axis=1,
    )
    # In-place updates keep xs's dtype (float32 after augmentation).
    xs += noise
    xs /= 2
    xs += 0.5

    ys = np.where(np.arange(num_samples) < per_arm, 1, 0)

    return xs.astype(np.float32), ys.astype(np.int64)


def create_labels_from_nearest(xs, sp_xs, sp_ys):
    """Label each query point with the label of its nearest reference point.

    Args:
        xs: query points, shape (num_queries, dim).
        sp_xs: reference points, shape (num_refs, dim).
        sp_ys: labels of the reference points, shape (num_refs,).

    Returns:
        Array of shape (num_queries,) with dtype ``sp_ys.dtype``. Ties in
        Euclidean distance resolve to the lowest reference index (np.argmin).
    """
    # One row at a time: compute the Euclidean distance from this query to
    # every reference point, then take the label of the closest one.
    # Row-wise processing avoids a (num_queries x num_refs x dim) buffer.
    return np.array(
        [sp_ys[np.argmin(np.linalg.norm(sp_xs - point, axis=1))] for point in xs],
        dtype=sp_ys.dtype,
    )


def grid_data(num_samples, valid_num_samples=16384):
    """Build a regular grid over the unit square, labelled by a spiral.

    The grid has ``side = int(sqrt(num_samples))`` evenly spaced values per
    axis on [0, 1]. It therefore holds ``side ** 2`` points, which is fewer
    than ``num_samples`` when that is not a perfect square. Each grid point
    takes the label of its nearest neighbour among ``valid_num_samples``
    points freshly drawn by ``spiral_data`` with default arguments. Labels
    near the arm boundaries can therefore vary from call to call.

    Args:
        num_samples: requested number of grid points.
        valid_num_samples: size of the reference spiral; must be even.

    Returns:
        (xs, ys): float32 array of shape (side ** 2, 2) and int64 labels.
    """
    side = int(np.sqrt(num_samples))
    ticks = np.linspace(0, 1.0, side)
    mesh_a, mesh_b = np.meshgrid(ticks, ticks)
    grid_points = np.column_stack((mesh_a.ravel(), mesh_b.ravel()))

    reference_points, reference_labels = spiral_data(valid_num_samples)
    grid_labels = create_labels_from_nearest(grid_points, reference_points, reference_labels)

    return grid_points.astype(np.float32), grid_labels.astype(np.int64)


class ReaderToy(Reader):
    """Reader over synthetic 2-D toy datasets generated in memory.

    Each sample is ``(coordinate, target)``. ``coordinate`` is a float32
    array reshaped to (1, 1, 2) and ``target`` is an int64 label.
    """

    def __init__(
            self,
            name='spiral',
            split='train',
    ):
        """Generate the requested toy dataset split.

        Args:
            name: toy dataset name: 'spiral', 'spiral_aug' or 'grid'.
            split: split name. For the spiral datasets 'train' yields 16384
                samples, 'validation' 4096 and anything else 2500. 'grid'
                ignores the split and always requests 2500 points (a 50 x 50
                grid).

        Sets:
            self.coordinate_and_targets: the raw ``(xs, ys)`` arrays.
            self.samples: list of ``(coordinate, target)`` tuples.

        NOTE(review): every split is generated independently from the global
        numpy / torch RNGs. For 'spiral_aug' each split is therefore warped by
        a different random MLP. For 'grid' each split is labelled against a
        different random reference spiral. Confirm this is intended.
        """
        super().__init__()
        assert name in ['spiral', 'grid', 'spiral_aug'], "toy dataset list : spiral, grid, spiral_aug"

        if name == 'grid':
            num_samples = 2500
        else:
            num_samples = {'train': 16384, 'validation': 4096}.get(split, 2500)

        if name == 'spiral':
            coordinate_and_targets = spiral_data(num_samples)
        elif name == 'spiral_aug':
            coordinate_and_targets = spiral_data(num_samples, non_linear_aug=True)
        else:  # 'grid'
            coordinate_and_targets = grid_data(num_samples)
        self.coordinate_and_targets = coordinate_and_targets

        # Iterate over what was actually generated, not ``num_samples``.
        # grid_data returns int(sqrt(n)) ** 2 points, which can be fewer than
        # requested; indexing range(num_samples) would then raise IndexError.
        coordinates, targets = coordinate_and_targets
        self.samples = [
            (coordinate.reshape(1, 1, 2), target)
            for coordinate, target in zip(coordinates, targets)
        ]

    def __getitem__(self, index):
        coordinate, target = self.samples[index]
        return coordinate, target

    def __len__(self):
        return len(self.samples)
