import os
import numpy as np
import torch
import os

from comb_modules.dijkstra import dijkstra
from decorators import input_to_numpy
from utils import TrainingIterator
from torch.utils.data import Dataset, DataLoader

def generate_data_for_species_random(true_weights, n_task):
  """Replicate each sample's first cost map across ``n_task`` task slots.

  Used as a sanity check: every task receives exactly the same data.

  Args:
    true_weights: array indexed as ``true_weights[i][0]`` to obtain the map
      copied for sample ``i``; the output grid is ``(W, W)`` with
      ``W = true_weights.shape[-1]`` (``true_weights[i][0]`` is broadcast
      into it).
    n_task: number of task copies to produce.

  Returns:
    float64 array of shape ``(len(true_weights), n_task, W, W)`` with
    ``out[i, s]`` equal to ``true_weights[i][0]`` (broadcast) for every ``s``.
  """
  side = true_weights.shape[-1]
  replicated = np.zeros((len(true_weights), n_task, side, side))
  for i, sample in enumerate(true_weights):
    # One broadcast assignment fills every task slot of sample i.
    replicated[i, :] = sample[0]
  return replicated

def generate_data_for_species(true_weights, n_task, lst_new_terrain_value=None):
  """Create per-species cost maps by permuting which cost each terrain gets.

  Each distinct value in ``true_weights`` is treated as one terrain type.
  Species ``s`` relabels every cell whose cost equals
  ``lst_new_terrain_value[0][j]`` to ``lst_new_terrain_value[s][j]``; e.g. a
  terrain costing 1 for one species may cost 9 for another. Entry 0 is the
  identity, so species 0 keeps the original costs.

  Args:
    true_weights: array of shape ``(N, W, W)`` of per-cell costs.
    n_task: number of species to produce.
    lst_new_terrain_value: optional list of permutations returned by an
      earlier call (e.g. on the training split) so another split reuses the
      same per-species mapping. Entry 0 must hold the original terrain costs;
      costs not present in it are left unchanged.

  Returns:
    ``(weights, permutations)``: ``weights`` has shape ``(N, n_task, W, W)``;
    ``permutations`` is the list used. When generated here it holds the
    identity followed by ``n_task`` pairwise-distinct non-identity
    permutations.

  Raises:
    ValueError: if there are too few distinct terrain costs to build
      ``n_task`` distinct non-identity permutations.
  """
  import math

  true_weights2 = np.zeros((len(true_weights), n_task, true_weights.shape[-1], true_weights.shape[-1]))
  terrain_value = np.unique(true_weights)
  print('terrain_value ', terrain_value)
  if lst_new_terrain_value is None:
    # k! - 1 non-identity orderings exist; capping k keeps the factorial cheap
    # (20! already exceeds any realistic n_task).
    available = math.factorial(min(len(terrain_value), 20)) - 1
    if n_task > available:
      raise ValueError(
        "need {} distinct terrain permutations but only {} exist for {} terrain values".format(
          n_task, available, len(terrain_value)))
    lst_new_terrain_value = [terrain_value]
    for _ in range(n_task):
      while True:
        candidate = np.random.permutation(terrain_value)
        # Accept only a permutation that differs from EVERY one kept so far.
        if not any(np.array_equal(candidate, kept) for kept in lst_new_terrain_value):
          lst_new_terrain_value.append(candidate)
          break
    print('Finished generate new terrain values')

  # Map from the base costs the permutations were built on, not from this
  # split's own unique values (which may be a different set).
  source_values = lst_new_terrain_value[0]
  for s in range(n_task):
    cur_weight = true_weights.copy()
    for old, new in zip(source_values, lst_new_terrain_value[s]):
      # Compare against the untouched input so replaced cells are not remapped.
      cur_weight = np.where(true_weights == old, new, cur_weight)
    true_weights2[:, s, :, :] = cur_weight

  return true_weights2, lst_new_terrain_value


