import os
import json
import matplotlib.pyplot as plt
import numpy as np
import scipy.ndimage

import os
import json
import matplotlib.pyplot as plt
import numpy as np
import scipy.ndimage

# dumps_paths = ["/home/tyurina/local_experiments/sgd_resnet_without_bn_bz_1/",
#                "/home/tyurina/local_experiments/mvr_resnet_without_bn_bz_1/"]
# dumps_paths = ["/home/tyurina/local_experiments/sgd_resnet_without_bn_bz_64_new_stat/",
#                "/home/tyurina/local_experiments/mvr_resnet_without_bn_bz_64_new_stat/"]
# dumps_paths = ["/home/tyurina/local_experiments/sgd_resnet_without_bn_bz_64_elu/",
#                "/home/tyurina/local_experiments/mvr_resnet_without_bn_bz_64_elu/"]

# dumps_paths = ["/home/tyurina/local_experiments/gd_resnet_without_bn_elu/"]

# dumps_paths = ["/home/tyurina/local_experiments/gd_resnet_without_bn_elu/",
#                "/home/tyurina/local_experiments/gd_resnet_without_bn_elu_more_stat/"]

# dumps_paths = ["/home/tyurina/local_experiments/gd_resnet_without_bn_elu_more_stat/",
#                "/home/tyurina/local_experiments/gd_resnet_without_bn_elu_small_steps/"]

# Active experiment directories: every entry is scanned for JSON dump files
# by the loading loop below. The commented-out assignments around this line
# are alternative experiment sets.
dumps_paths = ["/home/tyurina/local_experiments/gd_resnet_without_bn_elu_save_model/"]

# dumps_paths = ["/home/tyurina/local_experiments/gd_resnet_without_bn_elu_more_stat/"]

# dumps_paths = ["/home/tyurina/local_experiments/gd_two_layer_tanh/",
#                "/home/tyurina/local_experiments/holder_gd_0_75_two_layer_tanh/"]

# dumps_paths = ["/home/tyurina/local_experiments/gd_two_layer_tanh/",
#                "/home/tyurina/local_experiments/adaptive_gd_two_layer_tanh/",
#                "/home/tyurina/local_experiments/adaptive_holder_gd_0_25_two_layer_tanh/",
#                "/home/tyurina/local_experiments/adaptive_holder_gd_0_5_two_layer_tanh/"]

# dumps_paths = ["/home/tyurina/local_experiments/gd_resnet_without_bn_elu_small_steps/"]

# dumps_paths = ["/home/tyurina/local_experiments/sgd_resnet_without_bn_bz_64_elu_small_momentum/",
#                "/home/tyurina/local_experiments/mvr_resnet_without_bn_bz_64_elu_small_momentum/",
#                "/home/tyurina/local_experiments/sgd_resnet_without_bn_bz_64_elu/",
#                "/home/tyurina/local_experiments/mvr_resnet_without_bn_bz_64_elu/",
#                "/home/tyurina/local_experiments/gd_resnet_without_bn_elu/"]

# dumps_paths = ["/home/tyurina/local_experiments/sgd_resnet_without_bn_bz_64_elu_small_momentum/",
#                "/home/tyurina/local_experiments/sgd_resnet_without_bn_bz_64_elu/"]
# Collect every JSON dump found in the configured experiment directories.
# Each loaded dump is annotated with the path it came from ('_path') and is
# guaranteed to have a 'momentum' entry in its config (None when absent).
dumps = []
for dumps_path in dumps_paths:
    for file_name in os.listdir(dumps_path):
        # Skip entries that are not dumps: the 'source_folder' entry and any
        # file whose name marks it as temporary or as a saved model.
        is_auxiliary = (
            file_name == 'source_folder'
            or '_tmp_' in file_name
            or '_model_' in file_name
        )
        if is_auxiliary:
            continue

        dump_file = os.path.join(dumps_path, file_name)
        with open(dump_file) as fd:
            record = json.load(fd)

        config = record['config']
        # Optional filter (disabled):
        # if config['learning_rate'] not in [1.0, 0.1, 0.05]:
        #     continue
        record['_path'] = dump_file
        config.setdefault('momentum', None)
        dumps.append(record)
            
import sys
import torch
import torchvision
import torchvision.transforms as transforms
import torch.backends.cudnn as cudnn

sys.path.append('/home/tyurina/rsync_watch/distributed_optimization_library/code/')
from distributed_optimization_library.experiments.local_optimization_pytorch.optimize_model import StochasticModel
from distributed_optimization_library.experiments.local_optimization_pytorch.optimize_model import prepare_resnet_without_bn, parameters_to_tensor, prepare_two_layer_nn

