import torch.nn as nn
import torch
import torch.nn.functional as F
from tqdm import tqdm
from options import prettyprint_args, get_args 
from sklearn.metrics import precision_recall_fscore_support, accuracy_score
from load_data import load_data, build_dataloader 
from utils import batch_transform
from datasets.data_utils import mean_sub, batch_mean_sub

import numpy as np
from collections import defaultdict
import data_hooks

# Run on the first CUDA GPU when PyTorch reports CUDA as available, otherwise on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

def evaluate(dataloader, model, restore_training=True, logger=None, leave=False, temperature=1, threshold=None):
    """Evaluate ``model`` on ``dataloader``, voting over clips to get one prediction per sample.

    Each item of ``dataloader`` is unpacked as ``(X, y, metadata)``. ``X`` carries a
    leading singleton dimension (layout noted as 1B'CTWH) that is squeezed away so
    all clips of a sample go through the model as one batch. The per-clip logits
    are summed and the arg-max of the sum is the sample's prediction. Samples
    whose squeezed input is not 5-D are skipped entirely.

    Args:
        dataloader: iterable of ``(X, y, metadata)`` triples.
        model: module returning per-clip class logits for a 5-D clip batch.
        restore_training: if True, switch the model back to train mode at the end.
        logger: unused.
        leave: forwarded to ``tqdm``.
        temperature: softmax temperature; only affects ``max_scores``.
        threshold: unused.

    Returns:
        ``(total_loss, preds, labels, max_scores)``:
        ``total_loss`` is the sum over samples of the (mean) cross-entropy divided
        by that sample's clip count; ``preds`` and ``labels`` are float
        ``torch.Tensor`` objects on the CPU; ``max_scores`` is a list holding, per
        sample, the largest temperature-scaled softmax probability over all of
        its clips and classes.
    """
    model.eval()
    criterion = nn.CrossEntropyLoss()
    with torch.no_grad():
        running_loss = 0.
        predictions, labels, confidences = [], [], []
        for X, y, _ in tqdm(dataloader, leave=leave):
            clips = X.to(device).squeeze(0)  # drop the leading singleton dim of 1B'CTWH
            n_clips = clips.size(0)
            # one copy of the sample label per clip, so the loss is computed clip-wise
            targets = y.to(device).repeat(n_clips)
            if clips.ndim != 5:
                continue
            logits = model(clips)
            running_loss += criterion(logits, targets).item() / n_clips
            # sum logits over clips and take the arg-max as the sample-level vote
            _, vote = torch.max(logits.sum(dim=0), 0)
            probs = F.softmax(logits / temperature, dim=-1)
            confidences.append(torch.max(probs).item())
            predictions.append(vote.item())
            labels.append(targets[0].item())
    preds_out = torch.Tensor(predictions).cpu()
    labels_out = torch.Tensor(labels).cpu()
    if restore_training:
        model.train()
    return running_loss, preds_out, labels_out, confidences

