import sys
sys.path.append(sys.path[0]+r"/../")
import os
import numpy as np
import torch
import tyro
from dataclasses import dataclass, asdict, field
from typing import Literal

from mld.train_mld import create_gaussian_diffusion, load_mld
from mld.train_mvae import load_mvae
from model.text_encoder import CLIPTextEncoder, CLIPTextEncoderV2, CLIPTextEncoderV3

from datetime import datetime
from utils.fixseed import fixseed
from evaluation.metrics import calculate_mpjpe
from evaluation.metrics_intergen import *
from collections import OrderedDict
from visualize.plot_scripts import *
from os.path import join as pjoin
from tqdm import tqdm


@dataclass
class EvalModelArgs:
    """Evaluator-network configuration used for the InterHuman benchmark.

    When ``args.dataset`` contains ``'interhuman_d262'`` the script hands this
    object to ``EvaluatorModelWrapper`` together with the device. The upper-case
    fields are architecture hyper-parameters; presumably they have to match the
    weights stored at ``eval_checkpoint`` (TODO confirm in the wrapper).
    """

    # Architecture hyper-parameters of the evaluator network.
    NAME: str = 'InterCLIP'
    NUM_LAYERS: int = 8
    NUM_HEADS: int = 8
    DROPOUT: float = 0.1
    INPUT_DIM: int = 258
    LATENT_DIM: int = 1024
    FF_SIZE: int = 2048
    ACTIVATION: str = 'gelu'

    # MOTION_REP: global
    FINETUNE: bool = False

    # Path of the pretrained evaluator weights.
    eval_checkpoint: str = './evaluation/eval_model/InterHuman/interclip.ckpt'
    # Evaluator-side processing switch; its meaning is defined by the wrapper (TODO confirm).
    process_mode: int = 0


@dataclass
class EvalArgs:
    """Command-line options of the evaluation script (parsed with ``tyro.cli``)."""

    # Evaluator network configuration. A default_factory is required: dataclass
    # instances are unhashable, which Python >= 3.11 rejects as a plain field
    # default, and on older versions one instance would be shared by every EvalArgs.
    eval_model_args: EvalModelArgs = field(default_factory=EvalModelArgs)
    # Selects how many replications / diversity / multimodality samples are used.
    eval_mode: Literal['debug', 'fast', 'final'] = 'debug'
    extrapolation: bool = False
    # Not referenced in this script; presumably consumed by the motion loader (TODO confirm).
    scenario: str = None

    # When set and load_dir is empty, load_dir defaults to a "generated_motion*"
    # directory next to the checkpoint. NOTE: load_dir is deleted after evaluation.
    load_from_file: bool = False
    load_dir: str = ''
    # Forwards replication_times to the motion loader.
    generate_parallel: bool = False

    # Device selection: 'cuda:<device>' when cuda is set, otherwise 'cpu'.
    cuda: bool = True
    device: int = 7
    seed: int = 0

    # Matched by substring: 'interhuman_d262' or 'interx' select the dataset family.
    dataset: str = 'interhuman_d262'

    # Overridden to 32 for Inter-X datasets.
    batch_size: int = 96

    # With an empty denoiser_checkpoint and a non-empty mvae_checkpoint, only the motion VAE is evaluated.
    denoiser_checkpoint: str = ''
    mvae_checkpoint: str = ''

    # Copied into diffusion_args.respacing before the diffusion is created.
    respacing: str = ''

    guidance_param: float = 1.0
    """classifier-free guidance parameter for diffusion sampling"""
    export_smpl: int = 0
    zero_noise: int = 0
    # Forwarded to calc_joints_from_features when computing MPJPE.
    use_predicted_joints: int = 0

    fix_floor: int = 0

    # Motion length bounds copied into data_args (overridden for Inter-X datasets).
    min_length: int = 15
    max_length: int = 300

    # Passed to load_mld; also part of the output suffix when the denoiser uses pre-latents.
    pre_max_len: int = 38

    # Forwarded to the ground-truth dataset loader when > 0 (and added to the output suffix).
    cut_length: int = 0

    # Forwarded to the motion loader as reaction_gen.
    react_gen: bool = False
   
 
