import os, sys, math, random, itertools
from functools import partial
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision import datasets, transforms, models
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.checkpoint import checkpoint

from utils import *
from task_configs import tasks, get_task, ImageTask
from transfers import functional_transfers, finetuned_transfers, get_transfer_name, Transfer
from datasets import TaskDataset, load_train_val

from matplotlib.cm import get_cmap


# from pytorch_wavelets import DWTForward, DWTInverse

# xfm = DWTForward(J=3, mode='zero', wave='db1').cuda()

import IPython

import pdb

def get_energy_loss(
    config="", mode="winrate",
    pretrained=True, finetuned=True, **kwargs,
):
    """Instantiate the energy loss described by a named ``energy_configs`` entry.

    Args:
        config: key of the ``energy_configs`` entry whose contents are unpacked
            as keyword arguments into the loss constructor.
        mode: ``"standard"`` selects ``EnergyLoss``, ``"winrate"`` selects
            ``WinRateEnergyLoss``; any non-string value is used directly as
            the constructor.
        pretrained: forwarded to the constructor.
        finetuned: forwarded to the constructor.
        **kwargs: extra keyword arguments forwarded to the constructor.

    Returns:
        The object produced by the selected constructor.

    Raises:
        KeyError: if ``mode`` is an unrecognised string or ``config`` is not a
            key of ``energy_configs``.
    """
    if not isinstance(mode, str):
        loss_factory = mode
    else:
        # Names are only resolved when a string mode is given.
        loss_factory = {
            "standard": EnergyLoss,
            "winrate": WinRateEnergyLoss,
        }[mode]
    config_kwargs = energy_configs[config]
    return loss_factory(
        **config_kwargs,
        pretrained=pretrained,
        finetuned=finetuned,
        **kwargs
    )