def transform_evaluate(dataloader, aug_dataloader, model, transform, restore_training=False, data_hooks=None, logger=None, leave=True, examples=None, mean_sub=True, temperature=1):
    """Evaluate ``model`` on augmented inputs and feed every sample to the data hooks.

    ``dataloader`` and ``aug_dataloader`` are iterated in lockstep (``zip`` stops at
    the shorter one). The first supplies the clean input ``X`` and the label ``y``.
    The second supplies the augmented input ``Xt`` and its metadata. The model sees
    ``Xt``, mean-subtracted with ``batch_mean_sub`` when ``mean_sub`` is True. Per-clip
    logits are summed into one vote per sample, as in ``evaluate``.

    Args:
        dataloader: iterable of ``(X, y, metadata)`` triples (clean inputs).
        aug_dataloader: iterable of ``(Xt, y, metadata)`` triples (augmented inputs).
        model: module returning per-clip class logits for a 5-D clip batch.
        transform: passed through unchanged to every hook call.
        restore_training: if True, switch the model back to train mode at the end.
        data_hooks: dict mapping a name to a hook object. Each hook is called as
            ``hook(X, Xt, y, metadata, transform)`` on the non-mean-subtracted
            tensors, and its result is handed to ``hook.update``. ``None`` (the
            default) means no hooks.
        logger: unused.
        leave: forwarded to ``tqdm``.
        examples: stop after this many samples; ``None`` means no limit.
        mean_sub: whether to mean-subtract a copy of ``Xt`` before the forward pass.
        temperature: softmax temperature; only affects ``max_scores``.

    Returns:
        ``(total_loss, preds, labels, max_scores, hook_results)``.
        Samples whose model input is not 5-D are not run through the model. They
        get ``-1`` in both ``preds`` and ``max_scores`` and add nothing to
        ``total_loss``. ``preds`` and ``labels`` are float ``torch.Tensor`` objects
        on the CPU. ``hook_results`` is never populated here; it is always an empty
        ``defaultdict(list)``, and hooks keep their own state via ``update``.

    Raises:
        ValueError: if ``data_hooks`` is neither ``None`` nor a dict.
    """
    model.eval()
    loss_fn = nn.CrossEntropyLoss()
    if data_hooks is None:
        data_hooks = {}
    if not isinstance(data_hooks, dict):
        raise ValueError("Parameter 'data_hooks' must be a mapping from string keys to functions of the form fn(X, Xt, y, metadata).")
    hook_results = defaultdict(list)
    total_loss = 0.
    test_preds = []
    test_y = []
    max_scores = []
    with torch.no_grad():
        pairs = tqdm(zip(dataloader, aug_dataloader), leave=leave, total=len(dataloader))
        for count, (data, aug_data) in enumerate(pairs, start=1):
            X, y, _ = data  # 6D: 1B'CTWH
            Xt, _, metadata = aug_data
            X = X.squeeze(0)
            Xt = Xt.squeeze(0)
            # transpose the collated metadata (presumably per-field -> per-clip tuples; verify)
            metadata = list(zip(*metadata))
            Xin = batch_mean_sub(Xt.clone()) if mean_sub else Xt
            if Xin.ndim == 5:
                out = model(Xin.to(device))
                y = y.repeat(Xin.size(0)).to(device)
                loss = loss_fn(out, y)
                soft_out = F.softmax(out / temperature, dim=-1)
                max_scores.append(torch.max(soft_out).item())
                # normalize by the number of clips (dim 0 of the model input), matching
                # `evaluate`; dim 1 is the channel dimension
                total_loss += loss.item() / Xin.size(0)
                _, preds = torch.max(out.sum(dim=0), 0)
                test_preds.append(preds.item())
            else:
                max_scores.append(-1)
                test_preds.append(-1)
            test_y.append(y[0].item())
            for hook in data_hooks.values():
                # hooks see the non-mean-subtracted X and Xt
                hook.update(hook(X, Xt, y, metadata, transform))
            if count == examples:
                break
    test_preds = torch.Tensor(test_preds).cpu()
    test_y = torch.Tensor(test_y).cpu()
    if restore_training:
        model.train()
    return total_loss, test_preds, test_y, max_scores, hook_results

def attach_hooks(args):
    """Instantiate the data hooks named in ``args.hooks`` and print them.

    Each name is looked up as an attribute of the ``data_hooks`` module and called
    with no arguments. If a name is repeated, the last instance wins. A falsy
    ``args.hooks`` (``None`` or empty) yields no hooks.

    Args:
        args: namespace with a ``hooks`` attribute (a sequence of names, or None).

    Returns:
        dict mapping each hook name to its instance.

    Raises:
        AttributeError: if a name is not defined in ``data_hooks``.
    """
    hook_names = args.hooks or []
    hooks = {name: getattr(data_hooks, name)() for name in hook_names}
    print("=" * 70)
    print("HOOKS:")
    for fn in hooks.values():
        print(fn)
    print("=" * 70)
    return hooks


def report_hook_results(hooks, results, report_acc=True):
    """Print the hook-independent summary, then let every hook print its own report.

    Args:
        hooks: mapping of name -> hook object exposing ``report_results(y, preds)``.
        results: mapping with an ``'aug_info'`` entry. Its ``'y'`` and ``'preds'``
            entries are read when accuracy is reported or when hooks are present.
        report_acc: when True, also print the accuracy of ``results['preds']``
            against ``results['y']``.
    """
    summary = results['aug_info']
    print(summary)
    if report_acc:
        acc = accuracy_score(results['y'], results['preds'])
        print("Accuracy:", acc)
    for hook in hooks.values():
        hook.report_results(results['y'], results['preds'])


