import os
import math
from typing import Optional, Dict, Any, Tuple, List

import json
import gymnasium as gym
import torch
import torch.nn as nn
import numpy as np
from gymnasium import spaces
import random
import string
import wandb

# Checkpoint locations and model / sequence sizes.
MODEL_PATH = "out/pretrain_policy.pt"
FINE_TUNED_MODEL_PATH = "out/ppo_finetune_policy.pt"
HIDDEN_DIM = 256
MAX_LEN = 9 + 2  # 9 triangle tokens + 2 prompt tokens for a single-digit graph index ("<d>", "tri:")
META_FIXED_SEED = 42

# ----------------------
# Dataset utils
# ----------------------
DATA_ROOT = "data"
HASH_STR_LEN = 10  # length of hash strings in dataset
T_TRIANGLES = 6  # default number of triangles per item
SPECIAL_TOKENS = ["<mask>", "<sep>", "<a>", "</a>", "<q>", "</q>"]

# Diagnostic prints are enabled only when the DEBUG env var equals "true" (any case).
DEBUG = os.getenv("DEBUG", "False").lower() == "true"

def set_seed(seed: int):
    """Seed torch, NumPy and Python's ``random`` with ``seed``.

    When CUDA is available, also seeds the current and all CUDA devices and
    switches cuDNN to deterministic, non-benchmarking mode.
    """
    for seeder in (torch.manual_seed, np.random.seed, random.seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Ensure deterministic behavior
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def dataset_dir(data_root: str, hash_len: int, T_val: int) -> str:
    """Return ``<data_root>/triangle.<hash_len>[.T<T_val>]``.

    The ``.T<T_val>`` suffix is omitted for the default of 6 triangles per item.
    """
    t_suffix = "" if T_val == 6 else f".T{T_val}"
    return os.path.join(data_root, f"triangle.{hash_len}{t_suffix}")

def load_json(path: str):
    """Parse and return the JSON document stored at ``path`` (read as UTF-8 text)."""
    with open(path, encoding="utf-8") as handle:
        document = json.load(handle)
    return document

def extract_hash_or_empty(input_text: str) -> str:
    """
    Return the hash prefix of a triangle prompt, or "" for anything else.

    Triangle items have ``input_text`` of the form ``"<hash> tri: "`` and yield
    ``"<hash>"`` (everything before the first ``" tri: "``). Edge items
    (``"<hash> edge: "``) and any text without the ``" tri: "`` marker yield
    ``""`` — edges carry no hash.
    """
    head, marker, _ = input_text.partition(" tri: ")
    return head if marker else ""


class TransformerPolicy(nn.Module):
    """
    Simple GPT decoder-only Transformer.
    """
    def __init__(
        self,
        vocab_size: int,
        d_model: int = 256,
        n_layer: int = 4,
        n_head: int = 8,
        dim_ff: int = 4 * 256,
        dropout: float = 0.1,
        max_len: int = 64,
        tie_weights: bool = True,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.max_len = max_len

        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_len, d_model)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_head,
            dim_feedforward=dim_ff,
            dropout=dropout,
            batch_first=True,
            activation="gelu",
            norm_first=True,
        )
        self.blocks = nn.TransformerEncoder(encoder_layer, num_layers=n_layer)
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)

        if tie_weights:
            self.head.weight = self.tok_emb.weight

        self.register_buffer(
            "causal_mask",
            torch.triu(torch.ones(max_len, max_len, dtype=torch.bool), diagonal=1),
            persistent=False,
        )

    def forward(self, input_ids: torch.Tensor, attn_mask: torch.Tensor | None = None):
        B, T = input_ids.shape
        if T > self.max_len:
            raise ValueError(f"Sequence length {T} exceeds max_len {self.max_len}")

        pos = torch.arange(T, device=input_ids.device).unsqueeze(0)  # [1, T]
        x = self.tok_emb(input_ids) + self.pos_emb(pos)

        causal = self.causal_mask[:T, :T]  # [T, T], True above diagonal
        if attn_mask is not None:
            pad_mask = ~attn_mask.bool()  # [B, T]
        else:
            pad_mask = None

        x = self.blocks(x, mask=causal, src_key_padding_mask=pad_mask)
        x = self.ln_f(x)
        logits = self.head(x)  # [B, T, vocab]
        return logits


