import numpy as np
import torch
import math
#from torchvideo import VideoTransform
import os
import subprocess
import cv2
import utils
from sklearn.metrics import accuracy_score

# Utilities

def max_diff(r, g, b):
    """Return the largest absolute difference between the modes of ``r``, ``g`` and ``b``.

    Each argument is passed to ``mode``; the result is the maximum over the
    three pairwise absolute differences of those modal values.
    """
    # Compute each channel's mode once (previously every mode was computed twice).
    mr, mg, mb = mode(r), mode(g), mode(b)
    return max(abs(mr - mg), abs(mg - mb), abs(mr - mb))

def mode(x):
    """Return the most frequent value in ``x``.

    ``np.unique`` returns sorted values and ``argmax`` picks the first maximum,
    so ties resolve to the smallest tied value. An empty input returns the
    sentinel ``-1000``, which lies outside the 0-255 colour range.
    """
    uniq, freq = np.unique(x, return_counts=True)
    if not len(freq):
        return -1000  # invalid for colorspace (0-255)
    return uniq[freq.argmax()]

def color_histogram(vid):
    """Split a tensor into three flattened, 0-255 integer channel arrays.

    Despite the name, no histogram is computed: the raw per-channel values are
    returned so callers can take modes or norms over them.

    Args:
        vid: tensor whose elements are regrouped into 3 rows by
            ``reshape(3, -1)``. Presumably channel-first ``(3, ...)`` with
            values in [0, 1] — TODO confirm against callers.

    Returns:
        Tuple ``(r, g, b)`` of 1-D integer numpy arrays (values scaled by 255
        and truncated toward zero).
    """
    # reshape (unlike view) also accepts non-contiguous tensors, e.g. after a
    # transpose/permute; detach() allows tensors that require grad to reach numpy.
    flat = vid.detach().reshape(3, -1).cpu().numpy()
    image = (flat * 255).astype(int)
    return image[0], image[1], image[2]

def norm_per_pixel(r, g, b):
    """Sum the pairwise L2 distances between three channel vectors, divided by ``len(r)``.

    Computes ``(||r - g|| + ||g - b|| + ||r - b||) / len(r)``; the result is
    zero exactly when the three channels are identical.
    """
    channel_pairs = ((r, g), (g, b), (r, b))
    total = sum(np.linalg.norm(first - second) for first, second in channel_pairs)
    return total / len(r)

def pad_input_for_l2(X, Xt):
    """Make ``Xt`` match ``X`` along dim 0 so the two can be subtracted elementwise.

    A temporal corruption can change how many entries ``Xt`` has along dim 0.
    If ``Xt`` is shorter than ``X`` it is zero-padded at the end of dim 0 (the
    result is on the CPU). If it is longer it is truncated to ``X.size(0)``.
    If the sizes already match, ``Xt`` is returned unchanged.

    Args:
        X: reference tensor.
        Xt: transformed tensor; presumably the same shape as ``X`` except for
            dim 0 — TODO confirm.

    Returns:
        ``Xt`` itself when sizes match, otherwise a tensor with ``X.size(0)``
        entries along dim 0.
    """
    if X.size() == Xt.size():
        return Xt
    n, nt = X.size(0), Xt.size(0)
    if nt > n:
        # Padding an already-longer Xt would only widen the mismatch with X.
        return Xt[:n]
    Xt_cpu = Xt.cpu()
    pad_shape = (n - nt,) + tuple(Xt.size()[1:])
    # new_zeros keeps Xt's dtype rather than defaulting to float32.
    return torch.cat([Xt_cpu, Xt_cpu.new_zeros(pad_shape)], dim=0)

class DataHook(object):
    """Base class for analysis hooks over (original, transformed) inputs.

    Subclasses implement ``__call__(X, Xt, y, metadata, transform)`` to compute
    a value for one call, and may override ``report_results`` to summarize
    the values collected in ``self.results``. Presumably the driver passes
    each ``__call__`` return value to ``update`` — confirm against the caller.
    """

    def __init__(self):
        # One entry per update() call, in call order.
        self.results = list()

    def __call__(self, X, Xt, y, metadata, transform):
        """Compute this hook's value; must be overridden by subclasses."""
        raise NotImplementedError()

    def update(self, result):
        """Append ``result`` to ``self.results``."""
        self.results.append(result)

    def report_results(self, y, preds):
        """Print a summary of collected results; the base class prints nothing."""
        return None

    def __repr__(self):
        return type(self).__name__