def collect_incorrect(dataloader, model, id, transform=None, n=64, restore_training=False, leave=True):
    """Collect inputs that ``model`` wrongly assigns to class ``id`` (false positives).

    Iterates ``dataloader``, which yields ``(X, y)`` batches, until at least ``n``
    matching samples are gathered or the loader is exhausted. When ``transform`` is
    given, the model sees ``mean_sub(transform(x))`` for every sample, applied via
    ``batch_transform``. Otherwise it sees ``X`` unchanged.

    Args:
        dataloader: iterable of ``(X, y)`` batches.
        model: classifier returning logits of shape (batch, n_classes).
        id: class index whose false positives are collected.
        transform: optional per-sample transform; ``None`` means identity.
        n: stop once at least this many samples have been collected. The last
            batch is kept whole, so more than ``n`` may be returned.
        restore_training: if True, switch the model back to train mode at the end.
        leave: forwarded to ``tqdm``.

    Returns:
        ``(labels, preds, examples)`` as CPU tensors. ``examples`` holds the
        collected raw inputs, with ``transform`` applied (but no mean subtraction)
        when ``transform`` is not None.

    Note:
        ``torch.cat`` raises if ``dataloader`` yields no batches at all.
    """
    examples = []
    test_preds = []
    test_y = []
    size = 0
    model.eval()
    with torch.no_grad():
        for X, y in tqdm(dataloader, leave=leave):
            y = y.to(device)
            if transform is not None:
                Xt = batch_transform(X.clone(), lambda x: mean_sub(transform(x))).to(device)
            else:
                Xt = X.to(device)
            out = model(Xt)
            _, preds = torch.max(out, 1)
            # predicted `id` but labelled as something else
            mask = (y != preds) & (preds == id)
            # X was never moved to `device`; index it with a mask on X's own device
            hits = X[mask.to(X.device)]
            examples.append(hits)
            test_preds.append(preds[mask])
            test_y.append(y[mask])
            size += hits.size(0)
            if size >= n:
                break
        test_preds = torch.cat(test_preds).cpu()
        test_y = torch.cat(test_y).cpu()
        examples = torch.cat(examples, dim=0)
        if transform is not None:
            examples = batch_transform(examples, transform)
        examples = examples.cpu()
    if restore_training:
        model.train()
    return test_y, test_preds, examples


if __name__ == '__main__':
    # Load a pretrained model, evaluate it on the second dataloader returned by
    # build_dataloader (presumably the validation/test split — confirm), and
    # print aggregate classification metrics.
    from importlib import import_module
    args = get_args()
    # presumably the batch size; evaluate() squeezes a leading singleton dimension
    args.bs = 1
    print(prettyprint_args(args))
    if not args.load_pretrained:
        raise ValueError("Pretrained model path cannot be NoneType")
    module_name, model_constructor_name = args.model_module.rsplit('.', 1)
    module_obj = import_module(module_name)
    model_type = getattr(module_obj, model_constructor_name)
    kwds = {}
    if "resnet" in model_constructor_name:
        kwds = {'num_classes': args.n_classes, 'sample_size': args.load_width, 'sample_duration': args.max_frames}
    model = model_type(**kwds)
    model = model.to(device).eval()
    print("Loaded a model!")
    print("Type:", args.model_module)
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine
    checkpoint = torch.load(args.load_pretrained, map_location=device)
    try:
        # drop the first dotted component of every key (e.g. a wrapper prefix
        # such as "module." — NOTE(review): confirm all checkpoints carry one)
        checkpoint['state_dict'] = {k.split(".", 1)[1]: v for k, v in checkpoint['state_dict'].items()}
        model.load_state_dict(checkpoint['state_dict'])
    except KeyError:
        # no 'state_dict' entry: treat the checkpoint itself as the state dict
        model.load_state_dict(checkpoint)
    data = load_data(args)
    _, dataloader, _, _ = build_dataloader(data, args)
    # evaluate() returns (loss, preds, labels, max_scores); max_scores is unused here
    loss, preds, y, _ = evaluate(dataloader, model, restore_training=False)

    acc = (preds == y).sum().item() / y.size(0)
    prec, rec, f1, support = precision_recall_fscore_support(y, preds, average='macro')
    print("Accuracy:", acc)
    print("Loss:", loss)
    print("Precision:", prec)
    print("Recall:", rec)
    print("F1 (macro for n_classes > 2):", f1)
    _, counts = np.unique(preds.numpy(), return_counts=True)
    print("Support (predictions):", list(counts))
    _, counts = np.unique(y.numpy(), return_counts=True)
    print("Support (actual):", list(counts))
