import torch
import copy
import torch.nn as nn
from models.heads.heads import *
from models.heads.hydra import *
from utils.train_utils import *
from torch.distributions import Categorical
from torch.utils.data import DataLoader
from transformers import AdamW, get_linear_schedule_with_warmup
from accelerate import Accelerator
from torchmetrics import F1Score
from sklearn.metrics import f1_score
import tqdm
import numpy as np
import wandb
import yaml
import sys
import pickle
import argparse
# ---------------------------------------------------------------------------
# Script setup: CLI parsing, seeding, config/model/dataset loading, and the
# choice of loss/metric for the configured task.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(
    description="Evaluate a saved model checkpoint on the test split, "
                "optionally blanking out the image or text modality."
)
# argparse derives `dest` from the first long option, so these are read as
# args.ri / args.rt / args.c / args.sp.
parser.add_argument('--ri', '--remove_image', action='store_true',
                    help="replace the image inputs with zeros")
parser.add_argument('--rt', '--remove_text', action='store_true',
                    help="replace the text inputs with an empty sequence")
# Required: without it the script previously crashed in open(None).
parser.add_argument('--c', '--config', action='store', required=True,
                    help="path to the YAML config file")
parser.add_argument('--sp', '--save_path', action='store',
                    help="save path (parsed but not read by this setup code)")
args = parser.parse_args()

# Single seed. The script used to seed with 101 and then immediately re-seed
# with 42 before any random operation, so 42 was the only effective seed.
torch.manual_seed(42)

config_path = args.c
with open(config_path, 'r') as f:
    config = yaml.safe_load(f)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_gpus = torch.cuda.device_count()

_, processor = get_model_and_processor(config)

# NOTE(security): torch.load unpickles arbitrary objects; only point
# `model_checkpoint` at files you trust.
# Map onto the device selected above instead of hard-coding CUDA device 0, so
# the script also runs on CPU-only machines.
checkpoint = torch.load(config['model_checkpoint'], map_location=device)
# Checkpoints saved from a wrapper (e.g. DataParallel) expose the real model
# as `.module`; fall back to the loaded object itself when there is no wrapper.
model = getattr(checkpoint, "module", checkpoint)

test_db = get_dataset(config["dataset"], split='test')
num_classes = test_db.num_classes

print(f"There are {num_gpus} gpus")
bs = 100
test_dl = DataLoader(
    test_db,
    batch_size=bs,
    collate_fn=lambda x: collate_fn(x, processor, num_classes, config),
)
model.to(device)

task = config["task"]
if task == 'multilabel':
    loss_fn = torch.nn.BCEWithLogitsLoss()  # built-in sigmoid
elif task == 'multiclass':
    loss_fn = torch.nn.CrossEntropyLoss()  # built-in softmax
else:
    # Fail fast instead of leaving loss_fn/metric undefined until first use.
    raise ValueError(
        f"Unsupported task {task!r}; expected 'multilabel' or 'multiclass'"
    )


def metric(logits, labels):
    """Return the macro-averaged F1 score.

    Despite its name, ``logits`` is passed straight to sklearn's
    ``f1_score`` as ``y_pred``, so it must already hold hard predictions;
    ``labels`` is the ground truth (``y_true``).
    """
    return f1_score(labels, logits, average='macro')
def remove_text_input(batch):
    """Blank out the text modality of ``batch`` in place and return it.

    Every row of ``input_ids`` becomes ``[101, 102, 0, 0, ...]`` (presumably
    the BERT ``[CLS]``/``[SEP]`` ids followed by padding -- confirm against
    the tokenizer in use), ``attention_mask`` becomes ``[1, 1, 0, 0, ...]``
    as float, and ``token_type_ids`` is zeroed.

    Args:
        batch: mapping with a 2-D ``input_ids`` tensor of shape
            ``(batch_size, seq_len)`` and a ``token_type_ids`` tensor.

    Returns:
        The same ``batch`` object, with the three text entries replaced.
        Replacement tensors keep the ``(batch_size, seq_len)`` shape and live
        on the same device as the original ``input_ids``.
    """
    bs, seq_len = batch['input_ids'].shape
    device = batch['input_ids'].device

    # Truncate the two marker tokens to seq_len so that sequences shorter
    # than 2 keep their original width instead of silently growing to 2.
    markers = torch.tensor([101, 102], dtype=torch.long, device=device)[:seq_len]
    n_markers = markers.numel()

    input_ids = torch.zeros((bs, seq_len), dtype=torch.long, device=device)
    attention_mask = torch.zeros((bs, seq_len), dtype=torch.float, device=device)
    if n_markers:
        input_ids[:, :n_markers] = markers
        attention_mask[:, :n_markers] = 1.0

    batch['input_ids'] = input_ids
    batch['attention_mask'] = attention_mask
    batch['token_type_ids'] = torch.zeros_like(batch['token_type_ids']).long()
    return batch

