import os, sys
import torch.utils.data as data
from PIL import Image
from torchvision.transforms import Compose, Resize, RandomCrop, CenterCrop, RandomHorizontalFlip, ToTensor, Normalize
import random
import numpy as np
import torch
import glob
import json
import pickle

class dataset_multi(data.Dataset):
  """Multi-domain dataset that yields a random image from a random domain.

  The image paths of every domain are indexed once in the constructor, using the
  directory layout selected by ``opts.MDMM_dataset_name``.  ``__getitem__``
  ignores the requested index and draws both the domain and the image at
  random; ``__len__`` reports the size of the largest domain.

  Attributes set by the constructor:
    dataroot, dataroot_2: root directories copied from ``opts``.
    num_domains: number of domains of the selected dataset.
    input_dim: 1 for single-channel output, anything else keeps three channels.
    images: one list of image paths per domain.
    dataset_size: length of the largest per-domain list.
    transforms: composed torchvision transform applied to every image.
  """

  # Recording date -> weather label used by the 'ithaca' dataset layout.
  _ITHACA_DATE2WEATHER = {
    '01-16-2022': 'cloud',
    '01-17-2022': 'snow',
    '01-17-2022b': 'snow',
    '01-17-2022c': 'snow',
    '01-20-2022': 'sunny',
    '01-23-2022': 'night',
    '02-01-2022': 'night',
    '02-03-2022': 'snow',
    '02-04-2022': 'snow',
    '02-04-2022b': 'snow',
    '02-11-2022': 'sunny',
    '02-17-2022': 'cloud',
    '02-17-2022b': 'night',
    '02-21-2022': 'sunny',
    '02-22-2022': 'rain',
    '02-24-2022': 'night',
    '02-25-2022': 'snow',
    '03-03-2022b': 'night',
    '08-20-2021': 'sunny',
    '10-04-2021': 'rain',
    '10-08-2021': 'sunny',
    '11-19-2021': 'cloud',
    '11-22-2021': 'cloud',
    '11-23-2021': 'cloud',
    '11-29-2021': 'snow',
    '11-30-2021': 'cloud',
    '12-01-2021': 'cloud',
    '12-02-2021': 'night',
    '12-03-2021': 'sunny',
    '12-06-2021': 'cloud',
    '12-07-2021': 'cloud',
    '12-08-2021': 'rain',
    '12-09-2021': 'night',
    '12-13-2021': 'sunny',
    '12-14-2021': 'sunny',
    '12-15-2021': 'rain',
    '12-16-2021': 'cloud',
    '12-18-2021': 'rain',
    '12-19-2021': 'rain',
    '12-19-2021b': 'cloud',
  }

  def __init__(self, opts):
    """Index the image files of every domain and build the transform pipeline.

    Args:
      opts: option namespace.  Reads ``dataroot``, ``dataroot_2``,
        ``num_domains``, ``input_dim``, ``MDMM_dataset_name``, ``phase``,
        ``resize_size``, ``crop_size`` and ``no_flip``; the 'ithaca' dataset
        additionally reads ``sunnyNight``.

    Raises:
      ValueError: if ``opts.MDMM_dataset_name`` is not a supported dataset.
    """
    self.dataroot = opts.dataroot
    self.dataroot_2 = opts.dataroot_2
    self.num_domains = opts.num_domains
    self.input_dim = opts.input_dim

    domains = self._index_images(opts)

    # Per-domain image counts, one line per domain.
    stats = ''.join(
      '{}: {} images\n'.format(domains[i], len(self.images[i]))
      for i in range(self.num_domains)
    )
    print(stats)
    self.dataset_size = max(len(self.images[i]) for i in range(self.num_domains))

    self.transforms = self._build_transforms(opts)

  def _index_images(self, opts):
    """Fill ``self.images`` / ``self.num_domains`` and return the domain names."""
    name = opts.MDMM_dataset_name
    if name == 'weatherImage':
      domains = ['cloudy', 'foggy', 'rain', 'snow', 'sunny']
      self.images = [self._list_dir(os.path.join(self.dataroot, d)) for d in domains]
    elif name == 'ACDC':
      domains = self._index_acdc()
    elif name == 'ithaca':
      domains = self._index_ithaca(opts)
    elif name == 'ithaca_select':
      domains = self._index_ithaca_select()
    elif name == 'BDD100K':
      domains = ['night', 'cloudy', 'rainy', 'snowy', 'sunny']
      self.images = self._read_bdd100k_lists(domains)
    elif name == 'BDD100K_and_ithaca':
      domains = self._index_bdd100k_and_ithaca()
    elif name == 'waymo':
      domains = ['Day', 'Dawn/Dusk', 'Night']
      self.images = []
      for d in domains:
        img_dir = os.path.join(self.dataroot, d)
        # Keep every fifth file of the sorted listing.
        self.images.append([os.path.join(img_dir, x) for x in sorted(os.listdir(img_dir))[::5]])
    else:
      # Previously this fell through and crashed on an unbound local variable.
      raise ValueError('Unsupported MDMM_dataset_name: {!r}'.format(name))

    # Every dataset defines its own domain list; the domain count always follows it.
    self.num_domains = len(domains)
    return domains

  @staticmethod
  def _list_dir(img_dir):
    """Return the joined paths of all entries directly inside ``img_dir``."""
    return [os.path.join(img_dir, x) for x in os.listdir(img_dir)]

  @staticmethod
  def _list_nested(img_dir):
    """Return the paths of all entries exactly one directory level below ``img_dir``."""
    paths = []
    for sub in os.listdir(img_dir):
      sub_dir = os.path.join(img_dir, sub)
      paths += [os.path.join(sub_dir, x) for x in os.listdir(sub_dir)]
    return paths

  @staticmethod
  def _glob_dates(root):
    """Return ``(name, path)`` pairs for every entry directly under ``root``.

    ``name`` is the entry's basename truncated at its first '.'.  Name and path
    come from the same glob call so they can never be paired incorrectly.
    """
    return [(os.path.basename(p).split('.')[0], p) for p in glob.glob(f'{root}/*')]

  def _index_acdc(self):
    """Index ACDC: adverse domains from train/test, 'sunny' from the *_ref folders."""
    domains = ['fog', 'rain', 'snow', 'sunny']
    adverse = domains[:-1]
    self.images = []
    for domain in domains:
      paths = []
      for phase in ['train', 'test']:
        if domain != 'sunny':
          paths += self._list_nested(os.path.join(self.dataroot, domain, phase))
        else:
          # 'sunny' has no folder of its own: it is assembled from the
          # '<phase>_ref' folder of every adverse domain.
          for ref_domain in adverse:
            paths += self._list_nested(os.path.join(self.dataroot, ref_domain, phase + '_ref'))
      self.images.append(paths)
    return domains

  def _index_ithaca(self, opts):
    """Index the cam3 frames under ``dataroot`` by the hard-coded date/weather table.

    Raises KeyError for a date directory missing from the table.
    """
    if opts.sunnyNight:
      domains = ['night', 'sunny']
    else:
      domains = ['night', 'sunny', 'rain', 'cloud', 'snow']
    domain_idx = {d: i for i, d in enumerate(domains)}
    self.images = [[] for _ in domains]

    for date, date_path in self._glob_dates(self.dataroot):
      if 'poses_bounds' in date or date == '02-17-2022b':
        continue
      weather = self._ITHACA_DATE2WEATHER[date]
      if weather in domain_idx:
        frames = sorted(glob.glob(date_path + '/cam3/*'))
        # The last 15 frames of each date are left out (reason not visible
        # here; presumably held out for evaluation -- confirm).
        self.images[domain_idx[weather]] += frames[:-15]
    return domains

  def _index_ithaca_select(self):
    """Index the cam0 frame ranges described by the ithaca_select scene pickle."""
    domains = ['night', 'sunny', 'rain', 'cloud', 'snow']
    domain_idx = {d: i for i, d in enumerate(domains)}
    self.images = [[] for _ in domains]

    # NOTE(review): pickle.load executes arbitrary code from the file; only use
    # a scene file from a trusted source.
    with open('./TSIT/datasets/ithaca_select_scene.pickle', 'rb') as f:
      ithaca_scenes = pickle.load(f)

    # Group the per-date scene descriptions by weather label.
    weather2info = {}
    for date, info in ithaca_scenes.items():
      weather2info.setdefault(info['weather'], []).append(
        {'date': date, 'first': info['first'], 'last': info['last']}
      )

    for mode in domains:
      for date_info in weather2info[mode]:
        date = date_info['date']
        traj = [os.path.basename(x).split('.')[0] for x in glob.glob(f'{self.dataroot}/{date}/cam0/*')]
        traj.sort(key=int)  # file stems must be integer-like strings
        traj = np.array(traj)
        first_file_idx = np.argwhere(traj == date_info['first'])[0, 0]
        last_file_idx = np.argwhere(traj == date_info['last'])[0, 0]
        # Inclusive [first, last] range, capped at 8000 frames per date.
        traj = traj[first_file_idx:last_file_idx + 1][:8000]
        self.images[domain_idx[mode]] += [
          os.path.join(self.dataroot, date, "cam0", f"{time}.png") for time in traj
        ]
    return domains

  def _read_bdd100k_lists(self, domains):
    """Read '<domain>_train.txt' for each domain and return the listed image paths."""
    images = []
    for domain in domains:
      list_path = os.path.join(
        self.dataroot, 'bdd100k_lists/sunny2diffweathers/%s_%s.txt' % (domain, 'train')
      )
      with open(list_path) as c_list:
        rel_paths = c_list.read().splitlines()
      # Entries are relative to dataroot; blank lines are skipped.
      images.append([os.path.join(self.dataroot, p) for p in rel_paths if p != ''])
    return images

  def _index_bdd100k_and_ithaca(self):
    """Index BDD100K lists from ``dataroot`` and append ithaca cam3 frames from ``dataroot_2``."""
    # BDD100K domains, ordered so that each index lines up with the ithaca name below.
    self.images = self._read_bdd100k_lists(['night', 'sunny', 'rainy', 'cloudy', 'snowy'])

    domains = ['night', 'sunny', 'rain', 'cloud', 'snow']
    domain_idx = {d: i for i, d in enumerate(domains)}

    with open(os.path.join("./data/ithaca365-label", 'scene.json')) as f:
      scene_table = json.load(f)
    with open(os.path.join("./data/ithaca365-label", 'weather.json')) as f:
      weather_table = json.load(f)
    token2weather = {wea["token"]: wea["description"] for wea in weather_table}
    date2weather = {scene["name"]: token2weather[scene["weather_token"]] for scene in scene_table}

    for date, date_path in self._glob_dates(self.dataroot_2):
      if 'poses_bounds' in date:
        continue
      weather = date2weather[date]
      # Sort before slicing so the 15 dropped frames are the last ones rather
      # than an arbitrary filesystem-order subset (matches the 'ithaca' dataset).
      frames = sorted(glob.glob(date_path + '/cam3/*'))
      self.images[domain_idx[weather]] += frames[:-15]
    return domains

  def _build_transforms(self, opts):
    """Compose resize, crop, optional flip, tensor conversion and normalisation."""
    if opts.phase != 'train' and opts.MDMM_dataset_name == 'ithaca':
      transforms = [Resize((512, 768), Image.BICUBIC)]
    else:
      transforms = [Resize((opts.resize_size, opts.resize_size), Image.BICUBIC)]
    if opts.phase == 'train':
      transforms.append(RandomCrop(opts.crop_size))
    else:
      transforms.append(CenterCrop(opts.crop_size))
    if not opts.no_flip:
      transforms.append(RandomHorizontalFlip())
    transforms.append(ToTensor())
    transforms.append(Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]))
    return Compose(transforms)

  def __getitem__(self, index):
    """Return ``(image, one_hot_domain)`` for a random image of a random domain.

    ``index`` is ignored.  Raises ValueError (from ``random.randint``) when the
    drawn domain has no images.
    """
    cls = random.randint(0, self.num_domains - 1)
    c_org = np.zeros((self.num_domains,))
    img_path = self.images[cls][random.randint(0, len(self.images[cls]) - 1)]
    img = self.load_img(img_path, self.input_dim)
    c_org[cls] = 1
    return img, torch.FloatTensor(c_org)

  def load_img(self, img_name, input_dim):
    """Open ``img_name`` as RGB, apply the transforms and optionally reduce to one channel."""
    img = Image.open(img_name).convert('RGB')
    img = self.transforms(img)
    if input_dim == 1:
      # Weighted sum of the three channels, kept as a leading channel axis.
      img = img[0, ...] * 0.299 + img[1, ...] * 0.587 + img[2, ...] * 0.114
      img = img.unsqueeze(0)
    return img

  def __len__(self):
    """Return the number of images in the largest domain."""
    return self.dataset_size


