import pandas as pd
import torch
from torch.utils.data import Dataset
import numpy as np
import matplotlib.pyplot as plt


class CustomDataset(Dataset):
    """Tabular dataset loaded from a CSV file.

    Every column except the last is treated as an input feature; the last
    column is treated as the target. Rows are converted to float32 tensors
    lazily, one item at a time.
    """

    def __init__(self, csv_path):
        frame = pd.read_csv(csv_path)
        n_rows = len(frame)

        features = frame.iloc[:, :-1]
        target = frame.iloc[:, -1]

        self.inp = features.values
        # Keep the target 2-D (n_rows, 1) so each item yields a length-1 vector.
        self.outp = target.values.reshape(n_rows, 1)

    def __len__(self):
        """Return the number of rows loaded from the CSV."""
        return len(self.inp)

    def __getitem__(self, idx):
        """Return ``(features, target)`` for row ``idx`` as float32 tensors."""
        x_row = torch.FloatTensor(self.inp[idx])
        y_row = torch.FloatTensor(self.outp[idx])
        return x_row, y_row

         
def contour_2d(data, feature_idx, model, *, data_idx=20, n_grid=100,
               x_range=(0.0, 1.0), filename='./contour.png', dpi=1000,
               close=False):
    """Plot a 2-D contour map of ``model`` over two selected input features.

    The reference sample ``data[data_idx, :]`` is copied onto every point of an
    ``n_grid x n_grid`` grid. The two columns named by ``feature_idx`` are then
    swept over ``x_range`` while every other column keeps the reference value.
    The model is evaluated on the whole grid in one call, and the result is
    drawn as filled contours with labelled black iso-lines and saved to
    ``filename``.

    Args:
        data: 2-D array of samples, one row per sample (assumed to be a
            numeric NumPy array -- TODO confirm for other array types).
        feature_idx: Pair of column indices ``(i, j)``; column ``i`` is
            plotted on the horizontal axis and column ``j`` on the vertical.
        model: Callable taking a float32 tensor of shape
            ``(n_grid * n_grid, n_features)`` and returning a tensor with
            ``n_grid * n_grid`` elements (required by the reshape below).
        data_idx: Row of ``data`` used as the fixed reference sample.
        n_grid: Number of grid points per axis.
        x_range: ``(low, high)`` sweep range applied to both features.
        filename: Path of the saved image.
        dpi: Resolution of the saved image.
        close: If True, close the figure after saving so repeated calls do
            not accumulate open pyplot figures. Defaults to False so callers
            can still ``plt.show()`` the result.

    Returns:
        The matplotlib ``Figure`` that was drawn.
    """
    low, high = x_range
    x1_grid = np.linspace(low, high, n_grid)
    x2_grid = np.linspace(low, high, n_grid)
    n_points = n_grid * n_grid

    # Row k is (x1_grid[k % n_grid], x2_grid[k // n_grid]); reshaping the
    # model output to (n_grid, n_grid) therefore gives Z[row=x2, col=x1],
    # which is the layout ax.contour(x, y, Z) expects.
    X_grid = np.array(np.meshgrid(x1_grid, x2_grid)).reshape(2, n_points).T

    # One copy of the reference sample per grid point. np.tile works for any
    # number of features (repeat + in-place resize only worked for exactly
    # two), and the float cast stops grid values from being truncated when
    # ``data`` has an integer dtype.
    reference = np.asarray(data[data_idx, :], dtype=float)
    dt = np.tile(reference, (n_points, 1))
    dt[:, feature_idx[0]] = X_grid[:, 0]
    dt[:, feature_idx[1]] = X_grid[:, 1]

    # Inference only: no autograd graph is needed for plotting.
    with torch.no_grad():
        Y = model(torch.tensor(dt, dtype=torch.float32))
    Z = Y.detach().cpu().numpy().reshape(n_grid, n_grid)

    fig = plt.figure(figsize=(6, 4))
    ax = fig.add_subplot(111)
    CS = ax.contour(x1_grid, x2_grid, Z, levels=10, linewidths=0.5, colors='k')
    ax.contourf(x1_grid, x2_grid, Z, levels=10, cmap="RdBu_r")
    ax.clabel(CS, inline=1, fontsize=8, colors='k')
    ax.xaxis.set_visible(False)
    ax.yaxis.set_visible(False)
    fig.savefig(filename, bbox_inches='tight', dpi=dpi)
    if close:
        plt.close(fig)
    return fig


