import argparse

import os
import copy
import statistics

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import torch.nn.functional as F

from utils import add_artificial_bias, invert_masks, create_balanced_subset, dilate_masks_torch
from x_resnet import xfixup_resnet50

# Command-line interface: dataset location, checkpoint location, which trained
# model variant to evaluate, and the evaluation batch size.
parser = argparse.ArgumentParser(description='Bias experiment')
parser.add_argument(
    '--data_dir',
    default='/datasets/imagenet/',
    metavar='DIR',
    help='path to dataset',
)
parser.add_argument(
    '--store_dir',
    default='/models/',
    metavar='DIR',
    help='path to stored models',
)
parser.add_argument(
    '--model',
    choices=['default', 'presence_debias', 'presence_absence_debias'],
    required=True,
    help='model mode',
)
parser.add_argument(
    '--batch_size',
    type=int,
    default=128,
    help='batch size',
)

# Parsed at import time: this file is meant to be run as a script.
args = parser.parse_args()

# ----- Run configuration -----
DATA_DIR = args.data_dir      # root folder holding the 'train', 'val' and 'test' image folders
STORE_DIR = args.store_dir    # folder the trained checkpoints are read from
BATCH_SIZE = args.batch_size
NUM_EPOCHS = 20
NUM_WORKERS = 4               # DataLoader worker processes
RUNS = 5                      # number of seeds / checkpoints that are evaluated
# Prefer the first GPU, but fall back to the CPU so the script does not crash
# on machines without CUDA.
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
MODE = args.model  # 'default', 'presence_debias', 'presence_absence_debias'
train_wo_bias = False  # set true if no training bias should be available

# Global seed for everything that happens before the per-run seeds are set.
random_seed = 0
torch.manual_seed(random_seed)

# ----- ImageNet Normalization -----
# Per-channel statistics used to normalise inputs for every split.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

# ----- Transforms -----
# 'train' applies random augmentation; 'val' and 'test' use the same
# deterministic resize + centre crop so results are reproducible.
data_transforms = {
    'train': transforms.Compose([
        transforms.Resize(256),
        transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
        # A single horizontal flip is enough: the pipeline used to list this
        # transform twice, and two independent p=0.5 flips still flip the
        # image with probability 0.5 overall.
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std)
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std)
    ]),
    'test': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std)
    ]),
}

# ----- Datasets and Loaders -----
SPLITS = ('train', 'val', 'test')

# Step 1: one ImageFolder per split, read from DATA_DIR/<split> with the
# transform pipeline registered under the same key.
image_datasets = {}
for split in SPLITS:
    split_dir = os.path.join(DATA_DIR, split)
    image_datasets[split] = datasets.ImageFolder(split_dir, transform=data_transforms[split])

# Step 2: swap every split for the result of create_balanced_subset, which is
# given the folder's per-sample targets.
for split in SPLITS:
    full_set = image_datasets[split]
    image_datasets[split] = create_balanced_subset(full_set, targets=full_set.targets)

# Step 3: loaders; only the training split is shuffled.
dataloaders = {}
for split in SPLITS:
    dataloaders[split] = DataLoader(
        image_datasets[split],
        batch_size=BATCH_SIZE,
        shuffle=(split == 'train'),
        num_workers=NUM_WORKERS,
    )

# The balanced subset exposes the wrapped ImageFolder as `.dataset`; its class
# list is expected to be ['benign', 'malignant'] (per the original note).
class_names = image_datasets['train'].dataset.classes

