import torch
import json
import sys
sys.path.append("../")
#from prepare_data.old.construct_graphs_from_json import JsonToGraph
from dataset.json_graph import JsonToGraph,NewNodeException
from Autoregressive_model import AutoregressiveModel
from dataset.data_gan import MyDataset
import argparse
import importlib
from types import SimpleNamespace
import socket
import time
import types
from torch_geometric.loader import DataLoader
from dataset.gamedata import GameDataset
import os
import main
import numpy as np
class NoveltyDetector:
    """Online novelty detector wrapping a pretrained autoregressive graph model.

    JSON observations are fed one at a time through :meth:`process`. Each one is
    ingested by a ``GameDataset`` and the latest graph is scored with the model's
    ``compute_novelty_score``. The highest-percentile prediction of the current
    episode is kept in ``most_likely_novelty``. :meth:`produce_response` turns
    that state into ``"None"``, ``"New_String_Detected"`` or
    ``"Novelty_Detected"``.
    """

    # Steps at the start of each episode that are ingested but not scored.
    WARMUP_STEPS = 10

    def __init__(self, device):
        """Build the dataset and load the pretrained model.

        The model weights are read from
        ``./polycraft_results/<model_type>.pth``.

        Args:
            device: torch device the model is moved to (e.g. ``'cpu'``).
        """
        args = types.SimpleNamespace()
        args.emb_dim = 32
        # NOTE(review): machine-specific absolute path; here it is only carried
        # in the config passed to the model.
        args.test_dataset_dir = "/home/plymper/data/polycraftv2/nov_mapless"
        args.val_dataset_dir = "./"
        args.model_type = 'GraphSAGE'
        args.model_save_dir = './polycraft_results/'
        args.task = 'gridworld'
        args.net_type = args.model_type
        self.config = args

        self.device = device

        # Use the same graph-construction setup as the validation set.
        self.dataset = GameDataset(data_path=args.val_dataset_dir, concat_steps=5, mode='test')

        self.model = AutoregressiveModel(self.dataset.num_nodes, self.dataset.node_feature_dim,
                                         self.dataset.node_info, model=args.model_type, config=args)
        self.model.to(device)
        self.model.load_model(os.path.join(args.model_save_dir, args.model_type + ".pth"))
        self.model.model.eval()

        # Per-episode histories: reset() appends a fresh inner list per episode.
        self.loss_percentile = [[]]
        self.losses = [[]]
        # Prediction records keyed by a running counter (not cleared by reset()).
        self.predictions = {}
        self.index = 0
        self.new_string_detected = False
        self.most_likely_novelty = None
        self.new_episode = False
        self.step_n = 0

    def reset(self):
        """Start a new episode.

        Opens new per-episode history lists and clears the detection state and
        step counter. ``predictions`` and ``index`` are kept.
        """
        self.new_episode = True
        self.loss_percentile.append([])
        self.losses.append([])
        self.new_string_detected = False
        self.most_likely_novelty = None
        self.step_n = 0

    def update_most_likely_novelty(self, latest_pred):
        """Keep whichever prediction has the highest ``'percentile'`` so far.

        Args:
            latest_pred: dict with at least a ``'percentile'`` key.
        """
        if (self.most_likely_novelty is None
                or self.most_likely_novelty['percentile'] < latest_pred['percentile']):
            self.most_likely_novelty = latest_pred

    def process(self, json_obj):
        """Ingest one observation and, after the warm-up, score it.

        If the dataset rejects the observation with ``NewNodeException`` or
        ``KeyError``, the step is treated as containing an unseen string:
        ``None`` entries are recorded and ``new_string_detected`` is set. In that
        case ``step_n`` is not advanced.

        Otherwise the latest graph is scored and recorded, unless it is ``None``
        or fewer than ``WARMUP_STEPS`` steps have been processed this episode.

        Args:
            json_obj: decoded JSON observation passed to
                ``GameDataset.receive_json_obj``.
        """
        try:
            self.dataset.receive_json_obj(json_obj, task=self.config.task, new_episode=self.new_episode)
        except (NewNodeException, KeyError):
            self.predictions[self.index] = None
            self.index += 1
            self.loss_percentile[-1].append(None)
            self.losses[-1].append(None)
            self.new_string_detected = True
            print("NewStringDetected")
            return
        self.new_episode = False

        # Only the most recent step decides whether a new string is reported.
        self.new_string_detected = False
        graph = self.dataset[-1]
        if graph is None or self.step_n < self.WARMUP_STEPS:
            print(None)
        else:
            with torch.no_grad():
                # NOTE(review): input is always moved to CPU although the model
                # was moved to self.device; confirm compute_novelty_score handles
                # placement before running on a non-CPU device.
                x = graph.to('cpu')
                perc, loss, pred = self.model.compute_novelty_score(x, return_prediction=True)
                record = {"loss": loss, "pred": pred, "percentile": perc['graph']}
                self.predictions[self.index] = record
                self.loss_percentile[-1].append(perc['graph'])
                self.losses[-1].append(loss)
                self.update_most_likely_novelty(record)
                self.index += 1
                print(np.max(perc['numerical']), perc['graph'], loss)
        self.step_n += 1

    def produce_response(self):
        """Return the detection status string.

        Returns:
            ``"New_String_Detected"`` if the last processed step hit an unseen
            string. Otherwise ``"Novelty_Detected"`` if the best percentile this
            episode is >= 1, else ``"None"``.
        """
        # Checked first so that an unseen string is reported even before the
        # model has scored any graph (most_likely_novelty still None).
        if self.new_string_detected:
            return "New_String_Detected"
        if self.most_likely_novelty is None:
            return "None"
        if self.most_likely_novelty['percentile'] >= 1:
            return "Novelty_Detected"
        return "None"