class TriangleTokenizer:
    """
    Vocabulary and string <-> id conversion for triangle sequences.

    Ids are assigned in order: entity tokens, then ``special_tokens``, then the
    added tokens ``"tri:"``, ``"edge:"`` and the digits ``"0"``..``"9"``.
    ``<mask>`` is used as the padding id and ``<a>`` / ``</a>`` as BOS / EOS, so
    ``special_tokens`` must contain all three.
    """
    def __init__(self, entities, special_tokens):
        self.entities = list(entities)
        self.special_tokens = list(special_tokens)
        # Add special tokens for tri: and edge: plus single digits (graph indices)
        self.added_tokens = ["tri:", "edge:"] + [str(i) for i in range(10)]
        self.tokens = self.entities + self.special_tokens + self.added_tokens
        self.tok2id = {t: i for i, t in enumerate(self.tokens)}
        self.id2tok = {i: t for t, i in self.tok2id.items()}

        self.pad_id = self.tok2id["<mask>"]
        self.bos_id = self.tok2id["<a>"]
        self.eos_id = self.tok2id["</a>"]

    @property
    def vocab_size(self):
        """Number of tokens in the vocabulary."""
        return len(self.tokens)

    def encode(self, suffix):
        """
        Convert ``suffix`` to a list of token ids.

        Scans left to right. Special and added tokens are matched literally,
        first match wins. Any other ``<...>`` span is taken as one entity token.
        Whitespace and other stray characters are skipped. Pieces not in the
        vocabulary are dropped.
        """
        # Hoisted out of the scan loop. Empty literals are excluded because
        # they would match without advancing the cursor.
        literal_tokens = [tok for tok in self.special_tokens + self.added_tokens if tok]
        pieces = []
        i = 0
        n = len(suffix)
        while i < n:
            # see if it's a special token
            for tok in literal_tokens:
                if suffix.startswith(tok, i):
                    pieces.append(tok)
                    i += len(tok)
                    break
            else:
                if suffix[i] == "<":
                    # entity tokens are formatted <a_123>
                    j = suffix.find(">", i)
                    if j == -1:
                        # Unterminated "<...": previously i became 0 here and the
                        # loop never ended. Skip the stray "<" like any other
                        # accidental character.
                        i += 1
                    else:
                        pieces.append(suffix[i:j + 1])
                        i = j + 1
                else:
                    # whitespace or accidental char
                    i += 1
        return [self.tok2id[p] for p in pieces if p in self.tok2id]

    def decode(self, token_ids):
        """Decode a list of token IDs back to text (tokens joined with no separator)."""
        assert isinstance(token_ids, list), "token_ids must be a list"
        tokens = [self.id2tok[token_id] for token_id in token_ids]
        return "".join(tokens)

    def hash_onehot(self, hash_str):
        """
        One-hot encode ``hash_str`` over ``a-z0-9`` (36 symbols).

        Returns a flat float32 array of length ``len(hash_str) * 36``.
        Characters outside the alphabet produce an all-zero row.
        """
        alphabet = string.ascii_lowercase + string.digits
        A = len(alphabet)
        vec = np.zeros((len(hash_str), A), dtype=np.float32)
        for i, ch in enumerate(hash_str):
            k = alphabet.find(ch)
            if k >= 0:
                vec[i, k] = 1.0
        return vec.reshape(-1)  # flatten