def remove_image_input(batch):
    """Blank out the image modality of ``batch`` in place and return it.

    ``pixel_values`` is replaced by an all-zero float tensor and
    ``pixel_mask`` by an all-one long tensor; both keep the shape and device
    of the tensors they replace.
    """
    pixels = batch['pixel_values']
    batch['pixel_values'] = pixels.new_zeros(pixels.shape, dtype=torch.float)

    mask = batch['pixel_mask']
    batch['pixel_mask'] = mask.new_ones(mask.shape, dtype=torch.long)

    return batch

#means_text = torch.Tensor([2.816537, 6.121535, 2.0383527, 1.4921123, 1.26161, 1.104946, 0.9650882, 0.656466, 0.9487215, 0.7087383, 0.93216366, 0.69854546, 1.3113161, 0.50449944, 0.6917074, 0.4796518, 0.79453915, 1.6297787, 0.7784183, 0.39148328, 0.6385968, 0.8599843, 0.40938494])
# Two 24-element float vectors, one for the text branch and one for the image
# branch (presumably per-class mean statistics -- TODO confirm their origin).
# Not referenced by any active code in this script.
means_text = torch.tensor([
    0.5946929, 0.69299555, 1.3461456, 0.7967005,
    1.1016434, 0.9502602, 1.061284, 1.1228799,
    1.0665462, 1.3603711, 0.7424356, 0.90390503,
    1.1678252, 0.9362505, 1.0138079, 0.99920136,
    0.870878, 0.85788625, 0.92233753, 0.98191595,
    0.6288501, 0.89315426, 0.8495149, 1.0145243,
])
means_image = torch.tensor([
    0.33643255, 0.28821167, 0.8682864, 0.45460844,
    0.48021796, 0.37163508, 0.43501136, 0.4421642,
    0.6459671, 0.41186664, 0.26087832, 0.53775585,
    0.8507467, 0.52869505, 0.41222665, 0.32481325,
    0.48035103, 0.25507605, 0.6270396, 0.5454764,
    0.31979832, 0.32757083, 0.43387753, 0.92198336,
])


# ---------------------------------------------------------------------------
# Evaluation: run the model over the test loader, report macro-F1 per output
# head, and pickle the per-batch logits of one head.
# ---------------------------------------------------------------------------
model.eval()
with torch.no_grad():
    total_labels = []
    # head index -> list of per-batch logits tensors. A model returning a
    # single tensor is stored under key 0; a tuple-returning model is stored
    # under keys 0, 1, 2 (text, image, multimodal, in the order returned).
    total_logits = {}
    print("starting eval")
    for batch, labels in tqdm.tqdm(test_dl):
        # Optional modality ablation, applied before the batch is moved to
        # the device.
        if args.rt:
            batch = remove_text_input(batch)
        if args.ri:
            batch = remove_image_input(batch)
        batch = {k: v.to(device) for k, v in batch.items()}

        # Forward pass (inference only).
        logits = model(**batch)
        if hasattr(logits, "logits"):
            # Output object exposing `.logits`. The attribute was previously
            # misspelled as `.logit`, which raised AttributeError here.
            logits = logits.logits

        if isinstance(logits, tuple):
            # Multi-head model: keep exactly the first three heads.
            heads = (logits[0], logits[1], logits[2])
        else:
            heads = (logits,)
        for head_idx, head_logits in enumerate(heads):
            total_logits.setdefault(head_idx, []).append(head_logits.detach())

        # Labels are only consumed as a NumPy array below, so they are not
        # moved to the compute device.
        total_labels.append(labels.detach())

    total_labels = torch.cat(total_labels).cpu().numpy()
    for key, batches in total_logits.items():
        eval_logits = torch.cat(batches).cpu()
        if config['task'] == 'multilabel':
            # logit > 0 is equivalent to sigmoid(logit) > 0.5
            preds = (eval_logits.numpy() > 0).astype('int')
        elif config['task'] == 'multiclass':
            preds = torch.argmax(eval_logits, dim=1).numpy().astype('int')
        else:
            raise ValueError(
                f"Unsupported task {config['task']!r}; "
                "expected 'multilabel' or 'multiclass'"
            )
        f1 = metric(preds, total_labels)
        print(key, f1)

    # Persist the raw per-batch logits (a list of tensors) of the selected
    # head; head 0 unless the config overrides it.
    # NOTE(review): the output path comes from config['logits_save_path'];
    # the --sp/--save_path CLI flag is not consulted here -- confirm intent.
    save_index = config.get('save_index', 0)
    with open(config['logits_save_path'], 'wb') as f:
        pickle.dump(total_logits[save_index], f)
