import os
import pytest
import util_test
import collections
import tarfile
import io
from PIL import Image

from training.data import get_wds_dataset
from training.params import parse_args
from training.main import random_seed

# Assigned to args.train_num_samples; the expected per-caption counts in the
# tests below are derived from it.
TRAIN_NUM_SAMPLES = 10_000
# Relative tolerance passed to pytest.approx when comparing observed caption
# counts against their expected values.
RTOL = 0.2

# NOTE: we use two test tar files, which are created on the fly by build_inputs() and saved
# in a per-test subdirectory of the input directory returned by util_test.get_data_dirs().
# test_data_000.tar has 10 samples, and the captions are 000_0, 000_1, ..., 000_9
# test_data_001.tar has 5 samples, and the captions are 001_0, 001_1, ..., 001_4
def build_inputs(test_name):
    """Create the two test shards for ``test_name`` and return their directory.

    The shards are written to a subdirectory named ``test_name`` under the base
    input directory returned by ``util_test.get_data_dirs()``:

    * ``test_data_000.tar``: 10 samples, captions ``000_0`` ... ``000_9``
    * ``test_data_001.tar``: 5 samples, captions ``001_0`` ... ``001_4``

    Sample ``i`` of a shard consists of a black 32x32 RGB ``{i}.png`` and a
    UTF-8 encoded ``{i}.txt`` caption. Existing shard files are overwritten.
    """
    base_input_dir, _ = util_test.get_data_dirs()
    input_dir = os.path.join(base_input_dir, test_name)
    os.makedirs(input_dir, exist_ok=True)

    def add_member(tar, name, payload):
        # Store an in-memory bytes payload as a tar member called ``name``.
        info = tarfile.TarInfo(name)
        info.size = len(payload)
        tar.addfile(info, io.BytesIO(payload))

    def save_tar(idx, num_samples):
        filename = os.path.join(input_dir, f'test_data_{idx:03d}.tar')
        # The context manager guarantees the archive is closed even if writing
        # a member fails part-way through.
        with tarfile.open(filename, 'w') as tar:
            for sample_idx in range(num_samples):
                # Image and caption share the basename ``sample_idx``; webdataset
                # groups tar members into samples by that key.
                image_buf = io.BytesIO()
                Image.new('RGB', (32, 32)).save(image_buf, format='png')
                add_member(tar, f'{sample_idx}.png', image_buf.getvalue())

                caption = f'{idx:03d}_{sample_idx}'
                add_member(tar, f'{sample_idx}.txt', caption.encode('utf-8'))

    save_tar(0, 10)
    save_tar(1, 5)

    return input_dir


def build_params(input_shards, seed=0):
    """Build default parsed args plus trivial transforms for a webdataset test.

    Args:
        input_shards: value assigned to ``args.train_data``.
        seed: value stored in ``args.seed`` and passed to ``random_seed``.

    Returns:
        ``(args, preprocess_img, tokenizer)``. ``preprocess_img`` returns its
        input unchanged; ``tokenizer`` returns a one-element list holding the
        stripped text.
    """
    args = parse_args([])
    overrides = {
        'train_data': input_shards,
        'train_num_samples': TRAIN_NUM_SAMPLES,
        'dataset_resampled': True,
        'seed': seed,
        'workers': 1,
        'world_size': 1,
        'batch_size': 1,
    }
    for attr_name, attr_value in overrides.items():
        setattr(args, attr_name, attr_value)
    random_seed(seed)

    def preprocess_img(image):
        return image

    def tokenizer(text):
        return [text.strip()]

    return args, preprocess_img, tokenizer


def get_dataloader(input_shards):
    """Return the training dataloader for ``input_shards`` using ``build_params`` defaults."""
    params = build_params(input_shards)
    args, preprocess_img, tokenizer = params
    return get_wds_dataset(
        args, preprocess_img, is_train=True, tokenizer=tokenizer
    ).dataloader