class FixedGraphTriangleEnvironment(gym.Env):
    """
    Token-generation environment over a single fixed graph.

    An episode starts from the tokenised prompt ``"<graph_idx> tri: "``, and every
    ``step(action)`` appends one token id. Once ``TRIANGLE_TOKENS`` tokens have
    been generated, the episode terminates. The reward is 1.0 if the sequence
    parses into a triangle present in ``self.edges``
    (``parse_triangle_sequence`` + ``is_valid_triangle``), else 0.0; earlier steps
    yield 0.0. Observations are the current sequence right-padded with
    ``tokenizer.pad_id`` to ``max_sequence_length``.
    """

    # Tokens in a triangle answer: a b <sep> b c <sep> c a </a>
    TRIANGLE_TOKENS = 9

    def __init__(self, tokenizer, graph_idx=0, device="cuda"):
        super().__init__()

        self.tokenizer = tokenizer
        self.graph_idx = graph_idx
        # Room for the prompt plus a full triangle. The prompt length depends on
        # graph_idx (with TriangleTokenizer each digit is its own token, so 0-9
        # gives 2 prompt tokens and 10-99 gives 3). The old hard-coded 9 + 2 left
        # multi-digit graphs only 8 generation steps, so their triangles could
        # never be completed or rewarded. Single-digit graphs still get 11.
        prompt_len = len(tokenizer.encode(self._prompt()))
        self.max_sequence_length = prompt_len + self.TRIANGLE_TOKENS
        self.device = device  # stored only; not used inside this class

        self.action_space = spaces.Discrete(tokenizer.vocab_size)

        self.observation_space = spaces.Box(
            low=0, high=tokenizer.vocab_size - 1,
            shape=(self.max_sequence_length,), dtype=np.int64
        )

        # Load selected graph (adjacency dict: vertex -> neighbours)
        ddir = dataset_dir(DATA_ROOT, HASH_STR_LEN, T_TRIANGLES)
        graph_path = os.path.join(ddir, f"edges_{graph_idx}.json")
        if not os.path.exists(graph_path):
            raise FileNotFoundError(f"Could not find {graph_path}")
        self.edges = load_json(graph_path)

        print(f"Loaded graph {graph_idx} with {len(self.edges)} nodes")

    def _prompt(self):
        """Text prompt every episode starts from."""
        return f"{self.graph_idx} tri: "

    def reset(self, seed=None, options=None):
        """Start a new episode from the prompt; returns ``(observation, {"seed": seed})``."""
        super().reset(seed=seed)

        # Tokenize the initial observation
        self.current_sequence = self.tokenizer.encode(self._prompt())
        self.step_count = 0
        self.done = False

        # Remembered so step() can echo it in the info dict.
        self.seed = seed

        return self._get_observation(), {"seed": seed}

    def step(self, action):
        """Append ``action``; returns ``(obs, reward, terminated, False, {"seed": seed})``."""
        if self.done:
            return self._get_observation(), 0.0, True, False, {"seed": self.seed}

        # Add action (token) to sequence
        self.current_sequence.append(action)
        self.step_count += 1

        # Episodes always run to full length (no early stop on </a>), so the
        # reward is computed exactly once, on the final token.
        if len(self.current_sequence) >= self.max_sequence_length:
            self.done = True
            reward = self._calculate_reward()
        else:
            reward = 0.0

        return self._get_observation(), reward, self.done, False, {"seed": self.seed}

    def _get_observation(self):
        """Current sequence right-padded with pad_id to max_sequence_length, as int64."""
        sequence = self.current_sequence + [self.tokenizer.pad_id] * (self.max_sequence_length - len(self.current_sequence))

        return np.array(sequence, dtype=np.int64)

    def _calculate_reward(self):
        """1.0 if the finished sequence is a valid triangle of this graph, else 0.0."""
        if not self.done:
            return 0.0

        triangle_vertices = parse_triangle_sequence(self.current_sequence, self.tokenizer)

        # Check if the triangle is valid in the current graph
        if is_valid_triangle(triangle_vertices, self.edges):
            return 1.0  # Reward for valid triangle
        return 0.0


