"""
Load single dataset for intra-domain experiments with color modality, or multi-modalities (i.e., color, depth, ir).
"""

import os, torch, random
import util.utils_FAS as utils
from dassl.data.datasets import DATASET_REGISTRY
from .wrapper import FAS_RGB, FAS_RGB_VAL

class DatumXY:
    """Read-only record describing a pair of images plus its annotations.

    Args:
        impath_x (str): image path of fake.
        impath_y (str): image path of live.
        label (int): class label.
        classname (str): class name.
        video (str): identifier of the video the images come from.
    """

    def __init__(self, impath_x="", impath_y="", label=-1, classname="", video=""):
        # Only the two image paths are type-checked; the other fields are stored as given.
        for path in (impath_x, impath_y):
            assert isinstance(path, str)
        self._impath_x, self._impath_y = impath_x, impath_y
        self._label, self._classname = label, classname
        self._video = video

    # Getter-only properties: the stored values cannot be reassigned through
    # the public names after construction.
    impath_x = property(lambda self: self._impath_x)
    impath_y = property(lambda self: self._impath_y)
    label = property(lambda self: self._label)
    classname = property(lambda self: self._classname)
    video = property(lambda self: self._video)


def read_video(data_root, protocol, folder, split, sample):
    """Build the `DatumXY` items of one split of a video-level protocol.

    The list file ``<data_root>/<name>/protocol/<protocol>_video_<split>.txt``
    is read, where ``<name>`` is the part of the file name before ``'@'``.
    Every non-blank line must be ``"<video_dir> <label>"``; label 0 is treated
    as live and any other label as fake.

    Args:
        data_root (str): root directory holding the datasets.
        protocol (str): protocol name used to build the list-file name.
        folder (str): preferred frame sub-folder inside each video directory.
            Per video, falls back to ``'color'`` and then to the video
            directory itself when the sub-folder does not exist.
        split (str): ``'train'`` builds balanced fake/live pairs; any other
            value (e.g. ``'dev'``, ``'test'``) builds one item per video.
        sample (int | str): maximum number of frames randomly drawn from each
            video when ``split == 'train'``; ignored otherwise.

    Returns:
        list[DatumXY]: for ``'train'``, one item per balanced pair with
        ``impath_x`` a fake frame, ``impath_y`` a live frame and ``label=-1``.
        Otherwise one item per video with ``impath_x``/``impath_y`` the first
        and last frame of the sorted frame listing, plus ``label`` and
        ``video``.

    Raises:
        IndexError: in training, when one class has no frame at all while the
            other has some (nothing to oversample from).
    """

    def resolve_folder(video_dir):
        # Resolved per video from the *requested* folder, so a fallback needed
        # by one video never leaks into the videos that follow it.
        for candidate in (folder, 'color'):
            if utils.check_if_exist(os.path.join(video_dir, candidate)):
                return candidate
        return ''

    def video2list(root, txt):
        # get data from txt
        data_name = txt.split('@')[0]
        with open(os.path.join(root, data_name + '/protocol', txt)) as f:
            lines = f.readlines()

        is_train = split == 'train'
        records = []
        for line in lines:
            line = line.strip()
            if not line:
                # tolerate blank / trailing empty lines in the protocol file
                continue
            video, label = line.split(' ')
            video_dir = os.path.join(root, video)
            frame_dir = os.path.join(video_dir, resolve_folder(video_dir))
            # os.listdir returns entries in arbitrary order; sort so that the
            # "first" and "last" frame (and seeded sampling) are reproducible.
            frames = sorted(os.listdir(frame_dir))
            if not frames:
                continue

            if is_train:
                picked = random.sample(frames, k=min(int(sample), len(frames)))
                for frame in picked:
                    records.append((os.path.join(frame_dir, frame), int(label)))
            else:
                pair = []
                for frame in (frames[0], frames[-1]):
                    impath = os.path.join(frame_dir, frame)
                    if not utils.check_if_exist(impath):
                        impath = impath.replace('.jpg', '.png')
                    pair.append(impath)
                records.append((pair[0], pair[1], video, int(label)))

        if not is_train:
            return records

        # data balance to 1:1 by re-drawing samples of the smaller class
        lives = [rec for rec in records if rec[1] == 0]
        fakes = [rec for rec in records if rec[1] != 0]
        insert = len(fakes) - len(lives)
        if insert > 0:
            for _ in range(insert):
                lives.append(random.choice(lives))
        else:
            for _ in range(-insert):
                fakes.append(random.choice(fakes))

        assert len(lives) == len(fakes)
        return lives, fakes

    ################
    list_name = protocol + '_video_' + split + '.txt'
    items = []
    if split == 'train':
        lives_list, fakes_list = video2list(data_root, list_name)
        for fake, live in zip(fakes_list, lives_list):
            items.append(DatumXY(impath_x=fake[0], impath_y=live[0], label=-1))
        print('Load video {} {}={}'.format(protocol, split, len(lives_list)))
        return items

    impath_label_list = video2list(data_root, list_name)
    for impath1, impath2, video, label in impath_label_list:
        items.append(
            DatumXY(impath_x=impath1, impath_y=impath2, label=label, video=video)
        )
    print('Load video {} {}={}'.format(protocol, split, len(impath_label_list)))
    return items


