import argparse
import json
import os
import random
from typing import List, Optional, Tuple

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data.dataloader import default_collate
from torch.utils.data.distributed import DistributedSampler

from src.dataset.ytvos_dataset import YTVOSStandard, YTVOSEpisodic
from src.dataset.ytvos_transform import TrainTransform, TestTransform
from .classes import get_split_classes, filter_classes

# Module-level constant; it is not referenced by the loaders defined here.
# NOTE(review): the name suggests a downsampling factor -- confirm against its users.
ds_factor: float = 8.0

def get_train_loader(args: argparse.Namespace,
                     return_paths: bool = False
                     ) -> Tuple[torch.utils.data.DataLoader, Optional[DistributedSampler]]:
    """
        Build the train loader.

        The dataset is ``YTVOSStandard`` (a standard, non-episodic loader) unless
        ``args.episodic_train`` is set and truthy, in which case ``YTVOSEpisodic``
        is used instead.

        Args:
            args: experiment configuration. Reads ``train_split``, ``train_name``,
                ``episodic_train`` (optional), ``image_size``, ``distributed``,
                ``batch_size`` and ``workers``. The whole namespace is also
                forwarded to the dataset constructor.
            return_paths: currently unused; kept for backward compatibility.

        Returns:
            ``(train_loader, train_sampler)``. ``train_sampler`` is a
            ``DistributedSampler`` when ``args.distributed`` is set, else ``None``.

        Raises:
            AssertionError: if ``args.train_split`` is not one of 0, 1, 2, 3.
            ValueError: if ``args.train_name`` is not ``"ytvis"``.
    """
    assert args.train_split in [0, 1, 2, 3]
    collate_fn = default_collate

    # Fail fast with a clear message; otherwise ``train_data`` would be unbound
    # below and the call would die with an UnboundLocalError.
    if args.train_name != "ytvis":
        raise ValueError(
            "Unsupported train_name {!r}; only 'ytvis' is implemented".format(args.train_name))

    transform = TrainTransform(args.image_size)
    if getattr(args, 'episodic_train', False):
        train_data = YTVOSEpisodic(transform=transform, train=True, args=args)
    else:
        train_data = YTVOSStandard(transform=transform, train=True, args=args)

    if args.distributed:
        # The configured batch size is global: split it evenly across processes.
        world_size = torch.distributed.get_world_size()
        train_sampler = DistributedSampler(train_data)
        batch_size = int(args.batch_size / world_size)
    else:
        train_sampler = None
        batch_size = args.batch_size

    # DataLoader forbids shuffle=True together with an explicit sampler, so only
    # shuffle when no DistributedSampler is in use.
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler,
                                               drop_last=True,
                                               collate_fn=collate_fn)
    print('#############################', len(train_loader), ' ', batch_size, ' ', len(train_data))
    return train_loader, train_sampler

def get_val_loader(args: argparse.Namespace,
                   split_type: str = 'val'
                   ) -> Tuple[torch.utils.data.DataLoader, "TestTransform"]:
    """
        Build the episodic validation loader.

        Args:
            args: experiment configuration. Reads ``test_split``,
                ``temporal_episodic_val``, ``image_size`` and ``workers``. The
                whole namespace is also forwarded to ``YTVOSEpisodic``.
            split_type: currently unused; kept for backward compatibility.

        Returns:
            ``(val_loader, val_transform)``: an unshuffled, batch-size-1 loader
            over ``YTVOSEpisodic(train=False)``, plus the ``TestTransform`` it uses.

        Raises:
            AssertionError: if ``args.test_split`` is not a recognised split.
            ValueError: if ``args.temporal_episodic_val`` is not 3, the only
                episodic temporal configuration implemented here.
    """
    assert args.test_split in [0, 1, 2, 3, -1, 'default']

    # Only temporal episodic mode 3 is implemented. Any other value used to fall
    # through and crash with an UnboundLocalError on val_data / val_loader.
    if args.temporal_episodic_val != 3:
        raise ValueError(
            "Unsupported temporal_episodic_val {!r}; only 3 (YTVOSEpisodic) is "
            "implemented".format(args.temporal_episodic_val))

    val_transform = TestTransform(args.image_size)
    val_data = YTVOSEpisodic(transform=val_transform,
                             train=False,
                             args=args)

    # Cap loading at a single worker process; 0 keeps loading in the main process.
    workers = 1 if args.workers > 0 else args.workers
    val_loader = torch.utils.data.DataLoader(val_data,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=workers,
                                             pin_memory=True,
                                             sampler=None)

    return val_loader, val_transform