class GraynessRatio(DataHook):
    """Hook comparing channel spread before vs. after the transform.

    For every index along dim 0 it computes
    ``norm_per_pixel(original) / norm_per_pixel(transformed)`` and returns the
    mean of those ratios.
    """

    def __init__(self):
        super().__init__()

    def __call__(self, X, Xt, y, metadata, transform):
        # Entries are paired by index, so both inputs need the same dim-0 size.
        assert X.size(0) == Xt.size(0)
        ratios = [
            norm_per_pixel(*color_histogram(X[idx])) / norm_per_pixel(*color_histogram(Xt[idx]))
            for idx in range(X.size(0))
        ]
        return sum(ratios) / len(ratios)

    def report_results(self, y, preds):
        average_ratio = np.array(self.results).mean()
        print("Average 'grayness' ratio:", average_ratio)

class IsGray(DataHook):
    """Hook flagging transformed inputs whose modal colour is close to mid-gray.

    ``Xt`` counts as gray when its three channel modes differ pairwise by less
    than 20 and their sum lies within (128 ± 18) * 3.
    """

    def __init__(self):
        super().__init__()

    def __call__(self, X, Xt, y, metadata, transform=None):
        r, g, b = color_histogram(Xt)
        low, high = (128 - 18) * 3, (128 + 18) * 3
        mode_sum = mode(r) + mode(g) + mode(b)
        return bool(max_diff(r, g, b) < 20 and low <= mode_sum <= high)

    def report_results(self, y, preds):
        is_gray = np.array(self.results)
        not_gray = ~is_gray
        print("Number gray:", np.count_nonzero(is_gray))
        print("Number non-gray:", np.count_nonzero(not_gray))
        print("Accuracy on gray:", accuracy_score(y[is_gray], preds[is_gray]))
        print("Accuracy on non-gray:", accuracy_score(y[not_gray], preds[not_gray]))

class IFrameLocations(DataHook):
    """Hook recording I-frame indices of the source video at ``metadata[0][0]``.

    Frame picture types are read with ``ffprobe``; ``report_results`` summarizes
    I-frame counts and gaps, split by whether the prediction was correct.
    """

    def __init__(self):
        super(IFrameLocations, self).__init__()

    def __call__(self, X, Xt, y, metadata, transform=None):
        """Return 0-based line indices of ffprobe's CSV output that contain 'I'.

        Returns an empty list when ``metadata[0][0]`` is not an existing file.
        """
        path = metadata[0][0]
        if not os.path.isfile(path):
            return []
        # Run ffprobe alone and scan its output here rather than piping through
        # grep/cut Popen objects that were never waited on (left zombies behind).
        proc = subprocess.run(
            ['ffprobe', '-loglevel', 'quiet', '-select_streams', 'v', '-show_frames',
             '-show_entries', 'frame=pict_type', '-of', 'csv', path],
            stdout=subprocess.PIPE)
        lines = proc.stdout.decode('utf-8', errors='replace').split('\n')
        # Equivalent to `grep -n I | cut -d: -f1` followed by subtracting one.
        return [idx for idx, line in enumerate(lines) if 'I' in line]

    def report_results(self, y, preds):
        from itertools import chain

        n_iframes = np.array([len(locs) for locs in self.results])
        correct = y.numpy() == preds.numpy()
        incorrect = ~correct
        print("Average # of I-frames:", np.mean(n_iframes))
        print("Average # of I-frames (correct):", np.mean(n_iframes[correct]))
        print("Average # of I-frames (incorrect):", np.mean(n_iframes[incorrect]))
        # Gaps between consecutive I-frames per video. These lists are ragged,
        # so keep them in a Python list (np.array on ragged input raises on
        # recent NumPy) and apply the boolean masks explicitly.
        intervals = [np.diff(locs).tolist() for locs in self.results]

        def _mean_interval(mask):
            # Mean over all gaps of the videos selected by the boolean mask.
            selected = (gaps for gaps, keep in zip(intervals, mask) if keep)
            return np.mean(list(chain.from_iterable(selected)))

        print("Average I-frame interval:", np.mean(list(chain.from_iterable(intervals))))
        print("Average I-frame interval (correct):", _mean_interval(correct))
        print("Average I-frame interval (incorrect):", _mean_interval(incorrect))