def evaluate_matching_score(motion_loaders, file):
    """Compute MM Distance and top-3 R-precision for every motion loader.

    Text and motion embeddings come from the module-level ``eval_wrapper``; the
    InterHuman variant additionally receives ``gt_dataset_eval.primitive_utility``.
    For every batch, the text-vs-motion Euclidean distance matrix contributes its
    trace (each text against the motion at the same batch index) to the MM
    Distance and its row-wise ranking to the top-k hit counts. Both totals are
    normalised by the number of samples seen.

    Args:
        motion_loaders: mapping ``name -> iterable of evaluation batches``.
        file: open text file that receives a copy of every report line.

    Returns:
        ``(match_score_dict, R_precision_dict, activation_dict)`` keyed by loader
        name. ``activation_dict`` holds the concatenated motion embeddings, which
        the caller reuses for FID and Diversity.
    """
    match_score_dict = OrderedDict()
    R_precision_dict = OrderedDict()
    activation_dict = OrderedDict()
    print('========== Evaluating MM Distance ==========')
    for loader_name, loader in motion_loaders.items():
        motion_chunks = []
        num_samples = 0
        trace_sum = 0
        top_k_hits = 0
        with torch.no_grad():
            for _, batch in tqdm(enumerate(loader)):
                if 'interhuman_d262' in args.dataset:
                    text_emb, motion_emb = eval_wrapper.get_co_embeddings(batch, gt_dataset_eval.primitive_utility)
                elif 'interx' in args.dataset:
                    text_emb, motion_emb = eval_wrapper.get_co_embeddings(batch)
                dist_mat = euclidean_distance_matrix(text_emb.cpu().numpy(),
                                                     motion_emb.cpu().numpy())
                trace_sum += dist_mat.trace()

                # Rank motions for each text by ascending distance, then count top-3 hits.
                ranking = np.argsort(dist_mat, axis=1)
                top_k_hits += calculate_top_k(ranking, top_k=3).sum(axis=0)

                num_samples += text_emb.shape[0]
                motion_chunks.append(motion_emb.cpu().numpy())

            stacked_motion = np.concatenate(motion_chunks, axis=0)
            mm_dist = trace_sum / num_samples
            R_precision = top_k_hits / num_samples
            match_score_dict[loader_name] = mm_dist
            R_precision_dict[loader_name] = R_precision
            activation_dict[loader_name] = stacked_motion

        mm_report = f'---> [{loader_name}] MM Distance: {mm_dist:.4f}'
        print(mm_report)
        print(mm_report, file=file, flush=True)

        rp_report = f'---> [{loader_name}] R_precision: ' + ''.join(
            '(top %d): %.4f ' % (rank, score) for rank, score in enumerate(R_precision, start=1)
        )
        print(rp_report)
        print(rp_report, file=file, flush=True)

    return match_score_dict, R_precision_dict, activation_dict


def evaluate_fid(groundtruth_loader, activation_dict, file):
    """Compute the FID of each model's motion embeddings against the ground truth.

    Ground-truth motion embeddings are obtained from ``groundtruth_loader`` through
    the module-level ``eval_wrapper``. Their mean/covariance (scaled by the global
    ``emb_scale``) serve as the reference for every entry of ``activation_dict``.

    Args:
        groundtruth_loader: iterable of ground-truth evaluation batches.
        activation_dict: mapping ``model name -> motion embeddings`` (numpy).
        file: open text file that receives a copy of every report line.

    Returns:
        OrderedDict ``model name -> FID``.
    """
    print('========== Evaluating FID ==========')
    gt_chunks = []
    with torch.no_grad():
        for _, batch in tqdm(enumerate(groundtruth_loader)):
            if 'interhuman_d262' in args.dataset:
                gt_emb = eval_wrapper.get_motion_embeddings(batch, gt_dataset_eval.primitive_utility)
            elif 'interx' in args.dataset:
                gt_emb = eval_wrapper.get_motion_embeddings(batch)
            gt_chunks.append(gt_emb.cpu().numpy())
    gt_mu, gt_cov = calculate_activation_statistics(np.concatenate(gt_chunks, axis=0), emb_scale)

    eval_dict = OrderedDict()
    for model_name, model_embeddings in activation_dict.items():
        mu, cov = calculate_activation_statistics(model_embeddings, emb_scale)
        fid = calculate_frechet_distance(gt_mu, gt_cov, mu, cov)
        report = f'---> [{model_name}] FID: {fid:.4f}'
        print(report)
        print(report, file=file, flush=True)
        eval_dict[model_name] = fid
    return eval_dict


