import torch

from .shapenet import ShapeNet
from .magcorenet import MagcoreNet, MagcoreField
from .stressnet import StressField, StressNet_cls
from .airnet import AirNet_cls, AirField
from .elecnet import ElecNet_cls, ElecField

class AxisScaling(object):
    """Random per-axis scaling augmentation for a ``(surface, point)`` pair.

    Each of the three axes gets an independent scale factor drawn uniformly
    from ``interval``; the same factors are applied to ``surface`` and
    ``point``. Both are then multiplied by one common factor so that the
    largest absolute surface coordinate becomes 0.999999. Optionally, Gaussian
    jitter (std 0.005) is added to the surface, which is then clamped to
    [-1, 1]. ``point`` is never jittered or clamped.

    Inputs are expected to broadcast against a ``(1, 3)`` tensor, i.e. their
    last dimension should be 3. The caller's tensors are not modified in place.
    """

    def __init__(self, interval=(0.75, 1.25), jitter=True):
        """
        Args:
            interval: ``(low, high)`` tuple bounding the per-axis scale factors.
            jitter: if True, add Gaussian noise (std 0.005) to the rescaled
                surface and clamp it to [-1, 1].
        """
        assert isinstance(interval, tuple)
        self.interval = interval
        self.jitter = jitter

    def __call__(self, surface, point):
        """Return the augmented ``(surface, point)`` tensors."""
        low, high = self.interval
        # Uniform draw in [low, high) per axis. Previously this was hard-coded
        # to U(0.75, 1.25), silently ignoring ``self.interval``.
        scaling = torch.rand(1, 3) * (high - low) + low
        # Out-of-place multiplies: later in-place ops only touch these copies.
        surface = surface * scaling
        point = point * scaling

        max_abs = torch.abs(surface).max().item()
        # An all-zero surface cannot be normalised (1 / 0 would raise);
        # leave both tensors at their scaled values in that case.
        if max_abs > 0:
            scale = (1 / max_abs) * 0.999999
            surface *= scale
            point *= scale

        if self.jitter:
            surface += 0.005 * torch.randn_like(surface)
            surface.clamp_(min=-1, max=1)

        return surface, point
    
    
class BulkScaling(object):
    """Random per-axis scaling augmentation for a ``(bulk, point)`` pair.

    Each of the three axes gets an independent scale factor drawn uniformly
    from ``[min, min + range)``; the same factors are applied to ``bulk`` and
    ``point``. Both are then multiplied by one common factor so that the
    largest absolute ``bulk`` coordinate becomes 0.999999. Optionally, Gaussian
    jitter (std 0.005) is added to ``bulk``, which is then clamped to [-1, 1].
    ``point`` is never jittered or clamped.

    Inputs are expected to broadcast against a ``(1, 3)`` tensor, i.e. their
    last dimension should be 3. The caller's tensors are not modified in place.
    """

    def __init__(self, min, range, jitter=True):
        """
        Args:
            min: lower bound of the per-axis scale factor.
            range: width of the per-axis scale-factor interval.
            jitter: if True, add Gaussian noise (std 0.005) to the rescaled
                bulk and clamp it to [-1, 1].
        """
        # NOTE: ``min`` / ``range`` shadow builtins; the names are kept so
        # existing keyword-argument callers keep working.
        self.min = min
        self.range = range
        self.jitter = jitter

    def __call__(self, bulk, point):
        """Return the augmented ``(bulk, point)`` tensors."""
        scaling = torch.rand(1, 3) * self.range + self.min
        # Out-of-place multiplies: later in-place ops only touch these copies.
        bulk = bulk * scaling
        point = point * scaling

        max_abs = torch.abs(bulk).max().item()
        # An all-zero bulk cannot be normalised (1 / 0 would raise);
        # leave both tensors at their scaled values in that case.
        if max_abs > 0:
            scale = (1 / max_abs) * 0.999999
            bulk *= scale
            point *= scale

        if self.jitter:
            bulk += 0.005 * torch.randn_like(bulk)
            bulk.clamp_(min=-1, max=1)

        return bulk, point


def build_shape_surface_occupancy_dataset(split, args):
    """Build the ShapeNet surface/occupancy dataset for ``split``.

    ``'train'`` applies ``AxisScaling((0.75, 1.25), True)`` and enables
    sampling with ``num_samples=1024``. Every other split (``'val'`` included)
    uses no transform and ``sampling=False``. All splits request the surface
    (``return_surface=True``, ``surface_sampling=True``) with
    ``pc_size=args.point_cloud_size``.
    """
    shared = dict(return_surface=True, surface_sampling=True, pc_size=args.point_cloud_size)
    if split != 'train':
        # 'val' and any other split share the same unaugmented configuration.
        return ShapeNet(args.data_path, split=split, transform=None, sampling=False, **shared)
    augment = AxisScaling((0.75, 1.25), True)
    return ShapeNet(args.data_path, split=split, transform=augment, sampling=True, num_samples=1024, **shared)
    
