from run import ss_loader
import os
os.environ['KMP_DUPLICATE_LIB_OK']='True'

import json
import random
from argparse import ArgumentParser
import torch
import numpy as np

from pddl_parser import PDDLParser

# Target frame geometry used when normalising videos for statistics.
# norm_x is the width, norm_y the height, channels the size of the last axis.
norm_x, norm_y, channels = 420, 240, 3
# Array layout is rows first: (height, width, channels).
norm_video_shape = (norm_y, norm_x, channels)

def get_video_size(video):
    """Return the ``(width, height)`` of a video, read from its first frame.

    The first frame's ``shape`` is interpreted as ``(height, width, ...)``,
    so the returned tuple swaps the first two entries.
    """
    first_shape = video[0].shape
    return (first_shape[1], first_shape[0])

# A good resized version could be (420, 240)
def gather_video_size(dataset):
    """Count how many videos in ``dataset`` share each frame size.

    ``dataset`` yields ``(_, video)`` pairs. The result maps each size
    returned by ``get_video_size`` to the number of videos with that size.
    """
    size_counts = {}
    for _, video in dataset:
        key = get_video_size(video)
        size_counts[key] = size_counts.get(key, 0) + 1
    return size_counts

class StatsRecorder:
    """Running mean / standard deviation built from per-batch summaries.

    Each batch is described by its observation count (taken from the leading
    dimension of ``data``) together with its already-computed mean and std.
    Batches are merged with the pooled-variance formula, so when every
    ``newstd`` is a population std (``ddof=0``) the result equals the
    population std of all batches concatenated.

    Attributes:
        mean: running mean, or ``None`` while no batch has been recorded.
        std: running standard deviation, or ``None`` while empty.
        nobservations: total number of observations merged so far.
    """

    def __init__(self, data=None, newmean=None, newstd=None):
        """
        data: ndarray, shape (nobservations, ndimensions). Only its leading
            dimension (after ``np.atleast_2d``) is used, as the batch size.
        newmean: mean of ``data``; becomes the initial running mean.
        newstd: standard deviation of ``data``; becomes the initial running std.
        """
        if data is not None:
            self.mean = newmean
            self.std = newstd
            self.nobservations = np.atleast_2d(data).shape[0]
        else:
            # Define the attributes even when empty so that reading them on a
            # fresh recorder yields None instead of raising AttributeError.
            self.mean = None
            self.std = None
            self.nobservations = 0

    def update(self, data, newmean, newstd):
        """
        Merge one more batch into the running statistics.

        data: ndarray, shape (nobservations, ndimensions). Only its leading
            dimension (after ``np.atleast_2d``) is used, as the batch size.
        newmean: mean of the batch.
        newstd: standard deviation of the batch.
        """
        n = np.atleast_2d(data).shape[0]

        if self.nobservations == 0:
            # First batch: adopt its statistics as-is.
            self.mean = newmean
            self.std = newstd
            self.nobservations = n
            return

        m = self.nobservations * 1.0
        total = m + n
        previous_mean = self.mean

        self.mean = m / total * previous_mean + n / total * newmean
        # Pooled variance: weighted within-batch variances plus the
        # between-batch term caused by the two means differing.
        variance = (
            m / total * self.std ** 2
            + n / total * newstd ** 2
            + m * n / total ** 2 * (previous_mean - newmean) ** 2
        )
        self.std = np.sqrt(variance)

        self.nobservations += n


def obtain_video_mean_std(dataset):
    """Compute per-channel mean and std over all frames of all videos.

    Every frame is brought to ``norm_video_shape``; each video's per-channel
    statistics are then merged into a ``StatsRecorder``, weighted by the
    number of frames in that video.

    Args:
        dataset: iterable of ``(_, video)`` pairs, where ``video`` is a
            sequence of frames. Frames are assumed to be
            ``(height, width, channels)`` arrays whose last axis has
            ``channels`` entries -- TODO confirm against ``ss_loader``.

    Returns:
        ``(mean, std)``: two arrays with one entry per channel.

    Raises:
        ValueError: if the dataset yields no frames at all.
    """
    recorder = StatsRecorder()

    for did, (_, video) in enumerate(dataset):
        print(did)

        # A video without frames has no statistics to contribute.
        if len(video) == 0:
            continue

        # NOTE(review): np.resize does not interpolate; it truncates or tiles
        # the flattened pixel buffer to fill norm_video_shape. Channel
        # interleaving survives only when the frame's last axis has `channels`
        # entries. Swap in a real image resize if scaled frames are intended.
        norm_video = np.stack(
            [np.resize(frame, norm_video_shape) for frame in video]
        )

        # Layout is (frames, height, width, channels): reduce over everything
        # except the trailing channel axis. Reshaping to (channels, x, y)
        # would only reinterpret the buffer and mix the channels together.
        means = norm_video.mean(axis=(0, 1, 2))
        stds = norm_video.std(axis=(0, 1, 2))

        # Only the leading dimension (frame count) of the first argument is
        # used, as the weight of this video.
        recorder.update(norm_video, means, stds)

    if recorder.nobservations == 0:
        raise ValueError("cannot compute video statistics: dataset has no frames")

    return recorder.mean, recorder.std


