import torch
from options import TestOptions
from datasets import dataset_single
from model import MD_multi
from saver import save_imgs, save_concat_imgs
import os
import math

# Domain label lists keyed by a substring of a dataset/checkpoint name.
# Order matters: the first key contained in the name wins
# (e.g. 'ithacaConsistency' and 'ithacaFID' both match 'ithaca').
_DOMAINS_BY_KEY = (
  ('ACDC', ('fog', 'rain', 'snow', 'sunny')),
  ('weather', ('cloudy', 'foggy', 'rain', 'snow', 'sunny')),
  ('ithaca', ('night', 'sunny', 'rain', 'cloud', 'snow')),
  ('waymo', ('Day', 'Dawn/Dusk', 'Night')),
)

# Column of the domain one-hot set for every transfer below. The original
# script labels it 'sunny'; it is 'sunny' in the ithaca domain list above.
_SUNNY_INDEX = 1

# NOTE(review): all transfer modes write here rather than under
# opts.result_dir/opts.name (kept from the original script) -- confirm intended.
_TRANSFER_RESULT_DIR = './geonerfMDMM_ver0-adain_content_level_styleTwoBranch-mse-woMSE1-zInputStyle_isInput-delta_t_1x1-t0Rec-sunnyNight-styleWaymo'


def _domains_for(name):
  """Return the domain labels for a dataset/checkpoint name.

  The first key of _DOMAINS_BY_KEY contained in ``name`` wins; when no key
  matches, the single letters 'A'..'Z' are returned.
  """
  for key, domains in _DOMAINS_BY_KEY:
    if key in name:
      return list(domains)
  return [chr(c) for c in range(ord('A'), ord('Z') + 1)]


def _one_hot_domain(opts, index):
  """Return a (1, opts.num_domains) zero tensor on GPU ``opts.gpu`` with column ``index`` set to 1."""
  vec = torch.zeros((1, opts.num_domains)).cuda(opts.gpu)
  vec[:, index] = 1
  return vec


def _run_consistency_transfer(model, opts, loaders, subdir_parts):
  """Transfer every source image with every frame of the 'timeLapse' reference set.

  Each loader in ``loaders`` yields ``(img, c_org, which_idx)`` batches where
  ``which_idx`` holds at least 'pair' and 'novel' tensors. For each source image
  the reference loader is iterated in full and the outputs are saved under
  _TRANSFER_RESULT_DIR / *subdir_parts(which_idx) / pair<p> / novel<n>.
  """
  dataset_ref = dataset_single(opts, 'timeLapse')
  loader_ref = torch.utils.data.DataLoader(dataset_ref, batch_size=1, num_workers=opts.nThreads)
  with torch.no_grad():
    # Built once: the conditioning vector is the same for every transfer.
    domain_vec = _one_hot_domain(opts, _SUNNY_INDEX)
    for loader in loaders:
      for idx, data in enumerate(loader):
        print('{}/{}'.format(idx, len(loader)))
        img, _c_org, which_idx = data
        pair_idx, novel_idx = which_idx['pair'], which_idx['novel']
        img = img.cuda(opts.gpu)

        imgs = []
        names = []
        for ref_img, _ref_domain, ref_frame_idx in loader_ref:
          # Same device as the model and img (previously the default GPU).
          ref_img = ref_img.cuda(opts.gpu)
          imgs.append(model.test_forward_transfer(img, ref_img, domain_vec))
          names.append(f'novelView_phi{ref_frame_idx[0]}')
        out_dir = os.path.join(_TRANSFER_RESULT_DIR, *subdir_parts(which_idx),
                               f'pair{pair_idx.item()}', f'novel{novel_idx.item()}')
        save_imgs(imgs, names, out_dir)


def _run_fid_transfer(model, opts, loaders):
  """Transfer each source image with its paired reference and save it for FID evaluation.

  Each loader in ``loaders`` yields ``(img, c_org, which_idx, ref_img)`` batches
  where ``which_idx['set']`` is a tensor. Output goes to
  _TRANSFER_RESULT_DIR/to_calculate_FID/ithaca/<opts.time>/set<s>.
  """
  with torch.no_grad():
    domain_vec = _one_hot_domain(opts, _SUNNY_INDEX)
    for loader in loaders:
      for idx, data in enumerate(loader):
        print('{}/{}'.format(idx, len(loader)))
        img, _c_org, which_idx, ref_img = data
        set_idx = which_idx['set']
        img = img.cuda(opts.gpu)
        ref_img = ref_img.cuda(opts.gpu)

        img_ = model.test_forward_transfer(img, ref_img, domain_vec)
        out_dir = os.path.join(_TRANSFER_RESULT_DIR, 'to_calculate_FID', 'ithaca',
                               opts.time, f'set{set_idx.item()}')
        save_imgs([img_], ['novelView_t'], out_dir)


def main():
  """Run the test-time style-transfer mode selected by ``opts.MDMM_dataset_name``.

  Supported modes: 'ithacaConsistency', 'ttConsistency' (transfer against the
  'timeLapse' reference set) and 'ithacaFID' (transfer against paired
  references for FID). Any other name only loads data/model and reports that
  no mode matched.

  Earlier experiment variants (random-style sampling, phi-AdaIN sweep,
  ithaca->waymo transfer) were removed from here; see version-control history.
  """
  # parse options
  opts = TestOptions().parse()
  dataset_name = opts.MDMM_dataset_name
  test_domains = _domains_for(dataset_name)
  if 'Consistency' in dataset_name:
    test_domains = ['sunny']

  # data loader: one loader per test domain index
  print('\n--- load dataset ---')
  loaders = [
    torch.utils.data.DataLoader(dataset_single(opts, i), batch_size=1, num_workers=opts.nThreads)
    for i in range(len(test_domains))
  ]

  # model
  print('\n--- load model ---')
  model = MD_multi(opts)
  model.setgpu(opts.gpu)
  model.resume(opts.resume, train=False)
  model.eval()

  # directory (makedirs: also creates opts.result_dir if it is missing)
  os.makedirs(os.path.join(opts.result_dir, opts.name), exist_ok=True)

  # test transfer
  if dataset_name == 'ithacaConsistency':
    _run_consistency_transfer(model, opts, loaders, lambda which_idx: ('ithaca',))
  elif dataset_name == 'ttConsistency':
    _run_consistency_transfer(
      model, opts, loaders,
      lambda which_idx: ('tt', f"{which_idx['scene'][0]}-v1"))
  elif dataset_name == 'ithacaFID':
    _run_fid_transfer(model, opts, loaders)
  else:
    print('no test mode for dataset name {!r}; nothing to do'.format(dataset_name))

  return

if __name__ == "__main__":
  # Run the test pipeline only when executed as a script, not when imported.
  main()