def build_ar_sequences(train_items, tokenizer, max_len=256):
    """
    Build teacher-forcing tensors for next-token training.

    For each item, ``target_text`` is tokenised into the full sequence (e.g.
    ``<hash> tri: a b <sep> b c <sep> c a </a>``) and ``input_text`` into its
    prompt prefix. Inputs are ``target[:-1]`` and labels are ``target[1:]``.
    The loss mask is 0 on label positions that still belong to the prompt and
    1 on the generated part.

    Sequences shorter than ``max_len`` are right-padded with
    ``tokenizer.pad_id``, with attention and loss masks 0 on the padding.
    Longer ones are truncated to ``max_len``.

    Returns:
        (X, Y, M, L): inputs and labels as long tensors; attention mask M and
        loss mask L as bool tensors. Each has one row of length ``max_len`` per item.
    """
    seqs_in, seqs_lab, attn, loss_mask = [], [], [], []

    for item in train_items:
        target_seq = tokenizer.encode(item["target_text"])
        input_prefix_len = len(tokenizer.encode(item["input_text"]))

        # input and label sequences (labels shifted by one)
        inp = target_seq[:-1]
        lab = target_seq[1:]

        # lab[k] == target[k + 1], so the first input_prefix_len - 1 labels are
        # prompt tokens. Clamp at 0: with an empty prompt the old formula built a
        # mask one element longer than lab.
        num_prompt_labels = max(0, input_prefix_len - 1)
        loss_mask_seq = [0 if k < num_prompt_labels else 1 for k in range(len(lab))]

        # pad or truncate to max_len
        T = len(inp)
        if T < max_len:
            pad = max_len - T
            inp = inp + [tokenizer.pad_id] * pad
            lab = lab + [tokenizer.pad_id] * pad
            attn_mask = [1] * T + [0] * pad
            loss_mask_seq = loss_mask_seq + [0] * pad  # never train on padding
        else:
            inp = inp[:max_len]
            lab = lab[:max_len]
            attn_mask = [1] * max_len
            loss_mask_seq = loss_mask_seq[:max_len]

        seqs_in.append(inp)
        seqs_lab.append(lab)
        attn.append(attn_mask)
        loss_mask.append(loss_mask_seq)

    X = torch.tensor(seqs_in, dtype=torch.long)
    Y = torch.tensor(seqs_lab, dtype=torch.long)
    M = torch.tensor(attn, dtype=torch.bool)
    L = torch.tensor(loss_mask, dtype=torch.bool)
    return X, Y, M, L

def parse_triangle_sequence(sequence, tokenizer):
    """
    Extract triangle vertices from a token-id sequence.

    Padding tokens and the tokenizer's added tokens (``"tri:"``, ``"edge:"``,
    digits, i.e. the prompt) are dropped first. What remains must have the form
    ``<a_1> <a_2> <sep> <a_2> <a_3> <sep> <a_3> <a_1> </a>``. That means exactly
    two ``<sep>`` tokens, at positions 2 and 5, with the first ``</a>`` at
    position 8. Tokens after that ``</a>`` are ignored.

    The three edges must be distinct pairs of distinct vertices spanning
    exactly three vertices, so together they are the triangle's three sides.

    Args:
        sequence: iterable of token ids.
        tokenizer: object providing ``id2tok``, ``pad_id`` and ``added_tokens``.

    Returns:
        The three vertex tokens in order of first appearance (deterministic),
        or None if the sequence is not a well-formed triangle.
    """
    pad_token = tokenizer.id2tok[tokenizer.pad_id]
    prompt_tokens = set(tokenizer.added_tokens)

    # Convert token IDs to tokens, dropping padding and prompt (added) tokens
    tokens = [tokenizer.id2tok.get(token_id) for token_id in sequence]
    tokens = [t for t in tokens if t != pad_token and t not in prompt_tokens]

    if len(tokens) < 9:
        if DEBUG: print("not proper triangle format")
        return None

    sep_token = "<sep>"
    end_token = "</a>"

    # Find the positions of <sep> tokens
    sep_positions = [i for i, token in enumerate(tokens) if token == sep_token]
    if len(sep_positions) != 2:
        if DEBUG: print("no sep pos")
        return None

    # Find the position of the first </a> token
    if end_token not in tokens:
        if DEBUG: print("no end pos")
        return None
    end_pos = tokens.index(end_token)

    if sep_positions != [2, 5] or end_pos != 8:
        return None

    # Only the six vertex slots before </a> are considered.
    edges = [tokens[0:2], tokens[3:5], tokens[6:8]]
    nodes = list(dict.fromkeys(v for edge in edges for v in edge))
    if len(nodes) != 3:
        if DEBUG: print("no 3 nodes")
        return None

    # Each edge must join two different vertices, and no edge may repeat.
    edge_sets = {frozenset(edge) for edge in edges}
    if len(edge_sets) != 3 or any(len(e) != 2 for e in edge_sets):
        if DEBUG: print("edges not distinct")
        return None

    return nodes



