import argparse
import os
import warnings
from random import randint

import cv2
import einops
import librosa
import numpy as np
import torch
from PIL import Image
from matplotlib import pyplot as plt
from scipy import signal
import torchvision.transforms as transforms
from tqdm import tqdm

from dataset.Base import BaseDataset, get_transform, getIndex

# NOTE(review): this silences *all* warnings process-wide as soon as this module
# is imported (including deprecation warnings from pandas/torch/librosa used
# below); consider narrowing it with a category/module filter.
warnings.filterwarnings('ignore')


class FAVDataset(BaseDataset):
    """Audio-visual deepfake dataset driven by ``meta_data.csv`` under ``opt.dataroot``.

    Each item is ``(visual, audio, v_label, a_label, a_label * v_label)``.
    ``v_label`` / ``a_label`` are 1.0 when the video / audio track is real and
    0.0 when it is fake, so the last entry is 1.0 only when both are real.
    """

    # Position of each stage's split in the (train, valid, test) tuple
    # returned by get_labeled_video().
    _STAGE_TO_SPLIT = {'train': 0, 'valid': 1, 'test': 2}

    def __init__(self, opt, target_frames=8, audio_only=False, video_only=False):
        """Load the metadata index and select the split for ``opt.mode``.

        Args:
            opt: options object; ``opt.mode`` selects the split and
                ``opt.dataroot`` is the dataset root directory.
            target_frames: number of frames requested from
                ``extract_frames_and_audio`` per sample.
            audio_only: load only audio; ``visual`` is a placeholder array.
            video_only: load only video; ``audio`` is a placeholder array.

        Raises:
            ValueError: if ``opt.mode`` is not 'train', 'valid' or 'test'.
        """
        BaseDataset.__init__(self, opt)
        self.audio_only = audio_only
        self.video_only = video_only
        self.opt = opt
        self.stage = opt.mode
        if self.stage not in self._STAGE_TO_SPLIT:
            # Fail fast: otherwise self.vect would never be assigned and the
            # dataset would only break later inside __len__/__getitem__.
            raise ValueError(
                f"Unsupported mode {self.stage!r}; expected one of {sorted(self._STAGE_TO_SPLIT)}")
        self.data_index = getIndex(self.opt.dataroot, 'meta_data.csv', type='csv')
        self.target_frames = target_frames
        self.vect = self.get_labeled_video()[self._STAGE_TO_SPLIT[self.stage]]
        # NOTE(review): self.transform is not used by __getitem__ in this file.
        self.transform = get_transform(self.opt)

    def __getitem__(self, index):
        """Load sample ``index`` as ``(visual, audio, v_label, a_label, joint_label)``.

        ``visual`` / ``audio`` are converted to ``torch.FloatTensor`` unless that
        modality is disabled, in which case the placeholder array returned by
        ``extract_frames_and_audio`` is passed through unchanged.
        """
        row = self.vect[index]
        visual, audio = extract_frames_and_audio(row[0], self.audio_only, self.video_only,
                                                 target_frames=self.target_frames)
        v_label = self._flag_to_label(row[1])
        a_label = self._flag_to_label(row[2])
        if not self.video_only:
            audio = torch.FloatTensor(audio)
        if not self.audio_only:
            visual = torch.FloatTensor(visual)
        return visual, audio, v_label, a_label, a_label * v_label

    @staticmethod
    def _flag_to_label(flag):
        """Map a stored real/fake flag to 1.0 (real) or 0.0 (fake).

        Flags are collected as Python bools but become the strings 'True' /
        'False' once np.concatenate packs them next to the path strings;
        comparing ``str(flag)`` accepts either representation.
        """
        return 1.0 if str(flag) == 'True' else 0.0

    def __len__(self):
        return len(self.vect)

    def get_labeled_video(self):
        """Collect labelled sample paths from the metadata and split them.

        Rows whose ``washed`` column (position 10) equals 0 are skipped. Every
        kept row becomes ``[path, video_is_real, audio_is_real]`` and is grouped
        by its ``type`` column (position 5).

        Returns:
            The ``(train, valid, test)`` tuple produced by
            ``split_data_and_save_to_csv``.
        """
        root = self.opt.dataroot
        real_real, fake_real, fake_fake, real_fake = [], [], [], []
        # Column positions in meta_data.csv:
        # 0.source  1.target1   2.target2
        # 3.method  4.category  5.type
        # 6.race    7.gender    8.name  9.path
        # 10.washed
        # .iloc makes the access explicitly positional; plain row[10] on a
        # string-labelled pandas row relies on a deprecated fallback.
        # NOTE(review): assumes getIndex returns a pandas DataFrame (iterrows()).
        for _, row in tqdm(self.data_index.iterrows(), total=len(self.data_index)):
            if row.iloc[10] == 0:
                continue
            path = os.path.join(root, row.iloc[9].replace('FakeAVCeleb/', ''), row.iloc[8])
            sample_type = row.iloc[5]
            if sample_type == "RealVideo-RealAudio":
                real_real.append([path, True, True])
            elif sample_type == "FakeVideo-RealAudio":
                if return_true_with_probability():
                    fake_real.append([path, False, True])
            elif sample_type == "FakeVideo-FakeAudio":
                if return_true_with_probability():
                    fake_fake.append([path, False, False])
            elif sample_type == "RealVideo-FakeAudio":
                real_fake.append([path, True, False])
        return split_data_and_save_to_csv(real_real, fake_real, fake_fake, real_fake)