def read_image(data_root, protocol, split):
    """Build the `DatumXY` items of one split of an image-level protocol.

    The list file ``<data_root>/<name>/protocol/<protocol>_image_<split>.txt``
    is read, where ``<name>`` is the part of the file name before ``'@'``.
    Every non-blank line must be ``"<image_path> <label>"`` with the path
    relative to ``data_root``; label 0 is treated as live and any other label
    as fake.

    Args:
        data_root (str): root directory holding the datasets.
        protocol (str): protocol name used to build the list-file name.
        split (str): ``'train'`` builds balanced fake/live pairs; any other
            value (e.g. ``'dev'``, ``'test'``) builds one item per image.

    Returns:
        list[DatumXY]: for ``'train'``, one item per balanced pair with
        ``impath_x`` a fake image, ``impath_y`` a live image and ``label`` the
        fake image's label. Otherwise one item per listed image with both
        paths set to that image and ``label`` its label.

    Raises:
        IndexError: in training, when one class has no image at all while the
            other has some (nothing to oversample from).
    """

    def image2list(root, txt):
        data_name = txt.split('@')[0]
        with open(os.path.join(root, data_name + '/protocol', txt)) as f:
            lines = f.readlines()

        records = []
        for line in lines:
            line = line.strip()
            if not line:
                # tolerate blank / trailing empty lines in the protocol file
                continue
            image, label = line.split(' ')
            records.append((os.path.join(root, image), int(label)))

        if split != 'train':
            return records

        # data balance to 1:1 by re-drawing samples of the smaller class
        lives = [rec for rec in records if rec[1] == 0]
        fakes = [rec for rec in records if rec[1] != 0]
        insert = len(fakes) - len(lives)
        if insert > 0:
            for _ in range(insert):
                lives.append(random.choice(lives))
        else:
            for _ in range(-insert):
                fakes.append(random.choice(fakes))

        assert len(lives) == len(fakes)
        return lives, fakes

    ##########
    list_name = protocol + '_image_' + split + '.txt'
    items = []
    if split == 'train':
        lives_list, fakes_list = image2list(data_root, list_name)
        for fake, live in zip(fakes_list, lives_list):
            items.append(DatumXY(impath_x=fake[0], impath_y=live[0], label=fake[1]))
        print('Load image {} {}={}'.format(protocol, split, len(lives_list)))
        return items

    impath_label_list = image2list(data_root, list_name)
    for impath, label in impath_label_list:
        items.append(DatumXY(impath_x=impath, impath_y=impath, label=label))
    print('Load image {} {}={}'.format(protocol, split, len(impath_label_list)))
    return items

def build_dataset(data_root, protocol, is_video, folder='crop'):
    """Load and shuffle the train, dev and test splits of one protocol.

    Args:
        data_root (str): root directory holding the datasets.
        protocol (str): protocol name, optionally suffixed with ``'#<k>'``
            where ``<k>`` is the number of frames to sample per training video
            (defaults to 1 when the suffix is absent).
        is_video (bool): read video-level protocol lists when true, otherwise
            image-level lists.
        folder (str): preferred frame sub-folder for video protocols; unused
            for image protocols.

    Returns:
        tuple[list, list, list]: ``(train, dev, test)`` item lists, each
        shuffled in place.
    """
    parts = protocol.split('#')
    if len(parts) == 2:
        # sample stays a string here; read_video converts it with int()
        protocol, sample = parts
    else:
        sample = 1

    if is_video:
        data_train = read_video(data_root, protocol, folder, split='train', sample=sample)
        data_dev = read_video(data_root, protocol, folder, split='dev', sample=sample)
        data_test = read_video(data_root, protocol, folder, split='test', sample=sample)
    else:
        data_train = read_image(data_root, protocol, split='train')
        data_dev = read_image(data_root, protocol, split='dev')
        data_test = read_image(data_root, protocol, split='test')

    for data in (data_train, data_dev, data_test):
        random.shuffle(data)

    # Return the dev split in the middle slot; the test split used to be
    # returned twice, so the loaded dev data was silently discarded.
    return data_train, data_dev, data_test


@DATASET_REGISTRY.register()
class C_DATA:
    """modals = ['color']

    Builds the train/dev/test item lists through `build_dataset` and exposes
    one `torch.utils.data.DataLoader` per split, together with the class
    names and the text templates (each with a single `{}` placeholder).
    """

    def __init__(self, cfg):
        train, dev, test = build_dataset(
            cfg.DATASET.ROOT, cfg.DATASET.PROTOCOL, cfg.DATASET.IS_VIDEO
        )
        image_size = cfg.INPUT.SIZE[0]

        # Training loader: configurable preprocessing, shuffled, incomplete
        # final batch dropped.
        train_set = FAS_RGB(
            data_source=train,
            image_size=image_size,
            preprocess=cfg.DATASET.PREPROCESS,
            task='intra',
        )
        self.train_loader = torch.utils.data.DataLoader(
            train_set,
            batch_size=cfg.DATALOADER.TRAIN_X.BATCH_SIZE,
            shuffle=True,
            num_workers=cfg.DATALOADER.NUM_WORKERS,
            drop_last=True,
            pin_memory=False,
        )

        def make_eval_loader(items):
            # dev and test share the same settings: 'resize' preprocessing,
            # fixed order, every sample kept, a single worker.
            eval_set = FAS_RGB_VAL(
                data_source=items,
                image_size=image_size,
                preprocess='resize',
            )
            return torch.utils.data.DataLoader(
                eval_set,
                batch_size=cfg.DATALOADER.TEST.BATCH_SIZE,
                shuffle=False,
                num_workers=1,
                drop_last=False,
                pin_memory=False,
            )

        self.dev_loader = make_eval_loader(dev)
        self.test_loader = make_eval_loader(test)

        # Label id -> class name, and the same names ordered by label id.
        self.lab2cname = {0: 'live', 1: 'fake'}
        self.classnames = ['live', 'fake']
        self.templates = [
            'This is an example of a {} face',
            'This is a {} face',
            'This is how a {} face looks like',
            'A photo of a {} face',
            'Is not this a {} face ?',
            'A printout shown to be a {} face',
        ]

