# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
from typing import Tuple

import dgl
import pathlib
import torch
from dgl import DGLGraph
from torch import Tensor
from torch.utils.data import random_split, DataLoader, Dataset
from tqdm import tqdm

from im2mesh import onet

from im2mesh.config import get_inputs_field
from se3_transformer.data_loading.data_module import DataModule
from se3_transformer.data_loading.data_module import _get_dataloader 
from se3_transformer.model.basis import get_basis
from se3_transformer.runtime.utils import get_local_rank, str2bool, using_tensor_cores

# Maps cfg['method'] to the im2mesh method package providing config.get_data_fields.
method_dict = dict(onet=onet)

def _get_relative_pos(qm9_graph: DGLGraph) -> Tensor:
    """Return the per-edge displacement ``pos[dst] - pos[src]``.

    Node positions are read from ``qm9_graph.ndata['pos']``; the result has
    one row per edge, in the graph's edge order.
    """
    src_ids, dst_ids = qm9_graph.edges()
    positions = qm9_graph.ndata['pos']
    return positions[dst_ids] - positions[src_ids]


def _get_split_sizes(full_dataset: Dataset) -> Tuple[int, int, int]:
    len_full = len(full_dataset)
    len_train = 100_000
    len_test = int(0.1 * len_full)
    len_val = len_full - len_train - len_test
    return len_train, len_val, len_test


def singlepoint2graph(points, kmode='full'):
    """Build a DGL graph over the points of a single point cloud.

    Args:
        points: per-point coordinates indexed along dim 0, either a torch
            tensor or a NumPy array (converted with ``torch.from_numpy``).
            The k-NN mode computes pairwise norms over dim 2 of the
            broadcast difference, so it expects a 2-D ``(num_points, D)``
            input.
        kmode: ``'full'`` gives a fully connected graph (every ordered pair,
            self-loops included). Any other value is converted with ``int()``
            to a neighbour count ``k``. Each node then receives incoming
            edges from its ``k`` nearest points by Euclidean distance, which
            normally includes the node itself at distance 0.

    Returns:
        The graph, with the positions stored in ``ndata['pos']``. Edges are
        grouped by destination node in node order.
    """
    if kmode != 'full':
        k = int(kmode)
        pts = points if isinstance(points, torch.Tensor) else torch.from_numpy(points)
        dist = torch.norm(pts.unsqueeze(0) - pts.unsqueeze(1), dim=2, p=None)
        nearest = torch.topk(dist, k, dim=1, largest=False).indices
        # Row i of `nearest` lists i's neighbours; flattening row-major keeps
        # them aligned with dst = [0]*k + [1]*k + ...
        src = nearest.reshape(-1)
        dst = torch.arange(pts.shape[0], device=pts.device).repeat_interleave(k)
        graph = dgl.graph((src, dst))
        graph.ndata['pos'] = pts
        return graph

    node_ids = torch.arange(points.shape[0])
    pairs = torch.cartesian_prod(node_ids, node_ids)
    # Column 1 is the source and column 0 the destination, so edges come out
    # grouped by destination node.
    graph = dgl.graph((pairs[:, 1], pairs[:, 0]))
    if not isinstance(points, torch.Tensor):
        graph.ndata['pos'] = torch.from_numpy(points)
        return graph
    graph = graph.to(device=points.device)
    graph.ndata['pos'] = points
    return graph
def point2graph(samples, mode='full'):
    """Build a fully connected DGL graph for every sample that has ``'inputs'``.

    Args:
        samples: iterable of dict-like samples. Samples without an
            ``'inputs'`` entry are skipped. ``sample['inputs']`` goes through
            ``torch.from_numpy``, so it must be a NumPy array.
        mode: accepted for signature compatibility but currently ignored; the
            graph is always fully connected, self-loops included.

    Returns:
        A list with one graph per sample that has inputs, in input order. Each
        graph stores the points in ``ndata['pos']``.
    """
    graph_list = []
    for sample in samples:
        if 'inputs' not in sample:
            continue
        points = sample['inputs']
        node_ids = torch.arange(points.shape[0])
        edges = torch.cartesian_prod(node_ids, node_ids)
        graph = dgl.graph((edges[:, 0], edges[:, 1]))
        graph.ndata['pos'] = torch.from_numpy(points)
        graph_list.append(graph)
    return graph_list