def generate_data_for_species_noisy(true_weights, n_task, lst_new_terrain_value=None):
  """Like ``generate_data_for_species`` but adds Gaussian noise to every map.

  Each distinct value in ``true_weights`` is treated as one terrain type.
  Species ``s`` relabels cells whose cost equals
  ``lst_new_terrain_value[0][j]`` to ``lst_new_terrain_value[s][j]``. Then
  zero-mean Gaussian noise with standard deviation
  ``min(unique costs of true_weights) / 4`` is added once per cell, and the
  result is rounded to 3 decimals.

  Args:
    true_weights: array of shape ``(N, H, W)`` of per-cell costs.
    n_task: number of species to produce.
    lst_new_terrain_value: optional list of permutations from an earlier call
      (entry 0 holding the original terrain costs) to reuse the same mapping.

  Returns:
    ``(weights, permutations)``: ``weights`` has shape ``(N, n_task, W, W)``
    with ``W = true_weights.shape[-1]``; ``permutations`` is the list used.

  Raises:
    ValueError: if there are too few distinct terrain costs to build
      ``n_task`` distinct non-identity permutations.
  """
  import math

  true_weights2 = np.zeros((len(true_weights), n_task, true_weights.shape[-1], true_weights.shape[-1]))
  terrain_value = np.unique(true_weights)
  print('terrain_value ', terrain_value)
  if lst_new_terrain_value is None:
    # k! - 1 non-identity orderings exist; capping k keeps the factorial cheap.
    available = math.factorial(min(len(terrain_value), 20)) - 1
    if n_task > available:
      raise ValueError(
        "need {} distinct terrain permutations but only {} exist for {} terrain values".format(
          n_task, available, len(terrain_value)))
    lst_new_terrain_value = [terrain_value]
    for _ in range(n_task):
      while True:
        candidate = np.random.permutation(terrain_value)
        # Accept only a permutation that differs from EVERY one kept so far.
        if not any(np.array_equal(candidate, kept) for kept in lst_new_terrain_value):
          lst_new_terrain_value.append(candidate)
          break
    print('Finished generate new terrain values')

  noise_std = terrain_value.min() / 4
  source_values = lst_new_terrain_value[0]
  for s in range(n_task):
    cur_weight = true_weights.copy()
    for old, new in zip(source_values, lst_new_terrain_value[s]):
      cur_weight = np.where(true_weights == old, new, cur_weight)
    # Noise is drawn once per cell, after the full relabelling, and sized
    # from the actual map shape instead of a hard-coded 12x12 grid.
    noise = np.random.normal(0, noise_std, size=cur_weight.shape)
    cur_weight = np.round(cur_weight + noise, 3)
    true_weights2[:, s, :, :] = cur_weight

  return true_weights2, lst_new_terrain_value


