import logging
import argparse
import os
import json
from tqdm import tqdm

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import re
import torch

from utils import (
    parse_triangle_sequence, is_valid_triangle, generate, DATA_ROOT, HASH_STR_LEN, T_TRIANGLES, dataset_dir,
    MODEL_PATH, set_seed
)

# Select seaborn's "white" style as the plotting theme at import time.
sns.set_theme(style="white")


def complete(s):
    """Return *s* wrapped in angle brackets, adding only the missing ones.

    A trailing ``>`` is appended first when absent, then a leading ``<`` is
    prepended when absent, so ``"a_1"``, ``"<a_1"`` and ``"a_1>"`` all become
    ``"<a_1>"``.
    """
    closed = s if s.endswith(">") else s + ">"
    return closed if closed.startswith("<") else "<" + closed


class Evaluator:
    """Sample triangles from a trained policy and score them per graph.

    Attributes set by ``__init__``:
        model_path, device: the constructor arguments, stored unchanged.
        model, tokenizer: objects restored by ``load_model_and_tokenizer``.
        edges_data: list with the parsed content of every existing
            ``edges_{i}.json`` (i in 0..2); ``get_validity_score`` indexes
            each entry by vertex token.
        train_sequences: list of ``(graph_idx, canonical_sequence)`` string
            pairs extracted from ``train.json``.

    ``eval_model`` additionally stores ``detailed_results``.
    """

    def __init__(self, model_path, dataset, data_dir, device="cpu"):
        """Load the model, the graph edge files and the training sequences.

        Args:
            model_path: checkpoint file passed to ``torch.load``.
            dataset: dataset sub-directory name inside ``data_dir``.
            data_dir: root directory holding the datasets.
            device: device string the model is moved to.
        """
        self.model_path = model_path
        self.device = device

        # Load model and tokenizer
        self.model, self.tokenizer = self.load_model_and_tokenizer(model_path, device)

        dataset_path = os.path.join(data_dir, dataset)

        # Load edges data (up to 3 graphs).
        # NOTE(review): missing files are skipped, so a gap (e.g. edges_1.json
        # absent but edges_2.json present) shifts later graphs to a lower
        # list index than their file number.
        edges_data = []
        for i in range(3):
            edges_path = os.path.join(dataset_path, f"edges_{i}.json")
            if os.path.exists(edges_path):
                with open(edges_path, "r", encoding="utf-8") as f:
                    edges_data.append(json.load(f))
        self.edges_data = edges_data

        # Load training sequences for memorization evaluation
        with open(os.path.join(dataset_path, "train.json"), "r", encoding="utf-8") as f:
            train_items = json.load(f)

        # Keep only entries of the form "<graph_idx> tri: <sequence>" and store
        # the sequence in canonical (sorted-vertex) form.
        self.train_sequences = []
        for entry in train_items:
            target_text = entry["target_text"]
            if "tri:" not in target_text:
                continue
            parts = target_text.split("tri:")
            graph_idx = parts[0].strip()
            seq = self.canonicalize(parts[1].strip())
            self.train_sequences.append((graph_idx, seq))

    def load_model_and_tokenizer(self, model_path, device="cpu"):
        """Load model and tokenizer from pretrained checkpoint.

        Args:
            model_path: path of the checkpoint; it must contain the keys
                ``"tokenizer"``, ``"config"`` and ``"model"``.
            device: device used as ``map_location`` and as final placement.

        Returns:
            ``(model, tokenizer)`` with the model switched to eval mode.
        """
        from utils import TriangleTokenizer, TransformerPolicy

        # SECURITY: weights_only=False unpickles arbitrary Python objects, so
        # only checkpoints from a trusted source must be loaded here.
        model_dict = torch.load(model_path, map_location=device, weights_only=False)

        tokenizer_cfg = model_dict["tokenizer"]
        tokenizer = TriangleTokenizer(
            entities=tokenizer_cfg["entities"],
            special_tokens=tokenizer_cfg["special_tokens"],
        )

        config = model_dict["config"]
        model = TransformerPolicy(
            vocab_size=config["vocab_size"],
            d_model=config["d_model"],
            n_layer=config["n_layer"],
            n_head=config["n_head"],
            dim_ff=config["dim_ff"],
            dropout=config["dropout"],
            max_len=config["max_len"],
            tie_weights=True,
        )
        model.load_state_dict(model_dict["model"])
        model = model.to(device)
        model.eval()

        return model, tokenizer

    def canonicalize(self, seq):
        """Rewrite a triangle sequence so that its vertex indices are sorted.

        The input looks like ``<a_i><a_j><sep><a_j><a_k><sep><a_k><a_i>``; the
        first vertex of each of the first three ``<sep>``-separated segments is
        read, the three indices are sorted, and the sequence is rebuilt in the
        same format.

        Args:
            seq: sequence string to canonicalize.

        Returns:
            The canonical string, or ``seq`` unchanged when it cannot be
            parsed (fewer than three segments or a non-integer index).
        """
        segments = seq.split("<sep>")
        try:
            indices = [
                int(segments[pos].split("<a_")[1].split(">")[0]) for pos in range(3)
            ]
        except (IndexError, ValueError):
            print("Failed for canonicalizing:", seq)
            # Return the original *string*; returning the split list would
            # break the "same format" contract and any string comparison.
            return seq
        smallest, middle, largest = sorted(indices)
        return (
            f"<a_{smallest}><a_{middle}><sep>"
            f"<a_{middle}><a_{largest}><sep>"
            f"<a_{largest}><a_{smallest}>"
        )

    def eval_model(self, num_samples=100):
        """Evaluate the model by generating samples and computing metrics.

        Samples cycle through the loaded graphs. Besides returning the scores,
        the full per-sample record is stored in ``self.detailed_results``.

        Args:
            num_samples: number of sequences to generate.

        Returns:
            A one-element list ``[("model_evaluation", scores, valid_count)]``
            where ``scores`` is a list of ``(name, value)`` pairs.

        Raises:
            ValueError: if samples are requested but no graph was loaded, or
                (from ``eval_items``) if a seen triangle is scored invalid.
        """
        print(f"Generating {num_samples} samples for evaluation...")

        num_graphs = len(self.edges_data)
        if num_samples > 0 and num_graphs == 0:
            # Without this guard the modulo below fails with an opaque
            # ZeroDivisionError.
            raise ValueError("No edges_*.json graph data was loaded; cannot evaluate.")

        # Generate samples
        all_items = []
        for i in tqdm(range(num_samples), desc="Generating samples"):
            graph_id = i % num_graphs  # Cycle through available graphs
            prompt = f"{graph_id} tri: "

            try:
                output_ids = generate(self.model, self.tokenizer, max_new_tokens=9, prompt=prompt)
                output_text = self.tokenizer.decode(output_ids)

                # Parse the generated sequence
                vertices = parse_triangle_sequence(output_ids, self.tokenizer)

                # Check validity of the generated triangle
                is_valid = False
                if vertices and len(vertices) == 3:
                    is_valid = is_valid_triangle(vertices, self.edges_data[graph_id])

                item = {
                    "type": "test",
                    "model_output": output_text,
                    "target_text": output_text,  # For compatibility with eval_items
                    "graph_id": graph_id,
                    "vertices": vertices,
                    "is_valid": is_valid,
                    "prompt": prompt,
                }
            except Exception as e:
                # Best effort per sample: a failed generation is recorded as an
                # empty, invalid item so the sample count stays consistent.
                print(f"Error generating sample {i}: {e}")
                item = {
                    "type": "test",
                    "model_output": "",
                    "target_text": "",
                    "graph_id": graph_id,
                    "vertices": None,
                    "is_valid": False,
                    "prompt": prompt,
                }
            all_items.append(item)

        # Evaluate the generated items
        acc = self.eval_items(all_items)

        # The two answer lists are not 0/1 scores, so take them out before
        # averaging the remaining entries.
        test_predicted_answers = acc.pop("test_predicted_answers", [])
        test_predicted_answers_diversity = acc.pop("test_predicted_answers_diversity", [])

        scores = [(t, round(sum(acc[t]) / len(acc[t]), 3)) for t in acc]
        n_samples = len(test_predicted_answers)

        if n_samples > 0:
            # "" marks samples that did not qualify; drop them before counting
            # distinct answers, but keep dividing by the total sample count.
            test_predicted_answers = [ans for ans in test_predicted_answers if ans != ""]
            test_predicted_answers_diversity = [
                ans for ans in test_predicted_answers_diversity if ans != ""
            ]
            scores.append(("test_creativity_score", len(set(test_predicted_answers)) / n_samples))
            scores.append(
                ("test_diversity_score", len(set(test_predicted_answers_diversity)) / n_samples)
            )

        # Calculate validity statistics
        valid_count = sum(1 for item in all_items if item.get("is_valid", False))
        validity_rate = valid_count / len(all_items) if all_items else 0

        # Store detailed results for JSON export
        self.detailed_results = {
            "model_path": self.model_path,
            "num_samples": num_samples,
            "device": self.device,
            "scores": dict(scores),
            "samples": all_items,
            "summary": {
                "total_samples": n_samples,
                "test_seen_count": sum(acc.get("test_seen_score", [])),
                "test_unseen_count": sum(acc.get("test_unseen_score", [])),
                "test_creativity_count": len(set(test_predicted_answers)),
                "test_diversity_count": len(set(test_predicted_answers_diversity)),
                "valid_count": valid_count,
                "validity_rate": round(validity_rate, 3),
            },
        }

        return [("model_evaluation", scores, valid_count)]

    def eval_items(self, all_items):
        """Score every item and group the results by metric name.

        Args:
            all_items: dicts with ``model_output``/``target_text`` (or the
                space-separated spellings ``model output``/``target text``),
                an optional ``type`` (``"train"`` or ``"test"``, default
                ``"test"``) and, for test items, ``graph_id``.

        Returns:
            Dict mapping each metric name to the list of per-item results.

        Raises:
            ValueError: for an unknown item type, or when a triangle found in
                the training data is scored as invalid.
        """
        acc = {}  # maps each type of example to the corresponding list of eval results
        for item in all_items:
            t = item["type"] if "type" in item else "test"

            if "model_output" in item:
                pred, gold = item["model_output"], item["target_text"]
            else:
                pred, gold = item["model output"], item["target text"]

            if t == "train":
                acc.setdefault("train_memorization_score", []).append(self.eval_res(pred, gold))
            elif t == "test":
                # Create the lists up front, in a fixed order, so the order of
                # the reported scores does not depend on the data.
                seen_scores = acc.setdefault("test_seen_score", [])
                unseen_scores = acc.setdefault("test_unseen_score", [])
                unseen_answers = acc.setdefault("test_predicted_answers", [])
                valid_answers = acc.setdefault("test_predicted_answers_diversity", [])

                graph_id = item["graph_id"]
                seen_score = self.get_seen_score(pred, graph_id)
                seen_scores.append(seen_score)
                validity_score = self.get_validity_score(pred, graph_id)

                if seen_score == 1 and validity_score != 1:
                    raise ValueError("Seen triangle should be valid!")
                # Unseen means: valid for the graph but absent from training.
                unseen_score = int(seen_score != 1 and validity_score == 1)
                unseen_scores.append(unseen_score)

                # Store the predicted answer, or "" when it does not qualify.
                unseen_answers.append(pred.split("tri:")[1] if unseen_score == 1 else "")
                valid_answers.append(pred.split("tri:")[1] if validity_score == 1 else "")
            else:
                raise ValueError(f"Unknown type: {t}")
        return acc

    @staticmethod
    def _extract_answer(pred):
        """Return the text after the first ``tri:`` marker, or None if malformed.

        A prediction is well formed when it contains exactly one ``</a>`` and
        at least one ``tri:`` marker.
        """
        if pred.count("</a>") != 1:
            return None
        parts = pred.split("tri:")
        if len(parts) < 2:
            return None
        return parts[1]

    def get_seen_score(self, pred, graph_id):
        """Return 1 if the predicted triangle is in the training data of this graph.

        Args:
            pred: full model output, ``"<graph> tri: <sequence></a>"``.
            graph_id: graph the sample was generated for.

        Returns:
            1 when the canonicalized prediction matches a training sequence
            recorded for ``str(graph_id)``, else 0 (also for malformed input).
        """
        answer = self._extract_answer(pred)
        if answer is None:
            print("Incorrect format:", pred)
            return 0
        canonical = self.canonicalize(answer)
        graph_key = str(graph_id)

        # NOTE: only return seen if seen in training data for current graph
        return int(
            any(
                seq == canonical and seq_idx == graph_key
                for seq_idx, seq in self.train_sequences
            )
        )

    def get_validity_score(self, pred, graph_id):
        """Return 1 if the predicted vertices form a triangle in the graph.

        Args:
            pred: full model output, ``"<graph> tri: <sequence></a>"``.
            graph_id: index into ``self.edges_data``.

        Returns:
            1 when ``b in edges[a] and c in edges[b] and a in edges[c]`` holds
            for the parsed vertices, else 0 (also for malformed input or
            vertices missing from the graph).
        """
        answer = self._extract_answer(pred)
        if answer is None:
            print("Incorrect format:", pred)
            return 0
        # Drop all whitespace, as eval_res and the training-data loader do;
        # otherwise a space after "tri:" corrupts the first vertex token.
        body = "".join(answer.split("</a>")[0].split())
        try:
            ab, bc, ca = body.split("<sep>")
            a, _ = ab.split("><")
            b, c = bc.split("><")
        except ValueError:
            print("Failed for parsing:", body)
            return 0
        # NOTE(review): only the first vertex of the first edge and both
        # vertices of the second edge are used; the repeated vertices are not
        # checked for consistency -- confirm this is intended.
        a = complete(a)
        b = complete(b)
        c = complete(c)

        edges = self.edges_data[graph_id]

        try:
            return int(b in edges[a] and c in edges[b] and a in edges[c])
        except (KeyError, TypeError):
            print("Failed for dictionary:", a)
            return 0

    def eval_res(self, a, b):
        """Return 1 if prediction ``a`` equals target ``b`` up to the ``</a>`` marker.

        Spaces are ignored and everything from the first ``</a>`` on is
        discarded in both strings. The target must contain exactly one
        ``</a>``.
        """
        a = a.replace(" ", "")
        b = b.replace(" ", "")
        assert b.count("</a>") == 1, "target text must contain exactly one </a>"
        b = b.split("</a>")[0]
        a = a.split("</a>")[0]

        return int(a == b)


def main():
    """Parse the command line, run the evaluation and print the scores."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--model_path", default=MODEL_PATH, type=str, help="Path to pretrained model file.")
    cli.add_argument("--dataset", default="triangle.10", type=str, help="Dataset name.")
    cli.add_argument("--data_dir", default=DATA_ROOT, type=str, help="Data dir.")
    cli.add_argument("--num_samples", default=100, type=int, help="Number of samples to generate for evaluation.")
    cli.add_argument("--device", default="auto", type=str, help="Device to use (auto, cpu, cuda).")
    cli.add_argument("--seed", default=42, type=int, help="Seed for reproducibility.")
    opts = cli.parse_args()

    set_seed(opts.seed)

    # Resolve "auto" to CUDA when it is available, otherwise CPU; any other
    # value is used exactly as given.
    if opts.device != "auto":
        device = opts.device
    elif torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    print(f"Using device: {device}")
    print(f"Loading model from: {opts.model_path}")

    evaluator = Evaluator(opts.model_path, opts.dataset, opts.data_dir, device)
    results = evaluator.eval_model(opts.num_samples)

    print(results)
    
# Run the evaluation only when this file is executed as a script.
if __name__ == '__main__':
    main()
