import socket
import sys
import torch
import torch.nn as nn
import tqdm
from time import time
import torch.nn.functional as F
from TRADES.trades import trades_loss
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import pandas as pd
from tikzplotlib import save as tikz_save

from src.AIDomains.wrappers import propagate_abs
from src.AIDomains.zonotope import HybridZonotope
from src.adv_attack import adv_whitebox
from src.regularization import compute_bound_reg, compute_IBP_reg

  
def get_loss_FN(args):
    """Return the per-sample (unreduced) loss selected by ``args.loss_fn``.

    * ``"CE"``  -- ``nn.CrossEntropyLoss(reduction="none")``.
    * ``"PT1"`` -- unreduced cross-entropy plus ``args.pt1_e * (1 - p_y)``, where ``p_y`` is the softmax
      probability assigned to the target class. ``args.pt1_e`` is read each time the loss is called.

    Any other value fails an assertion.
    """
    choice = args.loss_fn
    if choice == "CE":
        return nn.CrossEntropyLoss(reduction="none")
    if choice == "PT1":
        def loss_FN(pred, y):
            per_sample_ce = F.cross_entropy(pred, y, reduction="none")
            target_prob = torch.gather(F.softmax(pred, 1), 1, y.unsqueeze(1))
            return per_sample_ce + args.pt1_e * (1 - target_prob).squeeze(1)
        return loss_FN
    assert False, f"Loss function {args.loss_fn} is unknown."


def compute_regularization(args, net_abs, data, adex, eps, tau, max_tau, data_range):
    """Return the additional training penalty as a 1-element tensor on ``data``'s device.

    Terms are accumulated in place onto an initial zero:

    * ``args.cert_reg == "bound_reg"`` and ``tau < max_tau``: ``compute_bound_reg``. When ``eps`` is 0, a box
      of radius ``args.min_eps_reg`` around ``data`` is first propagated through ``net_abs`` (after
      ``reset_bounds``).
    * otherwise, ``args.cert_reg == "ibp_reg"`` and ``eps > 0``: ``compute_IBP_reg``. With
      ``args.box_attack == "concrete_attack"``, a small box (radius ``0.05 * eps``) centred on ``adex`` is
      propagated first. Its centre is clamped so the small box stays inside the ``eps``-box around ``data``.
    * ``args.l1`` set: ``args.l1`` times the L1 norm of every parameter of ``net_abs``.
    """
    penalty = torch.zeros(1, device=data.device)
    mode = args.cert_reg

    if mode == "bound_reg" and tau < max_tau:
        if eps == 0.0:
            # NOTE(review): presumably this populates the layer bounds read by compute_bound_reg
            # when no perturbation is active — confirm.
            probe_box = HybridZonotope.construct_from_noise(x=data, eps=args.min_eps_reg, domain="box",
                                                            data_range=data_range)
            net_abs.reset_bounds()
            net_abs(probe_box)
        penalty += compute_bound_reg(net_abs, eps, args.eps_end, reg_lambda=args.reg_lambda)
    elif mode == "ibp_reg" and eps > 0.0:
        batch_size = data.shape[0]
        if args.box_attack == "concrete_attack":
            small_eps = eps * 0.05  # TODO add tau for reg to args
            outer_lb, outer_ub = HybridZonotope.construct_from_noise(x=data, eps=eps, domain="box",
                                                                     data_range=data_range).concretize()
            # Keep the small box fully inside the eps-box around the clean input.
            centres = torch.clamp(adex, outer_lb + small_eps, outer_ub - small_eps)
            small_box = HybridZonotope.construct_from_noise(x=centres, eps=small_eps, domain="box",
                                                            data_range=data_range)
            net_abs(small_box)  # TODO not naive box but use prop?
        penalty += compute_IBP_reg(net_abs, batch_size, args.reg_lambda)

    if args.l1 is not None:
        penalty += args.l1 * sum(param.abs().sum() for param in net_abs.parameters())

    return penalty


def get_epsilon(args, eps_test, max_tau, lambda_scheduler, eps_scheduler, scheduler_index, train):
    """Return the ``(eps, tau)`` pair for the current batch.

    ``eps`` is read from ``eps_scheduler`` at ``scheduler_index`` when training, and is ``eps_test`` otherwise.
    ``tau`` is ``lambda_ratio * eps``. ``lambda_ratio`` comes from ``lambda_scheduler`` when
    ``args.start_anneal_lambda`` is set, and from ``args.lambda_ratio`` otherwise.
    With ``args.start_sound``, ``tau`` is replaced by ``min(max_tau, eps)``.
    """
    eps = eps_scheduler.getcurrent(scheduler_index) if train else eps_test

    annealing = args.start_anneal_lambda is not None
    ratio = lambda_scheduler.getcurrent(scheduler_index) if annealing else args.lambda_ratio
    tau = ratio * eps

    # While the full eps-region is still smaller than the final small region (during annealing),
    # propagate the full region instead.
    if args.start_sound:
        tau = min(max_tau, eps)

    return eps, tau