def get_video_duration(video_path):
    """Return the duration of ``video_path`` in seconds, printing sanity-check warnings.

    The duration is ``frame_count / fps`` as reported by OpenCV. Diagnostics
    are printed when the first frame is not square (``un_squared``), is not
    224 pixels high (``un_formated``), or cannot be read. They are also
    printed when the video has fewer than 30 frames or is shorter than one
    second.

    Raises:
        FileNotFoundError: if ``video_path`` is not an existing file.
        ValueError: if OpenCV cannot open the file or reports a non-positive
            frame rate (which would otherwise divide by zero).
    """
    if not os.path.isfile(video_path):
        print(video_path)
        raise FileNotFoundError(f"Video file not found: {video_path}")

    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise ValueError(f"Could not open video: {video_path}")
        fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if fps <= 0:
            raise ValueError(f"Invalid frame rate {fps} for video: {video_path}")
        ret, frame = cap.read()
    finally:
        # Release the capture even when one of the checks above raises.
        cap.release()

    duration = total_frames / fps

    if not ret:
        # cap.read() returns (False, None) here; indexing frame.shape would crash.
        print(video_path, 'unreadable_first_frame')
    else:
        if frame.shape[0] != frame.shape[1]:
            print(video_path, 'un_squared')
        if frame.shape[0] != 224:
            print(video_path, 'un_formated')
    if total_frames < 30:
        print(fps, total_frames)
        print(video_path)
    if duration < 1:
        print(duration)
        print(video_path)
    return duration


def return_true_with_probability():
    """Decide whether a fake-video sample should be kept.

    Earlier work subsampled here to balance the ratio of positive and
    negative samples. That balancing is not used in this experiment, so every
    sample is kept and the answer is always ``True``.
    """
    keep_sample = True
    return keep_sample

def _open_video(video_path):
    """Open ``video_path`` with OpenCV, raising OSError if it cannot be opened."""
    video_capture = cv2.VideoCapture(video_path)
    if not video_capture.isOpened():
        video_capture.release()
        raise OSError(f"Could not open video: {video_path}")
    return video_capture


def _read_rgb_frame(video_capture, preprocess):
    """Read the next frame as a preprocessed (H, W, C) RGB tensor, or None at end of stream."""
    ret, frame = video_capture.read()
    if not ret:
        return None
    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    frame = preprocess(Image.fromarray(frame_rgb))
    return einops.rearrange(frame, 'C H W -> H W C')


def _fit_length(segment, length):
    """Truncate or right-pad ``segment`` with zeros so it has exactly ``length`` samples."""
    segment = np.asarray(segment)[:length]
    if len(segment) < length:
        segment = np.pad(segment, (0, length - len(segment)))
    return segment


