from   abc import abstractmethod
import boto3
import io
import sys
from   collections import namedtuple, Counter, defaultdict
from   dataclasses import dataclass, field
from   joblib import Parallel, delayed
import logging
import multiprocessing as mp
import os
import pickle
import json
from   PIL import Image, ImageFile
import pandas as pd
import random
import re
from   time import perf_counter 
from   tqdm import tqdm 
import numpy as np
import torch
import torch.utils.data as data
import torch.nn.functional as F
from   torchvision import transforms
import torchvision.transforms.functional as TF
from   typing import Optional, List, Callable, Union, Dict, Any
import warnings
from .task_configs import task_parameters_taskonomy, task_parameters_omnidata
from .transforms import get_task_transform, default_loader
from .splits import get_splits

# Module-wide PIL setting: tolerate image files whose data is truncated
# instead of raising while loading them.
ImageFile.LOAD_TRUNCATED_IMAGES = True # TODO Test this


#######
# curriculum
#######

       
class OmnidataDataset(data.Dataset):
    '''
    S3-backed base dataloader for Taskonomy / Omnidata component datasets.

    Every sample is fetched task by task from an S3 bucket. The environment
    variable S3_ENDPOINT must be set, together with S3_TASKONOMY_ACCESS,
    S3_TASKONOMY_KEY and S3_TASKONOMY_BUCKET when `dataset_name` contains
    'taskonomy', or S3_OMNIDATA_ACCESS, S3_OMNIDATA_KEY and
    S3_OMNIDATA_BUCKET otherwise.

    Subclasses must override load_metadata, load_splits, filter_dataset and
    get_transform; the versions defined here only raise NotImplementedError.

    Args:
        tasks: List of tasks
        split: One of {'train', 'val', 'test', 'all'}
        dataset_name: One of {'taskonomy_{variant}', 'replica', 'replica_gso', 'hypersim', 'blended_mvg', 'hm3d'}
        image_size: Target image size
        seed: Random seed for deterministic shuffling order
        filter_amount: How many "bad" images to remove. One of {'low', 'medium', 'high'}.
    '''

    # Upper bound on consecutive failed sample loads before __getitem__ gives
    # up; keeps a systemic failure (e.g. unreachable bucket) from retrying
    # without end.
    MAX_LOAD_RETRIES = 100

    def __init__(self,
                 tasks,
                 split='train',
                 dataset_name='replica',
                 image_size=512,
                 seed=0,
                 filter_amount='medium'):

        super(OmnidataDataset, self).__init__()
        self.tasks = tasks
        self.split = split
        self.dataset_name = dataset_name
        self.image_size = image_size
        self.seed = seed
        self.filter_amount = filter_amount

        # S3 bucket setup. Credentials are read from the environment and are
        # deliberately never printed or logged.
        env_prefix = 'S3_TASKONOMY' if 'taskonomy' in self.dataset_name else 'S3_OMNIDATA'
        self.session = boto3.session.Session()
        self.s3_client = self.session.client(
            service_name='s3',
            aws_access_key_id=os.environ.get(f'{env_prefix}_ACCESS'),
            aws_secret_access_key=os.environ.get(f'{env_prefix}_KEY'),
            endpoint_url=os.environ.get('S3_ENDPOINT')
        )
        self.bucket_name = os.environ.get(f'{env_prefix}_BUCKET')

        #  DataFrame containing information whether or not any file for any task exists
        self.df_meta = self.load_metadata()

        # Select splits based on selected size/variant
        splits = self.load_splits()
        if split == 'all':
            self.buildings = list(set(splits['train']) | set(splits['val']) | set(splits['test']))
        else:
            self.buildings = splits[split]
        self.buildings = sorted(self.buildings)
        self.df_meta = self.df_meta.loc[self.buildings]

        # Filter bad images
        print("___before filtering : ", len(self.df_meta))
        self.filter_dataset()
        print("___after filtering : ", len(self.df_meta))

        self.df_meta = self.df_meta[tasks] # Select tasks of interest

        if self.split == "train":
            # Only select rows where we have all the tasks
            self.df_meta = self.df_meta[self.df_meta.all(axis=1)]
            # Random shuffle
            self.df_meta = self.df_meta.sample(frac=1, random_state=seed)
            print(f'Using {len(self.df_meta)} images from {dataset_name} in split {self.split}.')
        else:
            # Non-train splits iterate over the pitch-filtered
            # (building, point, view) list instead of the dataframe.
            self.df_meta_filtered = self.load_filtered_metadata()
            print(f'Using {len(self.df_meta_filtered)} images from {dataset_name} in split {self.split}.')

    @abstractmethod
    def load_metadata(self):
        '''Return the metadata DataFrame; must be provided by subclasses.'''
        raise NotImplementedError (f'Cannot load metadata dataframe')

    @abstractmethod
    def load_splits(self):
        '''Return a mapping from split name to buildings; must be provided by subclasses.'''
        raise NotImplementedError (f'Cannot load dataset splits')

    @abstractmethod
    def filter_dataset(self):
        '''Filter self.df_meta; must be provided by subclasses.'''
        raise NotImplementedError (f'Filtering dataset not implemented')

    @abstractmethod
    def get_transform(self, file, task, task_id=0):
        '''Transform a loaded file for `task`; must be provided by subclasses.'''
        raise NotImplementedError (f'Dataset transforms not implemented')

    def __len__(self):
        '''Number of samples: dataframe rows for 'train', filtered list entries otherwise.'''
        if self.split == "train":
            return len(self.df_meta)
        return len(self.df_meta_filtered)

    def load_filtered_metadata(self):
        '''
        Build the shuffled list of (building, point_uuid, view_id) tuples
        used for non-train splits.

        Reads `component_datasets/<dataset_name>/pinfo_<split>.pkl` next to
        this module and keeps only the entries whose camera pitch, converted
        to degrees, lies strictly between 0 and 55. The list is shuffled with
        a fixed seed (4), so the order is reproducible.
        '''
        pinfo = pd.read_pickle(os.path.join(
            os.path.dirname(__file__), 'component_datasets', self.dataset_name, f'pinfo_{self.split}.pkl'))

        # Radians -> degrees. The 3.14 approximation is kept on purpose:
        # switching to math.pi could change which samples pass the filter.
        camera_pitches = pinfo["camera_rotation_final"][0]*180.0 / 3.14

        bpv_list = [
            (pinfo["building"][i], pinfo["point_uuid"][i], pinfo["view_id"][i].item())
            for i in range(len(camera_pitches))
            if 0 < camera_pitches[i] < 55.
        ]

        random.Random(4).shuffle(bpv_list)
        return bpv_list

    def process_point_info(self, pinfo):
        '''
        Flatten the numeric fields of a point-info dict into a 1-D tensor.

        List-valued fields contribute all of their elements, every other
        field is converted with float(); fields are concatenated in the fixed
        order given below.
        '''
        fields = ['camera_distance', 'obliqueness_angle', 'point_normal', 'camera_location',
                  'camera_rotation_final', 'field_of_view_rads', 'point_pitch', 'point_location']
        pinfo_tensor = []
        for k in fields:
            v = pinfo[k]
            if isinstance(v, list):
                pinfo_tensor += v
            else:
                pinfo_tensor.append(float(v))

        return torch.tensor(pinfo_tensor)

    def _object_key(self, task, building, point, view):
        '''Return (bucket key, file extension) of the object storing `task` for one view.'''
        # NOTE(review): the extension is always looked up in the Taskonomy
        # table, also for Omnidata datasets -- confirm this is intended.
        ext = task_parameters_taskonomy[task]['ext']
        if 'taskonomy' in self.dataset_name:
            domain_id = task_parameters_taskonomy[task]['domain_id']
            key = f'taskonomy_imgs/{task}/{building}/point_{point}_view_{view}_domain_{domain_id}.{ext}'
        else:
            domain_id = task_parameters_omnidata[task]['domain_id']
            key = f'omnidata_imgs/{task}/{self.dataset_name}/{building}/point_{point}_view_{view}_domain_{domain_id}.{ext}'
        return key, ext

    def _decode(self, obj, ext, task, building):
        '''Convert the raw bytes of a bucket object into an image / dict / array.'''
        if ext == 'png':
            return Image.open(io.BytesIO(obj))
        if ext == 'json':
            file = json.load(io.BytesIO(obj))
            if task == 'point_info':
                file['building'] = building
                # Tolerate files that do not carry this field.
                file.pop('nonfixated_points_in_view', None)
            return file
        if ext == 'npy':
            # A .npy payload starts with a header describing dtype and shape;
            # np.frombuffer would misread that header as array data.
            return np.load(io.BytesIO(obj))
        raise NotImplementedError(f'Loading extension {ext} not yet implemented')

    def __getitem__(self, index):
        '''
        Load every task in self.tasks, plus 'point_info', for one sample.

        Returns:
            Dict mapping task name to its transformed data; 'point_info' maps
            to the tensor produced by process_point_info.

        Raises:
            NotImplementedError: Propagated unchanged (unsupported extension
                or missing subclass hook); retrying cannot fix those.
            RuntimeError: If MAX_LOAD_RETRIES consecutive samples fail to load.
        '''
        tasks = list(self.tasks) + ["point_info"]
        last_error = None

        for _ in range(self.MAX_LOAD_RETRIES):
            # building / point / view are encoded in dataframe index
            if self.split == "train":
                building, point, view = self.df_meta.iloc[index].name
            else:
                building, point, view = self.df_meta_filtered[index]

            # Bound before the try block so the failure message below can
            # never hit an undefined name.
            key = None

            # TODO: Remove this try/except after we made sure there are no bad/missing images!
            # Very slow if it fails.
            try:
                result = {}
                for task_id, task in enumerate(tasks):
                    key, ext = self._object_key(task, building, point, view)

                    if self.dataset_name == 'hypersim' and task == 'reshading':
                        # Substituted with an all-zero tensor instead of being fetched.
                        result[task] = torch.zeros((1, self.image_size, self.image_size))
                        continue

                    # Load from S3 bucket
                    obj = self.s3_client.get_object(Bucket=self.bucket_name, Key=key)['Body'].read()
                    file = self._decode(obj, ext, task, building)

                    # Perform transformations
                    if task != 'point_info':
                        file = self.get_transform(file, task=task, task_id=task_id)
                    else:
                        file = self.process_point_info(file)

                    result[task] = file

                return result

            except NotImplementedError:
                raise
            except Exception as err:
                # In case image was faulty or not uploaded yet, try with random other image
                last_error = err
                print(key, " not found!!!!")
                index = np.random.randint(len(self))

        raise RuntimeError(
            f'Failed to load a sample after {self.MAX_LOAD_RETRIES} attempts') from last_error
    
    
    
