import torch
import torch.nn as nn

from utils import *
from energy import get_energy_loss
from graph import TaskGraph
from logger import Logger, VisdomLogger
from datasets import load_train_val, load_test, load_ood
from task_configs import tasks, RealityTask
from transfers import functional_transfers

from fire import Fire

import wandb
# Start the Weights & Biases run at import time so that the `wandb.config` and
# `wandb.log` calls inside `main` have a run to write to.
# NOTE(review): this executes whenever the module is imported, not only when it
# is run as a script.
wandb.init(project="consistency", entity="hello_world")

import pdb

def _save_checkpoint(graph):
	"""Write the graph weights and the optimizer state under ``RESULTS_DIR``."""
	graph.save(f"{RESULTS_DIR}/graph.pth")
	torch.save(graph.optimizer.state_dict(), f"{RESULTS_DIR}/opt.pth")


def _log_path_images(energy_loss, graph, logger, realities, step):
	"""Plot the energy-loss paths and log each resulting image to wandb at `step`."""
	path_values = energy_loss.plot_paths(graph, logger, realities, prefix="")
	for reality_paths, reality_images in path_values.items():
		wandb.log({reality_paths: [wandb.Image(reality_images)]}, step=step)


def main(
	loss_config="multiperceptual", mode="winrate", visualize=False,
	fast=False, batch_size=None, resume=False,
	subset_size=None, max_epochs=2000, dataaug=False, **kwargs,
):
	"""Train the task graph with the configured energy loss and log to wandb.

	Args:
		loss_config: Forwarded to `get_energy_loss` as `config`; also logged.
		mode: Forwarded to `get_energy_loss`.
		visualize: Accepted for interface compatibility; not read in this function.
		fast: Selects the small default batch size (4 instead of 64) and is
			forwarded to `load_train_val`.
		batch_size: Explicit batch size; when falsy the default above is used.
		resume: If true, restore graph weights and optimizer state from the
			hard-coded results directory; otherwise load only the
			('rgb', 'normal') weights from the hard-coded pretrained checkpoint.
		subset_size: Forwarded to `load_train_val`.
		max_epochs: Number of epochs to run.
		dataaug: Forwarded to `load_train_val`; also logged.
		**kwargs: Forwarded to `get_energy_loss`.

	`JOB` and `RESULTS_DIR` are expected to be in scope at module level
	(presumably via `from utils import *` -- confirm).
	"""

	# CONFIG
	# Resolve the default first so wandb records the batch size actually used
	# rather than None.
	batch_size = batch_size or (4 if fast else 64)
	# Single source of truth for the learning rate: logged and passed to the optimizer.
	lr = 3e-5
	wandb.config.update({"loss_config":loss_config,"batch_size":batch_size,"data_aug":dataaug,"lr":lr,
		"n_gauss":1,"distribution":"laplace"})

	energy_loss = get_energy_loss(config=loss_config, mode=mode, **kwargs)

	# DATA LOADING
	train_dataset, val_dataset, train_step, val_step = load_train_val(
		energy_loss.get_tasks("train"),
		batch_size=batch_size, fast=fast,
		subset_size=subset_size,
		dataaug=dataaug,
	)
	test_set = load_test(energy_loss.get_tasks("test"))
	ood_set = load_ood(energy_loss.get_tasks("ood"))
	# ood_syn_aug_set = load_ood(energy_loss.get_tasks("ood_syn_aug"), ood_path='/scratch-data/teresa/ood_syn_distortions')
	# ood_syn_set = load_ood(energy_loss.get_tasks("ood_syn"), ood_path='/scratch-data/teresa/ood_syn_distortions2', sample=35)

	train = RealityTask("train", train_dataset, batch_size=batch_size, shuffle=True)      # distorted and undistorted
	val = RealityTask("val", val_dataset, batch_size=batch_size, shuffle=True)
	test = RealityTask.from_static("test", test_set, energy_loss.get_tasks("test"))
	ood = RealityTask.from_static("ood", ood_set, [tasks.rgb,])
	# ood_syn_aug = RealityTask.from_static("ood_syn_aug", ood_syn_aug_set, [tasks.rgb,])          ## synthetic distortion images used for sig training
	# ood_syn = RealityTask.from_static("ood_syn", ood_syn_set, [tasks.rgb,])                      ## unseen syn distortions

	# GRAPH
	realities = [train, val, test, ood]
	graph = TaskGraph(tasks=energy_loss.tasks + realities, pretrained=True, finetuned=False,
		freeze_list=energy_loss.freeze_list,
	)
	graph.compile(torch.optim.Adam, lr=lr, weight_decay=2e-6, amsgrad=True)
	if resume:
		graph.load_weights('/workspace/shared/results_CH_test_1/graph.pth')
		graph.optimizer.load_state_dict(torch.load('/workspace/shared/results_CH_test_1/opt.pth'))
	else:
		graph.load_weights('/scratch-data/shared_uncertainty/results_rgb2normal_depthreshadecurvimgnetl1perceps_0.1nll_sigonly_gaussblurgaussnoiseaug_0.1nll1sig100lwfloss/graph.pth', [str(('rgb', 'normal'))])

	# LOGGING
	logger = VisdomLogger("train", env=JOB)    # fake visdom logger
	logger.add_hook(lambda logger, data: logger.step(), feature="loss", freq=20)
	energy_loss.logger_hooks(logger)

	# Baseline visualisation before any training step.
	# (A leftover `pdb.set_trace()` used to follow this call; it halted every run
	# and aborted non-interactive jobs, so it has been removed.)
	_log_path_images(energy_loss, graph, logger, realities, step=0)

	# TRAINING
	for epochs in range(0, max_epochs):

		logger.update("epoch", epochs)

		# Validation pass: no gradients, graph in eval mode.
		graph.eval()
		for _ in range(0, val_step):
			with torch.no_grad():
				val_loss, _, _ = energy_loss(graph, realities=[val])
				val_loss = sum(val_loss[loss_name] for loss_name in val_loss)
			val.step()
			logger.update("loss", val_loss)

		# Training pass. `coeffs` / `avg_grads` keep the values from the last
		# batch; they stay None when `train_step` is 0 so the logging below
		# can be skipped instead of raising NameError.
		coeffs = avg_grads = None
		graph.train()
		for _ in range(0, train_step):
			train_loss, coeffs, avg_grads = energy_loss(graph, realities=[train], compute_grad_ratio=True)
			train_loss = sum(train_loss[loss_name] for loss_name in train_loss)
			graph.step(train_loss)
			train.step()
			logger.update("loss", train_loss)

		energy_loss.logger_update(logger)

		data = logger.step()
		# Kept as in-place deletions: `data` may be logger-owned state.
		del data['loss']
		del data['epoch']
		# Log the first entry of each remaining series.
		data = {k: v[0] for k, v in data.items()}
		wandb.log(data, step=epochs)
		if coeffs is not None:
			wandb.log(coeffs, step=epochs)
			wandb.log(avg_grads, step=epochs)

		if epochs % 5 == 0:
			_save_checkpoint(graph)

		if epochs % 10 == 0:
			# Logged at epochs+1, matching the original step numbering.
			_log_path_images(energy_loss, graph, logger, realities, step=epochs + 1)

	# Final checkpoint after the last epoch.
	_save_checkpoint(graph)

# Script entry point: hand `main` to Fire so it can be invoked from the
# command line (only when this file is executed directly, not on import).
if __name__ == "__main__":
	Fire(main)