def get_propagation_region(args, net_abs, data, target, train, eps, tau, data_range, adv_step_size, adv_steps, dimwise_scaling):
    """Build the input region that is propagated through ``net_abs`` for the abstract loss.

    Training (``train=True``), selected by ``args.box_attack``:

    * ``"pgd_concrete"`` -- ``adv_whitebox`` (given ``tau`` and ``eps``) returns both a point ``adex`` and
      the abstract region ``data_abs``.
    * ``"centre"`` -- a box of radius ``tau`` around the clean input, with ``adex = data``.
    * any other value fails the assertion.

    After training-mode construction, the network is put back into train mode. The region may then be
    re-tagged as a "shrinking box" (``args.use_shrinking_box``), and/or its centre (``get_head()``) pushed
    once through the net with running-stat tracking disabled on ``net_abs[0]`` (``args.adv_bn``).

    Evaluation (``train=False``): a box of radius ``eps`` around the clean input; ``adex`` stays ``None``.

    Returns:
        ``(data_abs, adex)``.
    """
    adex = None
    if train:
        if args.bn_mode_attack == "eval":
            net_abs.eval()  # use eval mode of BN for attack

        if args.box_attack == "pgd_concrete":
            # NOTE(review): presumably makes BN layers use stored training statistics while attacking — confirm.
            net_abs.set_use_old_train_stats(True)
            adex, data_abs = adv_whitebox(net_abs, data, target, tau, eps, n_steps=adv_steps, step_size=adv_step_size,
                                          data_range=data_range, loss_function=args.box_attack_loss_fn, ODI_num_steps=0,
                                          restarts=1, train=True, dimwise_scaling=dimwise_scaling)

            net_abs.set_use_old_train_stats(False)
        elif args.box_attack == "centre":
            # Box of radius tau centred on the clean input.
            adex = data
            data_abs = HybridZonotope.construct_from_noise(x=data, eps=tau, domain="box", data_range=data_range,
                                                           dtype=data.dtype)
        else:
            assert False, f"box_attack: {args.box_attack} is unknown!"

        # Always return to train mode, even if the attack above ran in eval mode.
        net_abs.train()

        if args.use_shrinking_box:
            # Re-tag the region's domain and ratio; NOTE(review): exact semantics of `domain`/`c` live in HybridZonotope.
            shrinking_domain = args.shrinking_method + args.shrinking_relu_state
            data_abs.domain = shrinking_domain
            data_abs.c = args.shrinking_ratio
        if args.adv_bn:
            # Forward the region's centre once with running-stat tracking on net_abs[0] switched off.
            net_abs[0].set_track_running_stats(track_running_stats=False)
            midpoints = data_abs.get_head()
            net_abs(midpoints)
            net_abs[0].set_track_running_stats(track_running_stats=True)
    else:
        # Evaluation: certify the full eps-box around the clean input.
        data_abs = HybridZonotope.construct_from_noise(x=data, eps=eps, domain="box", data_range=data_range)

    # With train=True only "pgd_concrete" can reach this point among the "concrete" attacks;
    # forward the clean batch once with running-stat tracking on net_abs[0] switched off.
    if args.bn and "concrete" in args.box_attack and train:
        net_abs[0].set_track_running_stats(track_running_stats=False)
        net_abs(data)
        net_abs[0].set_track_running_stats(track_running_stats=True)

    return data_abs, adex


# Number of per-layer entries in the signal/noise lists returned by ``trades_loss``;
# the epoch-end figure has one panel of a 2x3 grid per entry.
_SNR_N_LAYERS = 6


