# Copyright (C) 2021 Project Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import ast
import json
import os

from core.data import codenet_paths

from apache_beam.io.gcp import gcsio


def get_submission_ids(problem_id):
  """Lists the Python submissions stored on GCS for one problem.

  The GCS root is derived from `codenet_paths.DATA_ROOT` by replacing '/mnt/'
  with 'gs://'. Every object listed under the prefix
  `<root>/data/<problem_id>/Python` is reported as one submission.

  Args:
    problem_id: Name of the problem's directory under `<root>/data`.

  Returns:
    A list of `(problem_id, submission_id)` tuples, one per listed object.
  """
  bucket_root = codenet_paths.DATA_ROOT.replace('/mnt/', 'gs://')
  # NOTE(review): the prefix carries no trailing separator, so any sibling
  # whose name merely starts with 'Python' would be listed too -- confirm that
  # no such directory exists in the dataset.
  listing_prefix = os.path.join(bucket_root, 'data', problem_id, 'Python')
  listed_objects = gcsio.GcsIO().list_prefix(listing_prefix)
  # Iterating the listing yields its keys, i.e. the object paths.
  return [
      (problem_id, _get_submission_id(object_path))
      for object_path in listed_objects
  ]


def _get_submission_id(submission_path):
  """Extracts the submission id from a submission file path.

  The id is the last '/'-separated component of the path, cut off at its first
  '.', so every extension is dropped.

  Args:
    submission_path: Path string, e.g. 'gs://root/data/p1/Python/s123.py'.

  Returns:
    The id string, e.g. 's123'. Empty if the path ends with '/'.
  """
  _, _, file_name = submission_path.rpartition('/')
  submission_id, _, _ = file_name.partition('.')
  return submission_id


def get_problem_and_submission_ids(ids_file):
  """Loads the ids stored as a UTF-8 JSON document in a GCS file.

  Args:
    ids_file: Path of the file, opened for reading via `gcsio.GcsIO().open`.

  Returns:
    The decoded JSON value. Presumably a list of [problem_id, submission_id]
    pairs, judging by the function name -- confirm against the file's writer.
  """
  with gcsio.GcsIO().open(ids_file, 'rb') as ids_stream:
    return json.loads(ids_stream.read().decode('utf-8'))


def get_problem_and_submission_ids_from_tuples(ids_file):
  """Loads ids from a GCS file holding one Python tuple literal per line.

  Every non-blank line is parsed with `ast.literal_eval` and must evaluate to
  a two-element sequence such as `('p00001', 's123456789')`. Blank and
  whitespace-only lines are skipped.

  Args:
    ids_file: Path of the file, opened for reading via `gcsio.GcsIO().open`.

  Returns:
    A list of `(problem_id, submission_id)` tuples in file order.

  Raises:
    SyntaxError: If a non-blank line is not valid Python literal syntax.
    ValueError: If a line is not a plain literal, or does not unpack into
      exactly two elements.
  """
  gcsio_client = gcsio.GcsIO()
  ids_list = []
  with gcsio_client.open(ids_file, 'rb') as f:
    for line in f.readlines():
      # Strip surrounding whitespace before parsing: literal_eval rejects
      # leading indentation on Python < 3.10, and an empty line would
      # otherwise raise SyntaxError.
      text = line.decode('utf-8').strip()
      if not text:
        continue
      problem_id, submission_id = ast.literal_eval(text)
      ids_list.append((problem_id, submission_id))
  return ids_list


def get_split_ids_files(splits):
  """Returns the ids files that make up the requested splits selection.

  Args:
    splits: One of 'all', 'sampled-test' or 'small'.

  Returns:
    A new list of ids file paths for the selection. The concrete paths have
    been anonymized out of this file, so each list is currently empty.

  Raises:
    ValueError: If `splits` is not one of the recognized values.
  """
  if splits == 'all':
    return [
        # Anonymized.
    ]
  if splits == 'sampled-test':
    return [
        # Anonymized.
    ]
  if splits == 'small':
    return [
        # Anonymized.
    ]
  raise ValueError('Unexpected splits value', splits)

  
def get_indexed_problem_and_submission_ids(split, ids_file):
  """Loads ids from a JSON file and keys each entry by split and position.

  Args:
    split: Label stored as the first element of every key, e.g. 'train'.
    ids_file: Path of a UTF-8 JSON file, opened via `gcsio.GcsIO().open`.

  Returns:
    A list of `((split, index), ids)` tuples, where `index` is the entry's
    zero-based position in the file and `ids` is the entry as decoded from
    JSON. Presumably each entry is a problem id / submission id pair, e.g.
    (('train', 0), ['p00001', 's1234567890']) -- confirm against the writer.
  """
  reader = gcsio.GcsIO()
  with reader.open(ids_file, 'rb') as ids_stream:
    payload = ids_stream.read()
  entries = json.loads(payload.decode('utf-8'))
  return [
      ((split, position), entry)
      for position, entry in enumerate(entries)
  ]