# Select the first loaded dump whose learning rate equals the target value.
# After this block `dump` (and `index`) refer to the selected run.
# A previously used target was 0.05.
target_learning_rate = 0.001
for index, dump in enumerate(dumps):
    if dump['config']['learning_rate'] == target_learning_rate:
        break
else:
    # Without this guard the loop would fall through and the rest of the
    # script would silently analyse the *last* dump (or fail later with a
    # NameError when no dumps were loaded at all).
    raise ValueError(
        'No dump with learning_rate == {} found among {} loaded dumps'.format(
            target_learning_rate, len(dumps)))

# Rebuild the network described by the selected dump and restore the weights
# stored next to it (the dump path with the '_model_20000' suffix).
resnet_kwargs = dump['config']['resnet_params']
model = prepare_resnet_without_bn(**resnet_kwargs)
# Alternative architecture:
# model = prepare_two_layer_nn(**dump['config']['two_layer_nn_params'])
checkpoint_path = dump['_path'] + '_model_20000'
model.load_state_dict(torch.load(checkpoint_path))
optimize_memory = False

loss_fn = torch.nn.CrossEntropyLoss()

# Per-channel mean / std handed to Normalize for the CIFAR-10 images.
channel_mean = (0.4914, 0.4822, 0.4465)
channel_std = (0.2023, 0.1994, 0.2010)
transform_train = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(channel_mean, channel_std),
])

# Only the first 5000 training samples are used for the evaluation.
full_trainset = torchvision.datasets.CIFAR10(
    root='/home/tyurina/data',
    train=True,
    download=True,
    transform=transform_train,
)
trainset = torch.utils.data.Subset(full_trainset, range(5000))

model_wrapper = StochasticModel(
    model,
    loss_fn,
    trainset,
    batch_size=1024,
    num_workers=4,
    use_cuda=True,
    optimize_memory=optimize_memory,
)


# Per-parameter loss sweep.
#
# For every parameter tensor of the model, move the current point along the
# part of the full gradient that belongs to that tensor only, evaluate the
# loss for a range of step sizes in [-2 * lr, 2 * lr], and save the resulting
# curve as '<index>.pdf'. Vertical lines mark step 0 and step == lr.
start_point = model_wrapper.current_point()
point = model_wrapper.current_point()
lr = dump['config']['learning_rate']

print('LR: ', lr)
gradient_orig = model_wrapper.gradient(point)
loss, _ = model_wrapper.last_loss_and_accuracy()
print('Norm gradient: ', np.linalg.norm(gradient_orig.cpu().numpy()))

plots_dir = '/home/tyurina/tmp/plot_nn_more'
# Make sure the output directory exists before the first savefig call.
os.makedirs(plots_dir, exist_ok=True)

# The step grid does not depend on the parameter, so build it once.
# Alternatives used before: np.linspace(-1.0, 1.0, num=1000), np.array([lr]).
steps = np.linspace(-2 * lr, 2 * lr, num=100)

# NOTE(review): assumes the flat gradient is laid out in the same order as
# model.parameters(), each tensor occupying numel() consecutive entries --
# confirm against StochasticModel / parameters_to_tensor.
start = 0
for index, parameter in enumerate(model_wrapper._model.parameters()):
    print('-' * 100)
    print("Index: ", index)
    print(parameter.shape)
    end = start + parameter.numel()

    # Keep only the gradient entries in [start, end); zero everything else.
    # Plain slicing replaces the former boolean mask built with `np.bool`,
    # an alias that no longer exists in NumPy >= 1.24.
    gradient = torch.clone(gradient_orig)
    gradient[:start] = 0
    gradient[end:] = 0

    losses = []
    for step in steps:
        new_point = point - step * gradient
        # The returned gradient is not needed; the call is kept because it
        # presumably refreshes the value reported by last_loss_and_accuracy()
        # (same call pattern as for the initial point above).
        model_wrapper.gradient(new_point)
        new_loss, _ = model_wrapper.last_loss_and_accuracy()
        losses.append(new_loss)
    start = end

    fig, ax = plt.subplots(figsize=(40, 20))
    ax.plot(steps, losses)
    ax.axvline(x=0.0)
    ax.axvline(x=lr)
    # Step size with the smallest loss on the grid.
    index_min = np.argmin(losses)
    print(steps[index_min])
    # Grid point closest to the learning rate: compare the loss there with
    # the loss at the starting point.
    index_min = np.argmin(np.abs(steps - lr))
    print(steps[index_min], loss, losses[index_min])
    fig.savefig(os.path.join(plots_dir, '{}.pdf'.format(index)))
    # Release the figure; otherwise one large figure per parameter tensor
    # stays alive until the script exits.
    plt.close(fig)
# print(cos_distance(to_np(gradient), to_np(new_gradient)))
