import torch
import argparse
import yaml
import math
import os
import time
import torch.nn as nn
from pathlib import Path
from tqdm import tqdm
from tabulate import tabulate
from torch.utils.data import DataLoader
from torch.nn import functional as F
from math import ceil
import numpy as np
from torch.utils.data import DistributedSampler, RandomSampler
from torch import distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from vit_bert_roberta.utils.metric import Metrics
import gc
from vit_bert_roberta.config import ex
from vit_bert_roberta.utils.utils import fix_seeds, setup_cudnn, get_logger, cal_flops
from vit_bert_roberta.dataset.food101_dataset import FOOD101Dataset
from vit_bert_roberta.dataset.mmimdb_dataset import MMIMDBDataset
from dataset.ks import KineticsSound
from dataset.ave import AVE
from dataset.creamad import CREAMAD
from vit_bert_roberta.models.unimodal import *
import pprint
from pytorch_metric_learning import losses


@torch.no_grad()
def evaluate(model, dataloader, device, num_classes, max_text_len, loss_fn=None, loss_fn2=None, loss_fn2_weight=0.2, is_mmimdb=False, enable_auroc=False, dataset_type=1):
    """Run `model` over `dataloader` and return the aggregated test scores.

    Args:
        model: Callable that takes a two-element list of inputs
            (``[images, texts]`` or ``[video, audio]``) and returns a pair
            whose first element is the logits.
        dataloader: Iterable yielding sample dicts. When ``dataset_type == 1``
            its ``dataset`` must expose ``tokenizer`` and ``image_processor``
            and each sample provides ``'image'``, ``'text'`` and ``'label'``;
            otherwise each sample provides ``'video'``, ``'audio'`` and
            ``'label'``.
        device: Device the inputs and labels are moved to.
        num_classes: Unused; kept for interface compatibility.
        max_text_len: ``max_length`` passed to the tokenizer (with truncation).
        loss_fn: Optional criterion called as ``loss_fn(logits, labels)``.
            When given, its per-batch mean is reported as ``'test_loss'``.
        loss_fn2: Unused; kept for interface compatibility.
        loss_fn2_weight: Unused; kept for interface compatibility.
        is_mmimdb: If true, ``sample['label']`` is stacked along dim 1 and cast
            to float, and predictions are the rounded sigmoid of the logits
            (as a NumPy array) instead of the arg-max class index.
        enable_auroc: Forwarded to ``Metrics.compute_score``.
        dataset_type: ``1`` selects the image/text input path; any other value
            selects the video/audio path.

    Returns:
        The mapping produced by ``Metrics.compute_score(prefix='test_', ...)``,
        extended with ``'test_loss'`` when ``loss_fn`` is provided. If the
        dataloader yields no batches, ``'test_loss'`` is NaN.
    """
    # Clean Memory
    torch.cuda.empty_cache()
    gc.collect()

    # Init Variables for Performance Calculation
    metric = Metrics()

    print('Evaluating...')
    model.eval()

    # Processor
    if dataset_type == 1:
        tokenizer = dataloader.dataset.tokenizer
        image_processor = dataloader.dataset.image_processor

    test_loss = 0.0
    num_batches = 0

    for sample in tqdm(dataloader):
        if dataset_type == 1:
            images = image_processor(sample['image'], return_tensors="pt", do_rescale=False).to(device)
            texts = tokenizer(
                sample['text'],
                return_tensors='pt',
                padding=True,
                truncation=True,
                max_length=max_text_len,
            )
            # Not every tokenizer emits all three fields (token_type_ids may be
            # absent for some models), so only move the ones that are present.
            for key in ('input_ids', 'attention_mask', 'token_type_ids'):
                if key in texts:
                    texts[key] = texts[key].to(device)

            inputs = [images, texts]
        else:
            inputs = [sample['video'].to(device), sample['audio'].to(device)]

        if is_mmimdb:
            lbl = torch.stack(sample['label'], dim=1).to(device).float()
        else:
            lbl = sample['label'].to(device)

        logits, _ = model(inputs)

        # Performance calculation
        if is_mmimdb:
            predicted = torch.sigmoid(logits).round().detach().cpu().numpy()
        else:
            _, predicted = torch.max(logits, 1)
        metric.update(predicted, lbl)

        if loss_fn is not None:
            test_loss += loss_fn(logits, lbl).item()

        num_batches += 1

    # Average over batches; guard against an empty dataloader instead of
    # crashing with ZeroDivisionError.
    test_loss = test_loss / num_batches if num_batches else float('nan')

    # Calculate evaluation metrics
    all_scores = metric.compute_score(prefix='test_', enable_auroc=enable_auroc)
    if loss_fn is not None:
        all_scores.update({'test_loss': test_loss})

    # Clean Memory
    metric.reset()
    torch.cuda.empty_cache()
    gc.collect()

    return all_scores