# Registry of named energy-loss configurations. ``get_energy_loss`` looks an
# entry up by name and unpacks it as keyword arguments into the selected loss
# class, so each entry supplies the ``paths`` / ``freeze_list`` / ``losses`` /
# ``plots`` arguments of ``EnergyLoss.__init__``:
#
#   "paths":       path name -> list of tasks. ``EnergyLoss.compute_paths``
#                  hands each list to ``graph.sample_path`` (by default for
#                  every path of the entry, whether referenced or not).
#   "freeze_list": list of [task_a, task_b] pairs, stored by ``EnergyLoss`` as
#                  ``str((task_a.name, task_b.name))``. Presumably these name
#                  transfers to keep frozen -- TODO confirm at the use site.
#   "losses":      loss type -> {tuple of reality names -> [(path1, path2), ...]};
#                  both path names must be keys of "paths".
#   "plots":       plot name -> dict(size=..., realities=..., paths=[...]).
#
# Conventions used below:
#   * "n(x)" / "f(<filter>(x))" end in the target task; "n0(x)" /
#     "f0(<filter>(x))" end in its "*_t0" variant, which also appears in the
#     entry's freeze_list. The "lwf" loss pairs the two.
#   * Every entry defines a "reshade" path that no loss or plot here refers
#     to; in the normal/depth entries it holds tasks.normal or
#     tasks.depth_zbuffer despite its name.
energy_configs = {

    # --- "nll" + "lwf" baselines on reshading (no "sig" term) ---
    "baseline_reshade_sigaug": {
        "paths": {
            "x": [tasks.rgb],
            "y^": [tasks.reshading],
            "n(x)": [tasks.rgb, tasks.reshading],
            "n0(x)": [tasks.rgb, tasks.reshading_t0],
            "reshade": [tasks.reshading],
        },
        "freeze_list": [[tasks.rgb, tasks.reshading_t0]],
        "losses": {
            "nll": {
                ("train", "val", "val_noaug", "val_augonly"): [
                    ("n(x)", "y^"),
                ],
            },
            "lwf": {
                ("train", "val", "val_noaug", "val_augonly"): [
                    ("n(x)", "n0(x)"),
                ],
            },
        },
        "plots": {
            "": dict(
                size=256,
                realities=("test", "ood", "ood_syn_aug", "ood_syn"),
                paths=[
                    "x",
                    "y^",
                    "n(x)",
                    "n0(x)"
                ]
            ),
        },
    },

    "baseline_edgereshade_sigaug": {
        "paths": {
            "x": [tasks.rgb],
            "y^": [tasks.reshading],
            "sobel(x)": [tasks.rgb, tasks.sobel_edges],
            "f(sobel(x))": [tasks.rgb, tasks.sobel_edges, tasks.reshading],
            "f0(sobel(x))": [tasks.rgb, tasks.sobel_edges, tasks.reshading_t0],
            "reshade": [tasks.reshading],
        },
        "freeze_list": [[tasks.sobel_edges, tasks.reshading_t0],
                        [tasks.rgb, tasks.sobel_edges]],
        "losses": {
            "nll": {
                ("train", "val", "val_noaug", "val_augonly"): [
                    ("f(sobel(x))", "y^"),
                ],
            },
            "lwf": {
                ("train", "val", "val_noaug", "val_augonly"): [
                    ("f(sobel(x))", "f0(sobel(x))"),
                ],
            },
        },
        "plots": {
            "": dict(
                size=256,
                realities=("test", "ood", "ood_syn_aug", "ood_syn"),
                paths=[
                    "x",
                    "y^",
                    "sobel(x)",
                    "f(sobel(x))",
                    "f0(sobel(x))"
                ]
            ),
        },
    },


    "baseline_binreshade_sigaug": {
        "paths": {
            "x": [tasks.rgb],
            "y^": [tasks.reshading],
            "bin(x)": [tasks.rgb, tasks.binarized],
            "f(bin(x))": [tasks.rgb, tasks.binarized, tasks.reshading],
            "f0(bin(x))": [tasks.rgb, tasks.binarized, tasks.reshading_t0],
            "reshade": [tasks.reshading],
        },
        "freeze_list": [[tasks.binarized, tasks.reshading_t0],
                        [tasks.rgb, tasks.binarized]],
        "losses": {
            "nll": {
                ("train", "val", "val_noaug", "val_augonly"): [
                    ("f(bin(x))", "y^"),
                ],
            },
            "lwf": {
                ("train", "val", "val_noaug", "val_augonly"): [
                    ("f(bin(x))", "f0(bin(x))"),
                ],
            },
        },
        "plots": {
            "": dict(
                size=256,
                realities=("test", "ood", "ood_syn_aug", "ood_syn"),
                paths=[
                    "x",
                    "y^",
                    "bin(x)",
                    "f(bin(x))",
                    "f0(bin(x))"
                ]
            ),
        },
    },

    "baseline_lapreshade_sigaug": {
        "paths": {
            "x": [tasks.rgb],
            "y^": [tasks.reshading],
            "lap(x)": [tasks.rgb, tasks.laplace_edges],
            "f(lap(x))": [tasks.rgb, tasks.laplace_edges, tasks.reshading],
            "f0(lap(x))": [tasks.rgb, tasks.laplace_edges, tasks.reshading_t0],
            "reshade": [tasks.reshading],
        },
        "freeze_list": [[tasks.laplace_edges, tasks.reshading_t0],
                        [tasks.rgb, tasks.laplace_edges]],
        "losses": {
            "nll": {
                ("train", "val", "val_noaug", "val_augonly"): [
                    ("f(lap(x))", "y^"),
                ],
            },
            "lwf": {
                ("train", "val", "val_noaug", "val_augonly"): [
                    ("f(lap(x))", "f0(lap(x))"),
                ],
            },
        },
        "plots": {
            "": dict(
                size=256,
                realities=("test", "ood", "ood_syn_aug", "ood_syn"),
                paths=[
                    "x",
                    "y^",
                    "lap(x)",
                    "f(lap(x))",
                    "f0(lap(x))"
                ]
            ),
        },
    },

    "baseline_gaussreshade_sigaug": {
        "paths": {
            "x": [tasks.rgb],
            "y^": [tasks.reshading],
            "gauss(x)": [tasks.rgb, tasks.gauss],
            "f(gauss(x))": [tasks.rgb, tasks.gauss, tasks.reshading],
            "f0(gauss(x))": [tasks.rgb, tasks.gauss, tasks.reshading_t0],
            "reshade": [tasks.reshading],
        },
        "freeze_list": [[tasks.gauss, tasks.reshading_t0],
                        [tasks.rgb, tasks.gauss]],
        "losses": {
            "nll": {
                ("train", "val", "val_noaug", "val_augonly"): [
                    ("f(gauss(x))", "y^"),
                ],
            },
            "lwf": {
                ("train", "val", "val_noaug", "val_augonly"): [
                    ("f(gauss(x))", "f0(gauss(x))"),
                ],
            },
        },
        "plots": {
            "": dict(
                size=256,
                realities=("test", "ood", "ood_syn_aug", "ood_syn"),
                paths=[
                    "x",
                    "y^",
                    "gauss(x)",
                    "f(gauss(x))",
                    "f0(gauss(x))"
                ]
            ),
        },
    },

    "baseline_embossreshade_sigaug": {
        "paths": {
            "x": [tasks.rgb],
            "y^": [tasks.reshading],
            "emboss(x)": [tasks.rgb, tasks.emboss],
            "f(emboss(x))": [tasks.rgb, tasks.emboss, tasks.reshading],
            "f0(emboss(x))": [tasks.rgb, tasks.emboss, tasks.reshading_t0],
            "reshade": [tasks.reshading],
        },
        "freeze_list": [[tasks.emboss, tasks.reshading_t0],
                        [tasks.rgb, tasks.emboss]],
        "losses": {
            "nll": {
                ("train", "val", "val_noaug", "val_augonly"): [
                    ("f(emboss(x))", "y^"),
                ],
            },
            "lwf": {
                ("train", "val", "val_noaug", "val_augonly"): [
                    ("f(emboss(x))", "f0(emboss(x))"),
                ],
            },
        },
        "plots": {
            "": dict(
                size=256,
                realities=("test", "ood", "ood_syn_aug", "ood_syn"),
                paths=[
                    "x",
                    "y^",
                    "emboss(x)",
                    "f(emboss(x))",
                    "f0(emboss(x))"
                ]
            ),
        },
    },


    # --- "*_calibsig": "main" + "lwf" + "sig" terms ---
    "baseline_reshade_sigaug_calibsig": {
        "paths": {
            "x": [tasks.rgb],
            "y^": [tasks.reshading],
            "n(x)": [tasks.rgb, tasks.reshading],
            "n0(x)": [tasks.rgb, tasks.reshading_t0],
            "reshade": [tasks.reshading],
        },
        "freeze_list": [[tasks.rgb, tasks.reshading_t0]],
        "losses": {
            "main": {
                ("train", "val", "val_noaug"): [
                    ("n(x)", "y^"),
                ],
            },
            "lwf": {
                ("train", "val", "val_noaug"): [
                    ("n(x)", "n0(x)"),
                ],
            },
            "sig": {
                ("train", "val", "val_noaug"): [
                    ("n(x)", "y^"),
                ],
            },
        },
        "plots": {
            "": dict(
                size=256,
                realities=("test", "ood", "ood_syn_aug", "ood_syn"),
                paths=[
                    "x",
                    "y^",
                    "n(x)",
                    "n0(x)"
                ]
            ),
        },
    },

    "baseline_edgereshade_sigaug_calibsig": {
        "paths": {
            "x": [tasks.rgb],
            "y^": [tasks.reshading],
            "sobel(x)": [tasks.rgb, tasks.sobel_edges],
            "f(sobel(x))": [tasks.rgb, tasks.sobel_edges, tasks.reshading],
            "f0(sobel(x))": [tasks.rgb, tasks.sobel_edges, tasks.reshading_t0],
            "reshade": [tasks.reshading],
        },
        "freeze_list": [[tasks.sobel_edges, tasks.reshading_t0],
                        [tasks.rgb, tasks.sobel_edges]],
        "losses": {
            "main": {
                ("train", "val", "val_noaug"): [
                    ("f(sobel(x))", "y^"),
                ],
            },
            "lwf": {
                ("train", "val", "val_noaug"): [
                    ("f(sobel(x))", "f0(sobel(x))"),
                ],
            },
            "sig": {
                ("train", "val", "val_noaug"): [
                    ("f(sobel(x))", "y^"),
                ],
            },
        },
        "plots": {
            "": dict(
                size=256,
                realities=("test", "ood", "ood_syn_aug", "ood_syn"),
                paths=[
                    "x",
                    "y^",
                    "sobel(x)",
                    "f(sobel(x))",
                    "f0(sobel(x))"
                ]
            ),
        },
    },

    "baseline_embossreshade_sigaug_calibsig": {
        "paths": {
            "x": [tasks.rgb],
            "y^": [tasks.reshading],
            "emboss(x)": [tasks.rgb, tasks.emboss],
            "f(emboss(x))": [tasks.rgb, tasks.emboss, tasks.reshading],
            "f0(emboss(x))": [tasks.rgb, tasks.emboss, tasks.reshading_t0],
            "reshade": [tasks.reshading],
        },
        "freeze_list": [[tasks.emboss, tasks.reshading_t0],
                        [tasks.rgb, tasks.emboss]],
        "losses": {
            "main": {
                ("train", "val", "val_noaug", "val_augonly"): [
                    ("f(emboss(x))", "y^"),
                ],
            },
            "lwf": {
                ("train", "val", "val_noaug", "val_augonly"): [
                    ("f(emboss(x))", "f0(emboss(x))"),
                ],
            },
            # NOTE(review): "sig" omits "val_augonly", which "main" and "lwf"
            # include in this entry -- confirm intended.
            "sig": {
                ("train", "val", "val_noaug"): [
                    ("f(emboss(x))", "y^"),
                ],
            },
        },
        "plots": {
            "": dict(
                size=256,
                realities=("test", "ood", "ood_syn_aug", "ood_syn"),
                paths=[
                    "x",
                    "y^",
                    "emboss(x)",
                    "f(emboss(x))",
                    "f0(emboss(x))"
                ]
            ),
        },
    },


    # --- "*_nlllwfsig": "main" trains on "train_undist"; "lwf" and "sig"
    #     train on "train_dist" ---
    "baseline_reshade_sigaug_nlllwfsig": {
        "paths": {
            "x": [tasks.rgb],
            "y^": [tasks.reshading],
            "n(x)": [tasks.rgb, tasks.reshading],
            "n0(x)": [tasks.rgb, tasks.reshading_t0],
            "reshade": [tasks.reshading],
        },
        "freeze_list": [[tasks.rgb, tasks.reshading_t0]],
        "losses": {
            # NOTE(review): this entry lists "val_ooddist" but not
            # "val_undist"; the edge/emboss reshade variants do the reverse and
            # the grey/bin/wav variants list both -- confirm intended.
            "main": {
                ("train_undist", "val_ooddist", "val_dist", "val"): [
                    ("n(x)", "y^"),
                ],
            },
            "lwf": {
                ("train_dist", "val_ooddist", "val_dist", "val"): [
                    ("n(x)", "n0(x)"),
                ],
            },
            "sig": {
                ("train_dist", "val_ooddist", "val_dist", "val"): [
                    ("n(x)", "y^"),
                ],
            },
        },
        "plots": {
            "": dict(
                size=256,
                realities=("test", "ood", "ood_syn_aug", "ood_syn"),
                paths=[
                    "x",
                    "y^",
                    "n(x)",
                    "n0(x)"
                ]
            ),
        },
    },

    "baseline_edgereshade_sigaug_nlllwfsig": {
        "paths": {
            "x": [tasks.rgb],
            "y^": [tasks.reshading],
            "sobel(x)": [tasks.rgb, tasks.sobel_edges],
            "f(sobel(x))": [tasks.rgb, tasks.sobel_edges, tasks.reshading],
            "f0(sobel(x))": [tasks.rgb, tasks.sobel_edges, tasks.reshading_t0],
            "reshade": [tasks.reshading],
        },
        "freeze_list": [[tasks.sobel_edges, tasks.reshading_t0],
                        [tasks.rgb, tasks.sobel_edges]],
        "losses": {
            "main": {
                ("train_undist", "val_undist", "val_dist", "val"): [
                    ("f(sobel(x))", "y^"),
                ],
            },
            "lwf": {
                ("train_dist", "val_undist", "val_dist", "val"): [
                    ("f(sobel(x))", "f0(sobel(x))"),
                ],
            },
            "sig": {
                ("train_dist", "val_undist", "val_dist", "val"): [
                    ("f(sobel(x))", "y^"),
                ],
            },
        },
        "plots": {
            "": dict(
                size=256,
                realities=("test", "ood", "ood_syn_aug", "ood_syn"),
                paths=[
                    "x",
                    "y^",
                    "sobel(x)",
                    "f(sobel(x))",
                    "f0(sobel(x))"
                ]
            ),
        },
    },

    "baseline_embossreshade_sigaug_nlllwfsig": {
        "paths": {
            "x": [tasks.rgb],
            "y^": [tasks.reshading],
            "emboss(x)": [tasks.rgb, tasks.emboss],
            "f(emboss(x))": [tasks.rgb, tasks.emboss, tasks.reshading],
            "f0(emboss(x))": [tasks.rgb, tasks.emboss, tasks.reshading_t0],
            "reshade": [tasks.reshading],
        },
        "freeze_list": [[tasks.emboss, tasks.reshading_t0],
                        [tasks.rgb, tasks.emboss]],
        "losses": {
            "main": {
                ("train_undist", "val_undist", "val_dist", "val"): [
                    ("f(emboss(x))", "y^"),
                ],
            },
            "lwf": {
                ("train_dist", "val_undist", "val_dist", "val"): [
                    ("f(emboss(x))", "f0(emboss(x))"),
                ],
            },
            "sig": {
                ("train_dist", "val_undist", "val_dist", "val"): [
                    ("f(emboss(x))", "y^"),
                ],
            },
        },
        "plots": {
            "": dict(
                size=256,
                realities=("test", "ood", "ood_syn_aug", "ood_syn"),
                paths=[
                    "x",
                    "y^",
                    "emboss(x)",
                    "f(emboss(x))",
                    "f0(emboss(x))"
                ]
            ),
        },
    },

    "baseline_greyreshade_sigaug_nlllwfsig": {
        "paths": {
            "x": [tasks.rgb],
            "y^": [tasks.reshading],
            "grey(x)": [tasks.rgb, tasks.grey],
            "f(grey(x))": [tasks.rgb, tasks.grey, tasks.reshading],
            "f0(grey(x))": [tasks.rgb, tasks.grey, tasks.reshading_t0],
            "reshade": [tasks.reshading],
        },
        "freeze_list": [[tasks.grey, tasks.reshading_t0],
                        [tasks.rgb, tasks.grey]],
        "losses": {
            "main": {
                ("train_undist", "val_undist", "val_ooddist", "val_dist", "val"): [
                    ("f(grey(x))", "y^"),
                ],
            },
            "lwf": {
                ("train_dist", "val_undist", "val_ooddist", "val_dist", "val"): [
                    ("f(grey(x))", "f0(grey(x))"),
                ],
            },
            "sig": {
                ("train_dist", "val_undist", "val_ooddist", "val_dist", "val"): [
                    ("f(grey(x))", "y^"),
                ],
            },
        },
        "plots": {
            "": dict(
                size=256,
                realities=("test", "ood", "ood_syn_aug", "ood_syn"),
                paths=[
                    "x",
                    "y^",
                    "grey(x)",
                    "f(grey(x))",
                    "f0(grey(x))"
                ]
            ),
        },
    },

    "baseline_binreshade_sigaug_nlllwfsig": {
        "paths": {
            "x": [tasks.rgb],
            "y^": [tasks.reshading],
            "binarized(x)": [tasks.rgb, tasks.binarized],
            "f(binarized(x))": [tasks.rgb, tasks.binarized, tasks.reshading],
            "f0(binarized(x))": [tasks.rgb, tasks.binarized, tasks.reshading_t0],
            "reshade": [tasks.reshading],
        },
        "freeze_list": [[tasks.binarized, tasks.reshading_t0],
                        [tasks.rgb, tasks.binarized]],
        "losses": {
            "main": {
                ("train_undist", "val_undist", "val_ooddist", "val_dist", "val"): [
                    ("f(binarized(x))", "y^"),
                ],
            },
            "lwf": {
                ("train_dist", "val_undist", "val_ooddist", "val_dist", "val"): [
                    ("f(binarized(x))", "f0(binarized(x))"),
                ],
            },
            "sig": {
                ("train_dist", "val_undist", "val_ooddist", "val_dist", "val"): [
                    ("f(binarized(x))", "y^"),
                ],
            },
        },
        "plots": {
            "": dict(
                size=256,
                realities=("test", "ood", "ood_syn_aug", "ood_syn"),
                paths=[
                    "x",
                    "y^",
                    "binarized(x)",
                    "f(binarized(x))",
                    "f0(binarized(x))"
                ]
            ),
        },
    },

    "baseline_wavreshade_sigaug_nlllwfsig": {
        "paths": {
            "x": [tasks.rgb],
            "y^": [tasks.reshading],
            "wav(x)": [tasks.rgb, tasks.wav],
            "f(wav(x))": [tasks.rgb, tasks.wav, tasks.reshading],
            "f0(wav(x))": [tasks.rgb, tasks.wav, tasks.reshading_t0],
            "reshade": [tasks.reshading],
        },
        "freeze_list": [[tasks.wav, tasks.reshading_t0],
                        [tasks.rgb, tasks.wav]],
        "losses": {
            "main": {
                ("train_undist", "val_undist", "val_ooddist", "val_dist", "val"): [
                    ("f(wav(x))", "y^"),
                ],
            },
            "lwf": {
                ("train_dist", "val_undist", "val_ooddist", "val_dist", "val"): [
                    ("f(wav(x))", "f0(wav(x))"),
                ],
            },
            "sig": {
                ("train_dist", "val_undist", "val_ooddist", "val_dist", "val"): [
                    ("f(wav(x))", "y^"),
                ],
            },
        },
        "plots": {
            "": dict(
                size=256,
                realities=("test", "ood", "ood_syn_aug", "ood_syn"),
                paths=[
                    "x",
                    "y^",
                    "wav(x)",
                    "f(wav(x))",
                    "f0(wav(x))"
                ]
            ),
        },
    },


    "baseline_normal_sigaug_nlllwfsig": {
        "paths": {
            "x": [tasks.rgb],
            "y^": [tasks.normal],
            "n(x)": [tasks.rgb, tasks.normal],
            "n0(x)": [tasks.rgb, tasks.normal_t0],
            "reshade": [tasks.normal],
        },
        "freeze_list": [[tasks.rgb, tasks.normal_t0]],
        "losses": {
            "main": {
                ("train_undist", "val_ooddist", "val_dist", "val"): [
                    ("n(x)", "y^"),
                ],
            },
            "lwf": {
                ("train_dist", "val_ooddist", "val_dist", "val"): [
                    ("n(x)", "n0(x)"),
                ],
            },
            "sig": {
                ("train_dist", "val_ooddist", "val_dist", "val"): [
                    ("n(x)", "y^"),
                ],
            },
        },
        "plots": {
            "": dict(
                size=256,
                realities=("test", "ood", "ood_syn_aug", "ood_syn"),
                paths=[
                    "x",
                    "y^",
                    "n(x)",
                    "n0(x)"
                ]
            ),
        },
    },

    "baseline_edgenormal_sigaug_nlllwfsig": {
        "paths": {
            "x": [tasks.rgb],
            "y^": [tasks.normal],
            "sobel(x)": [tasks.rgb, tasks.sobel_edges],
            "f(sobel(x))": [tasks.rgb, tasks.sobel_edges, tasks.normal],
            "f0(sobel(x))": [tasks.rgb, tasks.sobel_edges, tasks.normal_t0],
            "reshade": [tasks.normal],
        },
        # NOTE(review): the normal/depth filter variants list only the
        # [<filter>, *_t0] pair, whereas the reshade filter variants also list
        # [tasks.rgb, <filter>] -- confirm intended.
        "freeze_list": [[tasks.sobel_edges, tasks.normal_t0]],
        "losses": {
            "main": {
                ("train_undist", "val_ooddist", "val_dist", "val"): [
                    ("f(sobel(x))", "y^"),
                ],
            },
            "lwf": {
                ("train_dist", "val_ooddist", "val_dist", "val"): [
                    ("f(sobel(x))", "f0(sobel(x))"),
                ],
            },
            "sig": {
                ("train_dist", "val_ooddist", "val_dist", "val"): [
                    ("f(sobel(x))", "y^"),
                ],
            },
        },
        "plots": {
            "": dict(
                size=256,
                realities=("test", "ood", "ood_syn_aug", "ood_syn"),
                paths=[
                    "x",
                    "y^",
                    "sobel(x)",
                    "f(sobel(x))",
                    "f0(sobel(x))"
                ]
            ),
        },
    },

    "baseline_edgedepth_sigaug_nlllwfsig": {
        "paths": {
            "x": [tasks.rgb],
            "y^": [tasks.depth_zbuffer],
            "sobel(x)": [tasks.rgb, tasks.sobel_edges],
            "f(sobel(x))": [tasks.rgb, tasks.sobel_edges, tasks.depth_zbuffer],
            "f0(sobel(x))": [tasks.rgb, tasks.sobel_edges, tasks.depth_zbuffer_t0],
            "reshade": [tasks.depth_zbuffer],
        },
        "freeze_list": [[tasks.sobel_edges, tasks.depth_zbuffer_t0]],
        "losses": {
            "main": {
                ("train_undist", "val_ooddist", "val_dist", "val"): [
                    ("f(sobel(x))", "y^"),
                ],
            },
            "lwf": {
                ("train_dist", "val_ooddist", "val_dist", "val"): [
                    ("f(sobel(x))", "f0(sobel(x))"),
                ],
            },
            "sig": {
                ("train_dist", "val_ooddist", "val_dist", "val"): [
                    ("f(sobel(x))", "y^"),
                ],
            },
        },
        "plots": {
            "": dict(
                size=256,
                realities=("test", "ood", "ood_syn_aug", "ood_syn"),
                paths=[
                    "x",
                    "y^",
                    "sobel(x)",
                    "f(sobel(x))",
                    "f0(sobel(x))"
                ]
            ),
        },
    },

    "baseline_depth_sigaug_nlllwfsig": {
        "paths": {
            "x": [tasks.rgb],
            "y^": [tasks.depth_zbuffer],
            "n(x)": [tasks.rgb, tasks.depth_zbuffer],
            "n0(x)": [tasks.rgb, tasks.depth_zbuffer_t0],
            "reshade": [tasks.depth_zbuffer],
        },
        "freeze_list": [[tasks.rgb, tasks.depth_zbuffer_t0]],
        "losses": {
            "main": {
                ("train_undist", "val_ooddist", "val_dist", "val"): [
                    ("n(x)", "y^"),
                ],
            },
            "lwf": {
                ("train_dist", "val_ooddist", "val_dist", "val"): [
                    ("n(x)", "n0(x)"),
                ],
            },
            "sig": {
                ("train_dist", "val_ooddist", "val_dist", "val"): [
                    ("n(x)", "y^"),
                ],
            },
        },
        "plots": {
            "": dict(
                size=256,
                realities=("test", "ood", "ood_syn_aug", "ood_syn"),
                paths=[
                    "x",
                    "y^",
                    "n(x)",
                    "n0(x)"
                ]
            ),
        },
    },

    "baseline_embossdepth_sigaug_nlllwfsig": {
        "paths": {
            "x": [tasks.rgb],
            "y^": [tasks.depth_zbuffer],
            "emboss(x)": [tasks.rgb, tasks.emboss],
            "f(emboss(x))": [tasks.rgb, tasks.emboss, tasks.depth_zbuffer],
            "f0(emboss(x))": [tasks.rgb, tasks.emboss, tasks.depth_zbuffer_t0],
            "reshade": [tasks.depth_zbuffer],
        },
        "freeze_list": [[tasks.emboss, tasks.depth_zbuffer_t0]],
        "losses": {
            "main": {
                ("train_undist", "val_ooddist", "val_dist", "val"): [
                    ("f(emboss(x))", "y^"),
                ],
            },
            "lwf": {
                ("train_dist", "val_ooddist", "val_dist", "val"): [
                    ("f(emboss(x))", "f0(emboss(x))"),
                ],
            },
            "sig": {
                ("train_dist", "val_ooddist", "val_dist", "val"): [
                    ("f(emboss(x))", "y^"),
                ],
            },
        },
        "plots": {
            "": dict(
                size=256,
                realities=("test", "ood", "ood_syn_aug", "ood_syn"),
                paths=[
                    "x",
                    "y^",
                    "emboss(x)",
                    "f(emboss(x))",
                    "f0(emboss(x))"
                ]
            ),
        },
    },

    "baseline_greydepth_sigaug_nlllwfsig": {
        "paths": {
            "x": [tasks.rgb],
            "y^": [tasks.depth_zbuffer],
            "grey(x)": [tasks.rgb, tasks.grey],
            "f(grey(x))": [tasks.rgb, tasks.grey, tasks.depth_zbuffer],
            "f0(grey(x))": [tasks.rgb, tasks.grey, tasks.depth_zbuffer_t0],
            "reshade": [tasks.depth_zbuffer],
        },
        "freeze_list": [[tasks.grey, tasks.depth_zbuffer_t0]],
        "losses": {
            "main": {
                ("train_undist", "val_ooddist", "val_dist", "val"): [
                    ("f(grey(x))", "y^"),
                ],
            },
            "lwf": {
                ("train_dist", "val_ooddist", "val_dist", "val"): [
                    ("f(grey(x))", "f0(grey(x))"),
                ],
            },
            "sig": {
                ("train_dist", "val_ooddist", "val_dist", "val"): [
                    ("f(grey(x))", "y^"),
                ],
            },
        },
        "plots": {
            "": dict(
                size=256,
                realities=("test", "ood", "ood_syn_aug", "ood_syn"),
                paths=[
                    "x",
                    "y^",
                    "grey(x)",
                    "f(grey(x))",
                    "f0(grey(x))"
                ]
            ),
        },
    },

    "baseline_wavdepth_sigaug_nlllwfsig": {
        "paths": {
            "x": [tasks.rgb],
            "y^": [tasks.depth_zbuffer],
            # NOTE(review): "wav(x)" is disabled here and in the plot list
            # below; the reason is not visible in this file.
            # "wav(x)": [tasks.rgb, tasks.wav],
            "f(wav(x))": [tasks.rgb, tasks.wav, tasks.depth_zbuffer],
            "f0(wav(x))": [tasks.rgb, tasks.wav, tasks.depth_zbuffer_t0],
            "reshade": [tasks.depth_zbuffer],
        },
        "freeze_list": [[tasks.wav, tasks.depth_zbuffer_t0]],
        "losses": {
            "main": {
                ("train_undist", "val_ooddist", "val_dist", "val"): [
                    ("f(wav(x))", "y^"),
                ],
            },
            "lwf": {
                ("train_dist", "val_ooddist", "val_dist", "val"): [
                    ("f(wav(x))", "f0(wav(x))"),
                ],
            },
            "sig": {
                ("train_dist", "val_ooddist", "val_dist", "val"): [
                    ("f(wav(x))", "y^"),
                ],
            },
        },
        "plots": {
            "": dict(
                size=256,
                realities=("test", "ood", "ood_syn_aug", "ood_syn"),
                paths=[
                    "x",
                    "y^",
                    # "wav(x)",
                    "f(wav(x))",
                    "f0(wav(x))"
                ]
            ),
        },
    },

    "baseline_greynormal_sigaug_nlllwfsig": {
        "paths": {
            "x": [tasks.rgb],
            "y^": [tasks.normal],
            "grey(x)": [tasks.rgb, tasks.grey],
            "f(grey(x))": [tasks.rgb, tasks.grey, tasks.normal],
            "f0(grey(x))": [tasks.rgb, tasks.grey, tasks.normal_t0],
            "reshade": [tasks.normal],
        },
        "freeze_list": [[tasks.grey, tasks.normal_t0]],
        "losses": {
            "main": {
                ("train_undist", "val_ooddist", "val_dist", "val"): [
                    ("f(grey(x))", "y^"),
                ],
            },
            "lwf": {
                ("train_dist", "val_ooddist", "val_dist", "val"): [
                    ("f(grey(x))", "f0(grey(x))"),
                ],
            },
            "sig": {
                ("train_dist", "val_ooddist", "val_dist", "val"): [
                    ("f(grey(x))", "y^"),
                ],
            },
        },
        "plots": {
            "": dict(
                size=256,
                realities=("test", "ood", "ood_syn_aug", "ood_syn"),
                paths=[
                    "x",
                    "y^",
                    "grey(x)",
                    "f(grey(x))",
                    "f0(grey(x))"
                ]
            ),
        },
    },

    "baseline_wavnormal_sigaug_nlllwfsig": {
        "paths": {
            "x": [tasks.rgb],
            "y^": [tasks.normal],
            # NOTE(review): unlike baseline_wavdepth_sigaug_nlllwfsig,
            # "wav(x)" is still defined here although its plot entry below is
            # commented out, so compute_paths (with its default paths) still
            # samples it -- confirm intended.
            "wav(x)": [tasks.rgb, tasks.wav],
            "f(wav(x))": [tasks.rgb, tasks.wav, tasks.normal],
            "f0(wav(x))": [tasks.rgb, tasks.wav, tasks.normal_t0],
            "reshade": [tasks.normal],
        },
        "freeze_list": [[tasks.wav, tasks.normal_t0]],
        "losses": {
            "main": {
                ("train_undist", "val_ooddist", "val_dist", "val"): [
                    ("f(wav(x))", "y^"),
                ],
            },
            "lwf": {
                ("train_dist", "val_ooddist", "val_dist", "val"): [
                    ("f(wav(x))", "f0(wav(x))"),
                ],
            },
            "sig": {
                ("train_dist", "val_ooddist", "val_dist", "val"): [
                    ("f(wav(x))", "y^"),
                ],
            },
        },
        "plots": {
            "": dict(
                size=256,
                realities=("test", "ood", "ood_syn_aug", "ood_syn"),
                paths=[
                    "x",
                    "y^",
                    # "wav(x)",
                    "f(wav(x))",
                    "f0(wav(x))"
                ]
            ),
        },
    },

    "baseline_embossnormal_sigaug_nlllwfsig": {
        "paths": {
            "x": [tasks.rgb],
            "y^": [tasks.normal],
            "emboss(x)": [tasks.rgb, tasks.emboss],
            "f(emboss(x))": [tasks.rgb, tasks.emboss, tasks.normal],
            "f0(emboss(x))": [tasks.rgb, tasks.emboss, tasks.normal_t0],
            "reshade": [tasks.normal],
        },
        "freeze_list": [[tasks.emboss, tasks.normal_t0]],
        "losses": {
            "main": {
                ("train_undist", "val_ooddist", "val_dist", "val"): [
                    ("f(emboss(x))", "y^"),
                ],
            },
            "lwf": {
                ("train_dist", "val_ooddist", "val_dist", "val"): [
                    ("f(emboss(x))", "f0(emboss(x))"),
                ],
            },
            "sig": {
                ("train_dist", "val_ooddist", "val_dist", "val"): [
                    ("f(emboss(x))", "y^"),
                ],
            },
        },
        "plots": {
            "": dict(
                size=256,
                realities=("test", "ood", "ood_syn_aug", "ood_syn"),
                paths=[
                    "x",
                    "y^",
                    "emboss(x)",
                    "f(emboss(x))",
                    "f0(emboss(x))"
                ]
            ),
        },
    },


    "baseline_binnormal_sigaug_nlllwfsig": {
        "paths": {
            "x": [tasks.rgb],
            "y^": [tasks.normal],
            "bin(x)": [tasks.rgb, tasks.binarized],
            "f(bin(x))": [tasks.rgb, tasks.binarized, tasks.normal],
            "f0(bin(x))": [tasks.rgb, tasks.binarized, tasks.normal_t0],
            "reshade": [tasks.normal],
        },
        "freeze_list": [[tasks.binarized, tasks.normal_t0]],
        "losses": {
            "main": {
                ("train_undist", "val_ooddist", "val_dist", "val"): [
                    ("f(bin(x))", "y^"),
                ],
            },
            "lwf": {
                ("train_dist", "val_ooddist", "val_dist", "val"): [
                    ("f(bin(x))", "f0(bin(x))"),
                ],
            },
            "sig": {
                ("train_dist", "val_ooddist", "val_dist", "val"): [
                    ("f(bin(x))", "y^"),
                ],
            },
        },
        "plots": {
            "": dict(
                size=256,
                realities=("test", "ood", "ood_syn_aug", "ood_syn"),
                paths=[
                    "x",
                    "y^",
                    "bin(x)",
                    "f(bin(x))",
                    "f0(bin(x))"
                ]
            ),
        },
    },

}



