"""
 Copyright 2023 [Anonymized]
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      https://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 """

import os
from time import time
from typing import Any, Callable, Dict, List, Optional, Set, Tuple
import glob

import tensorflow as tf
# from google.cloud import storage
from tensorflow import Tensor
from tensorflow_similarity.losses import (
    CircleLoss,
    MultiSimilarityLoss,
    PNLoss,
    TripletLoss,
)


def read_tfrecord(
    tfrecord: Tensor
) -> Dict[str, Tensor]:
    """Read TF record files for RETVec training datasets.

    Each serialized example is expected to hold a scalar int64 ``index``
    feature and two scalar string features, ``aug_token0`` and
    ``aug_token1``.

    Args:
        tfrecord: TF record input (a single serialized example).

    Returns:
        A dict with two entries, each stacked along a new leading axis of
        size 2:
            - ``aug_token``: ``[aug_token0, aug_token1]``.
            - ``index``: the record's index repeated twice, so that every
              stacked token carries its own copy of the index.
    """
    # Number of numbered features stored per prefix in each record.
    num_aug = 2
    base_features = {
        "index": tf.io.FixedLenFeature([], tf.int64),
    }
    prefixes = ['aug_token']

    # Parsing spec: the base features plus `<prefix>0 .. <prefix>{n-1}`.
    features = base_features.copy()
    for p in prefixes:
        for i in range(num_aug):
            features[f'{p}{i}'] = tf.io.FixedLenFeature([], tf.string)

    rec = tf.io.parse_single_example(tfrecord, features)

    record = {}
    for p in prefixes:
        record[p] = tf.stack([rec[f'{p}{i}'] for i in range(num_aug)])

    # Repeat the base features so they line up one-to-one with the stacked
    # tokens once the record is sliced along its leading axis.
    for feature in base_features:
        record[feature] = tf.stack([rec[feature]] * num_aug)

    return record


def Sampler(
    shards_list: List[str],
    batch_size: int = 32,
    process_record: Optional[Callable] = None,
    parallelism: int = tf.data.AUTOTUNE,
    file_parallelism: Optional[int] = 1,
    prefetch_size: Optional[int] = None,
    buffer_size: Optional[int] = None,
    compression_type: Optional[str] = "GZIP",
) -> tf.data.Dataset:
    """Build a repeating, batched dataset from a list of TFRecord shards.

    The pipeline is pinned to the CPU and runs these steps in order:
    shuffle the shard names, interleave the shards, parse each record with
    ``read_tfrecord``, drop records that raise errors, optionally shuffle
    the parsed records, flatten each record into its individual examples,
    optionally apply ``process_record``, repeat forever, batch and prefetch.

    Args:
        shards_list: Paths of the TFRecord shards to read.
        batch_size: Number of examples per batch.
        process_record: Function mapped over every flattened example. When
            None, this mapping step is skipped.
        parallelism: ``num_parallel_calls`` used for the map steps.
        file_parallelism: Number of shards read concurrently; used as both
            ``cycle_length`` and ``num_parallel_calls`` of the interleave.
        prefetch_size: Value passed to ``Dataset.prefetch``.
        buffer_size: Size of the record-level shuffle buffer. When falsy,
            records are not shuffled beyond the shard-level shuffle.
        compression_type: Compression type of the shards.

    Returns:
        The batched dataset. It repeats indefinitely, so callers need to
        bound the number of steps themselves.

    Raises:
        ValueError: If ``shards_list`` is empty.
    """
    total_shards = len(shards_list)
    print("found ", total_shards, 'shards', time())

    # Fail early with an explicit message rather than an opaque TensorFlow
    # error coming from an empty dataset / zero-sized shuffle buffer.
    if not total_shards:
        raise ValueError(
            "shards_list is empty: no TFRecord shards to read from."
        )

    with tf.device('/CPU:0'):
        ds = tf.data.Dataset.from_tensor_slices(shards_list)
        # The buffer covers every shard name, so shard order is fully
        # shuffled.
        ds = ds.shuffle(total_shards)

        ds = ds.interleave(
            lambda x: tf.data.TFRecordDataset(x, compression_type=compression_type),  # noqa
            # NOTE: records are not flat at this point (each one holds
            # several stacked examples); they are flattened further down.
            block_length=1,
            num_parallel_calls=file_parallelism,
            cycle_length=file_parallelism,
            deterministic=False)

        ds = ds.map(read_tfrecord, num_parallel_calls=parallelism)

        # ignore corrupted read errors, i.e. corrupted tfrecords
        ds = ds.ignore_errors()

        # Shuffling happens before flattening, so the examples stacked in
        # one record stay next to each other in the output stream.
        if buffer_size:
            ds = ds.shuffle(buffer_size)

        # Slice every stacked record along its leading axis into
        # individual examples.
        ds = ds.flat_map(lambda *x: tf.data.Dataset.from_tensor_slices(x))

        # `process_record` defaults to None; only map when one was given.
        if process_record is not None:
            ds = ds.map(process_record, num_parallel_calls=parallelism)

        ds = ds.repeat()
        ds = ds.batch(batch_size)
        ds = ds.prefetch(prefetch_size)
        return ds


