import random
from collections import Counter
from dataclasses import dataclass
from typing import List, Tuple, Optional, Dict, Any

import numpy as np
from datasets import load_dataset
from tqdm import tqdm
from transformers import PreTrainedTokenizer

from .config import ExistenceConfig

@dataclass
class Slot:
    """One token position inside a tokenized sequence, with optional control contexts.

    The two control fields default to ``None``; ``get_right_context`` refuses
    to serve the "swap" / "shuffle" conditions until they have been assigned.
    """

    # Index of the source sequence and the offset of the slot inside it.
    seq_id: int
    position: int

    # Token id recorded for the slot, and the token ids of the whole sequence.
    true_token_id: int
    full_sequence: List[int]

    # Replacement right-hand contexts; filled in after construction.
    swap_right_tokens: Optional[List[int]] = None
    shuffle_right_tokens: Optional[List[int]] = None

    def get_left_context(self, L: int) -> List[int]:
        """Return up to ``L`` tokens immediately before the slot.

        Fewer than ``L`` tokens come back when the slot sits closer than
        ``L`` positions to the start of the sequence.
        """
        return self.full_sequence[max(0, self.position - L):self.position]

    def get_right_context(self, R: int, condition: str = "real") -> List[int]:
        """Return up to ``R`` tokens to the right of the slot.

        Args:
            R: Maximum number of tokens to return.
            condition: ``"real"`` reads from ``full_sequence``; ``"swap"`` and
                ``"shuffle"`` read from the corresponding control field.

        Raises:
            ValueError: If the requested control field is still ``None`` or
                ``condition`` is not one of the three known values.
        """
        if condition == "real":
            first = self.position + 1
            last = min(len(self.full_sequence), self.position + 1 + R)
            return self.full_sequence[first:last]

        if condition == "swap":
            if self.swap_right_tokens is not None:
                return self.swap_right_tokens[:R]
            raise ValueError("Swap control not initialized")

        if condition == "shuffle":
            if self.shuffle_right_tokens is not None:
                return self.shuffle_right_tokens[:R]
            raise ValueError("Shuffle control not initialized")

        raise ValueError(f"Unknown condition: {condition}")

def load_and_tokenize_dataset(
    dataset_name: str,
    tokenizer: PreTrainedTokenizer,
    config: ExistenceConfig
) -> List[List[int]]:
    """Load a text corpus by name and tokenize each non-empty document.

    ``"wikitext"`` loads ``Salesforce/wikitext`` with the config and split
    named in ``config``. ``"openwebtext"`` tries, in order,
    ``stas/openwebtext-10k``, the first 10000 rows of ``openwebtext`` and the
    first 10000 rows of ``allenai/c4`` (``en``), moving on to the next source
    whenever a load raises. A failure of the last source propagates.

    Args:
        dataset_name: Either ``"wikitext"`` or ``"openwebtext"``.
        tokenizer: Tokenizer whose ``encode`` is called with
            ``add_special_tokens=False``.
        config: Supplies ``wikitext_config``, ``wikitext_split``,
            ``openwebtext_split`` and ``min_seq_length``.

    Returns:
        One list of token ids per document that tokenizes to at least
        ``config.min_seq_length`` tokens.

    Raises:
        ValueError: If ``dataset_name`` is not recognized.
    """
    print(f"[Data] Loading dataset: {dataset_name}")

    if dataset_name == "wikitext":
        dataset = load_dataset(
            "Salesforce/wikitext",
            config.wikitext_config,
            split=config.wikitext_split
        )
    elif dataset_name == "openwebtext":
        # Deliberate best-effort chain: any failure moves on to the next
        # source, but each failure is reported so the reason is not lost.
        try:
            dataset = load_dataset(
                "stas/openwebtext-10k",
                split="train"
            )
            print("[Data] Loaded OpenWebText-10k")
        except Exception as mirror_error:
            print(f"[Data] OpenWebText-10k unavailable ({mirror_error}), trying OpenWebText")
            try:
                # NOTE(review): trust_remote_code=True allows the dataset
                # repository's loading script to run locally - confirm this
                # is still intended for both fallbacks.
                dataset = load_dataset(
                    "openwebtext",
                    split=f"{config.openwebtext_split}[:10000]",
                    trust_remote_code=True
                )
                print("[Data] Loaded OpenWebText subset")
            except Exception as e:
                print(f"[Data] OpenWebText unavailable ({e}), falling back to C4")
                dataset = load_dataset(
                    "allenai/c4",
                    "en",
                    split=f"{config.openwebtext_split}[:10000]",
                    trust_remote_code=True
                )
                print("[Data] Loaded C4 subset")
    else:
        raise ValueError(f"Unknown dataset: {dataset_name}")

    print(f"[Data] Loaded {len(dataset)} examples")

    sequences = []
    for item in tqdm(dataset, desc="Tokenizing"):
        text = item["text"]
        # Skip missing, empty and whitespace-only documents.
        if not text or not text.strip():
            continue

        tokens = tokenizer.encode(text, add_special_tokens=False)

        if len(tokens) >= config.min_seq_length:
            sequences.append(tokens)

    print(f"[Data] {len(sequences)} sequences pass length filter (>= {config.min_seq_length} tokens)")

    return sequences