def load_dataset_multi(data_dir, use_test_set, evaluate_with_extra, normalize, use_local_path,task_idx=None,n_task=None, normalize_path=False, num_sample=50):
  """Build one training and one evaluation DataLoader per species (task).

  Species ``i`` trains on the maps whose indices are the first ``num_sample``
  entries of ``{data_dir}/trainsampleidx_species{i}.npy``. It is evaluated on
  the whole validation (or test) split.

  Args:
    data_dir: directory containing the ``.npy`` dataset files.
    use_test_set: evaluate on the ``test`` split instead of ``val``.
    evaluate_with_extra: append ``_extra`` to the evaluation split prefix.
    normalize: standardise input maps per channel with training statistics.
    use_local_path: unused; kept for interface compatibility.
    task_idx: ``-1`` makes every species' loaders use species 0's costs and
      paths (sanity check); any other value has no effect.
    n_task: number of species to load; ``None`` loads every species stored
      in the cost/path files.
    normalize_path: unused; kept for interface compatibility.
    num_sample: number of training maps per species; the batch size is
      ``min(50, num_sample)``.

  Returns:
    ``(train_loaders, eval_loaders, metadata)``. Both loader lists hold one
    ``DataLoader`` per species.
  """
  train_prefix = "train"
  data_suffix = "maps"
  true_weights_suffix = "_9species"
  version_suffix = ''
  val_prefix = ("test" if use_test_set else "val") + ("_extra" if evaluate_with_extra else "")

  def _load(name):
    return np.load(os.path.join(data_dir, name))

  train_inputs = _load(train_prefix + "_" + data_suffix + "_orig.npy").astype(np.float32)
  train_inputs = train_inputs.transpose(0, 3, 1, 2)  # channel first
  train_labels = _load(train_prefix + true_weights_suffix + "_shortest_paths{}.npy".format(version_suffix))[:, :n_task, :, :]
  train_true_weights = _load(train_prefix + true_weights_suffix + "_vertex_weights{}.npy".format(version_suffix))[:, :n_task, :, :]
  val_inputs = _load(val_prefix + "_" + data_suffix + version_suffix + ".npy").astype(np.float32)
  val_inputs = val_inputs.transpose(0, 3, 1, 2)  # channel first
  val_labels = _load(val_prefix + true_weights_suffix + "_shortest_paths{}.npy".format(version_suffix))[:, :n_task, :, :]
  val_true_weights = _load(val_prefix + true_weights_suffix + "_vertex_weights{}.npy".format(version_suffix))[:, :n_task, :, :]
  if n_task is None:
    # Axis 1 of the cost/path files indexes the species.
    n_task = train_true_weights.shape[1]

  if normalize:
    mean = np.mean(train_inputs, axis=(0, 2, 3), keepdims=True)
    std = np.std(train_inputs, axis=(0, 2, 3), keepdims=True)
    train_inputs -= mean
    train_inputs /= std
    val_inputs -= mean
    val_inputs /= std
  else:
    # Identity statistics keep `denormalize` usable when not normalizing.
    mean, std = 0.0, 1.0

  num_sample_per_species = num_sample
  print('Total train sample:{} \t NUm sample per piece: {}'.format(len(train_inputs), num_sample_per_species))
  batch_size = min(50, num_sample_per_species)
  if task_idx == -1:
    print('Sanity check: use same data')

  lst_loader_train, lst_loader_val = [], []
  for i in range(n_task):
    idx = _load(train_prefix + "sampleidx_species{}{}.npy".format(i, version_suffix))[:num_sample_per_species]
    # The sanity check keeps species i's map subset but uses species 0's costs/paths.
    species = 0 if task_idx == -1 else i

    dataset_train = mapDataset(train_inputs[idx], train_true_weights[idx, species, :, :], train_labels[idx, species, :, :])
    lst_loader_train.append(DataLoader(dataset_train, batch_size=batch_size, shuffle=True))

    dataset_val = mapDataset(val_inputs, val_true_weights[:, species, :, :], val_labels[:, species, :, :])
    lst_loader_val.append(DataLoader(dataset_val, batch_size=batch_size, shuffle=False))

  val_full_images = _load(val_prefix + "_maps.npy")

  @input_to_numpy
  def denormalize(x):
    return (x * std) + mean

  metadata = {
    "input_image_size": val_full_images[0].shape[1],
    "output_features": val_true_weights[0].shape[-1] * val_true_weights[0].shape[-1],
    "num_channels": val_full_images[0].shape[-1],
    "denormalize": denormalize
  }
  return lst_loader_train, lst_loader_val, metadata