def evaluate_diversity(activation_dict, file):
    """Compute the Diversity score of each model's motion embeddings.

    Uses the module-level ``diversity_times``, ``emb_scale`` and ``divide_by``
    settings chosen by the script's eval mode / dataset.

    Args:
        activation_dict: mapping ``model name -> motion embeddings`` (numpy).
        file: open text file that receives a copy of every report line.

    Returns:
        OrderedDict ``model name -> diversity``.
    """
    print('========== Evaluating Diversity ==========')
    scores = OrderedDict()
    for name, embeddings in activation_dict.items():
        score = calculate_diversity(embeddings, diversity_times, emb_scale, divide_by)
        scores[name] = score
        report = f'---> [{name}] Diversity: {score:.4f}'
        print(report)
        print(report, file=file, flush=True)
    return scores


def evaluate_multimodality(mm_motion_loaders, file):
    """Compute the MultiModality score for each multimodal motion loader.

    Each batch is expected to hold several generations for the same prompt.
    Four-element batches have the leading dimension stripped: tensor/list
    entries of the feature dict(s) in ``batch[2]`` and ``batch[3]`` are replaced
    by their first element. Each batch's motion embeddings are stacked along a
    new leading axis and passed to ``calculate_multimodality``; a loader that
    yields no batches scores 0.

    Args:
        mm_motion_loaders: mapping ``model name -> iterable of multimodal batches``.
        file: open text file that receives a copy of every report line.

    Returns:
        OrderedDict ``model name -> multimodality``.
    """
    def _take_first(entries):
        # Tensor and list entries keep only their first element; anything else is left as is.
        for key, entry in entries.items():
            if isinstance(entry, (torch.Tensor, list)):
                entries[key] = entry[0]

    eval_dict = OrderedDict()
    print('========== Evaluating MultiModality ==========')
    for model_name, mm_loader in mm_motion_loaders.items():
        per_prompt = []
        with torch.no_grad():
            for batch in mm_loader:
                # (1, mm_replications, dim_pos)
                if len(batch) == 4:
                    if data_args.interaction:
                        for person in ('person1', 'person2'):
                            _take_first(batch[2][person])
                    else:
                        _take_first(batch[2])
                    batch[3] = batch[3][0]
                if 'interhuman_d262' in args.dataset:
                    embeddings = eval_wrapper.get_motion_embeddings(batch, gt_dataset_eval.primitive_utility)
                elif 'interx' in args.dataset:
                    embeddings = eval_wrapper.get_motion_embeddings(batch)
                per_prompt.append(embeddings.unsqueeze(0))
        if per_prompt:
            stacked = torch.cat(per_prompt, dim=0).cpu().numpy()
            multimodality = calculate_multimodality(stacked, mm_num_times, emb_scale, divide_by)
        else:
            multimodality = 0
        report = f'---> [{model_name}] Multimodality: {multimodality:.4f}'
        print(report)
        print(report, file=file, flush=True)
        eval_dict[model_name] = multimodality
    return eval_dict


