import os
from logging import WARNING
import warnings
warnings.filterwarnings("ignore")

import psutil
import ray

from logger import Logger
from utils import set_seed, save_metrics_params, update_params_from_cmdline, save_settings_to_json

import warcraft_shortest_path.data_utils as warcraft_shortest_path_data
import warcraft_shortest_path.attack_trainers as warcraft_shortest_path_trainers
#JK

import torch
from models import get_model
from warcraft_shortest_path.attack_trainers import AttackAbstractTrainer, DijkstraAttacker
from torch.nn.parameter import Parameter

from utils import optimizer_from_string
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
from utils import minimum, maximum
import comb_modules.dijkstra

# Registry: problem type -> dataset loader.  The loader is called with
# **params.loader_params and returns (train_iterator, test_iterator, metadata).
dataset_loaders = dict(
    warcraft_shortest_path=warcraft_shortest_path_data.load_dataset,
)

# Registry: problem type -> factory that is called with params.trainer_name
# and whose result this script stores as trainer_class.
trainer_loaders = dict(
    warcraft_shortest_path=warcraft_shortest_path_trainers.get_trainer,
)



# --- Configuration and data -------------------------------------------------
# Parse command-line overrides into the run configuration and make sure the
# model output directory exists.
params = update_params_from_cmdline(verbose=True)
os.makedirs(params.model_dir, exist_ok=True)

# Build the train/test iterators plus dataset metadata for the configured
# problem type.
dataset_loader = dataset_loaders[params.problem_type]
train_iterator, test_iterator, metadata = dataset_loader(**params.loader_params)

# Resolve the configured trainer class.  NOTE: this script later replaces
# trainer_class with DijkstraAttacker; the lookup is kept in case get_trainer
# validates params.trainer_name (not verified here).
trainer_class = trainer_loaders[params.problem_type](params.trainer_name)

# In-memory training examples; each entry is indexed below as
# [0] input image, [1] path, [2] weights.
train_array = train_iterator.array

# --- Attack sample and target -----------------------------------------------
# Take training example 10.  Entries 0/1/2 are used as the input image, its
# ground-truth path and its weights (semantics of the last two inferred from
# the variable names -- confirm against warcraft_shortest_path.data_utils).
sample_in = train_array[10][0]
sample_path = train_array[10][1]
sample_weights = train_array[10][2]

# Two hand-made adversarial target paths with the same shape as the true path:
#   target_path1 marks the first row and the last column,
#   target_path2 marks the first column and the last row.
target_path1 = torch.zeros(sample_path.shape)
target_path1[0, :] = 1
target_path1[:, -1] = 1
target_path2 = torch.zeros(sample_path.shape)
target_path2[:, 0] = 1
target_path2[-1, :] = 1

# Keep the ground truth for the final report.  From here on sample_path refers
# to the attack target; the attacker setup below uses target_path1 directly.
true_path = sample_path
sample_path = target_path1

# Value range of the clean input, handed to the trainer later (presumably to
# clip the perturbed image -- confirm in DijkstraAttacker).
image_min = sample_in.min()
image_max = sample_in.max()

print("sample_in min, max  = {} {}".format(image_min, image_max))

# --- Attacker setup ---------------------------------------------------------
# The attack always runs with the Dijkstra-based attacker; this overrides the
# trainer class resolved from params.trainer_name at start-up.
trainer_class = DijkstraAttacker

# Publish lambda_val to comb_modules.dijkstra through its meta_holder mapping
# (how the solver consumes it is not visible in this script).
LAMBDA_VAL = params["trainer_params"]["lambda_val"]
comb_modules.dijkstra.meta_holder["LAMBDA_VAL"] = LAMBDA_VAL

fast_mode = params.get("fast_mode", False)
trainer = trainer_class(
    train_iterator=train_iterator,
    test_iterator=test_iterator,
    metadata=metadata,
    fast_mode=fast_mode,
    **params.trainer_params
)

# Load pretrained model weights onto the CPU.  The path can be overridden via
# params.checkpoint_path; the default keeps the previous hard-coded file.
# NOTE(security): torch.load unpickles the file -- only load trusted checkpoints.
checkpoint_path = params.get("checkpoint_path", "combres1.pt")
state_dict = torch.load(checkpoint_path, map_location=torch.device("cpu"))
trainer.model.load_state_dict(state_dict)

# Attack state read by the attacker (assumed from usage -- confirm in
# DijkstraAttacker.train_epoch):
#   sample   -- (clean input, target path, weights) being attacked
#   delta    -- learnable perturbation, same shape as the input, starts at 0
#   maxdelta -- perturbation budget (overridable via params.maxdelta)
trainer.sample = sample_in, target_path1, sample_weights
trainer.delta = Parameter(torch.zeros(sample_in.shape))
trainer.maxdelta = params.get("maxdelta", 0.1)

# The optimizer is given only the perturbation, so it updates trainer.delta
# and never the model weights.  Only the configured learning rate is used.
optimizer_params = params["trainer_params"]["optimizer_params"]
trainer.optimizer = torch.optim.SGD([trainer.delta], lr=optimizer_params["lr"])

# Clean-input value range, exposed to the attacker.
trainer.image_min = image_min
trainer.image_max = image_max


# --- Run the attack and report ----------------------------------------------
loss_list = []  # 'train_loss' value returned by each train_epoch call
train_results = {}
for _ in range(params.num_epochs):
    train_results = trainer.train_epoch()
    loss_list.append(train_results['train_loss'])

print('true_path')
print(true_path)
print('target_path1')
print(target_path1)
print('predicted path')
# No epoch ran when params.num_epochs == 0, so there is no predicted path;
# print None instead of raising KeyError on the empty results dict.
print(train_results['path'] if loss_list else None)

plt.plot(range(len(loss_list)), loss_list, 'b*')
plt.ylim(0)
plt.ylabel("Hamming Loss")
plt.xlabel("PGD Iteration")
# Raw string: '\d' is an invalid escape sequence in a normal string literal
# (DeprecationWarning, SyntaxWarning since Python 3.12).
plt.title(r'$\delta = {}$'.format(trainer.maxdelta))
plt.show()