class OmnidataDataset2(data.Dataset):
    '''
    Local-filesystem dataloader over multi-task image files.

    RGB image paths are discovered under `data_path`; the file of every other
    task is located by replacing 'rgb' in the RGB path with the task name
    ('point_info' additionally maps 'point_info.png' to 'fixatedpose.json').

    Args:
        tasks: List of tasks
        dataset: Dataset folder name, used by the per-building directory layouts
        split: Split name (only read by load_images_from_file)
        image_size: Target image size
        data_path: Root directory that is scanned for images
        seed: Random seed for deterministic shuffling order
    '''

    # Upper bound on consecutive failed sample loads before __getitem__ gives up.
    MAX_LOAD_RETRIES = 100

    def __init__(self,
                 tasks,
                 dataset='omnidata',
                 split='train',
                 image_size=256,
                 data_path='/scratch/ldm_generated_data/',
                 seed=1):

        super(OmnidataDataset2, self).__init__()
        self.tasks = tasks
        self.data_path = data_path
        self.split = split
        self.dataset = dataset
        self.image_size = image_size
        self.seed = seed

        # load_images1, load_images_viewpoint_density and
        # load_images_from_file are alternative ways of building this list.
        self.dataset_urls = self.load_images(task='rgb')

        # Shuffle with a private, seeded generator so that the order depends
        # on `seed` only and not on the state of the global RNG.
        random.Random(self.seed).shuffle(self.dataset_urls)

        print(f'Using {len(self.dataset_urls)} images from {self.dataset} dataset.')

    def load_images_from_file(self):
        '''
        Read image paths, one per line, from
        /scratch/taming-transformers/<split>.txt.

        For the 'train' split, paths contained in self.filtered_imgs are
        skipped; when that attribute has not been set nothing is skipped.
        '''
        list_path = f'/scratch/taming-transformers/{self.split}.txt'
        excluded = getattr(self, 'filtered_imgs', ()) if self.split == 'train' else ()

        images = []
        with open(list_path, 'r') as file:
            for line in file:
                path = line.strip()
                if path in excluded:
                    print("filtered")
                else:
                    images.append(path)
        return images

    def load_images1(self, task):
        '''List every file under <data_path>/<task>/<dataset>/<building>/, in sorted order.'''
        #  folders are building names.
        task_path = os.path.join(self.data_path, task, self.dataset)
        images = []
        for building in sorted(os.listdir(task_path)):
            building_path = os.path.join(task_path, building)
            for fname in sorted(os.listdir(building_path)):
                images.append(os.path.join(building_path, fname))
        return images

    def load_images(self, task):
        '''
        List the files ending in '<task>.png' inside every sub-directory of
        data_path, in sorted order. Entries of data_path that are not
        directories are ignored.
        '''
        images = []
        for setting in sorted(os.listdir(self.data_path)):
            setting_path = os.path.join(self.data_path, setting)
            if not os.path.isdir(setting_path):
                continue
            for img in sorted(os.listdir(setting_path)):
                if img.endswith(f'{task}.png'):
                    images.append(os.path.join(setting_path, img))
        return images

    def load_images_viewpoint_density(self, task, view_per_point=1, total_views=None):
        '''
        Select images under <data_path>/<task>/<dataset>/<building>/ with a
        bounded number of views per point.

        The point id is the second '_'-separated token of the file name. At
        most `view_per_point` views (the first ones in sorted file order) are
        kept for each point.

        Args:
            task: Task folder to scan.
            view_per_point: Maximum number of views kept per point.
            total_views: If None ("sparse"), every point contributes its kept
                views. Otherwise ("dense"), randomly chosen points are added
                until at least `total_views` images are collected or no
                points are left.

        Returns:
            Tuple (images, n_points, n_imgs): the selected paths, the number
            of points used and the number of selected images.
        '''
        #  folders are building names.
        task_path = os.path.join(self.data_path, task, self.dataset)
        bpv = defaultdict(lambda: defaultdict(list))
        for building in os.listdir(task_path):
            for fname in sorted(os.listdir(os.path.join(task_path, building))):
                point = fname.split('_')[1]
                if len(bpv[building][point]) >= view_per_point:
                    continue
                bpv[building][point].append(fname)

        images = []
        n_points = 0

        if total_views is None: # sparse
            for b, points in bpv.items():
                for views in points.values():
                    n_points += 1
                    # A point may own fewer than view_per_point views; take
                    # what is there instead of popping from an empty list.
                    random.shuffle(views)
                    for fname in views[:view_per_point]:
                        images.append(os.path.join(task_path, b, fname))

        else: # dense
            # Stop as well once every point has been consumed, so asking for
            # more views than exist returns everything instead of failing.
            while len(images) < total_views and bpv:
                b = random.choice(list(bpv.keys()))
                p = random.choice(list(bpv[b].keys()))
                views = bpv[b].pop(p)
                if not bpv[b]:
                    bpv.pop(b)
                if not views:
                    continue
                n_points += 1
                for fname in views:
                    images.append(os.path.join(task_path, b, fname))

        n_imgs = len(images)
        print(f"!!!!!! {n_imgs} images - {n_points} points ")

        return images, n_points, n_imgs

    def get_transform(self, file, task, task_id=0):
        '''Apply the transform returned by get_task_transform for `task` at self.image_size.'''
        task_transform = get_task_transform(task=task, image_size=self.image_size)
        return task_transform(file)

    def __len__(self):
        '''Number of discovered RGB images.'''
        return len(self.dataset_urls)

    def __getitem__(self, index):
        '''
        Load every task in self.tasks for the sample at `index`.

        Returns:
            Dict mapping task name to the loaded data; every task except
            'point_info' is passed through get_transform.

        Raises:
            RuntimeError: If MAX_LOAD_RETRIES consecutive samples fail to load.
        '''
        last_error = None

        for _ in range(self.MAX_LOAD_RETRIES):
            rgb_url = self.dataset_urls[index]

            try:
                result = {}
                for task_id, task in enumerate(self.tasks):
                    # Replaces every occurrence of 'rgb' in the path, i.e.
                    # directory components as well as the file name.
                    task_url = rgb_url.replace('rgb', task)

                    if task == 'point_info':
                        task_url = task_url.replace('point_info.png', 'fixatedpose.json')

                    file = default_loader(task_url)

                    # Perform transformations
                    if task != 'point_info':
                        file = self.get_transform(file, task=task, task_id=task_id)

                    result[task] = file

                return result

            except Exception as err:
                # In case image was faulty or not uploaded yet, try with random other image
                last_error = err
                print(rgb_url, " not found!!!!")
                index = np.random.randint(len(self))

        raise RuntimeError(
            f'Failed to load a sample after {self.MAX_LOAD_RETRIES} attempts') from last_error