class TruncatedL2(DataHook):
    """Hook measuring mean per-pixel L2 distance over the overlap of ``X`` and ``Xt``.

    Both tensors are cut to the shorter length along dim 0. The L2 norm is taken
    over dim 1, averaged, and divided by sqrt(3) (presumably the maximum for 3
    channels in [0, 1] — confirm). If ``Xt`` has fewer than ``ndim`` dims, the
    call returns -1 to mark the sample as unreadable.
    """

    def __init__(self, ndim=5, device='cuda:0'):
        super().__init__()
        self.ndim = ndim
        self.device = device

    def __call__(self, X, Xt, y, metadata, transform=None):
        if Xt.ndim < self.ndim:
            return -1
        overlap = min(X.size(0), Xt.size(0))
        trans, orig = Xt[:overlap], X[:overlap]
        assert trans.size() == orig.size(), (X.size(), Xt.size(), trans.size(), orig.size())
        delta = trans.to(self.device) - orig.to(self.device)
        per_pixel = delta.pow(2).sum(dim=1).sqrt()
        return (per_pixel.mean() / math.sqrt(3)).item()  # normalize to 0-1

    def report_results(self, y, preds):
        l2 = np.array(self.results)
        readable = l2 >= 0
        print("Average L2 from original:", l2[readable].mean())
        if y is None or preds is None:
            return
        correct = y.numpy() == preds.numpy()
        print("L2, Correct:", l2[readable & correct].mean())
        print("L2, Incorrect:", l2[readable & ~correct].mean())

class TruncatedFramewiseL2(DataHook):
    """Hook returning per-frame L2 distances over the overlap of ``X`` and ``Xt``.

    Both tensors are cut to the shorter length along dim 0. The L2 norm over
    dim 1 is averaged over the last two dims (presumably H, W — confirm),
    divided by sqrt(3), and flattened to a numpy array. If ``Xt`` has fewer
    than ``ndim`` dims, the call returns -1.

    Args:
        ndim: minimum number of dims ``Xt`` must have (default 5, as before).
        device: where the subtraction runs; ``'cuda'`` means the current CUDA
            device, matching the previous ``.cuda()`` behaviour.
    """

    def __init__(self, ndim=5, device='cuda'):
        super(TruncatedFramewiseL2, self).__init__()
        self.ndim = ndim
        self.device = device

    def __call__(self, X, Xt, y, metadata, transform=None):
        if Xt.ndim < self.ndim:
            return -1
        # Keep only the leading entries present in both tensors
        # (slicing past the end is a no-op, so no length check is needed).
        Xtnew = Xt[:X.size(0)]
        Xorig = X[:Xt.size(0)]
        assert Xtnew.size() == Xorig.size(), (X.size(), Xt.size(), Xtnew.size(), Xorig.size())
        diff = Xtnew.to(self.device) - Xorig.to(self.device)
        dist = diff.pow(2).sum(dim=1).sqrt().mean(dim=(-2, -1)) / math.sqrt(3)
        return dist.view(-1).cpu().numpy()

    def report_results(self, y, preds):
        pass