def extract_frames_and_audio(video_path, audio_only, video_only, target_frames=16, frame_size=0.080,
                             audio_frame_length=0.064, frame_shift=0.016, target_sr=16000, image_size=224):
    """Load video frames and/or audio windows from ``video_path``.

    The mode is chosen by the flags:

    * cross-modal (neither flag set): returns ``(visual, audio_frames)``.
      ``audio_frames`` is a list of ``target_frames`` waveform windows, one per
      video frame. Each window has ``audio_length`` samples and is centred on
      that frame's timestamp (it starts ``central_offset`` seconds before the
      frame and spans ``2 * central_offset`` seconds plus one sample).
      Frames are first skipped until the timestamp reaches ``central_offset``,
      so the first window does not start before the audio.
    * audio only: returns ``(np.zeros((1, 1)), window)``, where ``window`` is
      the first ``target_sr * (target_frames + 3) * frame_shift + 1`` samples.
    * video only: returns ``(visual, np.zeros((1, 1)))``.

    ``visual`` is a float array of shape
    ``(target_frames, image_size, image_size, 3)``. Frames are resized and
    centre-cropped to ``image_size`` and stored RGB, channels last. With
    ``target_frames < 8`` every other video frame is skipped.

    If the video ends early, the remaining visual rows are zeros. In
    cross-modal mode the remaining audio windows are silence, so every
    sample has the same shape and can be batched.

    Raises:
        OSError: if OpenCV cannot open the video (video-reading modes).
        ValueError: if OpenCV reports a non-positive FPS (cross-modal mode).
    """
    skip = target_frames < 8
    preprocess = transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
    ])
    crossmodal = (not video_only) and (not audio_only)

    if crossmodal:
        # Half-width of each audio window (seconds); the window is centred on the frame time.
        central_offset = frame_size + (audio_frame_length - frame_shift) / 2
        audio_length = int(target_sr * (frame_size * 2 + audio_frame_length - frame_shift)) + 1
        # np.zeros (not np.empty): unread trailing frames must not contain garbage memory.
        visual_frames = np.zeros((target_frames, image_size, image_size, 3))
        audio_frames = []
        video_capture = _open_video(video_path)
        try:
            fps = video_capture.get(cv2.CAP_PROP_FPS)
            if fps <= 0:
                raise ValueError(f"Invalid frame rate {fps} for video: {video_path}")
            frame_length = 1 / fps
            # Sync: skip frames until a full audio half-window precedes the current frame.
            t = 0.
            while t < central_offset:
                t += frame_length
                ret, _ = video_capture.read()
                if not ret:
                    break
            y, _ = librosa.load(video_path, sr=target_sr)
            for i in range(target_frames):
                frame = _read_rgb_frame(video_capture, preprocess)
                if frame is None:
                    break
                # Clamp to 0: a negative start would wrap around to the end of the signal.
                start = max(0, int((t - central_offset) * target_sr))
                audio_frames.append(_fit_length(extract_audio_segment(y, start, audio_length), audio_length))
                visual_frames[i] = frame
                t += frame_length
                if skip:
                    ret, _ = video_capture.read()
                    if not ret:
                        break
                    t += frame_length
        finally:
            video_capture.release()
        # Pad missing trailing windows with silence to mirror the zero-filled visual rows.
        while len(audio_frames) < target_frames:
            audio_frames.append(np.zeros(audio_length, dtype=np.float32))
        return visual_frames, audio_frames

    if audio_only:
        time_length = target_frames * frame_shift
        audio_length = int(target_sr * (time_length + 3 * frame_shift)) + 1
        y, _ = librosa.load(video_path, sr=target_sr)
        # Pad short clips so every sample has the same length for batching.
        audio_frame = _fit_length(extract_audio_segment(y, 0, audio_length), audio_length)
        return np.zeros((1, 1)), audio_frame

    # Remaining case: video_only is True.
    visual_frames = np.zeros((target_frames, image_size, image_size, 3))
    video_capture = _open_video(video_path)
    try:
        for i in range(target_frames):
            frame = _read_rgb_frame(video_capture, preprocess)
            if frame is None:
                break
            visual_frames[i] = frame
            if skip:
                ret, _ = video_capture.read()
                if not ret:
                    break
    finally:
        video_capture.release()
    return visual_frames, np.zeros((1, 1))


