import argparse
import gc
import math
import os
import pprint
import sys
import time
from math import ceil
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import yaml
from pytorch_metric_learning import losses
from tabulate import tabulate
from torch import distributed as dist
from torch.nn import functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data import DistributedSampler, RandomSampler
from tqdm import tqdm

from config import ex
from dataset.food101_dataset import FOOD101Dataset
from dataset.mmimdb_dataset import MMIMDBDataset
from models.image_text_model import ViTBertMMT
from utils.metric import Metrics
from utils.utils import fix_seeds, setup_cudnn, get_logger, cal_flops

# Tensors shared between processes (e.g. DataLoader workers) are exchanged via
# named shared-memory files instead of file descriptors; this strategy is
# commonly chosen to avoid "too many open files" errors during data loading.
_TENSOR_SHARING_STRATEGY = 'file_system'
torch.multiprocessing.set_sharing_strategy(_TENSOR_SHARING_STRATEGY)

@torch.no_grad()
def evaluate(model, dataloader, device, num_classes, max_text_len, loss_fn=None, loss_fn2=None, loss_fn2_weight=0.0, is_mmimdb=False, enable_auroc=False, enable_mt=False, text_model='bert-base-uncased'):
    """Run one gradient-free pass over ``dataloader`` and return test scores.

    Each batch is preprocessed with the ``image_processor`` / ``tokenizer``
    attached to ``dataloader.dataset``, fed through ``model`` and compared
    against its labels via a ``Metrics`` accumulator.

    Args:
        model: Image-text model, called as ``model([images, texts], missing_type)``
            and unpacked into ``(logits, real_tokens, estimated_tokens, _)``.
        dataloader: Loader whose dataset exposes ``tokenizer`` and
            ``image_processor``; batches are dicts with ``'image'``, ``'text'``,
            ``'missing_type'`` and ``'label'`` entries.
        device: Device the batch tensors are moved to.
        num_classes: Not used by this function; kept for call-site compatibility.
        max_text_len: ``max_length`` passed to the tokenizer (with truncation).
        loss_fn: Optional classification loss, called as ``loss_fn(logits, labels)``.
        loss_fn2: Optional alignment loss between the stacked real and estimated
            tokens. It only contributes when ``enable_mt`` is true and the model
            returned at least one real token.
        loss_fn2_weight: Multiplier applied to the alignment loss.
        is_mmimdb: If true, labels are stacked along dim 1 as floats and
            predictions are rounded sigmoids (multi-label); otherwise the
            prediction is the argmax over dim 1.
        enable_auroc: Not used by this function; kept for call-site compatibility.
        enable_mt: Enables accumulation of the alignment loss.
        text_model: ``token_type_ids`` are moved to ``device`` only when this is
            ``'bert-base-uncased'``.

    Returns:
        dict: ``metric.compute_score(prefix='test_')``. If ``loss_fn`` is given, it
        also contains ``test_loss`` and ``test_ce_loss``. If ``loss_fn2`` is given,
        it also contains ``test_ce_loss`` and ``test_alignment_loss``. All losses
        are means over batches.
    """
    # Clean memory before evaluation.
    torch.cuda.empty_cache()
    gc.collect()

    metric = Metrics()

    print('Evaluating...')
    model.eval()

    # Processors attached to the dataset.
    tokenizer = dataloader.dataset.tokenizer
    image_processor = dataloader.dataset.image_processor

    test_loss = 0.0
    total_ce_loss = 0.0
    # Always initialised so the result dict can be built even if loss_fn2 is
    # passed without enable_mt (this used to raise NameError).
    total_alignment_loss = 0.0
    num_batches = 0

    for sample in tqdm(dataloader):
        images = image_processor(sample['image'], return_tensors="pt", do_rescale=False).to(device)
        texts = tokenizer(
            sample['text'],
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=max_text_len,
        )
        missing_type = sample['missing_type'].to(device)

        texts['input_ids'] = texts['input_ids'].to(device)
        texts['attention_mask'] = texts['attention_mask'].to(device)
        if text_model == 'bert-base-uncased':
            texts['token_type_ids'] = texts['token_type_ids'].to(device)

        if is_mmimdb:
            lbl = torch.stack(sample['label'], dim=1).to(device).float()
        else:
            lbl = sample['label'].to(device)
        logits_fused, real_tokens, estimated_tokens, _ = model([images, texts], missing_type)

        # Performance calculation
        if is_mmimdb:
            predicted = torch.sigmoid(logits_fused).round().detach().cpu().numpy()
        else:
            _, predicted = torch.max(logits_fused, 1)
        metric.update(predicted, lbl)

        if loss_fn is not None:
            ce_loss = loss_fn(logits_fused, lbl).item()
            test_loss += ce_loss
            total_ce_loss += ce_loss
        # The alignment loss only counts when MT is enabled and real tokens exist.
        # Otherwise nothing is accumulated, so no placeholder loss is computed.
        if loss_fn2 is not None and enable_mt and len(real_tokens) > 0:
            align_loss = loss_fn2_weight * loss_fn2(torch.stack(real_tokens), torch.stack(estimated_tokens)).item()
            test_loss += align_loss
            total_alignment_loss += align_loss

        num_batches += 1

    # Every sum is divided by the same batch count so the components are consistent
    # with test_loss. max(..., 1) avoids ZeroDivisionError on an empty loader.
    denom = max(num_batches, 1)
    test_loss /= denom
    total_ce_loss /= denom
    total_alignment_loss /= denom

    # Calculate evaluation metrics
    all_scores = metric.compute_score(prefix='test_')
    if loss_fn is not None:
        all_scores.update({'test_loss': test_loss, 'test_ce_loss': total_ce_loss})
    if loss_fn2 is not None:
        all_scores.update({'test_ce_loss': total_ce_loss, 'test_alignment_loss': total_alignment_loss})

    # Clean memory after evaluation.
    metric.reset()
    torch.cuda.empty_cache()
    gc.collect()

    return all_scores


