import matplotlib.pyplot as plt
import numpy as np
import math
import csv
from ast import literal_eval
import seaborn as sns
from scipy import stats

import pandas as pd
from torch.utils.data import Dataset

class DataLoader(Dataset):
    """Map-style dataset over the ``z`` column of a CSV file.

    Every ``z`` cell is parsed with ``ast.literal_eval`` when the file is
    read, and items are handed out as NumPy arrays.
    """

    def __init__(self, filename):
        # The 'z' cells are stored as text; literal_eval rebuilds the Python
        # literal (presumably a list of numbers -- confirm against the data).
        self.df = pd.read_csv(filename, converters={'z': literal_eval})

    def __len__(self):
        """Return the number of rows in the loaded CSV."""
        return self.df.shape[0]

    def __getitem__(self, index):
        """Return the ``z`` value of row ``index`` as a NumPy array.

        ``index`` is a row *label*; with the default index assigned by
        ``read_csv`` that coincides with the 0-based row position.
        """
        z_cell = self.df['z'].loc[index]
        return np.asarray(z_cell)


# Training-split CSV whose 'z' column is read through DataLoader.
csv_path = '/data/datasets/traffic/traffic_10x8x100/train/train_z_label.csv'
indices = [2, 8, 9] # Server2: Traffic_128_c11_0.15_semi0.1p_traffic
sign = 'comparison'

# Positions inside each z vector, in the order listed in `indices`:
# color, orientation, shape. Unpacking here keeps a single source of truth
# instead of repeating the numbers in the extraction loop.
color_idx, orien_idx, shape_idx = indices

dataloader = DataLoader(csv_path)
z_color = []
z_shape = []
z_orien = []

# Index explicitly rather than `for z in dataloader`: DataLoader looks rows
# up by label with .loc, which raises KeyError (not IndexError) past the
# end, so Python's fallback iteration protocol would not stop cleanly.
for i in range(len(dataloader)):
    z = dataloader[i]
    z_color.append(z[color_idx])
    z_shape.append(z[shape_idx])
    z_orien.append(z[orien_idx])

print(max(z_color))
print(min(z_color))


def plot_distribution(x, path, fig_name, fit=True):
    """Plot a histogram of ``x`` and save it as an image at ``path``.

    Args:
        x: 1-D sequence of numeric samples.
        path: Output image path; missing parent directories are created.
        fig_name: Title drawn on the axes.
        fit: If True, draw a density-normalised histogram with a KDE and
            overlay a gamma distribution fitted to ``x`` with
            ``scipy.stats.gamma.fit``. If False, draw a plain count histogram.
    """
    from pathlib import Path

    data = np.asarray(x, dtype=float)
    fig, ax = plt.subplots()
    if fit:
        # Axes-level replacement for the deprecated
        # ``sns.distplot(x, fit=stats.gamma)``: density histogram + KDE,
        # with the fitted gamma pdf drawn on top.
        sns.histplot(data, stat='density', kde=True, ax=ax)
        params = stats.gamma.fit(data)
        grid = np.linspace(data.min(), data.max(), 200)
        ax.plot(grid, stats.gamma.pdf(grid, *params), color='k')
    else:
        # ``sns.displot`` is figure-level and would open a second figure,
        # leaving ``fig`` blank; ``histplot`` draws into our own axes.
        sns.histplot(data, ax=ax)
    ax.set_xlabel('z')
    ax.set_title(fig_name)
    # savefig does not create directories; make sure the target exists.
    Path(path).parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(path)
    # Without closing, every call leaves another pyplot figure open.
    plt.close(fig)
    
# One figure per component, written to ./anomoly/<sign>/<component>.png.
for z_values, component in ((z_color, 'color'), (z_shape, 'shape'), (z_orien, 'orient')):
    plot_distribution(z_values, f'./anomoly/{sign}/{component}.png', f'{component} distribution')

# with open(csv_path) as csvfile:
#     reader = csv.reader(csvfile, delimiter=' ', quotechar='|')
#     for row in reader:
#         print(', '.join(row))
#         input()


