import sys
import csv
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import numpy as np
import copy
import random
import pickle
def load_object(file_path):
    """Unpickle and return the first object stored in the file at *file_path*.

    NOTE: ``pickle`` can execute arbitrary code while loading, so only call
    this on files you trust.
    """
    with open(file_path, mode='rb') as handle:
        return pickle.load(handle)

# Interactive run configuration.
#   subject_class  : target group, normalised to "A" or "B" (case/whitespace ignored)
#   subject_number : held-out model number (group A: 1-62, group B: 63-121)
#   assignment     : which pre-generated val/test triplet assignment file to use
# Whitespace is stripped because subject_number and assignment are spliced into file paths.
subject_class = input("Input the target group. (A or B)").strip().upper()
subject_number = input("Input the model number. (Group A : 1-62, Group B : 63-121)").strip()
assignment = input("Please select the target assignment for analysis. (1 or 2 or 3)").strip()

if subject_class not in ("A", "B"):
    # Fail fast, before the image pickle is loaded.
    raise ValueError("Unknown group " + repr(subject_class) + "; expected 'A' or 'B'.")

data_file_folder_groupA = "../1.subject_measurement_data/Group_A_human_data_main/"  # modify this path on your setting
data_file_folder_groupB = "../1.subject_measurement_data/Group_B_human_data_main/"  # modify this path on your setting
# Image lookup table indexed by the image keys found in the CSV files.
# Build it first with the "0. make image numpy file" step; modify this path on your setting.
loaded_img_data = load_object('../train_img_128.pkl')


# Run hyperparameters.
batch_size = 32
learning_rate = 1e-4
num_epochs = 5_000

dtype = torch.float

# Prefer CUDA, then Apple's MPS backend, and fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")


# Problem numbers handled by each group: A uses 11..510, B uses 521..1020.
# The matching pickle holds two objects in order: the test ids, then the val ids.
if subject_class == "A":
    group_letter, first_problem = "A", 11
else:
    group_letter, first_problem = "B", 11 + 510
training_data_num_ = list(range(first_problem, first_problem + 500))

assignment_path = './random_triplet_assignment/problem' + group_letter + str(assignment) + '.pkl'
with open(assignment_path, 'rb') as handle:
    ele_test = pickle.load(handle)
    ele_val = pickle.load(handle)

# Only the list order is affected; which split a problem lands in is fixed by the pickle.
random.shuffle(training_data_num_)

# Validation ids win over test ids; everything else is training.
val_data_num = [p for p in training_data_num_ if p in ele_val]
test_data_num = [p for p in training_data_num_ if p not in ele_val and p in ele_test]
training_data_num = [p for p in training_data_num_ if p not in ele_val and p not in ele_test]



def reading_csv(subject_class, subject_num, data_folder=None):
    """Read one subject's rows for their group's 500 problems from ``<folder>/<subject_num>.csv``.

    Args:
        subject_class: "A" selects ``data_file_folder_groupA`` and problems
            11-510; "B" selects ``data_file_folder_groupB`` and problems
            521-1020. Any other value reads from the group-B folder but returns [].
        subject_num: subject id; the file name is ``<subject_num>.csv``.
        data_folder: optional folder overriding the group default.

    Returns:
        A list of records ``[[c1], [c3, c4, c5], [c15, c16, c17], [c12, c13, c14]]``
        (raw CSV strings, ``cN`` = column N). Records are ordered by problem
        number (column 1); rows with the same number keep their file order.
        The header row, blank lines and out-of-range problems are skipped.

    Raises:
        OSError: the file cannot be opened.
        ValueError: column 1 of a data row is not an integer.
        IndexError: a data row has too few columns.
    """
    if data_folder is None:
        data_folder = data_file_folder_groupA if subject_class == "A" else data_file_folder_groupB
    path = os.path.join(data_folder, str(subject_num) + ".csv")

    if subject_class == "A":
        first_problem = 11
    elif subject_class == "B":
        first_problem = 510 + 11
    else:
        first_problem = None

    # newline='' is what the csv module expects for files it parses.
    with open(path, 'r', newline='') as f:
        rows = list(csv.reader(f))

    if first_problem is None:
        return []
    last_problem = first_problem + 500  # exclusive

    selected = []
    for row in rows[1:]:  # skip the header row
        if not row:
            continue  # blank line
        problem = int(row[1])
        if first_problem <= problem < last_problem:
            selected.append((problem, row))
    # One pass plus a stable sort replaces 500 full rescans; ties keep file order.
    selected.sort(key=lambda item: item[0])

    return [
        [[row[1]], [row[3], row[4], row[5]], [row[15], row[16], row[17]], [row[12], row[13], row[14]]]
        for _, row in selected
    ]