@ex.automain
def main(_config):
    """Evaluate a saved unimodal classifier on the test split of one dataset.

    Reads everything from `_config`: the experiment name selects the dataset,
    `model_name` selects the classifier, and `model_path` points at the saved
    state dict. The scores returned by `evaluate` are pretty-printed and also
    appended to a timestamped ``eval_<YYYYmmdd_HHMMSS>.txt`` file written next
    to the checkpoint.

    Args:
        _config: Mapping of experiment settings (injected by the `ex`
            decorator imported from ``vit_bert_roberta.config``).

    Raises:
        SystemExit: If ``_config['exp_name']`` is not a supported experiment.
        FileNotFoundError: If ``_config['model_path']`` does not exist.
        ValueError: If ``_config['model_name']`` is not a supported model.
    """
    print("Experiment Configurations:")
    print(_config)
    fix_seeds(_config["seed"])
    setup_cudnn()

    num_workers = _config['num_workers']
    device = torch.device(_config['device'])
    num_classes = _config['class_num']
    max_text_len = _config['max_text_len']
    batch_size = _config['batch_size']
    enable_lora = _config['enable_lora']
    exp_name = _config['exp_name']
    # MM-IMDb is evaluated as multi-label (BCE loss, sigmoid predictions).
    is_imdb = exp_name == "finetune_mmimdb"

    # construct missing modality info
    missing_info = {
        'ratio' : _config["missing_ratio"],
        'type' : _config["missing_type"],
        'both_ratio' : _config["both_ratio"],
        'missing_table_root': _config["missing_table_root"],
        'simulate_missing' : _config["simulate_missing"],
        'only_paired' : _config["only_paired"]
    }

    # Dataset
    # dataset_type: 1 -> Image-Text, 2 -> Audio-Video
    image_text_datasets = {
        "finetune_mmimdb": MMIMDBDataset,
        "finetune_food101": FOOD101Dataset,
    }
    audio_video_datasets = {
        "finetune_ks": KineticsSound,
        "finetune_ave": AVE,
        "finetune_creamad": CREAMAD,
    }
    if exp_name in image_text_datasets:
        dataset_type = 1
        valset = image_text_datasets[exp_name](
            f"{_config['data_root']}/{_config['datasets'][0]}",
            _config['train_transform_keys'],
            split="test",
            image_size=_config['image_size'],
            max_text_len=max_text_len,
            draw_false_image=_config['draw_false_image'],
            draw_false_text=_config['draw_false_text'],
            image_only=_config['image_only'],
            missing_info=missing_info,
        )
    elif exp_name in audio_video_datasets:
        dataset_type = 2
        valset = audio_video_datasets[exp_name](
            split="test",
            dataset_root_dir=_config['data_dir'],
            missing_info=missing_info,
        )
    else:
        # Equivalent to sys.exit(msg) without depending on a `sys` import.
        raise SystemExit("No valid experiment selected. Aborting!")

    model_path = Path(_config['model_path'])
    if not model_path.exists():
        raise FileNotFoundError(f"Model checkpoint not found: {model_path}")
    print(f"Evaluating {model_path}...")

    exp_time = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    eval_path = os.path.join(os.path.dirname(_config['model_path']), 'eval_{}.txt'.format(exp_time))

    # Pick the classifier class and the config key holding its LoRA target
    # modules; all classifiers share the same constructor arguments.
    model_name = _config['model_name']
    if model_name == 'bert':
        model_cls, target_modules_key = BertClassifier, 'bert_target_modules'
    elif model_name == 'vit':
        model_cls, target_modules_key = ViTClassifier, 'vit_target_modules'
    elif model_name == 'ast':
        model_cls, target_modules_key = ASTClassifier, 'vit_target_modules'
    elif model_name == 'video':
        model_cls, target_modules_key = VideoClassifier, 'vit_target_modules'
    else:
        raise ValueError(
            f"Unsupported model_name {model_name!r}; "
            "expected one of 'bert', 'vit', 'ast', 'video'."
        )

    model = model_cls(
        num_classes,
        max_text_len,
        r=_config['r'],
        lora_alpha=_config['lora_alpha'],
        lora_dropout=_config['lora_dropout'],
        target_modules=_config[target_modules_key],
        enable_lora=enable_lora,
    )

    # NOTE(security): torch.load unpickles the file; only point model_path at
    # checkpoints from a trusted source.
    msg = model.load_state_dict(torch.load(str(model_path), map_location='cpu'))
    print(msg)
    model = model.to(device)

    valloader = DataLoader(valset, batch_size=batch_size, num_workers=num_workers, pin_memory=False, sampler=None)
    loss_fn = nn.BCEWithLogitsLoss() if is_imdb else nn.CrossEntropyLoss()

    test_scores = evaluate(
        model, valloader, device, num_classes,
        max_text_len,
        loss_fn=loss_fn,
        loss_fn2=None,
        loss_fn2_weight=0.0,
        is_mmimdb=is_imdb,
        enable_auroc=False,
        dataset_type=dataset_type
    )
    pprint.pprint(test_scores)

    with open(eval_path, 'a+') as f:
        f.write(_config['model_path'])
        f.write("\n============== Eval on {} images =================\n".format(len(valset)))
        f.write("\n")
        f.write(str(_config))
        f.write("\n")
        f.write(str(_config["missing_type"]))
        f.write("\n")
        f.write(str(_config["missing_ratio"]))
        f.write("\n")
        f.write(str(test_scores))
        