import os
import clip
import pdb
import json
import math
import shutil

import torch
import torch.nn.functional as F
import torch.distributed as dist
from torchvision.transforms import Compose, Resize, ToTensor, Normalize, CenterCrop


import numpy as np
import argparse
import torch.nn as nn
from easydict import EasyDict as edict

import timm
from standard_video_dataset import StandardVidoDataset, standard_collate_fn
import time
import tqdm
import pandas as pd
from cal_ssim import cal_ssim_dist
from vbench.third_party.RAFT.core.raft import RAFT
from vbench.third_party.RAFT.core.utils_core.utils import InputPadder
from viclip import get_viclip, retrieve_text, _frame_from_video, frames2tensor, get_vid_feat
from PIL import Image
import subprocess


def parse_args():
    """Parse the command-line options of this script.

    Returns:
        argparse.Namespace with ``local_rank``, ``video_dir``, ``save_dir``,
        ``raft_model_path`` and ``dataset``. ``video_dir`` and ``save_dir``
        are mandatory; the remaining options fall back to their defaults.
    """
    cli = argparse.ArgumentParser(
        description="Distributed inference with a pretrained DINO model on DiDemo dataset."
    )
    cli.add_argument('--local_rank', type=int, default=0)

    # The two mandatory directory options share the same shape.
    required_dirs = (
        ('--video_dir', 'Directory containing video files.'),
        ('--save_dir', 'Directory to save extracted features.'),
    )
    for flag, help_text in required_dirs:
        cli.add_argument(flag, type=str, required=True, help=help_text)

    cli.add_argument(
        '--raft_model_path',
        type=str,
        default='vbench/third_party/RAFT/models/raft-things.pth',
    )
    cli.add_argument('--dataset', type=str, default='msrvtt')
    return cli.parse_args()


import pandas as pd
import json

def json_to_xlsx(json_file, xlsx_file):
    """Convert a JSON file into an Excel workbook.

    Args:
        json_file: Path of the JSON file, loaded with ``pandas.read_json``.
        xlsx_file: Path of the workbook to write (openpyxl engine, no index
            column).
    """
    table = pd.read_json(json_file)
    excel_writer = pd.ExcelWriter(xlsx_file, engine='openpyxl')
    with excel_writer as writer:
        table.to_excel(writer, index=False)