def coeff_hook(coeff):
    """Build a gradient hook that rescales every incoming gradient by ``coeff``.

    The returned callable takes a gradient tensor and returns
    ``coeff * grad.clone()``: a freshly allocated tensor, so the gradient
    handed to the hook is left untouched.
    """
    def scale_gradient(grad):
        copied = grad.clone()
        return coeff * copied

    return scale_gradient


class EnergyLoss(object):
    """Configurable set of path-to-path losses evaluated through a graph.

    Built from one entry of ``energy_configs``:

    * ``paths`` maps a path name (e.g. ``"f(bin(x))"``) to the list of tasks
      handed to ``graph.sample_path``.
    * ``losses`` maps a loss type (``"main"``, ``"lwf"``, ``"sig"``, ...) to
      ``{tuple_of_reality_names: [(path1, path2), ...]}``.
    * ``plots`` maps a plot name to a dict with ``realities`` and ``paths``.
    * ``freeze_list`` holds task pairs; each is stored as the ``str`` of its
      ``(name, name)`` tuple.

    ``__call__`` accumulates per-reality metrics in ``self.metrics``;
    ``logger_update`` flushes them to a logger and resets the store.
    """

    # Path names the calibration loss and the plots used to hard-code. They
    # are still preferred whenever present, so configs defining them behave
    # exactly as before; other configs fall back to the (prediction,
    # frozen reference) pairs listed under their "lwf" loss.
    _LEGACY_PRED = "f(wav(x))"
    _LEGACY_REF = "f0(wav(x))"

    def __init__(self, paths, losses, plots,
        pretrained=True, finetuned=False, freeze_list=None
    ):
        """Store the config and collect every task used by a loss or a plot.

        ``pretrained`` and ``finetuned`` are accepted so ``get_energy_loss``
        can forward them; this class does not use them.
        """
        self.paths, self.losses, self.plots = paths, losses, plots
        self.freeze_list = [str((path[0].name, path[1].name)) for path in (freeze_list or [])]
        self.metrics = {}

        tasks = []
        for loss_item in self.losses.values():
            for pairs in loss_item.values():
                for path1, path2 in pairs:
                    tasks += self.paths[path1] + self.paths[path2]
        for config in self.plots.values():
            for path in config["paths"]:
                tasks += self.paths[path]
        self.tasks = list(set(tasks))

    def compute_paths(self, graph, reality=None, paths=None):
        """Sample each named path through ``graph`` using one shared cache.

        An empty/None ``paths`` means "every configured path". Paths whose
        sample is ``None`` are dropped from the returned name -> value dict.
        """
        path_cache = {}
        paths = paths or self.paths
        path_values = {
            name: graph.sample_path(path,
                reality=reality, use_cache=True, cache=path_cache,
            ) for name, path in paths.items()
        }
        return {k: v for k, v in path_values.items() if v is not None}

    def get_tasks(self, reality):
        """Return the first task of every path used by a loss or plot of ``reality``."""
        tasks = []
        for loss_item in self.losses.values():
            for realities, pairs in loss_item.items():
                if reality in realities:
                    for path1, path2 in pairs:
                        tasks += [self.paths[path1][0], self.paths[path2][0]]

        for config in self.plots.values():
            if reality in config["realities"]:
                for path in config["paths"]:
                    tasks += [self.paths[path][0]]

        return list(set(tasks))

    def _reference_path(self, path, path_values):
        """Name of the frozen-reference path whose mean calibrates ``path``.

        Uses the legacy ``"f0(wav(x))"`` when it was sampled; otherwise the
        partner of ``path`` in the config's "lwf" pairs. If neither exists
        the legacy name is returned (and the caller's lookup will fail).
        """
        if self._LEGACY_REF in path_values:
            return self._LEGACY_REF
        for pairs in self.losses.get("lwf", {}).values():
            for path1, path2 in pairs:
                if path1 == path:
                    return path2
        return self._LEGACY_REF

    def _pair_loss(self, loss_type, output_task, path_values, path1, path2, mask, batch_mean, compute_mse):
        """Norm loss (MAE, or MSE when ``compute_mse``) for one path pair.

        ``path1``'s channels are split in half. For "main"/"lwf" the first
        halves of ``path1`` and ``path2`` are compared. For any other type
        (the "sig" calibration loss) ``exp`` of ``path1``'s second half (its
        sigma) is compared against ``|mu_ref - y|``, using the first halves of
        the reference path (see ``_reference_path``) and of ``path2``.
        """
        prediction = path_values[path1]
        nchannels = prediction.size(1) // 2
        target_mean = path_values[path2][:, :nchannels]
        if loss_type in ('main', 'lwf'):
            # standard mean-channel loss
            path_loss, _ = output_task.norm(prediction[:, :nchannels], target_mean, batch_mean=batch_mean, compute_mask=True, compute_mse=compute_mse, mask=mask)
        else:
            # calibration loss: || sigma(x) - |mu(x) - y^| ||
            reference = path_values[self._reference_path(path1, path_values)]
            abs_err = (reference[:, :nchannels] - target_mean).abs()
            path_loss, _ = output_task.norm(prediction[:, nchannels:].exp(), abs_err, batch_mean=batch_mean, compute_mask=True, compute_mse=compute_mse, mask=mask)
        return path_loss

    def __call__(self, graph, discriminator=None, realities=(), loss_types=None, batch_mean=True, use_l1=False):
        """Compute the configured losses for every reality in ``realities``.

        Args:
            graph: object whose ``sample_path`` yields the value of a path.
            discriminator: unused; kept for interface compatibility.
            realities: objects with a ``name`` matched against the reality
                tuples in ``self.losses``.
            loss_types: loss types to evaluate; None/empty selects all.
            batch_mean: forwarded to the task's ``nll``/``norm``.
            use_l1: unused; kept for interface compatibility.

        Returns:
            dict loss type -> loss summed over its pairs and over all given
            realities (``0`` when no pair of that type applies). Only "main"
            (its NLL) and "lwf"/"sig" (their MAE) contribute to the loss.

        Side effects:
            Appends detached scalar metrics to ``self.metrics[reality.name]``.

        A ``"y^"`` path must be among the sampled values: its task builds the
        mask passed to every loss.
        """
        active_types = loss_types or set(self.losses)
        loss = {}
        for reality in realities:
            # All (path1, path2) pairs configured for this reality, per loss type.
            loss_dict = {}
            for loss_type, loss_item in self.losses.items():
                loss_dict[loss_type] = []
                for realities_l, pairs in loss_item.items():
                    if reality.name in realities_l:
                        loss_dict[loss_type] += pairs

            # Sample only the paths of explicitly selected loss types. When
            # none were selected (or none apply here) the dict is empty and
            # compute_paths falls back to every configured path.
            needed = set()
            for loss_type, pairs in loss_dict.items():
                if loss_types and loss_type in loss_types:
                    for pair in pairs:
                        needed.update(pair)
            path_values = self.compute_paths(graph,
                paths={path: self.paths[path] for path in needed},
                reality=reality)

            if reality.name not in self.metrics:
                self.metrics[reality.name] = defaultdict(list)
            metrics = self.metrics[reality.name]

            target_task = self.paths['y^'][0]
            mask = target_task.build_mask(path_values['y^'], val=target_task.mask_val).float()

            for loss_type, pairs in sorted(loss_dict.items()):
                if loss_type not in active_types:
                    continue
                if loss_type not in loss:
                    loss[loss_type] = 0
                # Each pair is evaluated exactly once.
                for path1, path2 in pairs:
                    output_task = self.paths[path1][-1]
                    suffix = path1 + " -> " + path2

                    if loss_type == 'main':
                        path_loss_nll, _ = output_task.nll(path_values[path1], path_values[path2], batch_mean=batch_mean, compute_mask=True, mask=mask)
                        loss[loss_type] += path_loss_nll
                        metrics[loss_type + "_nll : " + suffix] += [path_loss_nll.mean().detach().cpu()]

                    path_loss = self._pair_loss(loss_type, output_task, path_values, path1, path2, mask, batch_mean, compute_mse=False)
                    # "main" is optimised through its NLL above; its MAE is only logged.
                    if loss_type in ('lwf', 'sig'):
                        loss[loss_type] += path_loss
                    metrics[loss_type + "_mae : " + suffix] += [path_loss.mean().detach().cpu()]

                    path_loss = self._pair_loss(loss_type, output_task, path_values, path1, path2, mask, batch_mean, compute_mse=True)
                    metrics[loss_type + "_mse : " + suffix] += [path_loss.mean().detach().cpu()]

        return loss

    def _metric_realities(self):
        """Map every metric name ``__call__`` may record to its reality names."""
        name_to_realities = defaultdict(list)
        for loss_type, loss_item in self.losses.items():
            for realities, pairs in loss_item.items():
                for path1, path2 in pairs:
                    for kind in ("_nll : ", "_mae : ", "_mse : "):
                        name_to_realities[loss_type + kind + path1 + " -> " + path2] += list(realities)
        return name_to_realities

    def logger_hooks(self, logger):
        """Register one joint-plot hook per metric name, spanning its realities."""
        for name, realities in self._metric_realities().items():
            def jointplot(logger, data, name=name, realities=realities):
                names = [f"{reality}|{name}" for reality in realities]
                if not all(x in data for x in names):
                    return
                data = np.stack([data[x] for x in names], axis=1)
                logger.plot(data, name, opts={"legend": names})

            logger.add_hook(partial(jointplot, name=name, realities=realities), feature=f"{realities[-1]}_{name}", freq=1)

    def logger_update(self, logger):
        """Send the mean of each accumulated metric to ``logger``, then reset."""
        for name, realities in self._metric_realities().items():
            for reality in realities:
                values = self.metrics.get(reality, {}).get(name)
                if not values:
                    continue
                logger.update(
                    f"{reality}|{name}",
                    torch.mean(torch.stack(values)),
                )
        self.metrics = {}

    def _plot_pair(self, values):
        """(prediction, frozen reference) path names to split into mean/sigma.

        Prefers the legacy wav names when the prediction was plotted,
        otherwise the first "lwf" pair whose two paths were both plotted.
        """
        if self._LEGACY_PRED in values:
            return self._LEGACY_PRED, self._LEGACY_REF
        for pairs in self.losses.get("lwf", {}).values():
            for path1, path2 in pairs:
                if path1 in values and path2 in values:
                    return path1, path2
        return self._LEGACY_PRED, self._LEGACY_REF

    def _error_map(self, target, prediction, cmap):
        """Colour-mapped, log-scaled squared error of the first 3 channels.

        Pixels outside the ``"y^"`` task mask are set to 0.505.
        NOTE(review): a perfect prediction gives ``log_errors.max() == 0``
        and therefore NaNs before colour mapping.
        """
        mask_task = self.paths["y^"][-1]
        mask = ImageTask.build_mask(target, val=mask_task.mask_val)
        errors = ((target[:, :3] - prediction[:, :3]) ** 2).mean(dim=1, keepdim=True)
        errors = (3 * errors / mask_task.variance).clamp(min=0, max=1)
        log_errors = torch.log(errors + 1)
        log_errors = log_errors / log_errors.max()
        log_errors = torch.tensor(cmap(log_errors.cpu()))[:, 0].permute((0, 3, 1, 2)).float()[:, 0:3]
        log_errors = log_errors.clamp(min=0, max=1).to(DEVICE)
        log_errors[~mask.expand_as(log_errors)] = 0.505
        return log_errors

    def plot_paths(self, graph, logger, realities=(), plot_names=None, epochs=0, tr_step=0, prefix=""):
        """Render the plotted paths of each reality into one image grid.

        ``logger``, ``epochs``, ``tr_step`` and ``prefix`` are unused and kept
        for interface compatibility. Each plotted reality name must match one
        of ``realities``.

        Returns a dict ``f"{reality}_{k1|k2|...}"`` -> tensor with one row per
        value key: batch images downsampled 2x and placed side by side.
        """
        sqrt2 = math.sqrt(2)
        cmap = get_cmap("jet")
        path_values = {}
        realities_map = {reality.name: reality for reality in realities}
        plot_realities = []  # realities of the last plot config, as before
        for name, config in (plot_names or self.plots.items()):
            paths = config["paths"]
            plot_realities = config["realities"]

            for reality in plot_realities:
                with torch.no_grad():
                    values = self.compute_paths(graph, paths={path: self.paths[path] for path in paths}, reality=realities_map[reality])
                    pred, ref = self._plot_pair(values)
                    if reality == 'test':  # compute error map
                        values['error'] = self._error_map(values["y^"], values[pred], cmap)

                    # Split [mean | log-sigma] channels; sigma is shown scaled by sqrt(2).
                    nchannels = values[pred].size(1) // 2
                    values[pred + '_m'] = values[pred][:, :nchannels]
                    values[pred + '_s'] = values[pred][:, nchannels:].exp() * sqrt2
                    values[ref + '_m'] = values[ref][:, :nchannels]
                    values[ref + '_s'] = values[ref][:, nchannels:].exp() * sqrt2
                    values = {k: v.clamp(min=0, max=1).cpu() for k, v in values.items()}
                    del values[pred]
                    del values[ref]
                    path_values[reality] = values

        def reshape_img_to_rows(x_):
            # Downsample each batch image by 2 and lay the batch out horizontally.
            downsample = lambda x: F.interpolate(x.unsqueeze(0), scale_factor=0.5, mode='bilinear').squeeze(0)
            return torch.cat([downsample(x_[i]) for i in range(x_.size(0))], dim=-1)

        all_images = {}
        for reality in plot_realities:
            rows = []
            for value in path_values[reality].values():
                row = reshape_img_to_rows(value)
                if row.size(0) == 1:  # grey-scale -> 3 channels
                    row = row.repeat(3, 1, 1)
                rows.append(row)
            plot_name = '|'.join(path_values[reality].keys())
            all_images[reality + '_' + plot_name] = torch.cat(rows, dim=-2)

        return all_images

    def __repr__(self):
        return str(self.losses)


