import os

import numpy as np
import pandas as pd
from pathlib import Path
import pytest
import torch

from ccvae.data.loaders import (
    inv_weighted_data_loaders,
    load_cached,
    load_celligner,
    make_strata,
    StratifiedSampler,
)
from ccvae.data.configs import CELLIGNER

@pytest.fixture
def celligner_dir(tmpdir_factory):
    """Build a temporary directory holding a tiny Celligner-style dataset.

    Four files are written, named after the entries of ``CELLIGNER``:

    * ``tumor_file``: tab-separated tumor table with sample columns T1-T3
      and a ``Gene`` column (``ABC1`` is deliberately listed twice).
    * ``cl_file``: cell line table with samples C1-C3 as rows and
      ``"SYMBOL (ENS n)"`` style gene labels as columns.
    * ``info_file``: per-sample ``disease``/``subtype``/``type`` annotations.
    * ``hgnc_file``: tab-separated gene symbol to ``locus_group`` table.

    Returns:
        The ``Path`` of the populated directory.
    """
    root = Path(tmpdir_factory.mktemp("celligner"))

    # Tumor expression, one column per tumor sample plus the gene symbols.
    tumors = pd.DataFrame(
        {
            "T1": [1, 1, 1, 1, 1],
            "T2": [2, 2, 2, 2, 2],
            "T3": [2, 2, 2, 2, 2],
            "Gene": ["ABC1", "ABC2", "ABC1", "ABC3", "ABC4"],
        }
    )
    tumors.to_csv(root / CELLIGNER["tumor_file"], sep="\t", index=True)

    # Cell line expression is assembled genes-by-samples and transposed just
    # before writing, so the file has one row per cell line.
    cell_line_genes = [
        "ABC1 (ENS 1)",
        "ABC2 (ENS 2)",
        "ABC5 (ENS 5)",
        "ABC3 (ENS 3)",
        "ABC6 (ENS 6)",
    ]
    cell_lines = pd.DataFrame(
        {
            "C1": [1, 1, 1, 1, 1],
            "C2": [2, 2, 2, 2, 2],
            "C3": [2, 2, 2, 2, 2],
        },
        index=cell_line_genes,
    )
    cell_lines.T.to_csv(root / CELLIGNER["cl_file"])

    # Sample annotations covering the three tumors followed by the three
    # cell lines.
    sample_ids = ["T1", "T2", "T3", "C1", "C2", "C3"]
    annotations = pd.DataFrame(
        {
            "disease": ["BRCA", "BRCA", "SRC", "SRC", "BRCA", "BRCA"],
            "subtype": ["basal", "luminal", "sarcoma", "sarcoma", "luminal", "basal"],
            "type": ["tumor", "tumor", "tumor", "CL", "CL", "CL"],
        },
        index=sample_ids,
    )
    annotations.to_csv(root / CELLIGNER["info_file"], header=True)

    # Gene table: four protein-coding symbols plus two other locus groups.
    gene_rows = [
        ("ABC1", "protein-coding"),
        ("ABC2", "protein-coding"),
        ("ABC3", "protein-coding"),
        ("ABC4", "protein-coding"),
        ("ABC5", "non-coding RNA"),
        ("ABC6", "pseudogene"),
    ]
    genes = pd.DataFrame(gene_rows, columns=["symbol", "locus_group"])
    genes.to_csv(root / CELLIGNER["hgnc_file"], sep="\t")

    return root

def test_load_celligner(celligner_dir):
    """Check ``load_celligner`` metadata and that ``load_cached`` agrees with it."""
    expression, sample_info = load_celligner(celligner_dir)

    # The metadata keeps the info file's columns and first-seen value order.
    columns = list(sample_info.columns)
    sample_types = list(sample_info["type"].unique())
    diseases = list(sample_info["disease"].unique())
    assert columns == ["disease", "subtype", "type"]
    assert sample_types == ["tumor", "CL"]
    assert diseases == ["BRCA", "SRC"]

    # Going through the caching wrapper must yield the same metadata and the
    # same expression values as the direct call.
    cached = load_cached(load_celligner, celligner_dir, 'celligner')
    cached_features = cached['features_tensor'].detach().numpy()

    pd.testing.assert_frame_equal(sample_info, cached['metadata_df'])
    np.testing.assert_array_equal(expression.to_numpy(), cached_features)