if __name__ == "__main__":
    # Script entry point: load the training split through ss_loader and print
    # the per-channel mean / std of its videos.

    data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), '../../data'))

    # Only --seed, --data_dir, --train_num, --val_num and --batch-size are read
    # below; the remaining options are parsed but unused by this script
    # (presumably kept for CLI parity with the training entry point -- confirm).
    parser = ArgumentParser("sth_sth")
    parser.add_argument("--phase", type=str, default='train')
    parser.add_argument("--n-epochs", type=int, default=1000)

    # setup question path
    parser.add_argument("--train_num", type=int, default=5000)
    parser.add_argument("--val_num", type=int, default=1000)
    parser.add_argument("--train_max_obj", type=int, default=5)
    parser.add_argument("--train_max_clause", type=int, default=10)
    parser.add_argument("--test_max_obj", type=int, default=100)
    parser.add_argument("--test_max_clause", type=int, default=100)

    # Training hyperparameters
    parser.add_argument("--batch-size", type=int, default=2)
    parser.add_argument("--learning-rate", type=float, default=0.0001)
    parser.add_argument("--latent-dim", type=float, default=64)
    parser.add_argument("--model_layer", type=int, default=2)
    parser.add_argument("--seed", type=int, default=1234)
    parser.add_argument("--provenance", type=str, default="difftopkclauses")
    parser.add_argument("--train-top-k", type=int, default=3)
    parser.add_argument("--test-top-k", type=int, default=3)
    parser.add_argument("--model_path", type=str, default="")
    parser.add_argument("--data_dir", type=str, default=data_dir)
    parser.add_argument("--use-cuda", action="store_true")
    parser.add_argument("--gpu", type=int, default=0)
    args = parser.parse_args()

    torch.manual_seed(args.seed)
    random.seed(args.seed)

    label_dir = os.path.join(args.data_dir, 'labels')
    video_dir = os.path.join(args.data_dir, '20bn-something-something-v2')
    bbox_dir = os.path.join(args.data_dir, 'bboxes')
    constraint_path = os.path.join(args.data_dir, 'constraints.pddl')
    manual_mapping_path = os.path.join(args.data_dir, 'manual_mapping.json')
    # Context managers close each file promptly instead of leaking the handle.
    with open(manual_mapping_path, 'r') as mapping_file:
        manual_mapping = json.load(mapping_file)
    template_mapping_path = os.path.join(args.data_dir, "template_mapping.json")
    scl_dir = os.path.join(args.data_dir, 'scl')
    common_scl_path = os.path.join(args.data_dir, 'common.scl')

    # Separate name so the ArgumentParser above is not shadowed.
    constraint_parser = PDDLParser()
    with open(constraint_path, 'r') as constraint_file:
        constraints = constraint_parser.parse(constraint_file.read())
    # Resolve the dataset splits under the user-supplied --data_dir, like
    # every other path above, rather than under the built-in default.
    train_dataset_path = os.path.join(args.data_dir, f"train_{args.train_num}.json")
    valid_dataset_path = os.path.join(args.data_dir, f"valid_{args.val_num}.json")
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    with open(template_mapping_path, 'r') as template_file:
        action2template = json.load(template_file)
    # Inverse lookup; if two actions share a template, the later one wins.
    template2action = {
        template: action for action, template in action2template.items()
    }

    train_dataset, test_dataset, train_loader, test_loader = ss_loader(train_dataset_path, valid_dataset_path, video_dir, args.batch_size, pairs=None)
    m, s = obtain_video_mean_std(train_dataset)
    print(f"the mean is : {m}, the std is: {s}")
