import torch
import argparse
import yaml
import math
import os
import time
import torch.nn as nn
from pathlib import Path
from tqdm import tqdm
from torch.utils.data import DataLoader
from torch.nn import functional as F
from math import ceil
import numpy as np
from torch.utils.data import DistributedSampler, RandomSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from utils.metric import Metrics
import gc
from config import ex
from utils.utils import fix_seeds, setup_cudnn, get_logger, cal_flops
from models.audio_video_model import AudioVideoModelWithMT
from dataset.ks import KineticsSound
from dataset.ave import AVE
from dataset.cremad import CREMAD
import pprint

# Select the 'file_system' strategy for sharing tensors between processes.
# NOTE(review): presumably chosen to avoid running out of file descriptors
# when DataLoader workers hand tensors to the main process — confirm.
torch.multiprocessing.set_sharing_strategy('file_system')

@torch.no_grad()
def evaluate(model, dataloader, device, num_classes, loss_fn=None, loss_fn2=None, loss_fn2_weight=0.0, is_mmimdb=False, enable_mt=False):
    """Run one evaluation pass over ``dataloader`` and return the scores.

    Args:
        model: Module called as ``model([video, audio], missing_type)``; it
            must return ``(logits_fused, real_tokens, estimated_tokens,
            fused_features)``.
        dataloader: Iterable of batches; each batch is indexed with the keys
            ``'video'``, ``'audio'``, ``'label'`` and ``'missing_type'``.
        device: Device the batch tensors are moved to before the forward pass.
        num_classes: Unused here; kept so existing callers keep working.
        loss_fn: Optional classification loss called as
            ``loss_fn(logits_fused, label)``.
        loss_fn2: Optional alignment loss called on the stacked real and
            estimated tokens.
        loss_fn2_weight: Factor applied to the alignment loss before it is
            accumulated.
        is_mmimdb: Unused here; kept so existing callers keep working.
        enable_mt: When true (and ``loss_fn2`` is given), the weighted
            alignment loss is added to the total loss.

    Returns:
        The object returned by ``Metrics.compute_score(prefix='test_')``,
        updated with ``'test_loss'`` / ``'test_ce_loss'`` when ``loss_fn`` is
        given and with ``'test_ce_loss'`` / ``'test_alignment_loss'`` when
        ``loss_fn2`` is given. All losses are per-batch averages;
        ``'test_alignment_loss'`` already includes ``loss_fn2_weight``.
    """
    # Clean Memory
    torch.cuda.empty_cache()
    gc.collect()

    metric = Metrics()
    print('Evaluating...')
    model.eval()

    test_loss = 0.0
    total_ce_loss = 0.0
    # Always initialised: it is reported whenever loss_fn2 is given, even if
    # enable_mt is False (previously this raised NameError).
    total_alignment_loss = 0.0
    num_batches = 0

    for sample in tqdm(dataloader):
        video_audio = [sample['video'].to(device), sample['audio'].to(device)]
        lbl = sample['label'].to(device)
        missing_type = sample['missing_type'].to(device)
        logits_fused, real_tokens, estimated_tokens, fused_features = model(video_audio, missing_type)

        _, predicted = torch.max(logits_fused, 1)
        metric.update(predicted, lbl)

        if loss_fn is not None:
            # .item() forces a device sync, so read the scalar only once.
            ce_loss = loss_fn(logits_fused, lbl).item()
            test_loss += ce_loss
            total_ce_loss += ce_loss

        if loss_fn2 is not None and enable_mt:
            if len(real_tokens) > 0:
                align_loss = loss_fn2(torch.stack(real_tokens), torch.stack(estimated_tokens))
            else:
                # No tokens were produced for this batch: evaluate the loss
                # on matching zero placeholders, as the original code did.
                align_loss = loss_fn2(torch.zeros(1, 768), torch.zeros(1, 768))
            weighted_align_loss = loss_fn2_weight * align_loss.item()
            test_loss += weighted_align_loss
            total_alignment_loss += weighted_align_loss

        num_batches += 1

    # Average every accumulator over the same batch count. The original
    # divided two of them by (count + 1) and crashed on an empty dataloader.
    if num_batches > 0:
        test_loss /= num_batches
        total_ce_loss /= num_batches
        total_alignment_loss /= num_batches

    # Calculate evaluation metrics
    all_scores = metric.compute_score(prefix='test_')
    if loss_fn is not None:
        all_scores.update({'test_loss': test_loss, 'test_ce_loss': total_ce_loss})
    if loss_fn2 is not None:
        all_scores.update({'test_ce_loss': total_ce_loss, 'test_alignment_loss': total_alignment_loss})

    # Clean Memory
    metric.reset()
    torch.cuda.empty_cache()
    gc.collect()

    return all_scores