def load_dataset_multi_same_map(data_dir, use_test_set, evaluate_with_extra, normalize, use_local_path,task_idx=None,n_task=None, normalize_path=False, num_sample=500):
  """Build per-species DataLoaders in which every species trains on the same maps.

  Every species uses the first ``num_sample`` training maps; only the costs
  and shortest paths differ per species. Evaluation uses the whole
  validation (or test) split.

  Args:
    data_dir: directory containing the ``.npy`` dataset files.
    use_test_set: evaluate on the ``test`` split instead of ``val``.
    evaluate_with_extra: append ``_extra`` to the evaluation split prefix.
    normalize: standardise input maps per channel with training statistics.
    use_local_path: unused; kept for interface compatibility.
    task_idx: ``-1`` makes every species' loaders use species 0's costs and
      paths (sanity check); any other value has no effect.
    n_task: number of species to load; ``None`` loads every species stored
      in the cost/path files.
    normalize_path: unused; kept for interface compatibility.
    num_sample: number of leading training maps used; the batch size is
      ``min(50, num_sample)``.

  Returns:
    ``(train_loaders, eval_loaders, metadata)``. Both loader lists hold one
    ``DataLoader`` per species.
  """
  train_prefix = "train"
  data_suffix = "maps"
  true_weights_suffix = "_9species"
  version_suffix = ''
  val_prefix = ("test" if use_test_set else "val") + ("_extra" if evaluate_with_extra else "")

  def _load(name):
    return np.load(os.path.join(data_dir, name))

  train_inputs = _load(train_prefix + "_" + data_suffix + "_orig.npy").astype(np.float32)
  train_inputs = train_inputs.transpose(0, 3, 1, 2)  # channel first
  train_labels = _load(train_prefix + true_weights_suffix + "_shortest_paths{}.npy".format(version_suffix))[:, :n_task, :, :]
  train_true_weights = _load(train_prefix + true_weights_suffix + "_vertex_weights{}.npy".format(version_suffix))[:, :n_task, :, :]
  val_inputs = _load(val_prefix + "_" + data_suffix + version_suffix + ".npy").astype(np.float32)
  val_inputs = val_inputs.transpose(0, 3, 1, 2)  # channel first
  val_labels = _load(val_prefix + true_weights_suffix + "_shortest_paths{}.npy".format(version_suffix))[:, :n_task, :, :]
  val_true_weights = _load(val_prefix + true_weights_suffix + "_vertex_weights{}.npy".format(version_suffix))[:, :n_task, :, :]
  if n_task is None:
    # Axis 1 of the cost/path files indexes the species.
    n_task = train_true_weights.shape[1]

  if normalize:
    mean = np.mean(train_inputs, axis=(0, 2, 3), keepdims=True)
    std = np.std(train_inputs, axis=(0, 2, 3), keepdims=True)
    train_inputs -= mean
    train_inputs /= std
    val_inputs -= mean
    val_inputs /= std
  else:
    # Identity statistics keep `denormalize` usable when not normalizing.
    mean, std = 0.0, 1.0

  num_sample_per_species = num_sample
  print('Total train sample:{} \t NUm sample per piece: {}'.format(len(train_inputs), num_sample_per_species))
  batch_size = min(50, num_sample_per_species)
  if task_idx == -1:
    print('Sanity check: use same data')

  # All species share the same leading block of training maps.
  shared_maps = train_inputs[:num_sample_per_species]
  lst_loader_train, lst_loader_val = [], []
  for i in range(n_task):
    species = 0 if task_idx == -1 else i

    dataset_train = mapDataset(
      shared_maps,
      train_true_weights[:num_sample_per_species, species, :, :],
      train_labels[:num_sample_per_species, species, :, :],
    )
    lst_loader_train.append(DataLoader(dataset_train, batch_size=batch_size, shuffle=True))

    dataset_val = mapDataset(val_inputs, val_true_weights[:, species, :, :], val_labels[:, species, :, :])
    lst_loader_val.append(DataLoader(dataset_val, batch_size=batch_size, shuffle=False))

  val_full_images = _load(val_prefix + "_maps.npy")

  @input_to_numpy
  def denormalize(x):
    return (x * std) + mean

  metadata = {
    "input_image_size": val_full_images[0].shape[1],
    "output_features": val_true_weights[0].shape[-1] * val_true_weights[0].shape[-1],
    "num_channels": val_full_images[0].shape[-1],
    "denormalize": denormalize
  }
  return lst_loader_train, lst_loader_val, metadata