def is_valid_triangle(vertices, edges):
    """
    Report whether ``vertices`` forms a closed triangle in ``edges``.

    Args:
        vertices: sequence of exactly three vertex names ``[a, b, c]``. None,
            empty, or any other length gives False.
        edges: dictionary mapping each vertex to a container of its neighbours.

    Returns:
        bool: True iff a, b and c are all keys of ``edges`` and b is listed
        under a, c under b, and a under c.
    """
    assert isinstance(edges, dict), "edges must be a dictionary"

    if not (vertices and len(vertices) == 3):
        return False

    a, b, c = vertices
    if any(v not in edges for v in (a, b, c)):
        return False

    return all(dst in edges[src] for src, dst in ((a, b), (b, c), (c, a)))

def _architecture_config(model):
    """Read TransformerPolicy hyper-parameters from ``model``, using the historical defaults when it has no encoder stack."""
    hparams = {
        "d_model": getattr(model, "d_model", 256),
        "n_layer": 4,
        "n_head": 8,
        "dim_ff": 1024,
        "dropout": 0.1,
    }
    layers = getattr(getattr(model, "blocks", None), "layers", None)
    if layers is not None and len(layers) > 0:
        first = layers[0]  # nn.TransformerEncoder clones one layer, so all share these settings
        hparams.update(
            n_layer=len(layers),
            n_head=first.self_attn.num_heads,
            dim_ff=first.linear1.out_features,
            dropout=first.dropout.p,
        )
    return hparams


def save_model(model, tokenizer, path):
    """
    Save ``model`` weights, its architecture config and the tokenizer vocabulary to ``path``.

    The architecture is read from the model itself rather than hard-coded.
    Previously every checkpoint claimed d_model=256 / 4 layers / 8 heads, so
    non-default models could not be reloaded by ``load_model_and_tokenizer``.
    Parent directories are created as needed.
    """
    parent = os.path.dirname(path)
    if parent:  # os.makedirs("") raises, so skip it for bare filenames
        os.makedirs(parent, exist_ok=True)
    torch.save({
        "model": model.state_dict(),
        "config": {
            "vocab_size": tokenizer.vocab_size,
            **_architecture_config(model),
            "max_len": model.max_len,
        },
        "tokenizer": {
            "entities": tokenizer.entities,
            "special_tokens": tokenizer.special_tokens
        },
    }, path)
    print(f"Model saved to {path}")