def merge_json_files(input_folder, output_file):
    """Merge the JSON-lines files of a folder into a single JSON array.

    Every ``*.json`` file directly inside ``input_folder`` is read line by
    line; each non-blank line must hold one JSON document. A line that
    decodes to a list contributes its elements, any other document is
    appended as one record. Files are visited in sorted name order so the
    merged output is reproducible.

    The output file is written only after every input has been parsed, so a
    failed merge does not wipe a previous result, and the output file itself
    is skipped when it lives inside ``input_folder``.

    Args:
        input_folder: Directory containing the per-writer ``.json`` files.
        output_file: Path of the merged JSON file (indented with 4 spaces).

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    output_abs = os.path.abspath(output_file)
    all_data = []
    for filename in sorted(os.listdir(input_folder)):
        if not filename.endswith(".json"):
            continue
        file_path = os.path.join(input_folder, filename)
        # Never feed a previous merge result back into the new one.
        if os.path.abspath(file_path) == output_abs:
            continue
        with open(file_path, 'r') as file:
            for line in file:
                if not line.strip():
                    continue
                data = json.loads(line)
                if isinstance(data, list):
                    all_data.extend(data)
                else:
                    # extend() on a dict/str would add keys/characters.
                    all_data.append(data)
    with open(output_file, 'w') as outfile:
        json.dump(all_data, outfile, indent=4)

def save_dicts_to_excel(dicts, filename):
    """Write a list of dictionaries to an Excel file, one row per dictionary.

    Args:
        dicts: Records handed to ``pandas.DataFrame``.
        filename: Destination path of the Excel file (written without the
            index column).
    """
    pd.DataFrame(dicts).to_excel(filename, index=False)
    print(f"Dictionaries saved to '{filename}'.")


def load_and_merge(pth_folder, save_folder):
    """Collect every ``.pth`` shard of a folder into one Excel sheet.

    Each ``.pth`` file in ``pth_folder`` is loaded on the CPU and its content
    is concatenated into one list, which is then written to
    ``merged_results.xlsx`` inside ``save_folder``.

    Args:
        pth_folder: Directory holding the per-rank ``.pth`` result files.
        save_folder: Directory receiving ``merged_results.xlsx``.
    """
    shard_paths = [
        os.path.join(pth_folder, name)
        for name in os.listdir(pth_folder)
        if name.endswith(".pth")
    ]
    merged = []
    for shard in shard_paths:
        merged.extend(torch.load(shard, map_location='cpu'))
    target = os.path.join(save_folder, 'merged_results.xlsx')
    save_dicts_to_excel(merged, target)


def cal_info_variance(model, video_data, video_lengths):
    """Measure how far the frame features of each video spread around their mean.

    The model output is L2-normalised along the last dimension and split into
    per-video chunks of ``video_lengths`` rows. For each chunk the score is
    the mean of ``1 - cosine_similarity`` between the chunk mean and every
    row of the chunk.

    Args:
        model: Callable mapping ``video_data`` to a feature tensor
            (presumably one row per frame -- confirm against the caller).
        video_data: Input passed to ``model`` unchanged.
        video_lengths: Row counts used to split the features per video.

    Returns:
        Tuple ``(distances, features)``: a list with one float per video and
        the tuple of normalised per-video feature tensors.
    """
    embeddings = F.normalize(model(video_data), dim=-1, p=2)
    per_video = embeddings.split(video_lengths)

    distances = []
    for feat in per_video:
        centroid = feat.mean(axis=0, keepdim=True)
        similarity = F.cosine_similarity(centroid[:, None], feat[None], dim=-1)
        distances.append((1 - similarity).mean(1).item())

    return distances, per_video


def build_raft(model_path, rank):
    """Build a RAFT optical-flow model wrapped in DistributedDataParallel.

    The checkpoint is loaded on the CPU and the ``'module.'`` substring is
    removed from every key before the weights are loaded into the model.

    Args:
        model_path: Path of the RAFT checkpoint.
        rank: Device index the model is moved to and bound to in DDP.

    Returns:
        The DDP-wrapped model, switched to eval mode.
    """
    raft_cfg = edict({
        "model": model_path,
        "small": False,
        "mixed_precision": False,
        "alternate_corr": False,
    })
    raft = RAFT(raft_cfg)

    checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
    cleaned = {name.replace('module.', ''): weight for name, weight in checkpoint.items()}
    raft.load_state_dict(cleaned)

    wrapped = nn.parallel.DistributedDataParallel(raft.to(rank), device_ids=[rank])
    wrapped.eval()
    return wrapped

def build_dino(model_name, rank):
    """Create a pretrained timm model and wrap it in DistributedDataParallel.

    Args:
        model_name: Model identifier passed to ``timm.create_model``.
        rank: Device index the model is moved to and bound to in DDP.

    Returns:
        The DDP-wrapped model, switched to eval mode.
    """
    backbone = timm.create_model(model_name, pretrained=True).to(rank)
    ddp_model = nn.parallel.DistributedDataParallel(backbone, device_ids=[rank])
    ddp_model.eval()
    return ddp_model

def build_dinov2(model_name, rank):
    """Load a DINOv2 model from torch.hub and wrap it in DistributedDataParallel.

    Args:
        model_name: Entry point name in the ``facebookresearch/dinov2`` hub repo.
        rank: Device index the model is moved to and bound to in DDP.

    Returns:
        The DDP-wrapped model, switched to eval mode.
    """
    backbone = torch.hub.load('facebookresearch/dinov2', model_name).to(rank)
    ddp_model = nn.parallel.DistributedDataParallel(backbone, device_ids=[rank])
    ddp_model.eval()
    return ddp_model

def build_viclip(model_name, rank):
    """Build one of the known ViCLIP variants wrapped in DistributedDataParallel.

    Args:
        model_name: Key of the variant; must be one of the names in the
            table below.
        rank: Device index the model is moved to and bound to in DDP.

    Returns:
        The DDP-wrapped ViCLIP model, switched to eval mode.

    Raises:
        KeyError: If ``model_name`` is not a known variant.
    """
    # (variant name, model size, checkpoint path)
    known_variants = (
        ('viclip-l-internvid-10m-flt', 'l', '/path/to/ViClip-InternVid-10M-FLT.pth'),
        ('viclip-l-internvid-200m', 'l', 'xxx/ViCLIP-L_InternVid-200M.pth'),
        ('viclip-b-internvid-10m-flt', 'b', 'xxx/ViCLIP-B_InternVid-FLT-10M.pth'),
        ('viclip-b-internvid-200m', 'b', 'xxx/ViCLIP-B_InternVid-200M.pth'),
    )
    model_cfgs = {
        name: {'size': size, 'pretrained': checkpoint}
        for name, size, checkpoint in known_variants
    }
    cfg = model_cfgs[model_name]

    viclip = get_viclip(cfg['size'], cfg['pretrained'])['viclip'].to(rank)
    ddp_model = nn.parallel.DistributedDataParallel(viclip, device_ids=[rank])
    ddp_model.eval()
    return ddp_model


def cal_flow_strength(model, video_data, video_lengths, video_names):
    """Score the optical-flow strength of each video, one frame pair at a time.

    For every pair of consecutive frames the flow model is run for 50
    iterations in test mode and the pair is scored with ``get_score``; the
    video score is the mean over its pairs.

    Args:
        model: Flow model called as ``model(image1, image2, iters, test_mode)``.
        video_data: Iterable of per-video frame tensors.
        video_lengths: Unused here.
        video_names: Names used for the progress print, indexed like
            ``video_data``.

    Returns:
        List with one flow score per video.
    """
    results = []
    for i, frames in enumerate(video_data):
        pair_scores = []
        for t in range(len(frames) - 1):
            first = frames[t][None]
            second = frames[t + 1][None]
            padder = InputPadder(first.shape)
            first, second = padder.pad(first, second)
            # The model returns two values; only the second one is scored.
            _, flow_up = model(first, second, iters=50, test_mode=True)
            pair_scores.append(get_score(first, flow_up))
        static_score = sum(pair_scores) / len(pair_scores)
        results.append(static_score)
        print(f'Video: {video_names[i]}, Flow: {static_score}')
    return results


def cal_flow_strength_batch(model, video_data, video_lengths, video_names, max_batch=40):
    """Score the optical-flow strength of each video using batched frame pairs.

    Consecutive frame pairs of a video are padded together and pushed through
    the flow model in chunks of at most ``max_batch`` pairs (20 iterations,
    test mode). The flow magnitude is computed from channels 0 and 1, and the
    score is the mean of the top 5% magnitudes of every pair.

    Args:
        model: Flow model called as ``model(image1, image2, iters, test_mode)``.
        video_data: Iterable of per-video frame tensors.
        video_lengths: Unused here.
        video_names: Names used for the progress print, indexed like
            ``video_data``.
        max_batch: Maximum number of frame pairs per model call.

    Returns:
        List with one flow score per video.
    """
    results = []
    for i, frames in enumerate(video_data):
        padder = InputPadder(frames.shape)
        prev_frames, next_frames = padder.pad(frames[:-1], frames[1:])

        flow_chunks = []
        for lo in range(0, prev_frames.shape[0], max_batch):
            hi = lo + max_batch
            _, flow_up = model(
                prev_frames[lo:hi].contiguous(),
                next_frames[lo:hi].contiguous(),
                iters=20,
                test_mode=True,
            )
            flow_chunks.append(flow_up)

        flows = torch.cat(flow_chunks, dim=0)
        magnitude = torch.sqrt(torch.square(flows[:, 0]) + torch.square(flows[:, 1]))
        per_pair = magnitude.flatten(1, 2)
        # Keep only the strongest 5% of the values of every frame pair.
        top_count = int(per_pair.shape[1] * 0.05)
        static_score = per_pair.topk(top_count, dim=1)[0].mean().item()
        results.append(static_score)
        print(f'Video: {video_names[i]}, Flow: {static_score}')
    return results

def get_score(img, flo, only_top=False):
    """Reduce a flow field to a single magnitude score.

    Args:
        img: Unused; kept for interface compatibility.
        flo: Flow tensor; the first batch element is used and its first two
            channels are read as the two flow components.
        only_top: When true, average only the largest 5% of the magnitudes;
            otherwise average all of them.

    Returns:
        The mean flow magnitude as a Python float.
    """
    flow = flo[0].permute(1, 2, 0).cpu().numpy()
    magnitude = np.sqrt(np.square(flow[:, :, 0]) + np.square(flow[:, :, 1]))

    height, width = magnitude.shape
    # Sorting the negated values yields the magnitudes in descending order.
    descending = abs(np.sort(-magnitude.flatten()))
    if only_top:
        descending = descending[:int(height * width * 0.05)]

    return np.mean(descending).item()

def build_clip(model_name, rank):
    """Load a CLIP model and wrap it in DistributedDataParallel.

    Args:
        model_name: Model identifier passed to ``clip.load``.
        rank: Device index used for loading and for DDP.

    Returns:
        Tuple ``(model, preprocess)``: the DDP-wrapped model in eval mode and
        the preprocessing object returned by ``clip.load``.
    """
    clip_model, preprocess = clip.load(model_name, device=rank)
    ddp_model = nn.parallel.DistributedDataParallel(clip_model, device_ids=[rank])
    ddp_model.eval()
    return ddp_model, preprocess

def cal_dino_segment_dist(model, video_data, video_lengths):
    """Average lagged feature dissimilarity of each video.

    The last intermediate layer of ``model.module`` is L2-normalised along
    the last dimension and split per video. For each video ``calc_acf`` is
    evaluated for every lag ``k`` from ``max(1, T // 8)`` up to ``T - 1``
    (``T`` = number of frames) and the values are averaged.

    The lag is clamped to at least 1: for videos with fewer than 8 frames
    ``T // 8`` is 0, and a lag of 0 makes ``features[:-0]`` an empty slice,
    which cannot be multiplied with the full tensor.

    Args:
        model: DDP-wrapped model whose ``module`` exposes
            ``get_intermediate_layers``.
        video_data: Frames of all videos, concatenated along dim 0
            (presumably -- confirm against the collate function).
        video_lengths: Number of frames per video, used to split the features.

    Returns:
        List with one float per video; ``nan`` for a video that has no
        usable lag (fewer than two frames).
    """
    features = model.module.get_intermediate_layers(video_data, n=1)[0]
    features = F.normalize(features, dim=-1, p=2)
    features = features.split(video_lengths)
    distances = []
    for feat in features:
        num_frames = feat.shape[0]
        first_lag = max(1, num_frames // 8)
        acf = [calc_acf(feat, k) for k in range(first_lag, num_frames)]
        if not acf:
            # No pair of distinct frames exists, so the score is undefined.
            distances.append(float('nan'))
            continue
        distances.append(sum(acf) / len(acf))
    return distances

def cal_viclip_segment_dist(model, video_data, video_lengths, block_lengths_ratio=(0.5, 0.25)):
    """Score how dissimilar the segments of each video are in ViCLIP space.

    For every ratio in ``block_lengths_ratio`` the video is cut into segments
    (``get_segments``), each segment is resampled to a common length
    (``to_same_size``), embedded with ``get_vid_feat`` and reduced to a mean
    pairwise similarity (``cal_segment_dist``). The returned value per video
    is ``1 - mean similarity over the ratios``.

    Args:
        model: DDP-wrapped ViCLIP model; ``model.module`` is passed to
            ``get_vid_feat``.
        video_data: Frames of all videos concatenated along dim 0.
        video_lengths: Number of frames per video, used to split ``video_data``.
        block_lengths_ratio: Segment lengths as fractions of the video length.

    Returns:
        List with one float per video.

    Raises:
        Exception: Any error raised by ``get_vid_feat`` is logged with the
            segment shape and re-raised. The function no longer drops into
            an interactive debugger, which would stall a distributed run.
    """
    distances = []
    for video in video_data.split(video_lengths):
        dist_obj = 0
        for r in block_lengths_ratio:
            segments = to_same_size(get_segments(video, r))
            try:
                seg_feats = get_vid_feat(segments, model.module)
            except Exception as e:
                print(f'ViCLIP feature extraction failed for segments of shape '
                      f'{tuple(segments.shape)} (ratio {r}): {e}')
                raise
            dist_obj += cal_segment_dist(seg_feats).item()
        distances.append(1 - dist_obj / len(block_lengths_ratio))
    return distances

def get_segments(video, block_ratio):
    """Cut a video into consecutive, equally long segments.

    The segment length is ``int(total_frames * block_ratio)``, clamped to at
    least one frame so that short clips (where the product rounds down to 0)
    no longer fail with a division by zero. A trailing segment that is
    shorter than the others is stretched by repeating each of its frames and
    truncating to the segment length.

    Args:
        video: Tensor whose first dimension indexes frames.
        block_ratio: Segment length as a fraction of the number of frames.

    Returns:
        Tensor of shape ``(num_segments, segment_length, *video.shape[1:])``.

    Raises:
        ValueError: If ``video`` has no frames.
    """
    total_frames = video.shape[0]
    if total_frames == 0:
        raise ValueError("get_segments() requires a video with at least one frame")

    segment_length = max(1, int(total_frames * block_ratio))

    segments = []
    for start in range(0, total_frames, segment_length):
        # Slicing clamps the end index to the last frame automatically.
        segment = video[start:start + segment_length]

        if segment.size(0) < segment_length:
            repeats = segment_length // segment.size(0) + 1
            segment = torch.repeat_interleave(segment, repeats=repeats, dim=0)[:segment_length]

        segments.append(segment)

    return torch.stack(segments)

def to_same_size(segments, target_size=8):
    """Resample every segment to exactly ``target_size`` frames.

    Segments longer than ``target_size`` are subsampled at evenly spaced
    indices (first and last frame included); shorter ones are padded by
    repeating their last frame.

    Args:
        segments: Tensor shaped ``(batch, segment_length, ...)``.
        target_size: Number of frames each segment should have.

    Returns:
        Tensor shaped ``(batch, target_size, ...)``.
    """
    _, segment_length, *rest_dims = segments.size()

    if segment_length <= target_size:
        pad_count = target_size - segment_length
        tail = segments[:, -1:].repeat(1, pad_count, *([1] * len(rest_dims)))
        return torch.cat((segments, tail), dim=1)

    keep = torch.linspace(0, segment_length - 1, steps=target_size).long()
    return segments[:, keep]

def cal_segment_dist(features):
    """Mean cosine similarity over the upper triangle of the similarity matrix.

    The rows of ``features`` are L2-normalised and compared pairwise. The
    mean is taken over the upper triangle *including* the diagonal, so each
    row's similarity with itself is part of the average.

    Args:
        features: 2-D tensor with one feature vector per row.

    Returns:
        Scalar tensor holding the mean similarity.
    """
    unit = F.normalize(features, dim=-1, p=2)
    similarity = unit @ unit.T
    upper = torch.triu(torch.ones_like(similarity, dtype=torch.bool))
    return similarity[upper].mean()

def calc_acf(features, k):
    """Return ``|1 - mean dot product|`` between features ``k`` steps apart.

    Args:
        features: Tensor whose first dimension is the step axis; the dot
            product is taken along the last dimension.
        k: Lag between the two compared slices.

    Returns:
        The absolute value of one minus the mean lagged dot product, as a
        Python float.
    """
    leading = features[:-k]
    lagged = features[k:]
    correlation = torch.sum(leading * lagged, dim=-1).mean()
    return torch.abs(1 - correlation).item()

def cal_frechet(block_stats):
    """Average Frechet distance between consecutive blocks.

    Args:
        block_stats: List of ``(mu, sigma)`` pairs, one per block, in
            temporal order.

    Returns:
        Mean of the Frechet distances between each block and its successor.
    """
    neighbour_pairs = zip(block_stats[:-1], block_stats[1:])
    pair_distances = [
        frechet_distance(*current, *following).item()
        for current, following in neighbour_pairs
    ]
    return sum(pair_distances) / len(pair_distances)
    
def cal_frechet_distance(model, video_data, video_lengths, block_lengths_ratio, features=None):
    """Average block-to-block Frechet distance of each video.

    Args:
        model: Callable producing features from ``video_data``; only used
            when ``features`` is None.
        video_data: Input for ``model``.
        video_lengths: Number of frames per video.
        block_lengths_ratio: Block lengths as fractions of each video length.
        features: Optional precomputed per-video features. When omitted they
            are computed with ``model``, L2-normalised and split by
            ``video_lengths``.

    Returns:
        List with one float per video: the mean over the block lengths of
        ``cal_frechet`` applied to the block statistics.
    """
    if features is None:
        normalised = F.normalize(model(video_data), dim=-1, p=2)
        features = normalised.split(video_lengths)

    distances = []
    for i, feat in enumerate(features):
        block_lengths = [int(r * video_lengths[i]) for r in block_lengths_ratio]
        per_length = [cal_frechet(stats) for stats in process_blocks(feat, block_lengths)]
        distances.append(sum(per_length) / len(per_length))
    return distances

def process_blocks(features, block_lengths):
    """Compute ``(mu, sigma)`` statistics for consecutive feature blocks.

    For each block length the features are cut into as many complete,
    non-overlapping blocks as fit; leftover rows at the end are ignored.

    Args:
        features: Sequence of feature rows (sliced along dim 0).
        block_lengths: Block sizes, in rows.

    Returns:
        One list per block length, each holding the ``(mu, sigma)`` pair of
        every block.
    """
    results = []
    for block_length in block_lengths:
        covered = (len(features) // block_length) * block_length
        results.append([
            calculate_statistics(features[start:start + block_length])
            for start in range(0, covered, block_length)
        ])
    return results

def calculate_covariance_matrix(features):
    """Return the sample covariance matrix of the rows of ``features``.

    Args:
        features: 2-D tensor with one observation per row.

    Returns:
        Covariance matrix normalised by ``rows - 1``.
    """
    centered = features - features.mean(dim=0)
    scatter = centered.T @ centered
    return scatter / (features.size(0) - 1)


def calculate_statistics(features):
    """Return the mean vector and covariance matrix of ``features``.

    Args:
        features: Tensor with one observation per row.

    Returns:
        Tuple ``(mu, sigma)``: the mean over dim 0 and the result of
        ``calculate_covariance_matrix``.
    """
    return torch.mean(features, axis=0), calculate_covariance_matrix(features)

def matrix_sqrt(matrix):
    """Rebuild ``matrix`` from its SVD with square-rooted singular values.

    Args:
        matrix: 2-D tensor.

    Returns:
        ``U @ diag(sqrt(S)) @ V^T`` where ``U, S, V`` come from
        ``torch.svd(matrix)``.
    """
    left, singular_values, right = torch.svd(matrix)
    root_diag = torch.diag(singular_values.sqrt())
    return torch.matmul(torch.matmul(left, root_diag), right.t())


def frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Frechet distance between two Gaussians given by mean and covariance.

    Args:
        mu1: Mean vector of the first distribution.
        sigma1: Covariance matrix of the first distribution.
        mu2: Mean vector of the second distribution.
        sigma2: Covariance matrix of the second distribution.
        eps: Unused; kept for interface compatibility.

    Returns:
        Scalar tensor: squared mean difference plus
        ``trace(sigma1 + sigma2 - 2 * matrix_sqrt(sigma1 @ sigma2))``.
    """
    delta = mu1 - mu2
    mean_term = delta.dot(delta)
    cross_root = matrix_sqrt(sigma1 @ sigma2)
    covariance_term = torch.trace(sigma1 + sigma2 - 2 * cross_root)
    return mean_term + covariance_term

def cal_temporal_entropy(video_paths, output_folder):
    """Run ``cal_temporal_info.sh`` on every video and collect its score.

    For each video the script is invoked as
    ``bash cal_temporal_info.sh <video_path> <output_folder>`` and its
    standard output is parsed as a float. ``output_folder`` is a scratch
    directory: it is (re)created before every invocation and deleted after
    it, so later videos never run against a folder that the cleanup of a
    previous video removed.

    Args:
        video_paths: Paths of the videos to score.
        output_folder: Scratch directory handed to the script. It is deleted
            recursively after each video, so it must not contain anything
            worth keeping.

    Returns:
        List with one float per video, in input order.

    Raises:
        SystemExit: With status 0, if ``output_folder`` is ``'/'`` or starts
            with ``'/ '`` (guard against wiping the filesystem root).
        ValueError: If the script output cannot be parsed as a float.
    """
    os.makedirs(output_folder, exist_ok=True)
    results = []
    for video_path in video_paths:
        if output_folder == '/' or output_folder.startswith('/ '):
            print(f'***************output_folder: {output_folder}***************')
            # Same effect as exit(0) without relying on the site module.
            raise SystemExit(0)
        # The folder is removed at the end of every iteration; recreate it.
        os.makedirs(output_folder, exist_ok=True)
        result = subprocess.run(
            ['bash', 'cal_temporal_info.sh', video_path, output_folder],
            stdout=subprocess.PIPE,
            text=True,
        )
        raw_output = result.stdout.strip()
        try:
            score = float(raw_output)
        except ValueError as err:
            raise ValueError(
                f'cal_temporal_info.sh returned non-numeric output for {video_path!r} '
                f'(exit code {result.returncode}): {raw_output!r}'
            ) from err
        results.append(score)
        print(f'{output_folder} done.')
        shutil.rmtree(output_folder)
    return results


def clear_folder(folder):
    """Delete everything inside ``folder`` while keeping the folder itself.

    Files and symlinks are unlinked, sub-directories are removed recursively.
    A failure on one entry is printed and does not stop the remaining
    entries from being processed.

    Args:
        folder: Directory to empty.
    """
    for entry in os.listdir(folder):
        file_path = os.path.join(folder, entry)
        try:
            is_file_like = os.path.isfile(file_path) or os.path.islink(file_path)
            if is_file_like:
                os.unlink(file_path)
                continue
            if os.path.isdir(file_path):
                shutil.rmtree(file_path)
        except Exception as e:
            print(f'Failed to delete {file_path}. Reason: {e}')

def main(args):
    """Compute all video metrics for ``args.video_dir`` across distributed ranks.

    Every rank builds its own models, scores the videos its sampler assigns
    to it and saves a list of per-video result dicts to
    ``{args.save_dir}/.tmp/results_{rank}.pth``. After a barrier, rank 0
    merges all shards into ``merged_results.xlsx`` under ``args.save_dir``
    and deletes the ``.tmp`` folder.

    Args:
        args: Namespace providing ``video_dir``, ``save_dir`` and
            ``raft_model_path``.
    """
    dist.init_process_group(backend='nccl')
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    # NOTE(review): the global rank is used as the CUDA device index; on a
    # multi-node launch this presumably needs the local rank -- confirm.
    torch.cuda.set_device(rank)
    output_file = f'{args.save_dir}/.tmp/results_{rank}.pth'
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    # NOTE(review): every rank empties the shared .tmp folder, so a rank that
    # arrives late can delete another rank's placeholder file. The final
    # torch.save below recreates each rank's file.
    clear_folder(os.path.dirname(output_file))
    
    # Create an empty placeholder for this rank's result file.
    with open(output_file, 'w') as file:
        file.write('')

    model_dino = build_dino('vit_small_patch16_224_dino', rank)
    model_dino_v2 = build_dinov2('dinov2_vitl14', rank)
    model_clip, process_clip = build_clip('ViT-L/14', rank)
    model_viclip = build_viclip('viclip-l-internvid-10m-flt', rank)
    # Keep the first two and the last CLIP preprocessing steps, then replace
    # the first one with a fixed 336x336 resize. Neither model_clip nor
    # process_clip is used again in this function.
    process_clip.transforms =  process_clip.transforms[:2] + process_clip.transforms[-1:]
    process_clip.transforms[0] = Resize((336, 336))
    model_flow = build_raft(args.raft_model_path, rank)

    # Transform handed to the dataset as `transform`: 224x224 resize plus
    # normalisation.
    transform = Compose([
        Resize((224, 224)),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Same normalisation without the resize, handed to the dataset as
    # `transform_no_resize`.
    transform_no_resize = Compose([
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    dataset = StandardVidoDataset(
        video_dir=args.video_dir,
        transform=transform,
        transform_no_resize=transform_no_resize,
    )  

    # NOTE(review): DistributedSampler with default settings presumably pads
    # the dataset so all ranks get equally many samples, which would put
    # duplicate videos into the merged results -- confirm.
    sampler = torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=2, sampler=sampler, num_workers=2, collate_fn=standard_collate_fn)
    # Only rank 0 shows a progress bar.
    if rank == 0:
        progress_bar = tqdm.tqdm(total=len(dataloader), position=0, leave=True)

    results = []
    with torch.no_grad():
        for frames in dataloader:
            org_videos = frames['org_videos']

            video_data = frames['videos'].to(rank)
            video_no_resize = [f.to(rank) for f in frames['videos_no_resize']]
            video_names = frames['video_names']
            video_lengths = frames['video_lengths']
            video_paths = frames['video_paths']
            # Inter-frame level metrics
            flow_strength = cal_flow_strength_batch(model_flow, video_no_resize, video_lengths, video_names)
            ssim_dists, msssim, phash_dists = cal_ssim_dist(video_no_resize, org_videos, video_names)

            # Inter-segment level metrics
            dino_segment_dist = cal_dino_segment_dist(model_dino_v2, video_data, video_lengths, )
            viclip_segment_dist = cal_viclip_segment_dist(model_viclip, video_data, video_lengths, [0.5, 0.25])

            # Video level metrics. The temporal-entropy scratch folder is
            # specific to this rank; dino_features is not used afterwards.
            temporal_info = cal_temporal_entropy(video_paths, f'{args.save_dir}/.tmp-rank-{rank}')
            dino_var_dist, dino_features = cal_info_variance(model_dino, video_data, video_lengths)

            # One result record per video of the batch.
            res = [{
                "video_name": video_names[i],
                "ssim": ssim_dists[i],
                "ms-ssim": msssim[i],
                "phash": phash_dists[i],
                "flow": flow_strength[i],
                "dino_segm_dist": dino_segment_dist[i],
                "viclip_segm_dist": viclip_segment_dist[i],
                "info_dino": dino_var_dist[i],
                "temporal_entropy": temporal_info[i],
            } for i in range(len(video_lengths))]
            results.extend(res)
            if rank == 0:
                progress_bar.update(1)

    if rank == 0:
        progress_bar.close()

    # Overwrites the placeholder created above with this rank's results.
    torch.save(results, output_file)

    # Wait for all processes to save their data
    dist.barrier()

    # Only the master process performs the merge and save
    if rank == 0:
        load_and_merge(f'{args.save_dir}/.tmp', args.save_dir)
        shutil.rmtree(f'{args.save_dir}/.tmp')

    # It's important to destroy the process group after all operations are complete
    dist.destroy_process_group()


# Script entry point: parse the command line and run the metric pipeline.
if __name__ == '__main__':
    args = parse_args()
    main(args)