def ShapeNet(mode, cfg, return_idx=False, return_category=False, adj_fun=None,
             mini_batch=1, val_mini_batch=1, kmode='full', number_iou_points=1000):
    ''' Returns the dataset for one split.

    Args:
        mode (str): split name, one of 'train', 'val' or 'test'. Any other
            value raises KeyError.
        cfg (dict): config dictionary. Reads cfg['method'] and, from
            cfg['data'], 'path', 'classes', 'rotation_folder', 'rotation'
            and the 'train_split' / 'val_split' / 'test_split' entries.
        return_idx (bool): whether to include an ID field
        return_category (bool): whether to include a category field
        adj_fun, mini_batch, val_mini_batch, kmode, number_iou_points:
            forwarded unchanged to ``Shapes3dDataset`` (``adj_fun`` is passed
            as ``adj_fn``).

    Returns:
        The ``Shapes3dDataset`` from ``im2mesh.data_scene`` when the data path
        ends in 'synthetic_room_dataset', otherwise from
        ``im2mesh.data_shape``.
    '''
    method = cfg['method']
    dataset_folder = cfg['data']['path']

    # PurePath.name ignores a trailing separator, so a configured path like
    # '.../synthetic_room_dataset/' is detected too; split('/')[-1] gave ''.
    if pathlib.PurePath(str(dataset_folder)).name == 'synthetic_room_dataset':
        print("MPIKA SYNTHETIC")
        import im2mesh.data_scene as data
    else:
        import im2mesh.data_shape as data

    categories = cfg['data']['classes']
    rot_folder = cfg['data']['rotation_folder']
    rot_augment = cfg['data']['rotation']

    splits = {
        'train': cfg['data']['train_split'],
        'val': cfg['data']['val_split'],
        'test': cfg['data']['test_split'],
    }
    split = splits[mode]

    # Method-specific fields (usually the outputs), then the input field.
    fields = method_dict[method].config.get_data_fields(mode, cfg)
    inputs_field = get_inputs_field(mode, cfg)
    if inputs_field is not None:
        fields['inputs'] = inputs_field

    if return_idx:
        fields['idx'] = data.IndexField()

    if return_category:
        fields['category'] = data.CategoryField()

    return data.Shapes3dDataset(
        dataset_folder, fields,
        split=split,
        categories=categories,
        adj_fn=adj_fun,
        mini_batch=mini_batch,
        val_mini_batch=val_mini_batch,
        kmode=kmode,
        number_iou_points=number_iou_points,
        rot_folder=rot_folder,
        rot_augment=rot_augment,
        cfg=cfg,
    )

