import argparse
import torch
import numpy as np
import torch.nn as nn
from torch.autograd import Variable
from torch.autograd import Function
#import opts
import time
import math
from torch.utils.data import Dataset, DataLoader
import os
def train_two_regression_gradgrad_version(A_out1,B_out1,real1,A_out2,B_out2,real2,loss_f,opt_NN1,Data_opt1=None,Data_opt2=None,x_data=None):
    """Run one network update followed by one data update.

    Stage 1: the summed loss of ``A_out1`` and ``B_out1`` against ``real1``
    is back-propagated with ``create_graph=True`` and ``opt_NN1`` takes a step.

    Stage 2: the gradients held by the data optimizers are cleared, the
    difference ``loss_f(A_out2, real2) - loss_f(B_out2, real2)`` is
    back-propagated, and both data optimizers take a step.

    Args:
        A_out1, B_out1: Outputs compared against ``real1`` in stage 1.
        real1: Target for stage 1.
        A_out2, B_out2: Outputs compared against ``real2`` in stage 2.
        real2: Target for stage 2.
        loss_f: Callable ``loss_f(output, target)`` returning a scalar tensor.
        opt_NN1: Optimizer stepped after the stage 1 backward pass.
        Data_opt1: Optional optimizer stepped after stage 2; skipped if None.
        Data_opt2: Optional optimizer stepped after stage 2; skipped if None.
        x_data: Optional tensor whose ``.grad`` is printed for debugging
            before and after the stage 2 backward pass.

    Returns:
        The stage 2 loss as a Python float.
    """
    opt_NN1.zero_grad()
    if Data_opt1 is not None:
        Data_opt1.zero_grad()

    # Stage 1: joint loss of both outputs on the first target.
    loss_first = loss_f(A_out1, real1)
    loss_second = loss_f(B_out1, real1)
    loss_train = loss_first + loss_second
    # The keyword was misspelled as `creat_graph`, which raised TypeError on
    # every call. create_graph=True keeps the graph of the derivative itself.
    loss_train.backward(create_graph=True)
    opt_NN1.step()

    # Stage 2: discard the stage 1 gradients accumulated on the data side.
    if Data_opt1 is not None:
        Data_opt1.zero_grad()
    if x_data is not None:
        print(x_data.grad)
    if Data_opt2 is not None:
        Data_opt2.zero_grad()

    loss_data = loss_f(A_out2, real2) - loss_f(B_out2, real2)
    loss_data.backward()
    if x_data is not None:
        print(x_data.grad)

    if Data_opt1 is not None:
        Data_opt1.step()
    if Data_opt2 is not None:
        Data_opt2.step()
    return loss_data.item()
class A_Dataset(Dataset):
    """Dataset wrapping an indexable collection of features and optional labels.

    Each item is a dict with a ``'feature'`` entry and, when labels were
    supplied, a ``'label'`` entry taken from the same index.
    """

    def __init__(self, features, labels=None):
        self.labels = labels
        self.features = features

    def __len__(self):
        """Return the number of feature entries."""
        return len(self.features)

    def __getitem__(self, idx):
        """Return the sample dict stored at ``idx``."""
        # Tensor indices are converted to plain Python values before indexing.
        index = idx.tolist() if torch.is_tensor(idx) else idx
        item = {'feature': self.features[index]}
        if self.labels is not None:
            item['label'] = self.labels[index]
        return item

def get_MSE_np(pred, real):
    """Return the mean of the squared element-wise differences ``real - pred``."""
    residual = real - pred
    squared = np.square(residual)
    return np.mean(squared)


def get_MAE_np(pred, real):
    """Return the mean of the absolute element-wise differences ``real - pred``."""
    residual = real - pred
    magnitude = np.absolute(residual)
    return np.mean(magnitude)


def get_MAPE_np(pred, real, epsilon=1):
    """Return the mean absolute percentage error of ``pred`` against ``real``.

    Entries of ``real`` equal to zero are replaced by ``epsilon`` in the
    denominator only, so the division is defined; the numerator always uses
    the true values.

    Args:
        pred: Predicted values (array-like).
        real: Ground-truth values (array-like). Not modified.
        epsilon: Denominator used where ``real`` is exactly zero.

    Returns:
        The mean of ``|real - pred| / denominator``.
    """
    real_arr = np.asarray(real)
    pred_arr = np.asarray(pred)
    # Build the denominator as a new array; the previous in-place assignment
    # overwrote the zeros in the caller's `real`.
    denominator = np.where(real_arr == 0, epsilon, real_arr)
    return np.mean(np.abs((real_arr - pred_arr) / denominator))


class linear_adjust_width(nn.Module):
    """Soft mask over the last dimension with a learnable cut-off position.

    Position ``i`` (counted from 1) is scaled by
    ``sigmoid((i - width_masked) * 4)``, so positions well below
    ``width_masked`` are pushed toward 0 and positions well above it are
    left close to unchanged.

    Args:
        linear_max_width: Number of positions in the mask.
        init_width: Initial value of the cut-off ``width_masked``.
        train_width: Whether ``width_masked`` requires gradients.
    """

    def __init__(self,linear_max_width,init_width,train_width=True):
        super(linear_adjust_width, self).__init__()
        self.linear_max_width=linear_max_width
        # Cut-off position, shape (1,).
        self.width_masked = torch.nn.Parameter(torch.ones([1])*init_width,requires_grad=train_width)

        # Fixed positions 1..linear_max_width, shape (linear_max_width,).
        # torch.range is deprecated; arange with an exclusive end of
        # linear_max_width + 1 yields the same values in the default dtype.
        positions = torch.arange(1, linear_max_width + 1, dtype=torch.get_default_dtype())
        self.width_mask = torch.nn.Parameter(positions, requires_grad=False)

    def forward(self, input,train_width=True):
        """Multiply ``input`` by the sigmoid mask.

        Args:
            input: Tensor that broadcasts against a ``(linear_max_width,)``
                mask.
            train_width: Only the literal ``True`` lets gradients flow to
                ``width_masked``; any other value uses a detached copy.

        Returns:
            ``input`` scaled element-wise by the mask.
        """
        if train_width is True:
            cutoff = self.width_masked
        else:
            cutoff = self.width_masked.detach()
        # The factor 4 sharpens the sigmoid transition around the cut-off.
        gate = torch.sigmoid((self.width_mask.detach() - cutoff) * 4)
        return input * gate