def extract_audio_segment(audio, t, length):
    """Return up to ``length`` samples of ``audio`` starting at sample index ``t``.

    A negative ``t`` is clamped to 0, so the window starts at the beginning of
    the signal. Without the clamp, Python's negative-index wrap-around would
    return an empty or wrong slice. Near the end of the signal the result may
    be shorter than ``length``.
    """
    start = max(t, 0)
    return audio[start:start + length]


def convert_image_channels(frame):
    """Convert an OpenCV BGR image array to the YCrCb colour space."""
    target_space = cv2.COLOR_BGR2YCrCb
    converted = cv2.cvtColor(frame, target_space)
    return converted


def extract_audio_with_stft(y, win_length, hop_size):
    """Compute a 3-channel (log-magnitude, sin-phase, cos-phase) STFT feature map.

    The waveform first goes through the emphasis filter
    ``x[n] - 0.97 * x[n-1]``. It is then transformed with ``librosa.stft``
    using a Blackman window, ``n_fft == win_length`` and no centre padding.

    Args:
        y: waveform samples.
        win_length: FFT size and window length, in samples.
        hop_size: hop between successive STFT frames, in samples.

    Returns:
        The STFT's two leading axes with a new last axis of size 3 holding
        ``log(|D| + e^-80)``, ``sin(angle(D))`` and ``cos(angle(D))``.
    """
    emphasized = signal.lfilter([1, -0.97], [1], y)
    spectrum = librosa.stft(
        y=emphasized,
        n_fft=win_length,
        hop_length=hop_size,
        win_length=win_length,
        window=signal.windows.blackman,
        center=False,
    )
    # e^-80 is a tiny floor so the log never sees an exact zero.
    log_magnitude = np.log(np.abs(spectrum) + np.exp(-80))
    phase = np.angle(spectrum)
    channels = (log_magnitude, np.sin(phase), np.cos(phase))
    return np.stack(channels, axis=2)


def split_array(array, train_ratio=0.7, val_ratio=0.):
    """Split ``array`` into contiguous (train, validation, test) slices.

    The first ``int(n * train_ratio)`` items form the training set. The next
    ``int(n * val_ratio)`` items form the validation set. All remaining items
    form the test set, so no sample is lost to integer truncation. For
    example, with 3 items and the default ratios the old ``int(3 * 0.3)``
    test size of 0 silently dropped one sample.

    Args:
        array: any sliceable sequence (list or numpy array).
        train_ratio: fraction of items assigned to training.
        val_ratio: fraction of items assigned to validation.

    Returns:
        ``(train_set, val_set, test_set)`` slices of ``array``.
    """
    total_samples = len(array)
    train_len = int(total_samples * train_ratio)
    val_len = int(total_samples * val_ratio)
    val_end = train_len + val_len
    return array[:train_len], array[train_len:val_end], array[val_end:]


def split_data_and_save_to_csv(a, b, c, d):
    """Split four sample groups and merge them into shuffled train/val/test arrays.

    Each group is split independently with ``split_array``, so every split
    keeps each group's share. The pieces for each split are concatenated and
    shuffled in place with ``np.random.shuffle``, in the order train, val,
    test. Despite the name, nothing is written to disk.

    Empty pieces are skipped before concatenation, because ``np.concatenate``
    cannot join a 1-D empty list with 2-D rows. Without this, for example, a
    group too small to contribute any training rows would raise
    ``ValueError``. A split with no samples at all is returned as
    ``np.array([])``.

    Returns:
        ``(train_set, val_set, test_set)`` numpy arrays.
    """
    per_group = [split_array(group) for group in (a, b, c, d)]
    merged = []
    # zip(*per_group) yields the train pieces, then val pieces, then test pieces.
    for pieces in zip(*per_group):
        non_empty = [piece for piece in pieces if len(piece)]
        split_set = np.concatenate(non_empty, axis=0) if non_empty else np.array([])
        np.random.shuffle(split_set)
        merged.append(split_set)
    return tuple(merged)