def build_magcore_shape_dataset(split, args):
    """Build the ``MagcoreNet`` dataset for ``split``.

    Both supported splits use ``transform=None`` and ``sampling=True``. They
    differ only in ``num_geom``: ``args.num_geom`` for ``'train'`` and
    ``args.val_num_geom`` for ``'val'``.

    Args:
        split: ``'train'`` or ``'val'``.
        args: namespace providing ``data_path``, ``geom_types``, ``num_geom`` /
            ``val_num_geom``, ``num_points``, ``num_queries`` and ``comp``.

    Raises:
        ValueError: if ``split`` is not ``'train'`` or ``'val'``. Previously
            this silently returned ``None``.
    """
    # NOTE: a BulkScaling(0.9, 0.2) transform was considered here; augmentation
    # is currently disabled (transform=None).
    if split not in ('train', 'val'):
        raise ValueError(f"Unsupported split {split!r}: expected 'train' or 'val'")
    num_geom = args.num_geom if split == 'train' else args.val_num_geom
    return MagcoreNet(args.data_path, split=split, transform=None, geom_types=args.geom_types,
                      num_geom=num_geom, num_points=args.num_points, num_queries=args.num_queries,
                      sampling=True, comp=args.comp)


def build_Bfield_dataset(split, args, mesh=True):
    """Build the ``MagcoreField`` dataset for ``split``.

    Both supported splits use ``transform=None`` and ``sampling=True``. They
    differ only in ``num_geom``: ``args.num_geom`` for ``'train'`` and
    ``args.val_num_geom`` for ``'val'``.

    Args:
        split: ``'train'`` or ``'val'``.
        args: namespace providing ``data_path``, ``geom_types``, ``num_geom`` /
            ``val_num_geom``, ``num_points``, ``num_queries`` and ``normalize``.
        mesh: forwarded unchanged to ``MagcoreField``.

    Raises:
        ValueError: if ``split`` is not ``'train'`` or ``'val'``. Previously
            this silently returned ``None``.
    """
    if split not in ('train', 'val'):
        raise ValueError(f"Unsupported split {split!r}: expected 'train' or 'val'")
    num_geom = args.num_geom if split == 'train' else args.val_num_geom
    return MagcoreField(args.data_path, split=split, transform=None, geom_types=args.geom_types,
                        num_geom=num_geom, num_points=args.num_points, num_queries=args.num_queries,
                        sampling=True, normalize=args.normalize, mesh=mesh)


def build_stress_shape_dataset(split, args):
    """Build the ``StressNet_cls`` dataset for ``split``.

    Both supported splits use ``transform=None`` and ``sampling=True``. They
    differ only in ``num_geom``: ``args.num_geom`` for ``'train'`` and
    ``args.val_num_geom`` for ``'val'``.

    Args:
        split: ``'train'`` or ``'val'``.
        args: namespace providing ``data_path``, ``geom_types``, ``num_geom`` /
            ``val_num_geom``, ``num_points`` and ``num_queries``.

    Raises:
        ValueError: if ``split`` is not ``'train'`` or ``'val'``. Previously
            this silently returned ``None``.
    """
    if split not in ('train', 'val'):
        raise ValueError(f"Unsupported split {split!r}: expected 'train' or 'val'")
    num_geom = args.num_geom if split == 'train' else args.val_num_geom
    return StressNet_cls(args.data_path, split=split, transform=None, geom_types=args.geom_types,
                         num_geom=num_geom, num_points=args.num_points, num_queries=args.num_queries,
                         sampling=True)


def build_Sfield_dataset(split, args, mesh=True):
    """Build the ``StressField`` dataset for ``split``.

    Both supported splits use ``transform=None`` and ``sampling=True``. They
    differ only in ``num_geom``: ``args.num_geom`` for ``'train'`` and
    ``args.val_num_geom`` for ``'val'``.

    Args:
        split: ``'train'`` or ``'val'``.
        args: namespace providing ``data_path``, ``geom_types``, ``num_geom`` /
            ``val_num_geom``, ``num_points``, ``num_queries``, ``normalize``,
            ``model`` and ``use_VAE``.
        mesh: forwarded unchanged to ``StressField``.

    Raises:
        ValueError: if ``split`` is not ``'train'`` or ``'val'``. Previously
            this silently returned ``None``.
    """
    if split not in ('train', 'val'):
        raise ValueError(f"Unsupported split {split!r}: expected 'train' or 'val'")
    num_geom = args.num_geom if split == 'train' else args.val_num_geom
    return StressField(args.data_path, split=split, transform=None, geom_types=args.geom_types,
                       num_geom=num_geom, num_points=args.num_points, num_queries=args.num_queries,
                       sampling=True, normalize=args.normalize, mesh=mesh, model=args.model,
                       use_VAE=args.use_VAE)
    
    