class ShapeNetModule(DataModule):
    """
    DataModule serving the ShapeNet (or synthetic-room) datasets.

    The train/val/test datasets are built with ``ShapeNet`` from ``cfg``.
    ``_collate`` batches the per-sample graphs into one DGL graph, offsets the
    per-sample index tuples into batch-wide indices, and splits the query
    points into ``mini_batch`` chunks.
    """
    NODE_FEATURE_DIM = 1
    EDGE_FEATURE_DIM = 0

    def __init__(self,
                 data_dir: pathlib.Path = '/Datasets/ShapeNet',
                 task: str = 'recon',
                 batch_size: int = 240,
                 num_workers: int = 4,
                 num_degrees: int = 4,
                 amp: bool = False,
                 precompute_bases: bool = False,
                 cfg: dict = None,
                 adj_fun=None,
                 mini_batch=1,
                 kmode='full',
                 fixed_percentage=False,
                 **kwargs):
        """
        Args:
            data_dir: unused; the dataset location comes from cfg['data']['path'].
            task: task name; the CLI only offers 'recon'.
            batch_size, num_workers: forwarded to ``DataModule``.
            num_degrees, amp: stored on the module.
            precompute_bases: not supported; True raises NotImplementedError.
            cfg: full configuration dict (required; indexed immediately).
            adj_fun: callable building a sample's graph; defaults to
                ``singlepoint2graph``.
            mini_batch: number of chunks the query points are split into.
            kmode: 'full' or a neighbour count k, forwarded to the datasets
                and used by ``_collate``.
            fixed_percentage: if True, ``_collate`` rebalances the IoU points.
            **kwargs: must contain 'val_mini_batch' and 'iou_number_points'.
        """
        # Must be set before DataModule.__init__ so that prepare_data has access to it.
        self.data_dir = cfg['data']['path']

        # The module bound here is not used further in __init__; ShapeNet()
        # selects the dataset module itself for every split.
        if pathlib.PurePath(str(self.data_dir)).name == 'synthetic_room_dataset':
            print("MPIKA SYNTHETIC")
            import im2mesh.data_scene as data  # noqa: F401
        else:
            import im2mesh.data_shape as data  # noqa: F401

        super().__init__(batch_size=batch_size, num_workers=num_workers, collate_fn=self._collate)
        print("Cuda Devices: ", torch.cuda.device_count())
        self.amp = amp
        self.adj_fun = singlepoint2graph if adj_fun is None else adj_fun
        self.task = task
        self.batch_size = batch_size
        self.num_degrees = num_degrees
        self.mini_batch = mini_batch
        self.val_mini_batch = kwargs['val_mini_batch']
        self.iou_number_points = kwargs['iou_number_points']
        self.kmode = kmode
        print("IOU Number Points", self.iou_number_points)
        self.fixed_percentage_ones = fixed_percentage
        if precompute_bases:
            # Basis caching was never ported from the QM9 module (see the inert
            # reference string at the end of this class). Fail early instead of
            # leaving the datasets unset and crashing later in a dataloader.
            raise NotImplementedError("precompute_bases is not implemented for ShapeNetModule")

        split_kwargs = dict(adj_fun=self.adj_fun,
                            mini_batch=self.mini_batch,
                            val_mini_batch=self.val_mini_batch,
                            kmode=kmode,
                            number_iou_points=self.iou_number_points)
        self.train_dataset = ShapeNet('train', cfg, **split_kwargs)
        self.val_dataset = ShapeNet('val', cfg, **split_kwargs)
        self.test_dataset = ShapeNet('test', cfg, **split_kwargs)

    def train_dataloader(self, shuf=True) -> DataLoader:
        return _get_dataloader(self.train_dataset, shuffle=shuf, **self.dataloader_kwargs)

    def val_dataloader(self, shuf=False) -> DataLoader:
        return _get_dataloader(self.val_dataset, shuffle=shuf, **self.dataloader_kwargs)

    def test_dataloader(self, shuf=True) -> DataLoader:
        # NOTE(review): shuffles by default, unlike val_dataloader; confirm intended.
        return _get_dataloader(self.test_dataset, shuffle=shuf, **self.dataloader_kwargs)

    def prepare_data(self):
        # Nothing is downloaded automatically for ShapeNet.
        print("Data should be downloaded manually")

    def _collate(self, samples):
        """Collate dataset samples into a single batch dict.

        For each sample, row 0 of ``occ_tuples`` is shifted by the running
        query count and row 1 by the running node count. The same is done for
        ``iou_tuples`` when the sample has ``points_iou``. The shifted tuples
        are presumably (query index, input-node index) pairs; confirm against
        the dataset.

        ``node_inds`` / ``query_inds`` hold ``(start, end)`` pairs. A sample's
        items occupy indices ``start + 1 .. end``, where ``start`` is the
        previous sample's last index (-1 for the first sample).

        Returns a dict with the stacked sample fields ('category' and 'model'
        stay lists), 'input_graph', 'node_feats', 'node_inds', 'query_inds',
        'o_feats_list', and 'points' / 'points.occ' split into ``mini_batch``
        chunks. When IoU points are present it also holds 'iou_edges' and
        'val_mini_batch', and 'points_iou' / 'points_iou.occ' are cut to
        ``iou_number_points``.
        """
        mini_batch = self.mini_batch
        val_mini_batch = self.val_mini_batch
        graphs = []
        node_inds = []
        query_inds = []
        dict_lists = {key: [] for key in samples[0].keys()}
        # Running "last used index" counters; -1 means nothing placed yet.
        cnt = -1
        cnt_queries = -1
        cnt_iou_queries = -1
        for sample in samples:
            sample_graph = sample['graph']
            graphs.append(sample_graph)
            has_sample_iou = 'points_iou' in sample.keys()

            # NOTE(review): the shifted indices are written back into the sample's
            # own arrays. If the dataset caches and reuses these arrays, offsets
            # would accumulate across epochs; confirm __getitem__ returns fresh ones.
            sample['occ_tuples'][0] = sample['occ_tuples'][0] + cnt_queries + 1
            sample['occ_tuples'][1] = sample['occ_tuples'][1] + cnt + 1
            if has_sample_iou:
                sample['iou_tuples'][0] = sample['iou_tuples'][0] + cnt_iou_queries + 1
                sample['iou_tuples'][1] = sample['iou_tuples'][1] + cnt + 1

            num_nodes = sample_graph.num_nodes()
            node_inds.append((cnt, cnt + num_nodes))
            cnt += num_nodes

            num_queries = sample['points'].shape[0] // mini_batch
            query_inds.append((cnt_queries, cnt_queries + num_queries))
            cnt_queries += num_queries

            if has_sample_iou:
                # Advance by the configured count rather than points_iou.shape[0],
                # because points_iou is cut to iou_number_points further below.
                cnt_iou_queries += self.iou_number_points // val_mini_batch

            for key in sample.keys():
                if key == 'graph':
                    continue
                value = sample[key]
                if torch.is_tensor(value) or key in ('category', 'model'):
                    dict_lists[key].append(value)
                else:
                    dict_lists[key].append(torch.from_numpy(value))

        batched_graph = dgl.batch(graphs)
        batched_graph.edata['rel_pos'] = _get_relative_pos(batched_graph)

        result_dict = {}
        if 'points_iou' in samples[0].keys():
            # cnt + 1 is the batch's total input-node count: row 0 is shifted past it.
            iou_edges = torch.cat(dict_lists['iou_tuples'], dim=1) + torch.tensor([[cnt + 1], [0]])
            result_dict['iou_edges'] = iou_edges

        node_feats = {'1': self._mean_relative_pos(batched_graph, self.kmode)}

        for key, values in dict_lists.items():
            if key in ('category', 'model'):
                result_dict[key] = values
            elif values:
                result_dict[key] = torch.stack(values, dim=0)

        result_dict['input_graph'] = batched_graph
        result_dict['node_feats'] = node_feats
        result_dict['node_inds'] = node_inds
        result_dict['query_inds'] = query_inds
        result_dict['o_feats_list'] = []

        # Query points beyond mini_size * mini_batch (integer division remainder) are dropped.
        mini_size = result_dict['points'].shape[1] // mini_batch
        all_points = result_dict['points']
        all_occ = result_dict['points.occ']
        result_dict['points.occ'] = []
        mini_points = []
        for i in range(mini_batch):
            chunk = slice(mini_size * i, mini_size * (i + 1))
            mini_points.append(all_points[:, chunk])
            result_dict['o_feats_list'].append({'1': mini_points[-1].reshape(-1, 1, 3)})
            result_dict['points.occ'].append(all_occ[:, chunk])
        result_dict['points'] = mini_points

        if 'points_iou' in samples[0].keys():
            if self.fixed_percentage_ones:
                percentage = 0.3
                number_points = self.iou_number_points
                number_ones = int(percentage * number_points)
                number_zeros = number_points - number_ones
                occ_iou = result_dict['points_iou.occ']
                pts_iou = result_dict['points_iou']
                # NOTE(review): the two topk calls are independent. With tied
                # occupancy values (e.g. fewer than number_ones occupied points),
                # the selected index sets may overlap, and the 30% target is not
                # guaranteed.
                _, top_indices = torch.topk(occ_iou, number_ones, dim=1)
                _, bot_indices = torch.topk(occ_iou, number_zeros, dim=1, largest=False)
                top_points = torch.gather(pts_iou, 1, top_indices.unsqueeze(2).repeat(1, 1, 3))
                bot_points = torch.gather(pts_iou, 1, bot_indices.unsqueeze(2).repeat(1, 1, 3))
                result_dict['points_iou'] = torch.cat((top_points, bot_points), dim=1)
                result_dict['points_iou.occ'] = torch.cat(
                    (torch.gather(occ_iou, 1, top_indices), torch.gather(occ_iou, 1, bot_indices)), dim=1)
            else:
                result_dict['points_iou'] = result_dict['points_iou'][:, :self.iou_number_points]
                result_dict['points_iou.occ'] = result_dict['points_iou.occ'][:, :self.iou_number_points]
            result_dict['val_mini_batch'] = val_mini_batch
        return result_dict

    @staticmethod
    def _mean_relative_pos(graph, kmode):
        """Average ``edata['rel_pos']`` over each node's incoming edges.

        Returns a tensor of shape ``(num_nodes, 1, 3)`` used as the degree-1
        node feature.
        """
        rel_pos = graph.edata['rel_pos']
        if kmode != 'full':
            # With the default singlepoint2graph, k-NN graphs have exactly k
            # incoming edges per node, grouped by destination in node order, so
            # a reshape yields the per-node groups.
            k = int(kmode)
            return torch.mean(rel_pos.reshape(-1, k, 3), 1, keepdim=True)
        # Fully connected graphs: the in-degree equals each sample's point count,
        # which may differ within a batch (int('full') used to crash here), so
        # average with a scatter-sum over destination nodes.
        _, dst = graph.edges()
        sums = rel_pos.new_zeros((graph.num_nodes(), rel_pos.shape[1]))
        sums.index_add_(0, dst, rel_pos)
        counts = graph.in_degrees().clamp(min=1).to(rel_pos.dtype).unsqueeze(1)
        return (sums / counts).unsqueeze(1)

    @staticmethod
    def add_argparse_args(parent_parser):
        parser = parent_parser.add_argument_group("ShapeNet dataset")
        parser.add_argument('--task', type=str, choices=['recon'], default='recon', help='Task to train on')
        parser.add_argument('--precompute_bases', type=str2bool, nargs='?', const=True, default=False,
                            help='Precompute bases at the beginning of the script during dataset initialization,'
                                 ' instead of computing them at the beginning of each forward pass.')
        return parent_parser

    def __repr__(self):
        return f'{type(self).__name__}({self.task})'

    # Inert reference code from the original QM9 module (a bare string, never
    # executed), kept as a starting point for implementing precompute_bases.
    '''
    class CachedBasesQM9EdgeDataset(QM9EdgeDataset):
        """ Dataset extending the QM9 dataset from DGL with precomputed (cached in RAM) pairwise bases """

        def __init__(self, bases_kwargs: dict, batch_size: int, num_workers: int, *args, **kwargs):
            """
            :param bases_kwargs:  Arguments to feed the bases computation function
            :param batch_size:    Batch size to use when iterating over the dataset for computing bases
            """
            self.bases_kwargs = bases_kwargs
            self.batch_size = batch_size
            self.bases = None
            self.num_workers = num_workers
            super().__init__(*args, **kwargs)

        def load(self):
            super().load()
            # Iterate through the dataset and compute bases (pairwise only)
            # Potential improvement: use multi-GPU and gather
            dataloader = DataLoader(self, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers,
                                    collate_fn=lambda samples: dgl.batch([sample[0] for sample in samples]))
        bases = []
        for i, graph in tqdm(enumerate(dataloader), total=len(dataloader), desc='Precomputing QM9 bases',
                             disable=get_local_rank() != 0):
            rel_pos = _get_relative_pos(graph)
            # Compute the bases with the GPU but convert the result to CPU to store in RAM
            bases.append({k: v.cpu() for k, v in get_basis(rel_pos.cuda(), **self.bases_kwargs).items()})
        self.bases = bases  # Assign at the end so that __getitem__ isn't confused

    def __getitem__(self, idx: int):
        graph, label = super().__getitem__(idx)

        if self.bases:
            bases_idx = idx // self.batch_size
            bases_cumsum_idx = self.ne_cumsum[idx] - self.ne_cumsum[bases_idx * self.batch_size]
            bases_cumsum_next_idx = self.ne_cumsum[idx + 1] - self.ne_cumsum[bases_idx * self.batch_size]
            return graph, label, {key: basis[bases_cumsum_idx:bases_cumsum_next_idx] for key, basis in
                                  self.bases[bases_idx].items()}
        else:
            return graph, label
'''