def _run_trades_snr(net_abs, data, target, args):
    """Run the project's TRADES ``trades_loss`` on one batch and restore the network's train/eval mode.

    A fresh SGD optimiser is built on every call, as before. Upstream TRADES' ``trades_loss`` switches the model
    to eval for the attack and ends with ``model.train()``. Restoring the previous mode prevents evaluation
    batches from running (and updating BatchNorm running statistics) in train mode.
    NOTE(review): assumes the modified ``trades_loss`` behaves like upstream in this respect — confirm.

    Returns:
        The ``(snr_loss, batch_metrics, signal_batch, noise_batch)`` tuple from ``trades_loss``, unchanged.
    """
    was_training = net_abs.training
    optimizer = optim.SGD(net_abs.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.wd)
    result = trades_loss(model=net_abs,
                         x_natural=data,
                         y=target,
                         optimizer=optimizer,
                         step_size=args.step_size,
                         epsilon=args.epsilon,
                         perturb_steps=args.num_steps,
                         beta=args.beta,
                         args=args)
    if was_training:
        net_abs.train()
    else:
        net_abs.eval()
    return result


def _concat_log(chunks):
    """Concatenate per-batch chunks (CPU tensors or arrays) into one flat float64 array of natural logs."""
    return np.log(np.concatenate([np.asarray(chunk, dtype=np.float64).reshape(-1) for chunk in chunks]))


def _plot_layer_panel(args, layer, signal, noise):
    """Draw the log-signal vs. log-noise scatter of one layer into panel ``layer`` (1-based) of a 2x3 grid.

    With ``args.save_plot`` the arrays are saved as the reference run (``plot_signal{layer}.npy`` /
    ``plot_noise{layer}.npy``). With ``args.load_plot`` the reference run is loaded and overlaid together
    with the ``y = x`` diagonal, and the current arrays are saved as ``*_ours.npy``.
    """
    print(signal.shape)
    plt.subplot(230 + layer)
    sns.scatterplot(x='x', y='y', data=pd.DataFrame({'x': signal, 'y': noise}), marker='.')

    signal_file = f'plot_signal{layer}.npy'
    noise_file = f'plot_noise{layer}.npy'
    if args.save_plot:
        np.save(signal_file, signal)
        np.save(noise_file, noise)
    if args.load_plot:
        ref_signal = np.load(signal_file)
        ref_noise = np.load(noise_file)
        np.save(f'plot_signal{layer}_ours.npy', signal)
        np.save(f'plot_noise{layer}_ours.npy', noise)
        sns.scatterplot(x='x', y='y', data=pd.DataFrame({'x': ref_signal, 'y': ref_noise}), marker='.')
        diagonal = np.linspace(np.min(ref_signal) - 0.5, np.max(ref_signal) + 0.5)
        sns.lineplot(x='x', y='y', data=pd.DataFrame({'x': diagonal, 'y': diagonal}), color='red')
        plt.xticks(fontsize=10)
        plt.yticks(fontsize=10)


def _plot_signal_noise(args, signal_chunks, noise_chunks):
    """Render all per-layer signal/noise panels to ``image_change.jpg``; does nothing if no data was collected."""
    if not signal_chunks or not all(signal_chunks):
        # No batch took the propagation path this epoch (e.g. kappa >= 1 or tau == 0 throughout).
        return
    sns.set()
    fig = plt.figure(figsize=(12, 6), dpi=1000)
    plt.rcParams["font.family"] = "serif"
    plt.rcParams["font.serif"] = "Times New Roman"
    for layer, (signal_layer, noise_layer) in enumerate(zip(signal_chunks, noise_chunks), start=1):
        _plot_layer_panel(args, layer, _concat_log(signal_layer), _concat_log(noise_layer))
    plt.subplots_adjust(wspace=0.3)
    plt.savefig('image_change.jpg', dpi=1000)
    # Release the (very large) figure; otherwise one leaks per epoch.
    plt.close(fig)