# Load every *other* subject's CSV in the selected group; the chosen model number
# is held out. Group A holds subjects 1-62, group B holds subjects 63-121.
S_problem_data_list = []
subject_A_list = list(range(1, 63))
subject_B_list = list(range(63, 122))

if subject_class == "A":
    group_subjects = subject_A_list
elif subject_class == "B":
    group_subjects = subject_B_list
else:
    # Previously this silently produced an empty list and a nan final score.
    raise ValueError("Unknown group " + repr(subject_class) + "; expected 'A' or 'B'.")

held_out = int(subject_number)
if held_out not in group_subjects:
    raise ValueError(
        f"Model number {held_out} is not in group {subject_class} "
        f"({group_subjects[0]}-{group_subjects[-1]})."
    )
group_subjects.remove(held_out)  # mutates subject_A_list / subject_B_list, as before
S_problem_data_list.extend(reading_csv(subject_class, s) for s in group_subjects)



class Autoencoder(nn.Module):
    """Convolutional autoencoder for single-channel 128x128 images.

    The encoder applies four 3x3 stride-2 convolutions (128 -> 8 spatially,
    1 -> 128 channels), then ``fc1`` maps the flattened 128*8*8 features to a
    64-d embedding. ``fc2`` and four stride-2 transposed convolutions map the
    embedding back to a 1x128x128 image squashed by a sigmoid.

    ``forward`` returns ``(reconstruction, embedding)``.

    Submodule names and creation order are part of the checkpoint format
    (state_dict keys), so they must not change.
    """

    def __init__(self):
        super().__init__()
        # Encoder.
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1)
        self.relu2 = nn.ReLU()
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.relu3 = nn.ReLU()
        self.conv4 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        self.relu4 = nn.ReLU()

        # Bottleneck.
        self.fc1 = nn.Linear(128 * 8 * 8, 64)
        self.fc2 = nn.Linear(64, 128 * 8 * 8)

        # Decoder.
        self.urelu1 = nn.ReLU()
        self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1)
        self.urelu2 = nn.ReLU()
        self.upconv2 = nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1)
        self.urelu3 = nn.ReLU()
        self.upconv3 = nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1)
        self.urelu4 = nn.ReLU()
        self.upconv4 = nn.ConvTranspose2d(16, 1, kernel_size=4, stride=2, padding=1)
        self.sig = nn.Sigmoid()

    def forward(self, x):
        encoder_stages = (self.conv1, self.relu1, self.conv2, self.relu2,
                          self.conv3, self.relu3, self.conv4, self.relu4)
        for stage in encoder_stages:
            x = stage(x)
        emb = self.fc1(torch.flatten(x, 1))

        recon = self.urelu1(self.fc2(emb)).view(-1, 128, 8, 8)
        decoder_stages = (self.upconv1, self.urelu2, self.upconv2, self.urelu3,
                          self.upconv3, self.urelu4, self.upconv4, self.sig)
        for stage in decoder_stages:
            recon = stage(recon)

        return recon, emb


class CustomDataset(Dataset):
    """Map-style dataset over a NumPy array, converted once to a float32 tensor.

    Items are indexed along the first axis; ``transform`` (if truthy) is
    applied to each item on access.
    """

    def __init__(self, numpy_data, transform=None):
        self.data = torch.from_numpy(numpy_data).float()
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        return self.transform(item) if self.transform else item