def sample_slots(
    sequences: List[List[int]],
    config: ExistenceConfig,
    rng: random.Random
) -> List[Slot]:
    """Draw slots uniformly, without replacement, from all eligible positions.

    A position is eligible when it has at least ``config.l_max`` tokens before
    it and at least ``config.r_max`` tokens after it in its sequence.

    Args:
        sequences: Tokenized sequences to sample from.
        config: Supplies ``l_max``, ``r_max`` and ``n_slots``.
        rng: Random generator used for the draw.

    Returns:
        ``min(config.n_slots, number of eligible positions)`` slots, in the
        order ``rng.sample`` produced them. Each slot references (does not
        copy) its source sequence.
    """
    # Enumerate every (sequence index, offset) pair with full-width context on both sides.
    candidates = [
        (doc_index, offset)
        for doc_index, tokens in enumerate(sequences)
        for offset in range(config.l_max, len(tokens) - config.r_max)
    ]

    print(f"[Data] Total valid positions: {len(candidates)}")

    draw_count = min(config.n_slots, len(candidates))
    picked = rng.sample(candidates, draw_count)

    slots = [
        Slot(
            seq_id=doc_index,
            position=offset,
            true_token_id=sequences[doc_index][offset],
            full_sequence=sequences[doc_index],
        )
        for doc_index, offset in picked
    ]

    print(f"[Data] Sampled {len(slots)} slots")

    return slots

def generate_controls(
    slots: List[Slot],
    sequences: List[List[int]],
    config: ExistenceConfig,
    rng: random.Random
) -> None:
    """Attach "swap" and "shuffle" right-context controls to every slot, in place.

    * swap: ``config.r_max`` tokens taken from a random position, preferring a
      sequence other than the slot's own. After 100 unsuccessful draws one
      more draw is accepted unconditionally, so it may come from the same
      sequence.
    * shuffle: the slot's real right context in a random order.

    Args:
        slots: Slots to update; ``swap_right_tokens`` and
            ``shuffle_right_tokens`` are overwritten.
        sequences: Tokenized sequences the slots were drawn from.
        config: Supplies ``r_max``.
        rng: Random generator used for every draw and shuffle.
    """
    print("[Data] Generating control contexts...")

    width = config.r_max

    # Every (sequence index, offset) pair that still has `width` tokens to its right.
    donor_pool = [
        (doc_index, offset)
        for doc_index, tokens in enumerate(sequences)
        for offset in range(len(tokens) - width)
    ]

    for slot in tqdm(slots, desc="Generating controls"):
        real_right = slot.get_right_context(width, "real")

        for _ in range(100):
            donor_id, donor_pos = rng.choice(donor_pool)
            if donor_id == slot.seq_id:
                # Same sequence as the slot: reject and draw again.
                continue
            window = sequences[donor_id][donor_pos + 1 : donor_pos + 1 + width]
            if len(window) == width:
                slot.swap_right_tokens = window
                break
        else:
            # All attempts were rejected; accept whatever the next draw gives.
            donor_id, donor_pos = rng.choice(donor_pool)
            slot.swap_right_tokens = sequences[donor_id][donor_pos + 1 : donor_pos + 1 + width]

        # Shuffle a copy so the underlying sequence is left untouched.
        permuted = list(real_right)
        rng.shuffle(permuted)
        slot.shuffle_right_tokens = permuted

    print("[Data] Control contexts generated")

def prepare_dataset(
    dataset_name: str,
    tokenizer: PreTrainedTokenizer,
    config: ExistenceConfig
) -> Tuple[List[Slot], List[List[int]]]:
    """Run the full pipeline: load and tokenize, sample slots, attach controls.

    Both numpy's global RNG and a dedicated ``random.Random`` are seeded with
    ``config.seed``; the latter drives slot sampling and control generation.

    Args:
        dataset_name: Name forwarded to ``load_and_tokenize_dataset``.
        tokenizer: Tokenizer forwarded to ``load_and_tokenize_dataset``.
        config: Experiment configuration shared by every stage.

    Returns:
        ``(slots, sequences)`` - the sampled slots, with controls attached,
        and the tokenized sequences they were drawn from.
    """
    # Seeding first keeps every later stage reproducible for a given config.seed.
    np.random.seed(config.seed)
    generator = random.Random(config.seed)

    tokenized = load_and_tokenize_dataset(dataset_name, tokenizer, config)
    sampled = sample_slots(tokenized, config, generator)

    # Mutates the slots in place; nothing is returned.
    generate_controls(sampled, tokenized, config, generator)

    return sampled, tokenized

def compute_token_frequencies(slots: List[Slot]) -> Dict[int, int]:
    """Count how many slots carry each token id as their ``true_token_id``.

    Args:
        slots: Slots whose ``true_token_id`` values are tallied.

    Returns:
        A plain dict mapping token id to its number of occurrences, with keys
        in order of first appearance in ``slots``. Empty input gives ``{}``.
    """
    # Counter does the tallying; converting back to dict keeps the original
    # contract (looking up an unseen token id raises KeyError instead of
    # returning 0).
    return dict(Counter(slot.true_token_id for slot in slots))

def compute_token_ranks(freq: Dict[int, int]) -> Dict[int, int]:
    """Rank token ids by descending frequency, starting at rank 1.

    Args:
        freq: Mapping from token id to occurrence count.

    Returns:
        Mapping from token id to rank. Every token gets a distinct rank;
        tokens with equal counts keep the order they have in ``freq``
        because the sort is stable.
    """
    most_frequent_first = sorted(freq, key=lambda token_id: -freq[token_id])
    return {
        token_id: place
        for place, token_id in enumerate(most_frequent_first, start=1)
    }