def test_make_strata():
    """Check how ``make_strata`` groups sample positions for two batch sizes.

    The expected values below use 0-based positions even though the series
    is indexed from 1, and categories with fewer members than ``batch_size``
    are expected under the ``"Other"`` key.
    """
    labels = [
        "Sarcoma",
        "Breast Cancer",
        "Breast Cancer",
        "Sarcoma",
        "Breast Cancer",
        "LMS",
        "CRC",
    ]
    categories = pd.Series(labels, index=range(1, 8))

    def as_lists(strata):
        # Normalise each stratum's members to a plain list for comparison.
        return {name: list(members) for name, members in strata.items()}

    pairs = as_lists(make_strata(categories, batch_size=2))
    assert pairs == {
        "Sarcoma": [0, 3],
        "Breast Cancer": [1, 2, 4],
        "Other": [5, 6],
    }

    triples = as_lists(make_strata(categories, batch_size=3))
    assert triples == {
        "Breast Cancer": [1, 2, 4],
        "Other": [0, 3, 5, 6],
    }

def test_stratified_sampler():
    """Check that one pass over ``StratifiedSampler`` yields every id once."""
    strata = {
        "Sarcoma": [1, 4],
        "Breast Cancer": [2, 3, 5],
        "Other": [0, 6],
    }
    # Keep the final partial batches (drop_last=False) so that every item is
    # visited in each epoch. Otherwise, which items make it into the batches
    # of a stratum whose size is not a multiple of the batch size is
    # stochastic.
    sampler = StratifiedSampler(strata, 2, drop_last=False)

    # Flatten the batches from a single pass and compare, order-insensitively,
    # with the ids that were handed to the sampler.
    seen_ids = np.sort(np.concatenate(list(sampler)))
    wanted_ids = np.sort(np.concatenate(list(strata.values())))
    np.testing.assert_array_equal(wanted_ids, seen_ids)


def test_inv_weighted_data_loaders(celligner_dir):
    """Check batching, seeding and sample coverage of ``inv_weighted_data_loaders``.

    Each sample is a single scalar equal to its row position in the metadata,
    so the samples visited in an epoch can be read straight from the batches.
    Two passes are made over both the train and the validation loader and the
    resulting batches are compared.
    """
    # Only the metadata is needed; the expression matrix is replaced by the
    # synthetic position tensor below.
    _, metadata = load_celligner(celligner_dir)
    data = torch.as_tensor(np.arange(len(metadata)).reshape((-1, 1)))
    strata = make_strata(metadata.disease, batch_size=2)
    train_loader, valid_loader = inv_weighted_data_loaders(
        tensor_dataset=torch.utils.data.TensorDataset(data),
        metadata_df=metadata,
        strata_column='disease',
        strata=strata,
        batch_size=3,
    )

    def batch_array(batch):
        # A batch drawn from the single-tensor dataset holds exactly one
        # tensor; unwrap it as a numpy array.
        assert len(batch) == 1
        return batch[0].numpy()

    train_epochs = [list(train_loader), list(train_loader)]
    valid_epochs = [list(valid_loader), list(valid_loader)]
    assert len(train_epochs[0]) == len(train_epochs[1])
    assert len(train_epochs[0]) == len(valid_epochs[0])
    assert len(train_epochs[0]) == len(valid_epochs[1])

    # The first epoch of the train_loader will match the valid_loader.
    # This is due to how the two loaders are seeded in the implementation rather than
    # necessity for training but it is asserted here to check that the loaders are
    # working as expected.
    for i, (train_batch, valid_batch) in enumerate(zip(train_epochs[0], valid_epochs[0])):
        np.testing.assert_array_equal(batch_array(train_batch),
                                      batch_array(valid_batch),
                                      f'epoch 0, batch {i}')

    # The valid_loader should return the same batches for each pass through an epoch.
    for i, (e0_batch, e1_batch) in enumerate(zip(valid_epochs[0], valid_epochs[1])):
        np.testing.assert_array_equal(batch_array(e0_batch),
                                      batch_array(e1_batch),
                                      f'valid batches {i} differ between epochs')

    # The train_loader should have shuffled the samples for the next epoch, so
    # every second-epoch train batch must differ from the matching valid batch.
    for i, (train_batch, valid_batch) in enumerate(zip(train_epochs[1], valid_epochs[1])):
        train_values = batch_array(train_batch)
        valid_values = batch_array(valid_batch)
        assert not np.array_equal(train_values, valid_values), (
            f'epoch 1, batch {i}: train batch was not reshuffled')

    # Every sample should be included in each epoch of both loaders (no
    # partial batch is dropped).
    # Assumes the data (one scalar per sample) are in numerical order.
    expected_samples = data.numpy().ravel()
    for loader_name, epochs in (('train', train_epochs), ('valid', valid_epochs)):
        for epoch, batches in enumerate(epochs):
            sample_indices = sorted(idx.item()
                                    for batch in batches
                                    for idx in batch[0])
            # Name the loader and epoch in the message for easier debugging.
            np.testing.assert_array_equal(
                expected_samples, sample_indices,
                f'{loader_name} epoch {epoch} did not cover every sample')

