from utils import test_eval
from model import *

class ARCDetector:
    """Train an ``ARC_New`` model on a set of graphs and evaluate it on held-out graphs.

    The detector keeps per-graph patterns (``self.pattern``) and structural
    patterns (``self.struct_pattern``) extracted by the model. They are keyed by
    the graph's index in ``data['train']``. During evaluation the extra key
    ``len(data['train'])`` holds the patterns of the test graph currently being
    scored.

    Args:
        train_config: dict read for ``'device'``, ``'epochs'`` and
            ``'testdsets'`` (the latter supplies the names used as keys of the
            dicts returned by :meth:`train`).
        model_config: dict forwarded as keyword arguments to ``ARC_New``. This
            class additionally reads ``'n_support'``, ``'mask_ratio'``,
            ``'temperature'``, ``'lr'`` and ``'weight_decay'`` from it.
        data: dict with ``'train'`` and ``'test'`` sequences whose items expose a
            ``.graph`` attribute. The graph is assumed to be a PyG-style object
            with ``adj``, ``ano_labels``, ``num_nodes`` and, for test graphs,
            ``shot_mask`` (TODO confirm against the data loader).
    """

    # Number of patterns extracted from each test graph during evaluation
    # (the value the original code hard-coded).
    _NUM_TEST_PATTERNS = 10

    def __init__(self, train_config, model_config, data):
        self.model_config = model_config
        self.train_config = train_config
        self.data = data
        device = train_config['device']
        self.model = ARC_New(device=device, **model_config).to(device)
        # Per-graph patterns, keyed by training-graph index (see class docstring).
        self.pattern = dict()
        self.struct_pattern = dict()
        self.n_support = model_config['n_support']
        # Override model hyper-parameters from the config after construction.
        self.model.mask_ratio = model_config['mask_ratio']
        self.model.temperature = model_config['temperature']
        self.model.domain_sim.temperature = model_config['temperature']

    def _ensure_one_node_features(self, graph):
        """Attach an all-ones ``one_node_features`` tensor of shape (num_nodes, st_dim) if absent."""
        if not hasattr(graph, 'one_node_features'):
            graph.one_node_features = torch.ones(
                size=(graph.num_nodes, self.model.st_dim),
                device=self.train_config['device'],
            )

    def _store_train_patterns(self, didx, graph, graph_emb, struct_emb):
        """Extract patterns from precomputed embeddings and store detached copies under ``didx``."""
        patterns, struct_patterns = self.model.patterns_extraction(
            graph_emb, struct_emb, graph.adj, graph.ano_labels, num_prompt=self.n_support)
        self.pattern[didx] = patterns.detach()
        self.struct_pattern[didx] = struct_patterns.detach()

    def _refresh_train_patterns(self, didx, graph):
        """Recompute embeddings and patterns for training graph ``didx`` without tracking gradients."""
        self._ensure_one_node_features(graph)
        # The stored patterns are detached anyway, so building an autograd
        # graph here would only waste memory.
        with torch.no_grad():
            graph_emb, struct_emb = self.model.get_embedding(graph, graph.adj)
            self._store_train_patterns(didx, graph, graph_emb, struct_emb)

    def _train_epochs(self):
        """Run the optimisation loop: one Adam step per training graph per epoch."""
        device = self.train_config['device']
        epochs = self.train_config['epochs']
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.model_config['lr'],
                                     weight_decay=self.model_config['weight_decay'])
        for e in range(epochs):
            for didx, train_data in enumerate(self.data['train']):
                train_graph = train_data.graph.to(device)
                self.model.graph_idx = didx
                self._ensure_one_node_features(train_graph)

                # Embeddings are kept with gradients: they feed the loss below.
                train_graph_emb, struct_emb = self.model.get_embedding(train_graph, train_graph.adj)
                self._store_train_patterns(didx, train_graph, train_graph_emb, struct_emb)

                loss = self.model(train_graph_emb, struct_emb, train_graph.adj, train_graph.ano_labels,
                                  self.pattern, self.struct_pattern, num_prompt=self.n_support)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                print('Epoch [{}/{}], Loss: {:.4f}'.format(e + 1, epochs, loss.item()))

                # Refresh this graph's patterns with the updated parameters.
                self._refresh_train_patterns(didx, train_graph)

    def _evaluate(self):
        """Score every test graph and return ``(test_score_list, similarity)``.

        Runs under ``torch.no_grad()``. The returned ``similarity`` tensors
        therefore do not require gradients.
        """
        device = self.train_config['device']
        test_key = len(self.data['train'])
        test_score_list = {}
        similarity = dict()
        with torch.no_grad():
            for didx, test_data in enumerate(self.data['test']):
                test_graph = test_data.graph.to(device)
                shot_mask = test_graph.shot_mask.bool()
                self._ensure_one_node_features(test_graph)
                test_graph_emb, struct_emb = self.model.get_embedding(test_graph, test_graph.adj)
                # Only nodes outside the shot mask are scored.
                query_labels = test_graph.ano_labels[~shot_mask].to(device)

                # Patterns of the test graph itself are added as extra knowledge.
                # NOTE(review): this key is overwritten per test graph and is left
                # in self.pattern afterwards. A later call to train() would feed
                # it into the training loss. Confirm whether it should be removed.
                patterns_test, struct_patterns = self.model.patterns_extraction_for_test_graph(
                    test_graph_emb, struct_emb, self._NUM_TEST_PATTERNS)
                self.pattern[test_key] = patterns_test.detach()
                self.struct_pattern[test_key] = struct_patterns.detach()

                query_scores, dom_sim = self.model.inference(
                    self.pattern, test_graph_emb, test_graph.adj, self.struct_pattern, struct_emb)
                test_score = test_eval(query_labels, query_scores[~shot_mask])

                test_data_name = self.train_config['testdsets'][didx]
                test_score_list[test_data_name] = {
                    'AUROC': test_score['AUROC'],
                    'AUPRC': test_score['AUPRC'],
                }
                similarity[test_data_name] = dom_sim.mean(dim=1)
        return test_score_list, similarity

    def train(self):
        """Train the model, re-extract training patterns, then evaluate on the test graphs.

        Returns:
            tuple ``(test_score_list, similarity)``. ``test_score_list`` maps each
            name in ``train_config['testdsets']`` to ``{'AUROC': ..., 'AUPRC': ...}``
            as produced by ``test_eval``. ``similarity`` maps the same names to
            ``dom_sim.mean(dim=1)`` returned by ``model.inference``.
        """
        import time

        self.model.train()
        start_time = time.time()
        self._train_epochs()
        print('Finish Training for {} epochs!'.format(self.train_config['epochs']))
        print('Time: {}'.format(time.time() - start_time))

        # Re-extract patterns for every training graph in eval mode.
        # NOTE(review): model.graph_idx is not updated here (only in the training
        # loop), so it keeps the last training index. Confirm get_embedding
        # does not depend on it.
        self.model.eval()
        device = self.train_config['device']
        for didx, train_data in enumerate(self.data['train']):
            self._refresh_train_patterns(didx, train_data.graph.to(device))

        return self._evaluate()