def contour_2d_original(data, feature_idx, model, *, data_idx=20, n_grid=100,
                        x_range=(0.0, 1.0),
                        filename='./contour_original.png', dpi=1000,
                        close=False):
    """Plot the analytic ground-truth surface over two selected input features.

    The reference sample ``data[data_idx, :]`` is copied onto every point of an
    ``n_grid x n_grid`` grid. The two columns named by ``feature_idx`` are then
    swept over ``x_range``. The closed-form target

        y = sin(x0 * 25 / pi) + (x0 - 0.5)**3 + exp(x1) + x1**2

    is evaluated on columns 0 and 1 of the grid. It is drawn as filled
    contours with labelled black iso-lines and saved to ``filename``.

    Args:
        data: 2-D array of samples, one row per sample (assumed to be a
            numeric NumPy array -- TODO confirm for other array types).
        feature_idx: Pair of column indices ``(i, j)``; column ``i`` is
            plotted on the horizontal axis and column ``j`` on the vertical.
        model: Unused; kept so the signature matches ``contour_2d``.
        data_idx: Row of ``data`` used as the fixed reference sample.
        n_grid: Number of grid points per axis.
        x_range: ``(low, high)`` sweep range applied to both features.
        filename: Path of the saved image.
        dpi: Resolution of the saved image.
        close: If True, close the figure after saving so repeated calls do
            not accumulate open pyplot figures. Defaults to False so callers
            can still ``plt.show()`` the result.

    Returns:
        The matplotlib ``Figure`` that was drawn.
    """
    low, high = x_range
    x1_grid = np.linspace(low, high, n_grid)
    x2_grid = np.linspace(low, high, n_grid)
    n_points = n_grid * n_grid

    # Row k is (x1_grid[k % n_grid], x2_grid[k // n_grid]); reshaping to
    # (n_grid, n_grid) yields Z[row=x2, col=x1] as ax.contour expects.
    X_grid = np.array(np.meshgrid(x1_grid, x2_grid)).reshape(2, n_points).T

    # One copy of the reference sample per grid point. np.tile handles any
    # feature count (repeat + in-place resize only worked for exactly two),
    # and the float cast prevents truncation for integer ``data``.
    reference = np.asarray(data[data_idx, :], dtype=float)
    dt = np.tile(reference, (n_points, 1))
    dt[:, feature_idx[0]] = X_grid[:, 0]
    dt[:, feature_idx[1]] = X_grid[:, 1]

    # Ground-truth function; it depends on columns 0 and 1 regardless of
    # which columns feature_idx sweeps.
    x0 = dt[:, 0]
    x1 = dt[:, 1]
    Y = np.sin(x0 * (25 / np.pi)) + (x0 - 0.5) ** 3 + np.exp(x1) + x1 ** 2
    Z = Y.reshape(n_grid, n_grid)

    fig = plt.figure(figsize=(6, 4))
    ax = fig.add_subplot(111)
    CS = ax.contour(x1_grid, x2_grid, Z, levels=10, linewidths=0.5, colors='k')
    ax.contourf(x1_grid, x2_grid, Z, levels=10, cmap="RdBu_r")
    ax.clabel(CS, inline=1, fontsize=8, colors='k')
    ax.xaxis.set_visible(False)
    ax.yaxis.set_visible(False)
    fig.savefig(filename, bbox_inches='tight', dpi=dpi)
    if close:
        plt.close(fig)
    return fig