@ex.automain
def main(_config):
    """Evaluate a saved ``AudioVideoModelWithMT`` checkpoint on a test split.

    Reads all settings from ``_config``: builds the test dataset selected by
    ``_config['exp_name']``, loads the state dict at ``_config['model_path']``,
    runs ``evaluate`` and appends the configuration and the resulting scores
    to a time-stamped ``eval_<timestamp>.txt`` next to the checkpoint.

    Args:
        _config: Mapping of experiment settings injected by ``ex``.

    Raises:
        ValueError: If ``_config['exp_name']`` is not one of the supported
            experiments.
        FileNotFoundError: If ``_config['model_path']`` does not exist.
    """
    print("Experiment Configurations:")
    print(_config)
    fix_seeds(_config["seed"])
    setup_cudnn()

    num_workers = _config['num_workers']
    device = torch.device(_config['device'])
    num_classes = _config['class_num']
    batch_size = _config['batch_size']
    enable_lora = _config['enable_lora']
    enable_mt = _config['enable_mt']

    # construct missing modality info
    missing_info = {
        'ratio': _config["missing_ratio"],
        'type': _config["missing_type"],
        'both_ratio': _config["both_ratio"],
        'missing_table_root': _config["missing_table_root"],
        'simulate_missing': _config["simulate_missing"],
        'only_paired': _config["only_paired"],
    }

    # Dataset: map the experiment name to its dataset class. An unknown name
    # now fails here with a clear message instead of leaving ``valset``
    # unbound and crashing later.
    dataset_classes = {
        "finetune_ks": KineticsSound,
        "finetune_ave": AVE,
        "finetune_cremad": CREMAD,
    }
    exp_name = _config['exp_name']
    if exp_name not in dataset_classes:
        raise ValueError(
            "Unsupported exp_name {!r}; expected one of {}".format(
                exp_name, sorted(dataset_classes)
            )
        )
    valset = dataset_classes[exp_name](
        split="test",
        dataset_root_dir=_config['data_dir'],
        missing_info=missing_info,
    )

    model_path = Path(_config['model_path'])
    if not model_path.exists():
        raise FileNotFoundError(f"Model checkpoint not found: {model_path}")
    print(f"Evaluating {model_path}...")

    exp_time = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    eval_path = os.path.join(os.path.dirname(_config['model_path']), 'eval_{}.txt'.format(exp_time))

    model = AudioVideoModelWithMT(
        num_classes,
        r=_config['r'],
        lora_alpha=_config['lora_alpha'],
        lora_dropout=_config['lora_dropout'],
        vit_target_modules=_config['vit_target_modules'],
        ast_target_modules=_config['ast_target_modules'],
        enable_lora=enable_lora,
        enable_mt=enable_mt,
    )
    # Load the state dictionary
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    state_dict = torch.load(str(model_path), map_location='cpu')
    msg = model.load_state_dict(state_dict)
    print(msg)
    model = model.to(device)

    sampler_val = None
    valloader = DataLoader(valset, batch_size=batch_size, num_workers=num_workers, pin_memory=False, sampler=sampler_val)
    loss_fn = nn.CrossEntropyLoss()
    alignment_loss = nn.MSELoss()

    test_scores = evaluate(
        model, valloader, device, num_classes,
        loss_fn=loss_fn,
        loss_fn2=alignment_loss if enable_mt else None,
        loss_fn2_weight=_config['mt_alignment_loss_weight'],
        is_mmimdb=False,
        enable_mt=enable_mt,
    )
    pprint.pprint(test_scores)

    with open(eval_path, 'a+') as f:
        # write(), not writelines(): the path is a single string, not a
        # sequence of lines.
        f.write(str(_config['model_path']))
        f.write("\n============== Eval on {} images =================\n".format(len(valset)))
        f.write("\n")
        f.write(str(_config))
        f.write("\n")
        f.write(str(_config["missing_type"]))
        f.write("\n")
        f.write(str(_config["missing_ratio"]))
        f.write("\n")
        f.write(str(test_scores))
        