def train_net(net_abs, epoch, train, args, data_loader, input_dim, data_range, eps_test, use_cuda, adv_steps_scheduler, eps_scheduler,
          clip_norm_scheduler=None, lambda_scheduler=None, kappa_scheduler=None, writer=None):
    """Run one epoch of training (``train=True``) or evaluation (``train=False``).

    Per batch, the natural loss and, depending on ``kappa``, ``tau`` and ``args.box_attack``, either an abstract
    (propagated-region) loss or an adversarial loss are computed. They are mixed as
    ``(1 - kappa) * robust + kappa * natural``.

    In training:
      * the TRADES loss is added to the robust term whenever an attack/centre point ``adex`` exists;
      * ``compute_regularization`` is added;
      * ``net_abs.optimizer`` takes a step.

    On the propagation path, the per-layer signal/noise statistics returned by ``trades_loss`` are
    collected and plotted to ``image_change.jpg`` at the end of the epoch.

    Args:
        net_abs: network wrapper; used here via ``optimizer``, ``reset_bounds``, ``set_dim``,
            ``set_use_old_train_stats`` and ``[0].set_track_running_stats``.
        epoch: epoch index, used for the schedulers and for logging.
        train: ``True`` to train, ``False`` to evaluate.
        args: experiment configuration namespace.
        data_loader: iterable of ``(data, target, index)`` batches.
        input_dim: input shape without the batch dimension.
        data_range: forwarded to the abstract-domain constructors and attacks.
        eps_test: perturbation radius used when ``train`` is ``False``.
        use_cuda: move batches to the GPU.
        adv_steps_scheduler, eps_scheduler, clip_norm_scheduler, lambda_scheduler, kappa_scheduler: schedulers
            exposing ``getcurrent(index)``. ``kappa_scheduler`` is used unconditionally despite its ``None``
            default.
        writer: summary writer, used when ``args.save`` is set.

    Returns:
        ``abs_tau_ok / n_samples``: the fraction of samples counted as robustly correct this epoch.
    """
    # get epoch parameters from schedules
    if args.adv_end_steps is None:
        adv_steps = args.adv_start_steps
    else:
        adv_steps = int(args.adv_start_steps + (args.adv_end_steps - args.adv_start_steps) * adv_steps_scheduler.getcurrent(epoch))
    if args.adv_step_size_end is None:
        adv_step_size = args.adv_step_size
    else:
        adv_step_size = args.adv_step_size + (args.adv_step_size_end - args.adv_step_size) * adv_steps_scheduler.getcurrent(epoch)

    if args.end_clip_norm is not None:
        clip_norm = clip_norm_scheduler.getcurrent(epoch)
    else:
        clip_norm = args.clip_norm

    max_tau = args.eps_end * max(args.lambda_ratio, args.end_lambda_ratio)

    # Set up logging
    n_samples = 0
    nat_ok, abs_tau_ok = 0, 0
    loss_total, robust_tau_loss_total, normal_loss_total, reg_loss_total = 0, 0, 0, 0

    time_start = time()
    loss_FN = get_loss_FN(args)
    print(net_abs)
    # set_dim is called in eval mode; the requested mode is selected right after.
    net_abs.eval()
    net_abs.set_dim(torch.rand((data_loader.batch_size, *input_dim), device="cuda" if use_cuda else "cpu"))
    if train:
        net_abs.train()
    else:
        net_abs.eval()

    pbar = tqdm.tqdm(data_loader)
    # One list per layer, each holding the per-batch signal/noise tensors returned by trades_loss.
    signal_chunks = [[] for _ in range(_SNR_N_LAYERS)]
    noise_chunks = [[] for _ in range(_SNR_N_LAYERS)]
    for batch_idx, (data, target, _index) in enumerate(pbar):
        if use_cuda:
            data, target = data.cuda(), target.cuda()
        net_abs.reset_bounds()

        # Get batch parameters
        scheduler_index = epoch * len(data_loader) + batch_idx
        eps, tau = get_epsilon(args, eps_test, max_tau, lambda_scheduler, eps_scheduler, scheduler_index, train)
        kappa = kappa_scheduler.getcurrent(scheduler_index)

        out_normal = net_abs(data)
        adex = None

        if not train or (kappa < 1.0 and tau > 0.0):
            # abstract propagation is needed for training or testing
            data_abs_tau, adex = get_propagation_region(args, net_abs, data, target, train, eps, tau, data_range,
                                                        adv_step_size, adv_steps, args.dimwise_scaling)
            net_abs.reset_bounds()
            out_abs, pseudo_labels = propagate_abs(net_abs, args.loss_domain, data_abs_tau, target)
            robust_loss = loss_FN(out_abs, pseudo_labels).mean()

            abs_tau_ok += torch.eq(out_abs.argmax(1), pseudo_labels).sum()

            # Collect per-layer signal/noise statistics for the epoch-end figure.
            _, _, signal_batch, noise_batch = _run_trades_snr(net_abs, data, target, args)
            for layer in range(_SNR_N_LAYERS):
                signal_chunks[layer].append(signal_batch[layer].cpu())
                noise_chunks[layer].append(noise_batch[layer].cpu())

        elif train and args.box_attack == "concrete_attack" and kappa < 1.0 and eps > 0.0:
            # adversarial loss for training
            if args.bn_mode_attack == "eval":
                net_abs.eval()
            else:
                net_abs.set_use_old_train_stats(True)
            adex, _ = adv_whitebox(net_abs, data, target, 0.0, eps, n_steps=adv_steps, step_size=adv_step_size,
                                   data_range=data_range, loss_function=args.box_attack_loss_fn, ODI_num_steps=0,
                                   restarts=1, train=True)

            if args.bn_mode_attack == "eval":
                # set status back to train
                net_abs.train()
                net_abs[0].set_track_running_stats(track_running_stats=False)
                out_adex = net_abs(adex)
                out_normal = net_abs(data)
                net_abs[0].set_track_running_stats(track_running_stats=True)
            else:
                out_adex = net_abs(adex)
                net_abs.set_use_old_train_stats(False)
            robust_loss = loss_FN(out_adex, target).mean()
            abs_tau_ok += torch.eq(out_adex.argmax(1), target).sum()
        else:
            # No robust term for this batch; allocate on the batch's device so CPU-only runs work.
            robust_loss = torch.tensor(0.0, device=data.device)

        normal_loss = loss_FN(out_normal, target).mean()
        nat_ok += torch.eq(out_normal.argmax(1), target).sum()

        if train:
            net_abs.optimizer.zero_grad()

            reg = compute_regularization(args, net_abs, data, adex, eps, tau, max_tau, data_range)

            robust_loss_scaled = (1 - kappa) * robust_loss
            normal_loss_scaled = kappa * normal_loss

            if adex is not None:
                # Add the TRADES loss whenever an attack/centre point was produced for this batch.
                snr_loss, _, _, _ = _run_trades_snr(net_abs, data, target, args)
                robust_loss_scaled = robust_loss_scaled + snr_loss

            combined_loss = robust_loss_scaled + normal_loss_scaled + reg

            if args.clip_robust_gradient and robust_loss > 0.0:
                # clip only the robust loss
                robust_loss_scaled.backward()
                torch.nn.utils.clip_grad_norm_(net_abs.parameters(), clip_norm)
                (normal_loss_scaled + reg).backward()
            else:
                combined_loss.backward()
                if args.clip_combined_gradient is not None:
                    # clip both losses
                    torch.nn.utils.clip_grad_norm_(net_abs.parameters(), clip_norm)

            net_abs.optimizer.step()
        else:
            combined_loss = (1 - kappa) * robust_loss + kappa * normal_loss
            reg = torch.tensor(0)

        time_epoch = time() - time_start

        reg_loss_total += reg.detach()
        robust_tau_loss_total += robust_loss.detach()
        normal_loss_total += normal_loss.detach()
        loss_total += combined_loss.detach()
        n_samples += target.size(0)

        description_str = f"[{epoch}:{batch_idx}:{'train' if train else 'test'}]: eps = [{tau:.6f}:{eps:.6f}], kappa={kappa:.3f}, loss nat: {normal_loss_total / (batch_idx + 1):.4f}, loss abs: {robust_tau_loss_total / (batch_idx + 1):.4f}, acc_nat={nat_ok / n_samples:.4f}, acc_abs={abs_tau_ok / n_samples:.4f}"
        pbar.set_description(description_str)
        pbar.refresh()

    _plot_signal_noise(args, signal_chunks, noise_chunks)

    # save metrics
    if args.save:
        if train:
            writer.add_scalar('kappa', kappa, epoch)
            writer.add_scalar('eps', eps, epoch)
            writer.add_scalar('tau', tau, epoch)
            writer.add_scalar('train_stand_acc', nat_ok / n_samples, epoch)
            writer.add_scalar('train_rob_acc', abs_tau_ok / n_samples, epoch)
            writer.add_scalar('train_loss', loss_total / len(pbar), epoch)
            writer.add_scalar('train_normal_loss', normal_loss_total / len(pbar), epoch)
            writer.add_scalar('train_robust_loss', robust_tau_loss_total / len(pbar), epoch)
            # Epoch mean of the regulariser (previously only the last batch's value was logged).
            writer.add_scalar('train_reg', reg_loss_total / len(pbar), epoch)
            writer.add_scalar('train_time', time_epoch, epoch)
        else:
            writer.add_scalar('test_stand_acc', nat_ok / n_samples, epoch)
            writer.add_scalar('test_rob_acc', abs_tau_ok / n_samples, epoch)
            writer.add_scalar('test_loss', loss_total / len(pbar), epoch)
            writer.add_scalar('test_normal_loss', normal_loss_total / len(pbar), epoch)
            writer.add_scalar('test_robust_loss', robust_tau_loss_total / len(pbar), epoch)
            writer.add_scalar('test_time', time_epoch, epoch)
    test_rob_acc = abs_tau_ok / n_samples
    return test_rob_acc