def process_tfrecord(e):
    """Split a parsed example into model inputs and targets.

    Args:
        e: Mapping holding at least the 'aug_token' and 'index' entries.

    Returns:
        An ``(inputs, targets)`` tuple of dicts: the inputs expose
        ``e['aug_token']`` under 'token' and the targets expose
        ``e['index']`` under 'similarity'.
    """
    inputs = dict(token=e['aug_token'])
    targets = dict(similarity=e['index'])
    return inputs, targets


def get_dataset_samplers(
    train_path: str, test_path: str, config: Dict
) -> Tuple[tf.data.Dataset, tf.data.Dataset]:
    """Build the train and test dataset samplers.

    Shards are every ``*.tfrecord`` file found recursively under the given
    folders.

    Args:
        train_path: Root folder of the training shards.
        test_path: Root folder of the test shards.
        config: Configuration dict; ``config["train"]["batch_size"]`` and
            ``config["train"]["shuffle_buffer"]`` are read from it.

    Returns:
        A ``(train_ds, test_ds)`` tuple. The train sampler reads shards in
        parallel and shuffles records; the test sampler reads one shard at
        a time and is given no shuffle buffer.
    """
    # os.cpu_count() can return None: always assume at least one CPU.
    num_cores = os.cpu_count() or 1

    train_cfg = config["train"]
    examples_per_batch = train_cfg["batch_size"]
    shuffle_buffer = train_cfg["shuffle_buffer"]

    train_files = glob.glob(f"{train_path}/**/*.tfrecord", recursive=True)
    test_files = glob.glob(f"{test_path}/**/*.tfrecord", recursive=True)

    # Training: heavily parallel reads, record shuffling, deep prefetch.
    train_sampler = Sampler(
        train_files,
        batch_size=examples_per_batch,
        process_record=process_tfrecord,
        parallelism=num_cores,
        file_parallelism=2 * num_cores,
        prefetch_size=1000,
        buffer_size=shuffle_buffer,
    )

    # Evaluation: a single shard at a time, minimal prefetch.
    test_sampler = Sampler(
        test_files,
        batch_size=examples_per_batch,
        process_record=process_tfrecord,
        file_parallelism=1,
        prefetch_size=1,
    )

    return train_sampler, test_sampler


def get_outputs_info(
    config: Dict,
) -> Tuple[List[Any], Set[str]]:
    """Returns the losses, and output names in the config.

    The similarity loss is selected by
    ``config["outputs"]["similarity_loss"]["type"]`` and always uses the
    cosine distance. Optional hyper-parameters are read from the same
    sub-config, with these fallbacks:

        - "multisim": alpha=2, beta=40, epsilon=0.1, lmda=0.5
        - "circle": gamma=256, margin=0.0
        - "triplet": no extra parameters
        - "pn": no extra parameters

    Args:
        config: Configuration dict containing the
            ``outputs.similarity_loss`` section.

    Returns:
        A tuple ``(losses, outputs)`` where ``losses`` is a one-element
        list holding the similarity loss and ``outputs`` is the set
        ``{"similarity"}``.

    Raises:
        KeyError: If the similarity loss section or its ``type`` is missing.
        ValueError: If the similarity loss type is not supported.
    """
    sim_loss_config = config["outputs"]["similarity_loss"]
    sim_loss_type = sim_loss_config["type"]

    if sim_loss_type == "multisim":
        sim_loss = MultiSimilarityLoss(
            distance="cosine",
            alpha=sim_loss_config.get("alpha", 2),
            beta=sim_loss_config.get("beta", 40),
            epsilon=sim_loss_config.get("epsilon", 0.1),
            lmda=sim_loss_config.get("lmda", 0.5),
        )

    elif sim_loss_type == "circle":
        sim_loss = CircleLoss(
            distance="cosine",
            gamma=sim_loss_config.get("gamma", 256),
            margin=sim_loss_config.get("margin", 0.0),
        )

    elif sim_loss_type == "triplet":
        sim_loss = TripletLoss(distance="cosine")

    elif sim_loss_type == "pn":
        sim_loss = PNLoss(distance="cosine")

    else:
        # Fail loudly: silently returning no loss would only surface much
        # later, far away from the misconfiguration that caused it.
        raise ValueError(
            f"Unsupported similarity loss type {sim_loss_type!r}; "
            "expected one of: 'multisim', 'circle', 'triplet', 'pn'."
        )

    return [sim_loss], {"similarity"}