class dataset_single(data.Dataset):
  """Single-domain dataset that serves its images in index order.

  The image list is chosen by ``domain`` (the special strings 'waymo' and
  'timeLapse') or otherwise by ``opts.MDMM_dataset_name``.  Every sample is
  returned together with ``c_org``, the one-hot encoding of the domain, and --
  depending on the dataset -- a frame name, an identifying dict and/or a
  reference image (see ``load_img``).

  Attributes set by the constructor:
    img: list of image paths.
    frame_idx: per-image file stems ('timeLapse' only), else None.
    which_idx: per-image identifying dicts (Consistency/FID datasets), else None.
    ref_img: per-image reference image paths ('ithacaFID' only), else None.
    num_domains: length of the one-hot domain vector.
    c_org: numpy one-hot vector with a 1 at the domain index.
    transforms: composed torchvision transform applied to every image.
  """

  def __init__(self, opts, domain):
    """Index the images of one domain and build the transform pipeline.

    Args:
      opts: option namespace.  Reads ``dataroot``, ``MDMM_dataset_name``,
        ``input_dim``, ``resize_size`` and ``crop_size``; some datasets also
        read ``phase``, ``num_domains`` or ``time``.
      domain: integer domain index, or the string 'waymo' / 'timeLapse' to load
        one of the two fixed image folders (the domain index is then 1).
    """
    self.opts = opts
    self.dataroot = opts.dataroot
    self.img = []
    self.frame_idx = None
    self.which_idx = None
    self.ref_img = None

    domain = self._index_images(opts, domain)

    self.size = len(self.img)
    self.input_dim = opts.input_dim

    self.c_org = np.zeros((self.num_domains,))
    self.c_org[domain] = 1

    self.transforms = self._build_transforms(opts)

  @staticmethod
  def _list_nested(img_dir):
    """Return the paths of all entries exactly one directory level below ``img_dir``."""
    paths = []
    for sub in os.listdir(img_dir):
      sub_dir = os.path.join(img_dir, sub)
      paths += [os.path.join(sub_dir, x) for x in os.listdir(sub_dir)]
    return paths

  def _index_images(self, opts, domain):
    """Fill ``self.img`` and ``self.num_domains``; return the integer domain index.

    Also fills ``frame_idx``, ``which_idx`` and ``ref_img`` for the datasets
    that provide them.
    """
    name = opts.MDMM_dataset_name

    # The two string-valued domains take precedence over the dataset name.
    if domain == 'waymo':
      self.num_domains = 5
      self.img = glob.glob('./data/waymo/sunny/Dawn/Dusk' + '/*')
      return 1

    if domain == 'timeLapse':
      self.num_domains = 5
      self.img = glob.glob('./data/videos_h264_imgs/24_hour_Time_Lapse__4_NEW' + '/*')
      self.frame_idx = [os.path.basename(x).split('.')[0] for x in self.img]
      return 1

    if name == 'ACDC':
      domains = ['fog', 'rain', 'snow', 'sunny']
      # Previously left unset, which made the constructor fail when building c_org.
      self.num_domains = len(domains)
      if domains[domain] != 'sunny':
        self.img += self._list_nested(os.path.join(self.dataroot, domains[domain], 'val'))
      else:
        # 'sunny' is assembled from the 'val_ref' folder of every adverse domain.
        for ref_domain in domains[:-1]:
          self.img += self._list_nested(os.path.join(self.dataroot, ref_domain, 'val_ref'))
      return domain

    if name == 'ithaca':
      domains = ['night', 'sunny', 'rain', 'cloud', 'snow']
      self.num_domains = len(domains)
      self.img = []
      domain_idx = {d: i for i, d in enumerate(domains)}

      with open(os.path.join("./data/ithaca365-label", 'scene.json')) as f:
        scene_table = json.load(f)
      with open(os.path.join("./data/ithaca365-label", 'weather.json')) as f:
        weather_table = json.load(f)
      token2weather = {wea["token"]: wea["description"] for wea in weather_table}
      date2weather = {scene["name"]: token2weather[scene["weather_token"]] for scene in scene_table}

      # Name and path come from one glob call so they always belong together.
      for date_path in glob.glob(f'{self.dataroot}/*'):
        date = os.path.basename(date_path).split('.')[0]
        if 'poses_bounds' in date:
          continue
        if domain_idx[date2weather[date]] == domain:
          self.img += glob.glob(date_path + '/cam3/*')
      return domain

    if name == 'waymo':
      domains = ['Day', 'Dawn/Dusk', 'Night']
      self.num_domains = len(domains)
      img_dir = os.path.join(self.dataroot, domains[domain])
      # Keep every fifth file of the sorted listing.
      self.img = [os.path.join(img_dir, x) for x in sorted(os.listdir(img_dir))[::5]]
      return domain

    if name == 'ithacaConsistency':
      self.num_domains = 5
      self.img = []
      self.which_idx = []
      consistency_path = './logs/ithaca/val_all/geonerfMDMM_ver0-adain_content_level_styleTwoBranch-mse-woMSE1-zInputStyle_isInput-delta_t_1x1-t0Rec-sunnyNight-styleWaymo/to_calculate_consistency'
      for pair_idx in range(15):
        for nv_idx in range(2):
          self.img.append(f'{consistency_path}/pair{pair_idx}/novel{nv_idx}/novelView.png')
          self.which_idx.append({'pair': pair_idx, 'novel': nv_idx})
      return domain

    if name == 'ttConsistency':
      self.num_domains = 5
      self.img = []
      self.which_idx = []
      for scene in ['Playground']:
        consistency_path = f'./logs/tt/val_all/geonerfMDMM_ver0-adain_content_level_styleTwoBranch-mse-woMSE1-zInputStyle_isInput-delta_t_1x1-t0Rec-sunnyNight-styleWaymo/{scene}/to_calculate_consistency-v1'
        for pair_idx in range(9):
          for nv_idx in range(2):
            self.img.append(f'{consistency_path}/pair{pair_idx}/novel{nv_idx}/novelView.png')
            self.which_idx.append({'scene': scene, 'pair': pair_idx, 'novel': nv_idx})
      return domain

    if name == 'ithacaFID':
      self.num_domains = 5
      self.img = []
      self.ref_img = []
      self.which_idx = []
      FID_path = './logs/ithaca/val_all/geonerfMDMM_ver0-adain_content_level_styleTwoBranch-mse-woMSE1-zInputStyle_isInput-delta_t_1x1-t0Rec-sunnyNight-styleWaymo/to_calculate_FID'
      for set_idx in range(100):
        self.img.append(f'{FID_path}/{opts.time}/set{set_idx}/novelView.png')
        self.ref_img.append(f'{FID_path}/{opts.time}/set{set_idx}/refImg.png')
        self.which_idx.append({'set': set_idx})
      return domain

    # Fallback layout: '<dataroot>/<phase><LETTER>' with domains named 'A'..'Z'.
    domains = [chr(i) for i in range(ord('A'), ord('Z') + 1)]
    # Previously left unset here; take the domain count from the options.
    self.num_domains = opts.num_domains
    img_dir = os.path.join(self.dataroot, opts.phase + domains[domain])
    self.img = [os.path.join(img_dir, x) for x in os.listdir(img_dir)]
    return domain

  def _build_transforms(self, opts):
    """Compose the transforms: Consistency/FID datasets skip resizing and cropping."""
    name = opts.MDMM_dataset_name
    if 'Consistency' in name or 'FID' in name:
      transforms = []
    elif name == 'ithaca':
      transforms = [Resize((512, 768), Image.BICUBIC)]
    else:
      transforms = [
        Resize((opts.resize_size, opts.resize_size), Image.BICUBIC),
        CenterCrop(opts.crop_size),
      ]
    transforms.append(ToTensor())
    transforms.append(Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]))
    return Compose(transforms)

  def __getitem__(self, index):
    """Return the sample tuple for image ``index`` (see ``load_img``)."""
    return self.load_img(self.img[index], self.input_dim, index=index)

  def load_img(self, img_name, input_dim, index=None):
    """Load and transform one image and bundle it with its metadata.

    Args:
      img_name: path of the image to open.
      input_dim: 1 to reduce the image to a single channel.
      index: position used to look up ``ref_img`` / ``which_idx`` /
        ``frame_idx``; required whenever one of them is set.

    Returns:
      ``(img, c_org, which_idx[index], ref_img)`` when both ``which_idx`` and
      ``ref_img`` are set, ``(img, c_org, which_idx[index])`` when only
      ``which_idx`` is set, ``(img, c_org, frame_idx[index])`` when
      ``frame_idx`` is set, otherwise ``(img, c_org)``.
    """
    img = Image.open(img_name).convert('RGB')
    img = self.transforms(img)
    if input_dim == 1:
      # Weighted sum of the three channels, kept as a leading channel axis.
      img = img[0, ...] * 0.299 + img[1, ...] * 0.587 + img[2, ...] * 0.114
      img = img.unsqueeze(0)

    has_ref = self.ref_img is not None
    if has_ref:
      ref_img = Image.open(self.ref_img[index]).convert('RGB')
      ref_img = self.transforms(ref_img)
      if input_dim == 1:
        ref_img = ref_img[0, ...] * 0.299 + ref_img[1, ...] * 0.587 + ref_img[2, ...] * 0.114
        ref_img = ref_img.unsqueeze(0)

    if self.which_idx is not None:
      if has_ref:
        return img, self.c_org, self.which_idx[index], ref_img
      return img, self.c_org, self.which_idx[index]
    if self.frame_idx is not None:
      return img, self.c_org, self.frame_idx[index]
    return img, self.c_org

  def __len__(self):
    """Return the number of indexed images."""
    return self.size