def test_single_source():
    """Test webdataset with a single tar file.

    Resampling the 10-sample shard test_data_000.tar should yield every caption
    roughly TRAIN_NUM_SAMPLES / 10 times.
    """
    input_dir = build_inputs('single_source')
    input_shards = os.path.join(input_dir, 'test_data_000.tar')
    dataloader = get_dataloader(input_shards)

    # sample[1] holds the captions of each batch; tally every caption string.
    counts = collections.Counter(txt for sample in dataloader for txt in sample[1])

    # Without this check an empty loader (or dropped captions) would pass vacuously.
    expected_keys = {f'000_{i}' for i in range(10)}
    assert set(counts) == expected_keys, f'unexpected captions: {sorted(counts)}'

    for key, count in counts.items():
        assert count == pytest.approx(TRAIN_NUM_SAMPLES / 10, rel=RTOL), f'{key}, {count}'


def test_two_sources():
    """Test webdataset with two tar files given as a brace-expanded shard pattern.

    Every caption from both shards is expected roughly TRAIN_NUM_SAMPLES / 15
    times (15 = 10 + 5 samples in total).
    """
    input_dir = build_inputs('two_sources')
    input_shards = os.path.join(input_dir, 'test_data_{000..001}.tar')
    dataloader = get_dataloader(input_shards)

    # sample[1] holds the captions of each batch; tally every caption string.
    counts = collections.Counter(txt for sample in dataloader for txt in sample[1])

    # Without this check an empty loader (or a shard never read) would pass vacuously.
    expected_keys = {f'000_{i}' for i in range(10)} | {f'001_{i}' for i in range(5)}
    assert set(counts) == expected_keys, f'unexpected captions: {sorted(counts)}'

    for key, count in counts.items():
        assert count == pytest.approx(TRAIN_NUM_SAMPLES / 15, rel=RTOL), f'{key}, {count}'


def test_two_sources_same_weights():
    """Test webdataset with two tar files, using --train-data-upsampling-factors=1::1.

    With equal upsampling factors, every caption from both shards is expected
    roughly TRAIN_NUM_SAMPLES / 15 times (15 = 10 + 5 samples in total).
    """
    input_dir = build_inputs('two_sources_same_weights')
    input_shards = f"{os.path.join(input_dir, 'test_data_000.tar')}::{os.path.join(input_dir, 'test_data_001.tar')}"
    args, preprocess_img, tokenizer = build_params(input_shards)
    args.train_data_upsampling_factors = '1::1'
    dataset = get_wds_dataset(args, preprocess_img, is_train=True, tokenizer=tokenizer)
    dataloader = dataset.dataloader

    # sample[1] holds the captions of each batch; tally every caption string.
    counts = collections.Counter(txt for sample in dataloader for txt in sample[1])

    # Without this check an empty loader (or a source never read) would pass vacuously.
    expected_keys = {f'000_{i}' for i in range(10)} | {f'001_{i}' for i in range(5)}
    assert set(counts) == expected_keys, f'unexpected captions: {sorted(counts)}'

    for key, count in counts.items():
        assert count == pytest.approx(TRAIN_NUM_SAMPLES / 15, rel=RTOL), f'{key}, {count}'

def test_two_sources_with_upsampling():
    """Test webdataset with two tar files, using --train-data-upsampling-factors=1::2.

    Captions from test_data_000.tar are expected roughly TRAIN_NUM_SAMPLES / 20
    times each. Captions from test_data_001.tar are expected roughly
    TRAIN_NUM_SAMPLES / 10 times each, i.e. twice as often.
    """
    input_dir = build_inputs('two_sources_with_upsampling')
    input_shards = f"{os.path.join(input_dir, 'test_data_000.tar')}::{os.path.join(input_dir, 'test_data_001.tar')}"
    args, preprocess_img, tokenizer = build_params(input_shards)
    args.train_data_upsampling_factors = '1::2'
    dataset = get_wds_dataset(args, preprocess_img, is_train=True, tokenizer=tokenizer)
    dataloader = dataset.dataloader

    # sample[1] holds the captions of each batch; tally every caption string.
    counts = collections.Counter(txt for sample in dataloader for txt in sample[1])

    # Without this check an empty loader (or a source never read) would pass vacuously.
    expected_keys = {f'000_{i}' for i in range(10)} | {f'001_{i}' for i in range(5)}
    assert set(counts) == expected_keys, f'unexpected captions: {sorted(counts)}'

    for key, count in counts.items():
        expected = TRAIN_NUM_SAMPLES / 20 if key.startswith('000') else TRAIN_NUM_SAMPLES / 10
        assert count == pytest.approx(expected, rel=RTOL), f'{key}, {count}'