# ----- Evaluation -----
def evaluate_model(model, dataloaders, phase):
    """Evaluate ``model`` on one split under one artificial-bias condition.

    ``phase`` has the form ``'<split>_<condition>'``. The part before the
    first underscore selects the loader from ``dataloaders``; the rest selects
    the per-class patch probabilities passed to ``add_artificial_bias``:

    * ``trainbias``   -- class 0 always gets a patch, class 1 never
    * ``inversebias`` -- class 1 always gets a patch, class 0 never
    * ``nobias``      -- no class gets a patch

    Args:
        model: Two-class classifier, already on ``DEVICE``.
        dataloaders: Mapping from split name to an iterable of
            ``(inputs, labels)`` batches.
        phase: Phase name as described above, e.g. ``'val_trainbias'``.

    Returns:
        Tuple ``(acc_class_0, acc_class_1, acc_average, attribution)``:
        accuracy on class 0, accuracy on class 1, the mean of those two, and
        the per-batch mean of the share of input-times-gradient magnitude that
        falls inside the (dilated) patch mask. A value is ``nan`` when there
        was nothing to compute it from (class never seen / empty loader).

    Raises:
        ValueError: If the condition part of ``phase`` is not one of
            ``trainbias``, ``inversebias`` or ``nobias``.
        KeyError: If ``dataloaders`` has no entry for the split part.
    """
    # Probability of receiving the artificial patch, keyed by class index.
    bias_by_condition = {
        'trainbias': {0: 1.0, 1: 0.0},
        'inversebias': {0: 0.0, 1: 1.0},
        'nobias': {0: 0.0, 1: 0.0},
    }
    split, _, condition = phase.partition('_')
    if condition not in bias_by_condition:
        # Fail early with a clear message instead of hitting an unbound
        # `patch_segmentation` in the middle of the first batch.
        raise ValueError(
            f"Unsupported phase {phase!r}: expected '<split>_<condition>' with "
            f"condition in {sorted(bias_by_condition)}"
        )
    class_bias = bias_by_condition[condition]

    print('Starting', phase)
    model.eval()

    # Accuracy is accumulated as counts over the whole split. Averaging
    # per-batch accuracies would give NaN whenever a batch lacks one class
    # and would over-weight a short final batch.
    correct_per_class = [0, 0]
    seen_per_class = [0, 0]
    batch_attributions = []

    for inputs, labels in dataloaders[split]:
        inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)

        inputs, patch_segmentation = add_artificial_bias(inputs, labels, class_bias)
        patch_segmentation = dilate_masks_torch(patch_segmentation)

        # Gradients w.r.t. the (patched) input are needed for the attribution.
        inputs.requires_grad = True
        outputs = model(inputs)

        # flip all mask for energy loss. However, if a mask is 0 everywhere it
        # stays zero to not get exploding attributions.
        masks_flipped = invert_masks(patch_segmentation)
        has_mask = patch_segmentation.sum(dim=(1, 2, 3)) != 0

        # Input-times-gradient of the logit belonging to the true class. The
        # attribution is only read out here, so no second-order graph is built.
        target_outputs = torch.gather(outputs, 1, labels.unsqueeze(-1))
        gradients = torch.autograd.grad(torch.unbind(target_outputs), inputs, create_graph=False)[0]
        input_x_gradient = inputs.detach() * gradients

        # Only samples that actually carry a patch contribute; the magnitude
        # is summed over dim 1 before being split by the two masks.
        magnitude = input_x_gradient[has_mask].abs().sum(dim=1, keepdim=True)
        attribution_inside = (magnitude * patch_segmentation[has_mask]).sum()
        attribution_outside = (magnitude * masks_flipped[has_mask]).sum()
        # The small constant keeps the ratio finite when no sample has a patch.
        inside_share = attribution_inside / (attribution_outside + attribution_inside + 1e-5)
        batch_attributions.append(inside_share.item())

        preds = torch.argmax(outputs, dim=1)
        hits = preds == labels
        for cls in (0, 1):
            in_class = labels == cls
            seen_per_class[cls] += int(in_class.sum().item())
            correct_per_class[cls] += int((hits & in_class).sum().item())

    nan = float('nan')
    running_acc_class_0 = correct_per_class[0] / seen_per_class[0] if seen_per_class[0] else nan
    running_acc_class_1 = correct_per_class[1] / seen_per_class[1] if seen_per_class[1] else nan
    running_acc_average = (running_acc_class_0 + running_acc_class_1) / 2
    running_attribution = sum(batch_attributions) / len(batch_attributions) if batch_attributions else nan

    # The printed labels are 1-based ("Class 1" is class index 0).
    print(f"{phase.capitalize()} Acc Class 1: {running_acc_class_0:.4f} Acc Class 2: {running_acc_class_1:.4f} Acc Avg: {running_acc_average:.4f} Attr: {running_attribution:.4f}")
    return running_acc_class_0, running_acc_class_1, running_acc_average, running_attribution