class WinRateEnergyLoss(EnergyLoss):
    """EnergyLoss whose reduction depends on the first reality's name.

    Always evaluates the "main", "lwf" and "sig" loss types with
    ``batch_mean=False``, then:

    * ``train_undist``: keeps only "main", averaged and scaled by 0.1;
    * ``train_dist``: drops "main"; "lwf" is averaged and scaled by 100 and
      "sig" is averaged;
    * any other reality: returns the losses as computed with
      ``batch_mean=False``.
    """

    def __init__(self, *args, **kwargs):
        # Options consumed here so EnergyLoss.__init__ never receives them.
        # NOTE(review): k, random_select and running_stats are not read by
        # this class; presumably leftovers of a win-rate selection scheme.
        self.k = kwargs.pop('k', 3)
        self.random_select = kwargs.pop('random_select', False)
        self.running_stats = {}
        # `paths` may arrive by keyword (as get_energy_loss passes it) or positionally.
        paths = kwargs['paths'] if 'paths' in kwargs else args[0]
        self.target_task = paths['y^'][0].name

        super().__init__(*args, **kwargs)

    def __call__(self, graph, discriminator=None, realities=(), loss_types=None):
        # The loss-type selection is fixed for this variant; the `loss_types`
        # argument is accepted for interface compatibility but ignored.
        loss_types = ["main", "lwf", "sig"]
        loss_dict = super().__call__(graph, discriminator=discriminator, realities=realities, loss_types=loss_types, batch_mean=False)
        if not realities:
            return loss_dict

        phase = realities[0].name
        if phase == 'train_undist':
            loss_dict.pop("lwf", None)
            loss_dict.pop("sig", None)
            loss_dict["main"] = loss_dict["main"].mean() * 0.1
        elif phase == 'train_dist':
            loss_dict.pop("main", None)
            loss_dict["lwf"] = loss_dict["lwf"].mean() * 100.0
            loss_dict["sig"] = loss_dict["sig"].mean()

        return loss_dict