def build_airfran_shape_dataset(split, args):
    """Build the ``AirNet_cls`` dataset for ``split``.

    Both supported splits use ``transform=None`` and ``sampling=True``. They
    differ only in ``num_geom``: ``args.num_geom`` for ``'train'`` and
    ``args.val_num_geom`` for ``'val'``.

    Args:
        split: ``'train'`` or ``'val'``.
        args: namespace providing ``data_path``, ``geom_types``, ``num_geom`` /
            ``val_num_geom``, ``num_points`` and ``num_queries``.

    Raises:
        ValueError: if ``split`` is not ``'train'`` or ``'val'``. Previously
            this silently returned ``None``.
    """
    if split not in ('train', 'val'):
        raise ValueError(f"Unsupported split {split!r}: expected 'train' or 'val'")
    num_geom = args.num_geom if split == 'train' else args.val_num_geom
    return AirNet_cls(args.data_path, split=split, transform=None, geom_types=args.geom_types,
                      num_geom=num_geom, num_points=args.num_points, num_queries=args.num_queries,
                      sampling=True)
    
    
def build_Pfield_dataset(split, args, mesh=True):
    """Build the ``AirField`` dataset for ``split``.

    Both supported splits use ``transform=None`` and ``sampling=True``. They
    differ only in ``num_geom``: ``args.num_geom`` for ``'train'`` and
    ``args.val_num_geom`` for ``'val'``.

    Args:
        split: ``'train'`` or ``'val'``.
        args: namespace providing ``data_path``, ``geom_types``, ``num_geom`` /
            ``val_num_geom``, ``num_points``, ``num_queries``, ``normalize``,
            ``model`` and ``use_VAE``.
        mesh: forwarded unchanged to ``AirField``.

    Raises:
        ValueError: if ``split`` is not ``'train'`` or ``'val'``. Previously
            this silently returned ``None``.
    """
    if split not in ('train', 'val'):
        raise ValueError(f"Unsupported split {split!r}: expected 'train' or 'val'")
    num_geom = args.num_geom if split == 'train' else args.val_num_geom
    return AirField(args.data_path, split=split, transform=None, geom_types=args.geom_types,
                    num_geom=num_geom, num_points=args.num_points, num_queries=args.num_queries,
                    sampling=True, normalize=args.normalize, mesh=mesh, model=args.model,
                    use_VAE=args.use_VAE)


def build_elec_shape_dataset(split, args):
    """Build the ``ElecNet_cls`` dataset for ``split``.

    Both supported splits use ``transform=None`` and ``sampling=True``. They
    differ only in ``num_geom``: ``args.num_geom`` for ``'train'`` and
    ``args.val_num_geom`` for ``'val'``.

    Args:
        split: ``'train'`` or ``'val'``.
        args: namespace providing ``data_path``, ``geom_types``, ``num_geom`` /
            ``val_num_geom``, ``num_points`` and ``num_queries``.

    Raises:
        ValueError: if ``split`` is not ``'train'`` or ``'val'``. Previously
            this silently returned ``None``.
    """
    if split not in ('train', 'val'):
        raise ValueError(f"Unsupported split {split!r}: expected 'train' or 'val'")
    num_geom = args.num_geom if split == 'train' else args.val_num_geom
    return ElecNet_cls(args.data_path, split=split, transform=None, geom_types=args.geom_types,
                       num_geom=num_geom, num_points=args.num_points, num_queries=args.num_queries,
                       sampling=True)
    
    
def build_Efield_dataset(split, args, mesh=True):
    """Build the ``ElecField`` dataset for ``split``.

    Both supported splits use ``transform=None`` and ``sampling=True``. They
    differ only in ``num_geom``: ``args.num_geom`` for ``'train'`` and
    ``args.val_num_geom`` for ``'val'``.

    Args:
        split: ``'train'`` or ``'val'``.
        args: namespace providing ``data_path``, ``geom_types``, ``num_geom`` /
            ``val_num_geom``, ``num_points``, ``num_queries`` and ``normalize``.
        mesh: forwarded unchanged to ``ElecField``.

    Raises:
        ValueError: if ``split`` is not ``'train'`` or ``'val'``. Previously
            this silently returned ``None``.
    """
    if split not in ('train', 'val'):
        raise ValueError(f"Unsupported split {split!r}: expected 'train' or 'val'")
    num_geom = args.num_geom if split == 'train' else args.val_num_geom
    return ElecField(args.data_path, split=split, transform=None, geom_types=args.geom_types,
                     num_geom=num_geom, num_points=args.num_points, num_queries=args.num_queries,
                     sampling=True, normalize=args.normalize, mesh=mesh)
    