def compute_path(true_label,true_weights, n_task):
  """Compute the Dijkstra shortest-path grid for every sample and task.

  Args:
    true_label: existing label array; only its length and trailing two
      (grid) dimensions are used, to size the output.
    true_weights: per-sample, per-task cost grids indexed as
      ``true_weights[i][s]``.
    n_task: number of tasks to solve per sample.

  Returns:
    float64 array of shape ``(len(true_label), n_task) + true_label.shape[-2:]``
    holding ``dijkstra(true_weights[i][s]).shortest_path`` at ``[i, s]``.

  Raises:
    ValueError: if ``true_label`` and ``true_weights`` hold a different
      number of samples (previously this left zero rows or raised IndexError).
  """
  num_samples = len(true_label)
  if len(true_weights) != num_samples:
    raise ValueError(
      "true_label and true_weights must have the same number of samples, got {} and {}".format(
        num_samples, len(true_weights)))

  paths = np.zeros((num_samples, n_task) + tuple(true_label.shape[-2:]))
  for i in range(num_samples):
    for s in range(n_task):
      paths[i, s, :, :] = dijkstra(true_weights[i][s]).shortest_path
  return paths

def load_dataset(data_dir, use_test_set, evaluate_with_extra, normalize, use_local_path,weight_suffix="",task_idx=0,n_task=None, normalize_path=False):
  """Load maps, per-species costs and shortest paths as training iterators.

  Costs and paths are read from precomputed files
  ``{split}_{weight_suffix}species_vertex_weights.npy`` and
  ``{split}_{weight_suffix}species_shortest_paths.npy``, whose axis 1 indexes
  the species. This loader does not create missing files; ``np.load`` raises
  if they are absent.

  Args:
    data_dir: directory containing the ``.npy`` dataset files.
    use_test_set: evaluate on the ``test`` split instead of ``val``.
    evaluate_with_extra: append ``_extra`` to the evaluation split prefix.
    normalize: standardise input maps per channel with training statistics.
    use_local_path: unused; kept for interface compatibility.
    weight_suffix: species-count tag inserted into the cost/path file names.
    task_idx: with ``n_task == 1``, the species to keep; it also names the
      cached subset file ``trainsampleidx_species{task_idx}.npy``. ``-1``
      replaces every species by species 0 (sanity check).
    n_task: number of species to keep; ``None`` keeps every species stored
      in the files.
    normalize_path: unused; kept for interface compatibility.

  Returns:
    ``(train_iterator, eval_iterator, metadata)``.

  Raises:
    Exception: if ``{data_dir}/train_maps.npy`` does not exist.
  """
  train_prefix = "train"
  data_suffix = "maps"
  # Use the precomputed multi-species files, then keep only n_task species.
  true_weights_suffix = "_{}species".format(weight_suffix)
  val_prefix = ("test" if use_test_set else "val") + ("_extra" if evaluate_with_extra else "")

  def _load(name):
    return np.load(os.path.join(data_dir, name))

  train_data_path = os.path.join(data_dir, train_prefix + "_" + data_suffix + ".npy")
  if not os.path.exists(train_data_path):
    raise Exception(f"Cannot find {train_data_path}")
  train_inputs = np.load(train_data_path).astype(np.float32)
  num_train_samples = len(train_inputs)
  num_sample_per_species = 666

  if n_task == 1:
    # Single-species run: train on a fixed random subset of maps, cached on
    # disk so repeated runs reuse the same subset.
    meta_file = os.path.join(data_dir, train_prefix + "sampleidx_species{}.npy".format(task_idx))
    if os.path.exists(meta_file):
      idx = np.load(meta_file)
    else:
      # NOTE(review): randint samples with replacement, so the subset may
      # contain duplicate maps; confirm this is intended.
      idx = np.random.randint(num_train_samples, size=num_sample_per_species)
      np.save(meta_file, idx)
  else:
    idx = np.arange(num_train_samples)
  train_inputs = train_inputs.transpose(0, 3, 1, 2)[idx]  # channel first

  if normalize:
    mean = np.mean(train_inputs, axis=(0, 2, 3), keepdims=True)
    std = np.std(train_inputs, axis=(0, 2, 3), keepdims=True)
    train_inputs -= mean
    train_inputs /= std
  else:
    # Identity statistics keep `denormalize` usable when not normalizing.
    mean, std = 0.0, 1.0

  # The evaluation maps file is read once. The raw array goes into metadata;
  # a float32 channel-first copy (astype copies) becomes the model input.
  val_full_images = _load(val_prefix + "_" + data_suffix + ".npy")
  val_inputs = val_full_images.astype(np.float32).transpose(0, 3, 1, 2)
  if normalize:
    val_inputs -= mean
    val_inputs /= std

  # `idx` is applied exactly once, here; the arrays below are already subset.
  train_true_weights = _load(train_prefix + true_weights_suffix + "_vertex_weights.npy")[idx]
  val_true_weights = _load(val_prefix + true_weights_suffix + "_vertex_weights.npy")
  train_labels = _load(train_prefix + true_weights_suffix + "_shortest_paths.npy")[idx]
  val_labels = _load(val_prefix + true_weights_suffix + "_shortest_paths.npy")
  if n_task is None:
    n_task = train_true_weights.shape[1]

  if task_idx == -1:
    # Sanity check: costs AND paths of every task come from species 0, so
    # weights and labels stay consistent for any n_task.
    train_true_weights = generate_data_for_species_random(train_true_weights, n_task)
    val_true_weights = generate_data_for_species_random(val_true_weights, n_task)
    train_labels = generate_data_for_species_random(train_labels, n_task)
    val_labels = generate_data_for_species_random(val_labels, n_task)
    print('Use same data for sanity check')

  if n_task == 1:
    # Keep one species and drop the species axis.
    train_true_weights = train_true_weights[:, task_idx, :, :]
    val_true_weights = val_true_weights[:, task_idx, :, :]
    train_labels = train_labels[:, task_idx, :, :]
    val_labels = val_labels[:, task_idx, :, :]
  elif task_idx != -1:
    train_true_weights = train_true_weights[:, :n_task, :, :]
    val_true_weights = val_true_weights[:, :n_task, :, :]
    train_labels = train_labels[:, :n_task, :, :]
    val_labels = val_labels[:, :n_task, :, :]

  print('Train data size: ', train_labels.shape, train_inputs.shape, train_true_weights.shape)

  print('val data size: ', val_labels.shape, val_inputs.shape, val_true_weights.shape)

  train_iterator = TrainingIterator(dict(images=train_inputs, labels=train_labels, true_weights=train_true_weights))
  eval_iterator = TrainingIterator(
    dict(images=val_inputs, labels=val_labels, true_weights=val_true_weights)
  )

  @input_to_numpy
  def denormalize(x):
    return (x * std) + mean

  metadata = {
    "input_image_size": val_full_images[0].shape[1],
    "output_features": val_true_weights[0].shape[-1] * val_true_weights[0].shape[-1],
    "num_channels": val_full_images[0].shape[-1],
    "denormalize": denormalize
  }
  print(metadata)
  return train_iterator, eval_iterator, metadata



class mapDataset(Dataset):
    """Map-style dataset yielding ``(image, cost, path)`` float32 tensors.

    Item ``ind`` is ``(tmaps[ind], costs[ind], paths[ind])``, each copied into
    a new ``torch.float32`` tensor; the dataset length is ``len(costs)``.
    """

    def __init__(self, tmaps, costs, paths):
        self.tmaps = tmaps  # input map images
        self.costs = costs  # per-item cost grids
        self.paths = paths  # per-item shortest-path grids

    def __len__(self):
        return len(self.costs)

    def __getitem__(self, ind):
        # torch.tensor copies and casts any numeric array/scalar to float32,
        # unlike the legacy FloatTensor constructor, which treats an int-like
        # argument as a size.
        return (
            torch.tensor(self.tmaps[ind], dtype=torch.float32),  # image
            torch.tensor(self.costs[ind], dtype=torch.float32),
            torch.tensor(self.paths[ind], dtype=torch.float32),
        )