def recv_json(sock):
    """Read one burst of bytes from ``sock`` and return it undecoded.

    Reads 4 KiB chunks until a chunk comes back shorter than the buffer size.
    A short chunk, including ``b''`` when the peer has closed, ends the
    message. A tiny sleep precedes each ``recv``, presumably to let more of the
    message arrive before reading.

    NOTE(review): this framing is heuristic. TCP is a byte stream, so a message
    may be cut short by an early short read or merged with the next message.
    A message whose length is an exact multiple of 4096 makes the loop block on
    one extra ``recv``.

    Args:
        sock: connected stream socket (anything with ``recv(bufsize)``).

    Returns:
        The concatenated chunks as ``bytes``. This is ``b''`` if the first
        read returned nothing.
    """
    chunk_size = 4096  # 4 KiB
    chunks = []
    while True:
        time.sleep(0.00001)
        chunk = sock.recv(chunk_size)
        chunks.append(chunk)
        if len(chunk) < chunk_size:
            # Short read: end of the available data, or b'' when the peer closed.
            return b"".join(chunks)
   


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description='detector')
    # NOTE(review): --node_path is parsed but never used in this script.
    parser.add_argument("--node_path", type=str, default="./",
                        help="Path to node dict")
    parser.add_argument("--port", type=int, default=6012,
                        help="port to connect to")
    args = parser.parse_args()

    nvdet = NoveltyDetector('cpu')

    HOST = '127.0.0.1'  # Standard loopback interface address (localhost)
    PORT = args.port    # Port to listen on (non-privileged ports are > 1023)

    # Serve a single client: accept one connection and handle its messages
    # until it disconnects.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind((HOST, PORT))
        s.listen()
        conn, addr = s.accept()
        with conn:
            print('NOVDET Connected by', addr, flush=True)
            while True:
                data = recv_json(conn)
                if not data:
                    # On a blocking socket recv() returns b'' only after the peer
                    # has closed; looping again would busy-spin forever.
                    print("NOVDET client disconnected", flush=True)
                    break

                message = data.decode()
                print(message)
                if message == "SEND_DETECTIONS\n":
                    print("RESPONDING")
                    response = nvdet.produce_response()
                    # sendall() keeps writing until every byte is sent; send() may not.
                    conn.sendall((str(response) + '\n').encode())
                elif message == "RESET\n":
                    print("RESETTING")
                    nvdet.reset()
                else:
                    # Anything else is treated as a JSON observation.
                    js = json.loads(message)
                    nvdet.process(js)
                    print("NOVDET PROCESSED", flush=True)
                



    