def evaluate_mpjpe(motion_loaders, file):
    """Compute the mean per-joint position error of generated vs. ground-truth motions.

    Items of ``motion_loaders['generated'].dataset`` are compared index-by-index
    with the module-level ``gt_dataset_eval``. This assumes both datasets are
    aligned by index (TODO confirm against the loader implementation).
    - Ground-truth joints come from each person's ``'joints'`` entry.
    - Generated joints are recovered with
      ``primitive_utility.calc_joints_from_features``.
    - Both are truncated to the ground-truth motion length.
    - Per-person errors are averaged within a sample, and the reported MPJPE is
      the mean over samples.

    Args:
        motion_loaders: mapping that must contain a ``'generated'`` loader.
        file: open text file that receives a copy of the report line.

    Returns:
        OrderedDict ``{'generated': mpjpe}``.

    Raises:
        ValueError: if ``args.dataset`` matches neither supported dataset family.
    """
    eval_dict = OrderedDict()
    persons = ('person1', 'person2')
    print('========== Evaluating MPJPE ==========')
    gen_dataset = motion_loaders['generated'].dataset

    # Positions of the per-person motion dict and of the motion length in a dataset item.
    if 'interhuman_d262' in args.dataset:
        motion_slot, length_slot = 2, 3
    elif 'interx' in args.dataset:
        motion_slot, length_slot = 4, 5
    else:
        raise ValueError(f'Unsupported dataset for MPJPE: {args.dataset}')

    per_sample = []
    with torch.no_grad():
        for idx in range(len(gt_dataset_eval)):
            # Fetch each item once. __getitem__ may be costly, and reading the motion
            # and its length from a single call keeps the two consistent.
            gt_item = gt_dataset_eval[idx]
            gt_motion = gt_item[motion_slot]
            motion_lens = gt_item[length_slot]
            gen_motion = gen_dataset[idx][motion_slot]

            sample_error = 0
            for person in persons:
                gt_joints = gt_motion[person]['joints'][:motion_lens].reshape(motion_lens, -1, 3)
                gen_joints = gen_dataset.primitive_utility.calc_joints_from_features(
                    gen_motion[person], use_predicted_joints=args.use_predicted_joints)
                gen_joints = gen_joints[:motion_lens].reshape(motion_lens, -1, 3).to(gt_joints.device)
                sample_error += calculate_mpjpe(gt_joints, gen_joints)
            sample_error /= len(persons)  # average over the persons
            per_sample.append(sample_error.cpu().numpy())

    mpjpe = np.mean(per_sample)
    print(f'---> MPJPE: {mpjpe:.4f}')
    print(f'---> MPJPE: {mpjpe:.4f}', file=file, flush=True)
    eval_dict['generated'] = mpjpe
    return eval_dict


def get_metric_statistics(values, num_samples=None):
    """Return the mean and 95% confidence half-width of ``values`` along axis 0.

    Args:
        values: array-like of per-replication scores, shape ``(R,)`` or ``(R, k)``.
        num_samples: sample count used for the standard error. Defaults to
            ``len(values)``, i.e. the number of replications actually collected.
            Previously this was read from the global ``replication_times``.

    Returns:
        ``(mean, conf_interval)``, where
        ``conf_interval = 1.96 * std / sqrt(num_samples)`` and ``std`` is the
        population standard deviation (``ddof=0``).
    """
    values = np.asarray(values)
    if num_samples is None:
        num_samples = values.shape[0]
    mean = np.mean(values, axis=0)
    std = np.std(values, axis=0)
    conf_interval = 1.96 * std / np.sqrt(num_samples)
    return mean, conf_interval


def _tee(message, file):
    """Print ``message`` to stdout and write the same line to ``file``, flushing immediately."""
    print(message)
    print(message, file=file, flush=True)