class CustomDataset_new(Dataset):
    """Triplet dataset yielding ``(anchor, second, third)`` image tensors.

    Each record of ``conc_data`` holds three image keys in ``record[1]`` (looked
    up in ``img_dic``) and three numeric strings in ``record[2]``. The code
    treats ``record[2][0]``, ``[1]`` and ``[2]`` as values for the image pairs
    (0, 1), (1, 2) and (0, 2). Every record yields three samples, one per anchor
    image. Of the two other images, the one whose pair value with the anchor is
    smaller (ties go to the first-listed candidate) goes in the second slot.
    """

    # (anchor, first candidate, value index of (anchor, first),
    #  second candidate, value index of (anchor, second))
    _ANCHOR_LAYOUT = (
        (0, 1, 0, 2, 2),
        (1, 0, 0, 2, 1),
        (2, 0, 2, 1, 1),
    )

    def __init__(self, conc_data, img_dic, transform=None):
        self.img_dic = img_dic
        self.conc = conc_data

        anchors, seconds, thirds = [], [], []
        for record in conc_data:
            keys, values = record[1], record[2]
            for anchor, first, first_v, other, other_v in self._ANCHOR_LAYOUT:
                anchors.append(img_dic[keys[anchor]])
                if float(values[first_v]) <= float(values[other_v]):
                    near, far = first, other
                else:
                    near, far = other, first
                seconds.append(img_dic[keys[near]])
                thirds.append(img_dic[keys[far]])

        self.transform = transform

        self.main_data1 = torch.from_numpy(np.array(anchors))
        self.main_data2 = torch.from_numpy(np.array(seconds))
        self.main_data3 = torch.from_numpy(np.array(thirds))

    def __len__(self):
        return len(self.main_data1)

    def __getitem__(self, idx):
        triple = (self.main_data1[idx], self.main_data2[idx], self.main_data3[idx])
        if self.transform:
            triple = tuple(self.transform(t) for t in triple)
        return triple

class CustomDataset_new2(Dataset):
    """Triplet dataset that also returns the pair values used to order each triplet.

    Records are read exactly as in the three-image triplet layout:
    ``record[1]`` holds three image keys (looked up in ``img_dic``), and
    ``record[2]`` holds values for the pairs (0, 1), (1, 2) and (0, 2). Each
    record yields three samples, one per anchor image, as
    ``(anchor, second, third, sim1, sim2)``:

    * ``second`` is the candidate with the smaller-or-equal pair value (ties go
      to the first-listed candidate); ``third`` is the other one.
    * ``sim1`` / ``sim2`` are the float64 pair values of (anchor, second) and
      (anchor, third).

    ``transform`` is applied to the three images only, not to the values.
    """

    # (anchor, first candidate, value index of (anchor, first),
    #  second candidate, value index of (anchor, second))
    _ANCHOR_LAYOUT = (
        (0, 1, 0, 2, 2),
        (1, 0, 0, 2, 1),
        (2, 0, 2, 1, 1),
    )

    def __init__(self, conc_data, img_dic, transform=None):
        self.img_dic = img_dic
        self.conc = conc_data

        anchors, seconds, thirds = [], [], []
        near_values, far_values = [], []
        for record in conc_data:
            keys, values = record[1], record[2]
            for anchor, first, first_v, other, other_v in self._ANCHOR_LAYOUT:
                anchors.append(img_dic[keys[anchor]])
                first_value = float(values[first_v])
                other_value = float(values[other_v])
                if first_value <= other_value:
                    seconds.append(img_dic[keys[first]])
                    thirds.append(img_dic[keys[other]])
                    near_values.append(first_value)
                    far_values.append(other_value)
                else:
                    seconds.append(img_dic[keys[other]])
                    thirds.append(img_dic[keys[first]])
                    near_values.append(other_value)
                    far_values.append(first_value)

        self.transform = transform

        self.main_data1 = torch.from_numpy(np.array(anchors))
        self.main_data2 = torch.from_numpy(np.array(seconds))
        self.main_data3 = torch.from_numpy(np.array(thirds))
        self.sim1 = torch.from_numpy(np.array(near_values))
        self.sim2 = torch.from_numpy(np.array(far_values))

    def __len__(self):
        return len(self.main_data1)

    def __getitem__(self, idx):
        images = (self.main_data1[idx], self.main_data2[idx], self.main_data3[idx])
        if self.transform:
            images = tuple(self.transform(t) for t in images)
        return images + (self.sim1[idx], self.sim2[idx])


class ToTensor:
    """Callable transform that returns its input as a ``torch.float32`` tensor.

    Uses ``torch.as_tensor``, which reuses the input's memory when no
    conversion is needed (e.g. an existing float32 tensor).
    """

    target_dtype = torch.float32

    def __call__(self, sample):
        converted = torch.as_tensor(sample, dtype=self.target_dtype)
        return converted


# Load the pretrained autoencoder for the selected model number and switch it to
# eval mode. The checkpoint is read onto the CPU; load_state_dict copies the
# weights into the model's parameters on `device`.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints;
# on PyTorch versions that support it, consider weights_only=True.
model_path = f"../Model_folder/{subject_number}.pth"  # When the model file changes, please modify this part accordingly.
checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
model = Autoencoder().to(device)
model.load_state_dict(checkpoint)
model.eval()


