import argparse
import json
import os
import os.path as osp
from collections import OrderedDict
from copy import deepcopy
from glob import glob

import numpy as np
from tabulate import tabulate

import matplotlib.pyplot as plt

# Command-line interface. The two experiment roots and the window size are
# mandatory: without them the script has nothing to compare, so argparse
# reports a clean usage error instead of a later TypeError on None.
parser = argparse.ArgumentParser(
    description='Plot per-frame mIoU of TTT-MAE against a zero-shot baseline.'
)
parser.add_argument(
    '--zero_shot_dir',
    type=str,
    required=True,
    help='Directory whose sub-directories are zero-shot runs, one per video, '
         'each containing performance.txt.'
)
parser.add_argument(
    '--mae_dir',
    type=str,
    required=True,
    help='Directory whose sub-directories are TTT-MAE runs, one per video, '
         'each containing <win_size>_win/performance.txt.'
)
parser.add_argument(
    '--win_size',
    type=int,
    required=True,
    help='Window size of the TTT-MAE runs to compare against.'
)
parser.add_argument(
    '--video',
    type=str,
    default=None,
    help='Video id to plot (e.g. 0002). If omitted, the mIoU delta is '
         'averaged over all videos.'
)
args = parser.parse_args()

def _find_exp_dir(exps, vid):
    """Return the single experiment directory in *exps* whose name contains *vid*.

    Matching uses the directory's own name rather than the full path, so a
    video id that happens to appear in a parent directory cannot match every
    experiment.

    Raises:
        ValueError: if zero or more than one directory matches.
    """
    matches = [p for p in exps if vid in osp.basename(osp.normpath(p))]
    if len(matches) != 1:
        raise ValueError(
            'expected exactly one experiment dir for video %s, found %d: %s'
            % (vid, len(matches), matches)
        )
    return matches[0]


def _read_miou(path):
    """Read one float per non-blank line of *path* and return them as a list."""
    with open(path, 'r') as fp:
        return [float(line) for line in fp if line.strip()]


def _load_video_results(zero_shot_exps, mae_exps, vid, win_size):
    """Load the zero-shot and TTT-MAE per-frame mIoU curves for video *vid*.

    The zero-shot curve drops its first ``win_size - 1`` values so it lines up
    with the windowed TTT-MAE curve (presumably the MAE run only reports frames
    that have a full window of history — TODO confirm).

    Returns:
        ``(miou_zero_shot, miou_mae)`` as lists of floats.
    """
    zs_exp = _find_exp_dir(zero_shot_exps, vid)
    miou_zero_shot = _read_miou(osp.join(zs_exp, 'performance.txt'))[(win_size - 1):]

    mae_exp = _find_exp_dir(mae_exps, vid)
    miou_mae = _read_miou(osp.join(mae_exp, str(win_size) + '_win', 'performance.txt'))
    return miou_zero_shot, miou_mae


def _nan_pad(rows, width):
    """Stack 1-D sequences into a (len(rows), width) float array, right-padded with NaN."""
    out = np.full((len(rows), width), np.nan)
    for i, row in enumerate(rows):
        out[i, :len(row)] = row
    return out


if __name__ == '__main__':
    target_video = args.video
    all_vid = target_video is None
    win_size = args.win_size

    # Fail early with a usage message rather than a TypeError on None below.
    if args.zero_shot_dir is None or args.mae_dir is None or win_size is None:
        parser.error('--zero_shot_dir, --mae_dir and --win_size are required')
    if win_size < 1:
        parser.error('--win_size must be a positive integer')

    zero_shot_dir = args.zero_shot_dir
    zero_shot_exps = glob(zero_shot_dir + "/*/", recursive=False)

    mae_dir = args.mae_dir
    mae_exps = glob(mae_dir + "/*/", recursive=False)

    train_set = ["0000", "0001", "0003", "0004", "0005", "0009", "0011", "0012", "0015", "0017", "0019", "0020"]
    val_set = ["0002", "0006", "0007", "0008", "0010", "0013", "0014", "0016", "0018"]
    all_videos = [*train_set, *val_set]

    out_dir = 'exp_dir'
    os.makedirs(out_dir, exist_ok=True)

    if all_vid:
        zero_shot_rows = []
        mae_rows = []
        for vid in all_videos:
            miou_zero_shot, miou_mae = _load_video_results(zero_shot_exps, mae_exps, vid, win_size)
            zero_shot_rows.append(miou_zero_shot)
            mae_rows.append(miou_mae)

        # Videos (and possibly the two methods) have different lengths. Pad
        # every curve with NaN to one common width, each by its OWN length, so
        # the nan-aware reductions below simply ignore missing frames.
        width = max(len(row) for row in zero_shot_rows + mae_rows)
        all_zero_shot_results = _nan_pad(zero_shot_rows, width)
        all_mae_results = _nan_pad(mae_rows, width)

        # Per-frame delta (TTT-MAE minus zero-shot), aggregated across videos.
        all_delta = all_mae_results - all_zero_shot_results
        all_delta_mean = np.nanmean(all_delta, axis=0)
        all_delta_std = np.nanstd(all_delta, axis=0)

        # Plot the mean delta with a +/-1 std band (GP-like).
        frames = np.arange(len(all_delta_mean)) + 1
        fig, ax = plt.subplots()
        ax.plot(frames, all_delta_mean, 'b-', label='Mean delta (TTT-MAE - Zero-Shot)')
        ax.plot(frames, np.zeros_like(all_delta_mean), 'r-', label='No change')
        ax.fill_between(frames, all_delta_mean - all_delta_std, all_delta_mean + all_delta_std,
                        alpha=0.3, label='+/- 1 std')
        ax.set_title('Performance Delta vs. Video Progression (All Videos Avg.)')
        ax.set_xlabel('Frame')
        ax.set_ylabel('mIoU delta')
        ax.legend()
        fig.savefig(osp.join(out_dir, 'all_miou_vs_time.png'))
        plt.close(fig)

    else:
        miou_zero_shot, miou_mae = _load_video_results(zero_shot_exps, mae_exps, target_video, win_size)
        if len(miou_mae) != len(miou_zero_shot):
            raise ValueError(
                'need same number of frames for both experiments (TTT-MAE: %d, zero-shot: %d)'
                % (len(miou_mae), len(miou_zero_shot))
            )

        # Plot zero-shot against video-TTT-MAE.
        frames = np.arange(len(miou_mae)) + 1
        fig, ax = plt.subplots()
        ax.plot(frames, miou_mae, 'b-', label='TTT-MAE')
        ax.plot(frames, miou_zero_shot, 'r-', label='Zero-Shot')
        ax.set_title('mIoU vs. Video Progression (Video ' + target_video + ')')
        ax.set_xlabel('Frame')
        ax.set_ylabel('mIoU')
        ax.legend()
        fig.savefig(osp.join(out_dir, target_video + '_miou_vs_time.png'))
        plt.close(fig)