def evaluation(log_file):
    """Run every evaluation replication and write per-run and summary metrics.

    For each of the ``replication_times`` replications:
    - build fresh generated / multimodality loaders via the module-level
      ``eval_motion_loaders`` factory;
    - compute MM Distance, R_precision, FID, Diversity, MultiModality and MPJPE
      against ``gt_loader_eval``;
    - record each score.

    Afterwards the mean and 95% confidence interval of every metric are reported.

    Relies on module globals set up by the ``__main__`` block (``args``,
    ``replication_times``, ``gt_loader_eval``, ``eval_motion_loaders``,
    ``denoiser_args``, ``text_encoder``).

    Args:
        log_file: path of the log file to overwrite. Its parent directory is
            created if it does not exist yet.
    """
    log_dir = os.path.dirname(log_file)
    if log_dir:
        # The results directory may not exist yet (e.g. on a fresh checkout). Without
        # this, open() below fails only after every model has already been loaded.
        os.makedirs(log_dir, exist_ok=True)

    metric_names = ('MM Distance', 'R_precision', 'FID', 'Diversity', 'MultiModality', 'MPJPE')
    with open(log_file, 'w') as f:
        # metric name -> model name -> list of per-replication scores
        all_metrics = OrderedDict((name, OrderedDict()) for name in metric_names)
        for replication in range(replication_times):
            motion_loaders = {'ground truth': gt_loader_eval}
            mm_motion_loaders = {}

            loader_kwargs = {'replication': replication}
            if args.generate_parallel:
                loader_kwargs['replication_times'] = replication_times
            if args.denoiser_checkpoint != '' and not denoiser_args.load_text_embedding:
                loader_kwargs['text_encoder'] = text_encoder
            loader_kwargs['reaction_gen'] = args.react_gen

            motion_loaders['generated'], mm_motion_loaders['generated'] = eval_motion_loaders(**loader_kwargs)

            _tee(f'==================== Replication {replication} ====================', f)
            _tee(f'Time: {datetime.now()}', f)
            mat_score_dict, R_precision_dict, acti_dict = evaluate_matching_score(motion_loaders, f)

            _tee(f'Time: {datetime.now()}', f)
            fid_score_dict = evaluate_fid(gt_loader_eval, acti_dict, f)

            _tee(f'Time: {datetime.now()}', f)
            div_score_dict = evaluate_diversity(acti_dict, f)

            _tee(f'Time: {datetime.now()}', f)
            mm_score_dict = evaluate_multimodality(mm_motion_loaders, f)

            _tee(f'Time: {datetime.now()}', f)
            mpjpe_dict = evaluate_mpjpe(motion_loaders, f)

            _tee('!!! DONE !!!', f)

            replication_scores = (
                ('MM Distance', mat_score_dict),
                ('R_precision', R_precision_dict),
                ('FID', fid_score_dict),
                ('Diversity', div_score_dict),
                ('MultiModality', mm_score_dict),
                ('MPJPE', mpjpe_dict),
            )
            for metric_name, score_dict in replication_scores:
                for model_name, score in score_dict.items():
                    all_metrics[metric_name].setdefault(model_name, []).append(score)

        for metric_name, metric_dict in all_metrics.items():
            _tee('========== %s Summary ==========' % metric_name, f)

            for model_name, values in metric_dict.items():
                mean, conf_interval = get_metric_statistics(np.array(values))
                if np.ndim(mean) == 0:
                    # Scalar metric. Any numpy/Python float is accepted, not just float32/64,
                    # so no summary line is silently dropped.
                    _tee(f'---> [{model_name}] Mean: {mean:.4f} CInterval: {conf_interval:.4f}', f)
                else:
                    # Vector metric (R_precision): one entry per top-k.
                    line = f'---> [{model_name}]' + ''.join(
                        '(top %d) Mean: %.4f CInt: %.4f;' % (k, m, ci)
                        for k, (m, ci) in enumerate(zip(mean, conf_interval), start=1)
                    )
                    _tee(line, f)


