import os
import logging
from torch.utils import data
import numpy as np
import yaml
import torch
from scipy.spatial.transform import Rotation as R

logger = logging.getLogger(__name__)


# Fields
class Field:
    ''' Base class for one data field of a dataset item.

    Concrete fields override both methods below; the base implementations
    only raise NotImplementedError.
    '''

    def load(self, data_path, idx, category):
        ''' Load the data of this field for a single data point.

        Args:
            data_path (str): path to the model's data directory / file
            idx (int): index of the data point
            category (int): index of the category
        '''
        raise NotImplementedError

    def check_complete(self, files):
        ''' Report whether the files required by this field are present.

        Args:
            files (list of str): names of the files found for one model
        '''
        raise NotImplementedError


class Shapes3dDataset(data.Dataset):
    ''' 3D Shapes dataset class.

    Items are loaded from ``<dataset_folder>/<category>/<model>`` through the
    configured fields, optionally rotated, passed through ``transform`` and,
    when an ``inputs`` entry is present, extended with a graph built by
    ``adj_fn`` plus index tuples pairing query points with graph nodes.
    '''

    def __init__(self, dataset_folder, fields, split=None,
                 categories=None, no_except=True, transform=None, adj_fn=None,
                 mini_batch=1, val_mini_batch=1, kmode='full',
                 number_iou_points=1000, rot_augment='aligned', rot_folder=None,
                 cfg=None):
        ''' Initialization of the 3D shape dataset.

        Args:
            dataset_folder (str): dataset folder
            fields (dict): dictionary of fields (name -> object providing
                ``load(model_path, idx, c_idx)`` and ``check_complete(files)``)
            split (str): which split is used; model names are read from
                ``<category>/<split>.lst`` (blank lines ignored). If None,
                every sub-directory of a category folder is used as a model.
            categories (list): list of categories to use (default: every
                sub-directory of ``dataset_folder``)
            no_except (bool): if True, a field that fails to load makes the
                item None (with a logged warning) instead of raising
            transform (callable): transformation applied to data points
            adj_fn (callable): ``adj_fn(inputs, kmode)`` builds the object
                stored under ``'graph'``; it must provide ``num_nodes()``
            mini_batch (int): the number of query points is divided by this
                to size ``'occ_tuples'``
            val_mini_batch (int): ``number_iou_points`` is divided by this to
                size ``'iou_tuples'``
            kmode: forwarded unchanged to ``adj_fn``
            number_iou_points (int): see ``val_mini_batch``
            rot_augment (str): 'aligned' (no rotation), 'so3' or 'pca' (apply
                the stored 'so3' matrix) or 'z' (rotate about z by the stored
                'z' angle, in degrees)
            rot_folder (str): root holding
                ``<category>/<model>/random_rotations.npz``
            cfg: stored as ``self.cfg``; not used by this class
        '''
        # Attributes
        self.dataset_folder = dataset_folder
        self.fields = fields
        self.no_except = no_except
        self.transform = transform
        self.adj_fn = adj_fn
        self.mini_batch = mini_batch
        self.val_mini_batch = val_mini_batch
        self.number_iou_points = number_iou_points
        self.kmode = kmode
        self.rot_augment = rot_augment
        self.rot_folder = rot_folder
        self.cfg = cfg

        # If categories is None, use all subfolders
        if categories is None:
            categories = [c for c in os.listdir(dataset_folder)
                          if os.path.isdir(os.path.join(dataset_folder, c))]

        # Read metadata file, or synthesize minimal metadata per category
        metadata_file = os.path.join(dataset_folder, 'metadata.yaml')
        if os.path.exists(metadata_file):
            with open(metadata_file, 'r') as f:
                self.metadata = yaml.safe_load(f)
        else:
            self.metadata = {
                c: {'id': c, 'name': 'n/a'} for c in categories
            }
        print(categories)

        # Set index
        for c_idx, c in enumerate(categories):
            self.metadata[c]['idx'] = c_idx

        # Get all models
        self.models = []
        for c in categories:
            subpath = os.path.join(dataset_folder, c)
            if not os.path.isdir(subpath):
                logger.warning('Category %s does not exist in dataset.', c)
                # Skip it: there is no split file / model folder to read.
                continue
            self.models += [{'category': c, 'model': m}
                            for m in self._list_models(subpath, split)]

    @staticmethod
    def _list_models(subpath, split):
        ''' Returns the model names of one category folder.

        With ``split`` given, names are read from ``<subpath>/<split>.lst``,
        one per line; surrounding whitespace (including a Windows ``\\r``)
        is stripped and blank lines such as a trailing newline are skipped,
        so no model with an empty name is produced. With ``split`` None,
        every sub-directory of ``subpath`` is returned, sorted.
        '''
        if split is None:
            return sorted(d for d in os.listdir(subpath)
                          if os.path.isdir(os.path.join(subpath, d)))
        split_file = os.path.join(subpath, split + '.lst')
        with open(split_file, 'r') as f:
            return [line.strip() for line in f if line.strip()]

    def __len__(self):
        ''' Returns the length of the dataset.
        '''
        return len(self.models)

    def _load_rotation(self, category, model):
        ''' Loads the stored rotation parameters of one model.

        Returns None when no rotation is configured (``rot_folder`` is None
        or ``rot_augment`` is 'aligned'); otherwise a dict with every array
        stored in ``random_rotations.npz``.
        '''
        if self.rot_folder is None or self.rot_augment == 'aligned':
            return None
        rotation_path = os.path.join(self.rot_folder, category, model,
                                     'random_rotations.npz')
        # Copy the arrays out so the .npz file is closed immediately instead
        # of leaving one open file handle behind per loaded item.
        with np.load(rotation_path) as npz:
            return {key: npz[key] for key in npz.files}

    def _rotate(self, v, rotation):
        ''' Applies the configured rotation to an array whose last axis is 3.

        Args:
            v (np.ndarray): array to rotate (multiplied on the right by the
                3x3 rotation matrix)
            rotation (dict or None): parameters from ``_load_rotation``

        Raises:
            ValueError: if a rotation mode is set but no rotation was loaded.
        '''
        if self.rot_augment not in ('so3', 'pca', 'z'):
            return v
        if rotation is None:
            raise ValueError(
                'rot_augment=%r requires rot_folder to be set'
                % self.rot_augment)
        if self.rot_augment == 'z':
            r = R.from_euler('z', rotation['z'], degrees=True).as_matrix()
        else:
            # 'so3' and 'pca' both use the stored 'so3' matrix.
            r = rotation['so3']
        # Cast to float32 in every mode (the 'z' branch always did); a
        # float64 stored matrix would otherwise up-cast the so3/pca output.
        return np.matmul(v, r).astype(np.float32)

    @staticmethod
    def _index_tuples(number_rows, number_nodes):
        ''' Returns a (2, number_rows * number_nodes) integer tensor whose
        columns enumerate every (row, node) pair, row-major.
        '''
        g_x, g_y = torch.meshgrid(torch.arange(number_rows),
                                  torch.arange(number_nodes))
        return torch.cat((g_x.reshape(1, -1), g_y.reshape(1, -1)))

    def _load_sample(self, category, model, idx):
        ''' Builds one dataset item (shared by ``__getitem__`` and
        ``item_by_category``).

        Args:
            category (str): category folder name
            model (str): model folder name
            idx (int): index forwarded to every ``field.load`` call

        Returns:
            dict or None: None when a field fails to load and ``no_except``
            is set.
        '''
        c_idx = self.metadata[category]['idx']
        model_path = os.path.join(self.dataset_folder, category, model)
        rotation = self._load_rotation(category, model)

        sample = {}
        for field_name, field in self.fields.items():
            try:
                field_data = field.load(model_path, idx, c_idx)
            except Exception:
                if self.no_except:
                    logger.warning(
                        'Error occured when loading field %s of model %s',
                        field_name, model)
                    return None
                raise

            if isinstance(field_data, dict):
                for k, v in field_data.items():
                    # Arrays with a trailing axis of size 3 are treated as 3D
                    # coordinates and receive the rotation augmentation.
                    if v.shape[-1] == 3:
                        v = self._rotate(v, rotation)
                    # The None key holds the field's main value.
                    if k is None:
                        sample[field_name] = v
                    else:
                        sample['%s.%s' % (field_name, k)] = v
            else:
                sample[field_name] = field_data

        if self.transform is not None:
            sample = self.transform(sample)

        sample['category'] = category
        sample['model'] = model

        if 'inputs' in sample:
            sample['graph'] = self.adj_fn(sample['inputs'], self.kmode)
            number_queries = sample['points'].shape[0]
            sample['occ_tuples'] = self._index_tuples(
                number_queries // self.mini_batch,
                sample['graph'].num_nodes())

        if 'points_iou' in sample:
            # Shuffle the IoU points and their occupancies together.
            perm = torch.randperm(sample['points_iou'].shape[0])
            sample['points_iou'] = sample['points_iou'][perm]
            sample['points_iou.occ'] = sample['points_iou.occ'][perm]
            # NOTE: relies on 'graph', i.e. on an 'inputs' entry being present.
            sample['iou_tuples'] = self._index_tuples(
                self.number_iou_points // self.val_mini_batch,
                sample['graph'].num_nodes())

        return sample

    def item_by_category(self, category, model):
        ''' Returns the item for an explicit (category, model) pair.

        Fields are loaded with ``idx=0``.

        Args:
            category (str): category folder name
            model (str): model folder name
        '''
        return self._load_sample(category, model, 0)

    def __getitem__(self, idx):
        ''' Returns an item of the dataset.

        Args:
            idx (int): ID of data point
        '''
        entry = self.models[idx]
        return self._load_sample(entry['category'], entry['model'], idx)

    def get_model_dict(self, idx):
        ''' Returns the ``{'category': ..., 'model': ...}`` entry at idx. '''
        return self.models[idx]

    def test_model_complete(self, category, model):
        ''' Tests if model is complete.

        Args:
            category (str): category folder name
            model (str): modelname

        Returns:
            bool: True if every field's ``check_complete`` accepts the files
            in the model folder.
        '''
        model_path = os.path.join(self.dataset_folder, category, model)
        files = os.listdir(model_path)
        for field_name, field in self.fields.items():
            if not field.check_complete(files):
                logger.warning('Field "%s" is incomplete: %s',
                               field_name, model_path)
                return False

        return True


def collate_remove_none(batch):
    ''' Collate a batch after discarding samples that are None.

    Samples can be None, e.g. when a dataset item fails to load; the remaining
    samples are stacked by torch's default collate so each field gets an
    outer dimension equal to the number of kept samples.

    Args:
        batch: list of samples, some of which may be None
    '''
    kept = [sample for sample in batch if sample is not None]
    return data.dataloader.default_collate(kept)


def worker_init_fn(worker_id):
    ''' Worker init function to ensure true randomness.

    Seeds NumPy's global RNG from ``os.urandom`` so that each DataLoader
    worker gets a fresh, independent NumPy seed.

    Args:
        worker_id (int): id of the worker process
    '''
    base_seed = int.from_bytes(os.urandom(4), byteorder="big")
    # np.random.seed only accepts seeds in [0, 2**32 - 1]. base_seed alone can
    # already be 2**32 - 1, so adding worker_id could overflow that range;
    # wrap around instead of raising ValueError.
    np.random.seed((base_seed + worker_id) % 2 ** 32)