@ex.automain
def main(_config):
    """Evaluate a trained checkpoint on the test split and log the scores.

    Builds the test dataset selected by ``_config['exp_name']``, restores a
    ``ViTBertMMT`` from ``_config['model_path']``, runs ``evaluate`` and
    appends the config plus scores to ``eval_<timestamp>.txt`` in the
    checkpoint's directory.

    Args:
        _config: Experiment configuration mapping (presumably injected by the
            ``ex`` experiment object imported from ``config`` — confirm).

    Raises:
        FileNotFoundError: If ``_config['model_path']`` does not exist.
        SystemExit: If ``_config['exp_name']`` is not a supported experiment.
    """
    print("Experiment Configurations:")
    print(_config)
    fix_seeds(_config["seed"])
    setup_cudnn()

    num_workers = _config['num_workers']
    device = torch.device(_config['device'])
    num_classes = _config['class_num']
    max_text_len = _config['max_text_len']
    batch_size = _config['batch_size']
    enable_lora = _config['enable_lora']
    enable_mt = _config['enable_mt']
    text_model = _config['text_model']
    exp_name = _config['exp_name']
    is_imdb = exp_name == "finetune_mmimdb"

    # construct missing modality info
    missing_info = {
        'ratio': _config["missing_ratio"],
        'type': _config["missing_type"],
        'both_ratio': _config["both_ratio"],
        'missing_table_root': _config["missing_table_root"],
        'simulate_missing': _config["simulate_missing"],
        'only_paired': _config["only_paired"],
    }

    # Both supported experiments take identical dataset constructor arguments.
    dataset_classes = {
        "finetune_mmimdb": MMIMDBDataset,
        "finetune_food101": FOOD101Dataset,
    }
    dataset_cls = dataset_classes.get(exp_name)
    if dataset_cls is None:
        # Requires the module-level `import sys`.
        sys.exit("No valid experiment selected. Aborting!")
    # NOTE(review): the test split is built with train_transform_keys — confirm intended.
    valset = dataset_cls(
        _config['data_dir'],
        _config['train_transform_keys'],
        split="test",
        image_size=_config['image_size'],
        max_text_len=max_text_len,
        draw_false_image=_config['draw_false_image'],
        draw_false_text=_config['draw_false_text'],
        image_only=False,
        missing_info=missing_info,
        enable_mt=enable_mt,
        text_model=text_model,
    )

    model_path = Path(_config['model_path'])
    if not model_path.exists():
        raise FileNotFoundError(f"Model checkpoint not found: {model_path}")
    print(f"Evaluating {model_path}...")

    exp_time = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    eval_path = os.path.join(os.path.dirname(_config['model_path']), 'eval_{}.txt'.format(exp_time))

    model = ViTBertMMT(
        num_classes,
        max_text_len,
        r=_config['r'],
        lora_alpha=_config['lora_alpha'],
        lora_dropout=_config['lora_dropout'],
        vit_target_modules=_config['vit_target_modules'],
        bert_target_modules=_config['bert_target_modules'],
        enable_lora=enable_lora,
        enable_mt=enable_mt,
        text_model=text_model,
    )

    # torch.load unpickles the checkpoint: only load files from trusted sources.
    state_dict = torch.load(str(model_path), map_location='cpu')
    msg = model.load_state_dict(state_dict)
    print(msg)
    model = model.to(device)

    valloader = DataLoader(valset, batch_size=batch_size, num_workers=num_workers, pin_memory=False, sampler=None)
    loss_fn = nn.BCEWithLogitsLoss() if is_imdb else nn.CrossEntropyLoss()
    alignment_loss = nn.MSELoss()

    test_scores = evaluate(
        model, valloader, device, num_classes,
        max_text_len,
        loss_fn=loss_fn,
        loss_fn2=alignment_loss if enable_mt else None,
        loss_fn2_weight=_config['mt_alignment_loss_weight'],
        is_mmimdb=is_imdb,
        enable_auroc=False,
        enable_mt=enable_mt,
        text_model=text_model,
    )
    pprint.pprint(test_scores)

    with open(eval_path, 'a', encoding='utf-8') as f:
        f.write(str(_config['model_path']))
        f.write("\n============== Eval on {} images =================\n".format(len(valset)))
        f.write("\n")
        f.write(str(_config))
        f.write("\n")
        f.write(str(_config["missing_type"]))
        f.write("\n")
        f.write(str(_config["missing_ratio"]))
        f.write("\n")
        f.write(str(test_scores))
        f.write("\n")
        