if __name__ == '__main__':
    args = tyro.cli(EvalArgs)
    fixseed(args.seed)

    torch.set_num_threads(4)
    torch.set_num_interop_threads(4)

    # Data-loading and evaluator implementation for the dataset family.
    if 'interhuman_d262' in args.dataset:
        from data_loaders.get_data_interhuman_v3 import get_dataset_motion_loader, get_motion_loader
        from data_loaders.HHI.networks.evaluator_wrapper import EvaluatorModelWrapper
    elif 'interx' in args.dataset:
        from data_loaders.get_data_interx_v3 import get_dataset_motion_loader, get_motion_loader, EvaluatorModelWrapper
    else:
        # Every later stage dispatches on these two families. Fail before any model
        # is loaded instead of with a NameError much later.
        raise ValueError(f'Unsupported dataset: {args.dataset}')

    print(f'Eval mode [{args.eval_mode}]')
    if args.eval_mode == 'debug':
        diversity_times = 30
        replication_times = 1
        mm_num_samples = 10
        mm_num_repeats = 3
        mm_num_times = 1
    elif args.eval_mode == 'fast':
        diversity_times = 300
        replication_times = 3
        mm_num_samples = 100
        mm_num_repeats = 10
        mm_num_times = 3
    elif args.eval_mode == 'final':  # same as InterGen
        diversity_times = 300
        replication_times = 20
        mm_num_samples = 100
        mm_num_repeats = 30
        mm_num_times = 10
    else:
        raise ValueError(f'Unknown eval_mode: {args.eval_mode}')

    batch_size = args.batch_size  # same as InterGen
    emb_scale = 6
    divide_by = 2

    # Only a motion VAE checkpoint was given: evaluate the VAE without a denoiser.
    mvae_only = args.denoiser_checkpoint == '' and args.mvae_checkpoint != ''

    # Load model to be evaluated
    print("Creating model and diffusion...")
    device = 'cuda:{}'.format(args.device) if args.cuda else 'cpu'
    if mvae_only:
        vae_args, vae_model = load_mvae(args.mvae_checkpoint, device)
        denoiser_args, denoiser_model, data_args = None, None, vae_args.data_args
        diffusion, diffusion_args = None, None
        text_encoder = None
        data_args.padding = vae_args.padding
        data_args.interaction = True
    else:
        denoiser_args, denoiser_model, vae_args, vae_model, data_args = load_mld(args.denoiser_checkpoint, device, args.pre_max_len)
        data_args.interaction = True
        diffusion_args = denoiser_args.diffusion_args
        diffusion_args.respacing = args.respacing
        print('diffusion_args:', asdict(diffusion_args))
        diffusion = create_gaussian_diffusion(diffusion_args)
        if not denoiser_args.load_text_embedding:
            if denoiser_args.text_encoder_version == "v1":
                text_encoder = CLIPTextEncoder(denoiser_args.clip_version, clip_device=device)
            elif denoiser_args.text_encoder_version == "v2":
                text_encoder = CLIPTextEncoderV2(denoiser_args.clip_version, clip_final_proj=denoiser_args.clip_final_proj, clip_device=device)
            elif denoiser_args.text_encoder_version == "v3":
                text_encoder = CLIPTextEncoderV3(denoiser_args.clip_version, clip_final_proj=denoiser_args.clip_final_proj, clip_device=device)
            else:
                raise ValueError(f'Unknown text_encoder_version: {denoiser_args.text_encoder_version}')
            text_encoder.to(device)
            # The text-encoder weights are read from the denoiser checkpoint's
            # 'text_encoder_state_dict' entry.
            text_encoder_ckpt = torch.load(args.denoiser_checkpoint, map_location=device)
            text_encoder.load_state_dict(text_encoder_ckpt['text_encoder_state_dict'])
            text_encoder.eval()

    # Inter-X overrides. This uses the same substring test as the dataset dispatch
    # above, so that every 'interx*' dataset gets consistent settings.
    if 'interx' in args.dataset:
        batch_size = 32
        args.min_length = 24
        args.max_length = 150
        data_args.max_text_len = 35
        data_args.unit_length = 4
        emb_scale = 1
        divide_by = 1

    # Ground-truth dataset settings.
    split = 'test' if args.dataset != "babel" else "val"
    mode = 'merged' if data_args.interaction else 'sep'
    data_args.min_length = args.min_length
    data_args.max_length = args.max_length

    # Suffix shared by the generated-motion directory name and the log file name.
    suffix = ''
    if denoiser_args is not None and denoiser_args.use_pre_latent:
        suffix = f"_premaxlen{args.pre_max_len}"
    if args.cut_length > 0:
        suffix += f"_cutlen{args.cut_length}"
    if args.react_gen:
        suffix += "_reactgen"
    suffix += f"_gs{args.guidance_param}"
    ckpt_name = os.path.splitext(os.path.basename(args.denoiser_checkpoint))[0]
    if ckpt_name == 'checkpoint_300000':
        ckpt_name = 'ckpt_300k'
    suffix += f"_{ckpt_name}"

    if args.load_from_file and args.load_dir == '':
        # Default location of previously generated motions: next to the checkpoint.
        if mvae_only:
            ckpt_dir = os.path.dirname(args.mvae_checkpoint)
            args.load_dir = os.path.join(ckpt_dir, f"generated_motion_mvae{suffix}")
        else:
            ckpt_dir = os.path.dirname(args.denoiser_checkpoint)
            args.load_dir = os.path.join(ckpt_dir, f"generated_motion{suffix}")

    def eval_motion_loaders(**kwargs):
        """Build the (generated, multimodality) loaders for one replication.

        Extra keyword arguments (replication index, text encoder, ...) are forwarded
        to ``get_motion_loader``. ``gt_dataset_eval`` is looked up at call time,
        after it has been created below.
        """
        return get_motion_loader(args,
                                 batch_size,
                                 denoiser_model,
                                 denoiser_args,
                                 diffusion,
                                 diffusion_args,
                                 vae_model,
                                 vae_args,
                                 data_args,
                                 gt_dataset_eval,
                                 device,
                                 mm_num_samples,
                                 mm_num_repeats,
                                 **kwargs,
                                 )

    # load ground truth dataset
    text_sep = denoiser_args.text_sep if denoiser_args is not None else False
    clip_version = denoiser_args.clip_version if denoiser_args is not None else 'ViT-B/32'
    load_text_embedding = denoiser_args.load_text_embedding if denoiser_args is not None else False
    use_indi_text = denoiser_args.use_indi_text if denoiser_args is not None else False
    gt_loader_eval, gt_dataset_eval = get_dataset_motion_loader(data_args, batch_size, device, split, mode,
                                                                text_sep,
                                                                cut_length=args.cut_length,
                                                                clip_version=clip_version,
                                                                load_text_embedding=load_text_embedding,
                                                                use_indi_text=use_indi_text)

    if 'interhuman_d262' in args.dataset:
        evalmodel_cfg = args.eval_model_args
        eval_wrapper = EvaluatorModelWrapper(evalmodel_cfg, device)
    elif 'interx' in args.dataset:
        from utils.get_opt import get_opt
        evalmodel_cfg = get_opt("./evaluation/eval_model/InterX/checkpoints/Comp_v6_KLD01/opt.txt", device, complete=False)
        eval_wrapper = EvaluatorModelWrapper(evalmodel_cfg)

    if mvae_only:
        log_file = f'./evaluation/eval_results_mvae/{os.path.basename(os.path.dirname(args.mvae_checkpoint))}_{args.eval_mode}{suffix}.log'
    else:
        log_file = f'./evaluation/eval_results_modified_dataloader/{os.path.basename(os.path.dirname(args.denoiser_checkpoint))}_{args.eval_mode}{suffix}.log'
    evaluation(log_file)

    # Delete the generated motion files.
    # An empty load_dir means no directory was configured. It must be skipped:
    # Path('') is '.', and rmtree on it would wipe the current working directory.
    import shutil
    from pathlib import Path
    if args.load_dir:
        folder = Path(args.load_dir)
        resolved = folder.resolve()
        cwd = Path.cwd().resolve()
        if resolved == cwd or resolved in cwd.parents:
            print(f'Skipping deletion of {folder}: it contains the current working directory')
        else:
            try:
                shutil.rmtree(folder)
            except FileNotFoundError:
                pass