# Evaluate the loaded model against every other subject's judgements in the group.
# For each subject, build (anchor, second, third) test triplets and record the
# fraction for which the anchor's embedding is closer to `second` than to `third`.
evaluation_result = []


def criterion4(output_tensor, input_img, img1_emb, img2_emb, img3_emb, sim1, sim2, experiment_setting):
    """Training loss: reconstruction MSE plus value-weighted embedding distances.

    Not called during this evaluation. It is defined once, rather than once per
    subject, and kept for reference.

    With d(a, b) = per-sample mean squared difference of embeddings (img2/img3
    embeddings detached and perturbed by Gaussian noise with std 0.01):
      * experiment_setting 0: 1.2*mean(sim2*d(1,2)) + mean(sim1*d(1,3)) + reconstruction MSE
      * experiment_setting 1: reconstruction MSE only (triplet-term ablation)
      * otherwise:            embedding terms only (reconstruction ablation)

    NOTE(review): both embedding terms are added (minimised), so the anchor is
    pulled toward both images. Confirm this is intended.
    """
    squared_diff = (output_tensor - input_img) ** 2

    noise1 = (torch.randn(img2_emb.shape)*0.01).to(device)
    noise2 = (torch.randn(img3_emb.shape)*0.01).to(device)

    sim_diff_s = sim2*torch.mean((img1_emb - img2_emb.detach()+noise1) ** 2, dim=1)
    sim_diff_l = sim1*torch.mean((img1_emb - img3_emb.detach()+noise2) ** 2, dim=1)

    if int(experiment_setting) == 0:  # standard setting
        mse_loss = 1.2 * torch.mean(sim_diff_s) + torch.mean(sim_diff_l) + torch.mean(squared_diff)
    elif int(experiment_setting) == 1:  # ablation of triplet loss
        mse_loss = torch.mean(squared_diff)
    else:  # ablation of reconstruction loss
        mse_loss = 1.2*torch.mean(sim_diff_s) + torch.mean(sim_diff_l)  # 1.5

    return mse_loss


criterion = nn.MSELoss()


def measure(emb1, emb2, emb3):
    """Return one 0/1 hit per sample: 1 when the ``emb1`` row is strictly closer
    (mean squared difference) to the ``emb2`` row than to the ``emb3`` row."""
    anchors = emb1.cpu().numpy()
    seconds = emb2.cpu().numpy()
    thirds = emb3.cpu().numpy()
    return [
        1 if np.mean((a - s) ** 2) < np.mean((a - t) ** 2) else 0
        for a, s, t in zip(anchors, seconds, thirds)
    ]


def get_accuracy(model, loader, letter=None):
    """Mean triplet hit rate of *model* over *loader* (``letter`` is unused).

    Each batch item is an image tensor to which a channel axis is added before
    it goes through the model. Only the embeddings are compared.
    """
    hits = []
    with torch.no_grad():
        for data1, data2, data3 in loader:
            embeddings = []
            for data in (data1, data2, data3):
                _, emb = model(data.to(device).unsqueeze(1))
                embeddings.append(emb)
            hits += measure(*embeddings)
    return np.mean(hits)


# Set lookup instead of scanning the test-id list for every CSV row.
test_problem_ids = set(test_data_num)

for S_problem_data in S_problem_data_list:
    # Only the test split is needed for evaluation. The unused per-subject
    # training/validation datasets, optimizer and model deepcopy were dropped:
    # they only cost time and memory.
    conc_test_data = [row for row in S_problem_data if int(row[0][0]) in test_problem_ids]
    if not conc_test_data:
        # An empty dataset cannot be evaluated (and RandomSampler rejects it).
        print("Skipping a subject with no test triplets.", file=sys.stderr)
        continue

    custom_dataset_test = CustomDataset_new(conc_test_data, loaded_img_data, transform=ToTensor())
    # Order does not affect the mean accuracy, so the test set is not shuffled.
    test_loader = DataLoader(custom_dataset_test, batch_size=batch_size, shuffle=False)

    test_acc = get_accuracy(model, test_loader, "test start")
    evaluation_result.append(float(test_acc))

print("NSE for model", subject_number, " : " , np.mean(evaluation_result))