class FixedSeedWrapper(gym.Wrapper):
    """
    Wrapper that controls the seed forwarded to the inner env's ``reset``.

    If the caller passes a seed and ``allow_override`` is truthy, that seed
    (cast to int) is used. Otherwise ``fixed_seed`` is used, and it must be set.
    """
    def __init__(self, env, fixed_seed: Optional[int] = None, allow_override: bool = True):
        super().__init__(env)

        # Mirror graph metadata on the wrapper so sampling code can read it directly.
        # TODO move to FixedGraphTriangleEnvironment
        for attr_name in ("graph_idx", "edges"):
            if hasattr(env, attr_name):
                setattr(self, attr_name, getattr(env, attr_name))

        self.fixed_seed = fixed_seed
        self.allow_override = allow_override

    def reset(self, *, seed=None, **kwargs):
        """Reset the wrapped env with the caller's seed (eval) or the fixed fallback."""
        if seed is not None and self.allow_override:
            chosen_seed = int(seed)
        else:
            assert self.fixed_seed is not None, "Fixed seed is not set!!!"
            # The None branch is reachable only when asserts are stripped (python -O).
            chosen_seed = None if self.fixed_seed is None else int(self.fixed_seed)
        return self.env.reset(seed=chosen_seed, **kwargs)

def make_env(tokenizer, graph_idx, device):
    """Build the single-graph triangle env, wrapped so resets default to META_FIXED_SEED (caller seeds still allowed)."""
    base_env = FixedGraphTriangleEnvironment(tokenizer, graph_idx, device=device)
    return FixedSeedWrapper(base_env, fixed_seed=META_FIXED_SEED, allow_override=True)


def validate_triangle_generation(policy, tokenizer, envs, device="cpu", num_samples=10, wandb_run=None, epoch=None):
    """
    Measure how often ``policy`` samples a valid triangle for each env's graph.

    For every env in ``envs``, draws ``num_samples`` generations of 9 new tokens
    from the prompt ``"<env.graph_idx> tri: "``. Each generation is parsed with
    ``parse_triangle_sequence`` and counted correct if ``is_valid_triangle``
    accepts it against ``env.edges``.

    The policy runs in eval mode and its previous train/eval mode is restored
    on exit. Previously the policy was left in eval mode, silently disabling
    dropout for any training that followed.

    Args:
        policy: the model to sample from.
        tokenizer: tokenizer used for prompts and parsing.
        envs: environments exposing ``graph_idx`` and ``edges``.
        device: device for generation.
        num_samples: generations per environment.
        wandb_run: optional run object; its ``log`` method receives the metrics.
        epoch: optional epoch number added to the logged metrics.

    Returns:
        float: correct / total, or 0.0 when nothing was sampled.
    """
    was_training = policy.training
    correct_triangles = 0
    total_samples = 0

    try:
        policy.eval()
        with torch.no_grad():
            for env in envs:
                prompt = f"{env.graph_idx} tri: "
                for _ in range(num_samples):
                    output_ids = generate(policy, tokenizer, max_new_tokens=9, device=device, prompt=prompt)
                    vertices = parse_triangle_sequence(output_ids, tokenizer)

                    if vertices and len(vertices) == 3:
                        if is_valid_triangle(vertices, env.edges):
                            correct_triangles += 1
                    else:
                        print(f"Failed to parse triangle for graph {env.graph_idx}: {vertices}")

                    total_samples += 1
    finally:
        policy.train(was_training)

    accuracy = correct_triangles / total_samples if total_samples > 0 else 0.0
    print(f"Triangle accuracy: {correct_triangles}/{total_samples} = {accuracy:.2%}")

    # Log validation metrics to the provided wandb run
    if wandb_run is not None:
        log_dict = {
            "val/accuracy": accuracy,
            "val/correct_triangles": correct_triangles,
            "val/total_samples": total_samples,
        }
        if epoch is not None:
            log_dict["val/epoch"] = epoch
        wandb_run.log(log_dict)

    return accuracy