class LoopedL2(DataHook):
    """Hook measuring mean per-pixel L2 distance after looping the shorter input.

    The shorter of ``X`` / ``Xt`` is repeated from its start along dim 0 until
    both have ``max(len(X), len(Xt))`` entries. The L2 norm over dim 1 is then
    averaged and divided by sqrt(3). Empty inputs return -1 (unreadable).
    """

    def __init__(self):
        super(LoopedL2, self).__init__()

    @staticmethod
    def _loop_to_length(T, n):
        """Repeat ``T`` from its start along dim 0 and cut it to exactly ``n`` entries."""
        reps = -(-n // T.size(0))  # ceil(n / len(T)); caller guarantees len(T) > 0
        return torch.cat([T] * reps, dim=0)[:n]

    def __call__(self, X, Xt, y, metadata, transform=None):
        # An empty X cannot be looped (the old doubling loop never terminated).
        if len(Xt) == 0 or len(X) == 0:
            return -1
        # Bring both to the longer length. Cutting each to the *other's*
        # original length left them unequal whenever the lengths differed.
        n = max(len(X), len(Xt))
        Xcopy = self._loop_to_length(X, n)
        Xtcopy = self._loop_to_length(Xt, n)
        diff = Xcopy - Xtcopy
        dist = diff.pow(2).sum(dim=1).sqrt().mean() / math.sqrt(3)
        return dist.item()

    def report_results(self, y, preds):
        l2 = np.array(self.results)
        # Must be a boolean mask (-1 marks unreadable) so it can combine with `correct`.
        readable = l2 >= 0
        print("Average looped L2:", l2[readable].mean())
        correct = y.numpy() == preds.numpy()
        print("Looped L2, Correct:", l2[readable & correct].mean())
        print("Looped L2, Incorrect:", l2[readable & ~correct].mean())

class LoopedFramewiseL2(DataHook):
    """Hook returning per-frame L2 distances after looping the shorter input.

    The shorter of ``X`` / ``Xt`` is repeated from its start along dim 0 until
    both have ``max(len(X), len(Xt))`` entries. The L2 norm over dim 1 is
    averaged over the last two dims (presumably H, W — confirm), divided by
    sqrt(3), and flattened. Empty inputs return -1.
    """

    def __init__(self):
        super(LoopedFramewiseL2, self).__init__()

    @staticmethod
    def _loop_to_length(T, n):
        """Repeat ``T`` from its start along dim 0 and cut it to exactly ``n`` entries."""
        reps = -(-n // T.size(0))  # ceil(n / len(T)); caller guarantees len(T) > 0
        return torch.cat([T] * reps, dim=0)[:n]

    def __call__(self, X, Xt, y, metadata, transform=None):
        # An empty X cannot be looped (the old doubling loop never terminated).
        if len(Xt) == 0 or len(X) == 0:
            return -1
        # Bring both to the longer length so they always line up entry by entry.
        n = max(len(X), len(Xt))
        Xcopy = self._loop_to_length(X, n)
        Xtcopy = self._loop_to_length(Xt, n)
        diff = Xtcopy - Xcopy
        dist = diff.pow(2).sum(dim=1).sqrt().mean(dim=(-2, -1)) / math.sqrt(3)
        return dist.view(-1).cpu().numpy()

    def report_results(self, y, preds):
        pass

class ExtendedL2(DataHook):
    """Hook measuring mean per-pixel L2 distance after aligning ``Xt`` to ``X``.

    ``Xt`` is brought to ``X``'s dim-0 length with ``pad_input_for_l2``. The L2
    norm over dim 1 is then averaged and divided by sqrt(3). Empty ``Xt``
    returns -1.
    """

    def __init__(self):
        super(ExtendedL2, self).__init__()

    def __call__(self, X, Xt, y, metadata, transform=None):
        if len(Xt) == 0:
            return -1  # invalid value for l2
        Xout = pad_input_for_l2(X, Xt)
        diff = Xout.cuda() - X.cuda()
        dist = diff.pow(2).sum(dim=1).sqrt().mean() / math.sqrt(3)  # normalize to 0-1
        return dist.item()

    def report_results(self, y, preds):
        l2 = np.array(self.results)
        # Boolean mask (-1 marks unreadable); it was previously the filtered
        # values, which cannot be combined with `correct` via `&`.
        readable = l2 >= 0
        print("Average L2 from original:", l2[readable].mean())
        correct = y.numpy() == preds.numpy()
        print("L2, Correct:", l2[readable & correct].mean())
        print("L2, Incorrect:", l2[readable & ~correct].mean())

class ExtendedFramewiseL2(DataHook):
    """Hook returning per-frame L2 distances after aligning ``Xt`` to ``X``.

    ``Xt`` is aligned with ``pad_input_for_l2``. The L2 norm over dim 1 is
    averaged over the last two dims, divided by sqrt(3), and flattened to a
    numpy array. Empty ``Xt`` returns -1.
    """

    def __init__(self):
        super().__init__()

    def __call__(self, X, Xt, y, metadata, transform=None):
        if not len(Xt):
            return -1
        padded = pad_input_for_l2(X, Xt).cuda()
        reference = X.cuda()
        pixel_l2 = (padded - reference).pow(2).sum(dim=1).sqrt()
        frame_l2 = pixel_l2.mean(dim=(-2, -1)) / math.sqrt(3)
        return frame_l2.view(-1).cpu().numpy()

    def report_results(self, y, preds):
        return None

class SegmentsDropped(DataHook):
    """Hook counting how many dim-0 entries (segments) differ between ``X`` and ``Xt``."""

    def __init__(self):
        super(SegmentsDropped, self).__init__()

    def __call__(self, X, Xt, y, metadata, transform=None):
        """Return ``[abs(len(X) - len(Xt)), len(X)]``."""
        dropped = abs(len(X) - len(Xt))
        return [dropped, len(X)]

    def report_results(self, y, preds):
        # Column 0 = segments dropped, column 1 = original segment count.
        # reshape keeps the 2-D layout even when no results were collected.
        segments_dropped = np.array(self.results).reshape(-1, 2)
        n_dropped = segments_dropped[:, 0]
        n_total = segments_dropped[:, 1]
        # Index the array, not self.results (a list, which cannot be column-indexed).
        unreadable = (n_dropped == n_total)
        print("Average segments dropped:", n_dropped.mean())
        print("Average proportion of segments dropped:", (n_dropped / (n_total + 1e-15)).mean())
        dropped_segments = (n_dropped > 0)
        readable_dropped = dropped_segments & ~unreadable
        print("Number with dropped segments:", np.count_nonzero(dropped_segments))
        print("Number readable with dropped segments:", np.count_nonzero(readable_dropped))
        print("Accuracy on data w/ dropped segments:", accuracy_score(y[dropped_segments], preds[dropped_segments]))
        print("Accuracy on readable data w/ dropped segments:", accuracy_score(y[readable_dropped], preds[readable_dropped]))
        print("Accuracy on data w/o dropped segments:", accuracy_score(y[~dropped_segments], preds[~dropped_segments]))

class HitIFrame(DataHook):
    """Hook recording whether a corruption hit an I-frame, plus input readability.

    The hit flag is read from ``metadata[-1][-1]['corrupted_iframe']``; every
    call also records whether ``Xt`` was empty.
    """

    def __init__(self):
        super(HitIFrame, self).__init__()
        # One entry per __call__: True when the transformed input was empty.
        self.unreadable_cache = []

    def __call__(self, X, Xt, y, metadata, transform):
        hit = metadata[-1][-1]['corrupted_iframe']
        self.unreadable_cache.append(len(Xt) == 0)
        return hit

    def report_results(self, y, preds):
        # dtype=bool so the masks stay boolean for 0/1 flags or empty lists.
        hit_iframe = np.array(self.results, dtype=bool)
        no_hit = ~hit_iframe
        print("Frequency of I-frame hits:", np.count_nonzero(hit_iframe))
        # Use the y/preds arguments (the old code referenced an undefined `results`).
        print("Accuracy on I-frame whacked data:", accuracy_score(y[hit_iframe], preds[hit_iframe]))
        print("Accuracy on non I-frame whacked data:", accuracy_score(y[no_hit], preds[no_hit]))
        readable = ~np.array(self.unreadable_cache, dtype=bool)
        iframe_readable = readable & hit_iframe
        noiframe_readable = readable & no_hit
        print("Frequency of readable I-frame hits:", np.count_nonzero(iframe_readable))
        # Compare labels with predictions; comparing y with itself always gave 1.0.
        print("Accuracy on readable I-frame whacked data:", accuracy_score(y[iframe_readable], preds[iframe_readable]))
        print("Accuracy on readable non-I-frame whacked data:", accuracy_score(y[noiframe_readable], preds[noiframe_readable]))

class Unreadable(DataHook):
    """Hook flagging samples whose transformed input ``Xt`` is empty."""

    def __init__(self):
        super().__init__()

    def __call__(self, X, Xt, y, metadata, transform):
        return len(Xt) == 0

    def report_results(self, y, preds):
        flags = np.array(self.results)
        readable = ~flags
        print("Number unreadable:", np.count_nonzero(flags))
        print("Accuracy on readable:", accuracy_score(y[readable], preds[readable]))

class SaveLocation(DataHook):
    """Hook collecting the value under ``'locations'`` from each sample's metadata."""

    def __init__(self):
        super().__init__()

    def __call__(self, X, Xt, y, metadata, transform):
        # The last element of the last metadata entry is a mapping with a
        # 'locations' key (presumably filled in by the transform — confirm).
        last_entry = metadata[-1][-1]
        return last_entry['locations']

    def report_results(self, y, preds):
        saved = len(self.results)
        print(saved, "corruptions saved")

class SavePath(DataHook):
    """Hook collecting ``metadata[0][0]`` (the sample's file path) for each call."""

    def __init__(self):
        super().__init__()

    def __call__(self, X, Xt, y, metadata, transform):
        try:
            path = metadata[0][0]
        except IndexError as err:
            # Attach the offending metadata to make malformed loader output debuggable.
            raise IndexError(str(err) + "; metadata state = " + str(metadata))
        return path

    def report_results(self, y, preds):
        saved = len(self.results)
        print(saved, "files saved")

class Debug(DataHook):
    """Hook returning tensor sizes and the source file's frame count for debugging."""

    def __init__(self):
        super(Debug, self).__init__()

    def __call__(self, X, Xt, y, metadata, transform):
        """Return ``(X.size(), Xt.size(), frames)``.

        ``frames`` is OpenCV's ``CAP_PROP_FRAME_COUNT`` for the file at
        ``metadata[0][0]``, or -1 when that path is not an existing file.
        """
        path = metadata[0][0]
        frames = -1
        if os.path.isfile(path):
            with utils.stderr_suppress():
                cap = cv2.VideoCapture(path)
                try:
                    frames = cap.get(cv2.CAP_PROP_FRAME_COUNT)
                finally:
                    # Release the capture handle now instead of relying on GC.
                    cap.release()
        return X.size(), Xt.size(), frames

    def report_results(self, y, preds):
        pass