# ----- Evaluate stored checkpoints -----
# One entry per run (seed) is appended to each list below.
acc_class_0_trainbias = []
acc_class_1_trainbias = []
acc_avg_trainbias = []
attr_trainbias = []

acc_class_0_inversebias = []
acc_class_1_inversebias = []
acc_avg_inversebias = []
attr_inversebias = []

acc_class_0_nobias = []
acc_class_1_nobias = []
acc_avg_nobias = []
attr_nobias = []

# Evaluation phase -> the four result lists, in the order evaluate_model
# returns its values (class 0 acc, class 1 acc, average acc, attribution).
# Dict order is the evaluation order: trainbias, inversebias, nobias.
_results_by_phase = {
    'val_trainbias': (acc_class_0_trainbias, acc_class_1_trainbias, acc_avg_trainbias, attr_trainbias),
    'val_inversebias': (acc_class_0_inversebias, acc_class_1_inversebias, acc_avg_inversebias, attr_inversebias),
    'val_nobias': (acc_class_0_nobias, acc_class_1_nobias, acc_avg_nobias, attr_nobias),
}

for seed in range(RUNS):  # different training runs

    torch.manual_seed(seed)

    # ----- Model -----
    model = xfixup_resnet50()
    model.fc = nn.Linear(model.fc.in_features, 2)  # binary classification

    # Checkpoints trained without the artificial bias carry a different prefix.
    checkpoint_prefix = 'model_no_train_bias' if train_wo_bias else 'model'
    checkpoint_path = os.path.join(STORE_DIR, f'{checkpoint_prefix}_seed{seed}_mode_{MODE}.pth')
    # Load onto the CPU first so a checkpoint saved from a GPU that is not
    # present on this machine can still be restored; the model is moved to
    # DEVICE afterwards.
    state_dict = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(state_dict)

    model = model.to(DEVICE)

    for phase, result_lists in _results_by_phase.items():
        metrics = evaluate_model(model, dataloaders, phase)
        for result_list, value in zip(result_lists, metrics):
            result_list.append(value)

# Summary over all runs: for every bias condition print mean / sample standard
# deviation / best value of the three accuracies, then mean / standard
# deviation of the attribution score. Labels are 1-based ("Class 1" = class 0).
for heading, class_0_accs, class_1_accs, average_accs, attributions in (
    ('Trainbias:', acc_class_0_trainbias, acc_class_1_trainbias, acc_avg_trainbias, attr_trainbias),
    ('Inversebias:', acc_class_0_inversebias, acc_class_1_inversebias, acc_avg_inversebias, attr_inversebias),
    ('Nobias:', acc_class_0_nobias, acc_class_1_nobias, acc_avg_nobias, attr_nobias),
):
    print(heading)
    for label, per_run_values in (
        ('Class 1', class_0_accs),
        ('Class 2', class_1_accs),
        ('Average', average_accs),
    ):
        print(label + ' mean:', statistics.mean(per_run_values))
        print(label + ' std:', statistics.stdev(per_run_values))
        print(label + ' max:', max(per_run_values))
    print('Attr mean:', statistics.mean(attributions))
    print('Attr std:', statistics.stdev(attributions))

    # Keep the blank line between sections out of the output: the original
    # prints the sections back to back.