@torch.no_grad()
def generate(policy, tokenizer, max_new_tokens=16, prompt="0 tri: ", device="cuda", temperature=1.0):
    """
    Autoregressively extend ``prompt`` with up to ``max_new_tokens`` tokens.

    Works on a single sequence (batch size 1). With ``temperature > 0`` each
    next token is sampled from softmax(logits / temperature); otherwise the
    argmax is taken. Generation stops right after ``tokenizer.eos_id`` is
    emitted, and that id is included in the output.

    The policy is switched to eval mode for generation, and its previous
    train/eval mode is restored afterwards.

    Returns:
        list[int]: prompt token ids followed by the generated ids.

    Raises:
        ValueError: if ``prompt`` encodes to no tokens.
    """
    prompt_ids = tokenizer.encode(prompt)
    if not prompt_ids:
        # torch.tensor([[]]) would be an empty *float* tensor and fail obscurely later.
        raise ValueError(f"prompt {prompt!r} encodes to no tokens")

    was_training = policy.training
    policy.eval()
    try:
        inp = torch.tensor([prompt_ids], dtype=torch.long, device=device)  # [1, P]
        for _ in range(max_new_tokens):
            attn = torch.ones_like(inp, dtype=torch.bool)
            logits = policy(inp, attn_mask=attn)[:, -1, :]  # [1, vocab]
            if temperature > 0.0:
                dist = torch.distributions.Categorical(logits=logits / temperature)
                next_id = dist.sample().unsqueeze(0)  # [1] -> [1, 1]
            else:
                next_id = torch.argmax(logits, dim=-1, keepdim=True)  # greedy, [1, 1]
            inp = torch.cat([inp, next_id], dim=1)
            if next_id.item() == tokenizer.eos_id:
                break
    finally:
        policy.train(was_training)
    return inp.squeeze(0).tolist()


def _nan_to_none(x):
    """Return None for NaN/inf/None, else a plain float."""
    if x is None:
        return None
    try:
        xf = float(x)
    except (TypeError, ValueError):
        return None
    if math.isnan(xf) or math.isinf(xf):
        return None
    return xf

def init_wandb(run_name: str, config: dict):
    """
    Start a wandb run for this project, or return None if that fails.

    Calls ``wandb.login()`` and then ``wandb.init`` for the
    "polychromic-triangle-discovery" project.

    ``WANDB_MODE`` defaults to "online". A value already present in the
    environment (e.g. "offline" or "disabled") is respected instead of being
    overwritten. Any exception from login/init is printed and swallowed, so
    logging stays best-effort. There is no offline fallback: on failure the
    caller simply gets None and should skip logging.

    Returns:
        The wandb run object, or None.
    """
    try:
        wandb.login()
        print("[wandb] Login successful. Initializing run.")
        # Only default the mode; do not clobber an explicit user choice.
        os.environ.setdefault("WANDB_MODE", "online")
        run = wandb.init(
            project="polychromic-triangle-discovery",
            name=run_name,
            config=config,
            resume=None,
            tags=["pretrain", "triangle-discovery", "transformer"],
        )
        return run
    except Exception as e:  # deliberate best-effort boundary
        print(f"[wandb] init failed: {e}")
        return None


def load_model_and_tokenizer(model_path=MODEL_PATH, device="cpu"):
    """
    Rebuild a TransformerPolicy and its TriangleTokenizer from a checkpoint.

    Reads a checkpoint with "model", "config" and "tokenizer" entries (the
    layout ``save_model`` writes). The tokenizer is rebuilt from its saved
    entities / special tokens, and the model from the saved architecture config
    with tied weights. The weights are then loaded and the model is moved to
    ``device``.

    Returns:
        (model, tokenizer)
    """
    checkpoint = torch.load(model_path, map_location=device)

    tokenizer_cfg = checkpoint["tokenizer"]
    tokenizer = TriangleTokenizer(
        entities=tokenizer_cfg["entities"],
        special_tokens=tokenizer_cfg["special_tokens"]
    )

    arch = checkpoint["config"]
    hparams = {
        key: arch[key]
        for key in ("vocab_size", "d_model", "n_layer", "n_head", "dim_ff", "dropout", "max_len")
    }
    model = TransformerPolicy(**hparams, tie_weights=True)
    model.load_state_dict(checkpoint["model"])

    return model.to(device), tokenizer
