import random
from typing import List

import numpy as np
import torch
from absl import logging
from lark import Transformer
from torch.utils.data import IterableDataset, get_worker_info

from td.environments import Environment, environments
from td.grammar import Grammar
from td.learning.constrained_decoding import (
    DecoderState,
)
from td.learning.gpt import TreeDiffusion
from td.learning.tokenizer import Tokenizer
from td.samplers import ConstrainedRandomSampler


class CSG2DACompositor(Transformer):
    """Lark transformer that turns a CSG2DA parse tree into flat terms.

    - ``circle``/``quad`` nodes are re-serialised to their textual form.
    - ``binop`` nodes are spliced into a flat ``[shape, op, shape, op, ...]``
      list; operands that are already lists are inlined.
    - Every quantised leaf rule maps to the symbol the grammar writes for it:
      ints ``0``-``9``, then ``"A"``-``"F"`` for ten..fifteen and ``"G"``-``"N"``
      for the eight angle buckets.
    """

    def __init__(
        self,
        visit_tokens: bool = True,
    ) -> None:
        super().__init__(visit_tokens)

    # -- primitives ----------------------------------------------------------

    def quad(self, children):
        qx, qy, qw, qh, angle = children
        return f"(Quad {qx} {qy} {qw} {qh} {angle})"

    def circle(self, children):
        radius, cx, cy = children
        return f"(Circle {radius} {cx} {cy})"

    # -- composition ---------------------------------------------------------

    def binop(self, children):
        operator, lhs, rhs = children

        def as_terms(node):
            # Already-flattened sub-expressions are inlined; single shapes are wrapped.
            return node if isinstance(node, list) else [node]

        return [*as_terms(lhs), operator, *as_terms(rhs)]

    def s(self, children):
        return children[0]

    # -- constant leaves -----------------------------------------------------
    # Each of these rules ignores its children and returns a fixed symbol, so
    # the handlers are stamped out by a factory that only exists while the
    # class body executes (it is deleted below). Lark looks handlers up by
    # attribute name, so these behave exactly like hand-written methods.

    def _leaf(symbol):
        def handler(self, _children):
            return symbol

        return handler

    add = _leaf("+")
    subtract = _leaf("-")

    zero = _leaf(0)
    one = _leaf(1)
    two = _leaf(2)
    three = _leaf(3)
    four = _leaf(4)
    five = _leaf(5)
    six = _leaf(6)
    seven = _leaf(7)
    eight = _leaf(8)
    nine = _leaf(9)
    ten = _leaf("A")
    eleven = _leaf("B")
    twelve = _leaf("C")
    thirteen = _leaf("D")
    fourteen = _leaf("E")
    fifteen = _leaf("F")

    zerodeg = _leaf("G")
    onedeg = _leaf("H")
    twodeg = _leaf("I")
    threedeg = _leaf("J")
    fourdeg = _leaf("K")
    fivedeg = _leaf("L")
    sixdeg = _leaf("M")
    sevendeg = _leaf("N")

    del _leaf


# Shared transformer instance; CSG2DACompositor's handlers keep no per-call
# state, so one instance can serve every ``flatten`` call.
flattener = CSG2DACompositor()


def flatten(env, expression):
    """Parse ``expression`` with ``env.grammar`` and flatten it into a term list.

    The result alternates shapes and operators, e.g.
    ``["(Circle 1 2 3)", "+", "(Quad 0 0 4 4 G)"]``, and is the inverse of
    ``unflatten``.

    Returns:
        Always a list. A program consisting of a single primitive comes back as
        a one-element list rather than a bare string.
    """
    flat = flattener.transform(env.grammar.parse(expression))
    # ``binop`` yields a list, but ``quad``/``circle`` return plain strings, so
    # a lone primitive would otherwise leak out as a ``str`` and break list
    # concatenation / ``unflatten`` in callers.
    if isinstance(flat, list):
        return flat
    return [flat]


def unflatten(terms):
    """Fold a flat ``[operand, op, operand, op, ...]`` list back into an expression.

    The fold is left-associative: ``[a, "+", b, "-", c]`` becomes
    ``"(- (+ a b) c)"``. A trailing operator with no operand after it is
    ignored, an empty input yields ``""`` and a single operand is returned
    as-is (not converted to a string).
    """
    if not terms:
        return ""

    expression = terms[0]
    # Odd positions hold operators and even positions (after the first) hold
    # operands; zip stops at the last complete (operator, operand) pair.
    for operator, operand in zip(terms[1::2], terms[2::2]):
        expression = f"({operator} {expression} {operand})"
    return expression


class ComposeFlowDataset(IterableDataset):
    """Infinite stream of next-shape prediction batches for CSG2DA.

    For each example a random target program is sampled and flattened into
    ``["+", shape0, op1, shape1, ...]``. The list is cut at a random shape
    boundary. The model is conditioned on the target image, the image of the
    prefix and the prefix tokens, and is trained to predict the next
    ``op shape`` pair, or ``<EOS>`` when the cut is at the very end.
    """

    def __init__(
        self,
        batch_size,
        env_name,
        max_sequence_length,
        min_primitives,
        max_primitives,
        target_observation,
    ):
        # ``flatten`` relies on CSG2DACompositor's rule names, hence CSG2DA only.
        assert env_name == "csg2da", "Only CSG2DA is supported"

        self._env_name = env_name
        self._batch_size = batch_size
        self._max_sequence_length = max_sequence_length
        self._min_primitives = min_primitives
        self._max_primitives = max_primitives
        # When truthy, targets are rendered with ``compile_observation``
        # instead of ``compile``.
        self._target_observation = target_observation

    def _produce_batch(self):
        """Sample ``batch_size`` examples and collate them into numpy arrays.

        Returns:
            ``(tokens, context_mask, target_images, current_images)``.
            ``tokens`` and ``context_mask`` are padded to
            ``tokenizer.max_sequence_length``; the mask is 1 only on the
            next-shape tokens. Images are transposed with ``(0, 3, 1, 2)``,
            presumably HWC -> CHW (assumes ``env.compile`` returns
            (H, W, C); TODO confirm).
        """

        def sample_fn():
            return self._sampler.sample(
                self._env.grammar.start_symbol,
                min_primitives=self._min_primitives,
                max_primitives=self._max_primitives,
            )

        def convert_expression_to_training_data(target_expression):
            target_image = (
                self._env.compile(target_expression)
                if not self._target_observation
                else self._env.compile_observation(target_expression)
            )
            terms = flatten(self._env, target_expression)
            # A single-primitive program can flatten to a bare string (the
            # ``quad``/``circle`` handlers return strings). Without this the
            # list concatenation below raised TypeError; the retry loop caught
            # it and discarded every one-shape target.
            if not isinstance(terms, list):
                terms = [terms]
            # Prefix "+" so the list is a sequence of (op, shape) pairs.
            flattened = ["+"] + terms
            step = random.randint(0, len(flattened) // 2)
            context = flattened[: step * 2]
            rest = flattened[step * 2 :]
            next_shape = rest[:2]
            if not next_shape:
                next_shape = ["<EOS>"]

            # context[1:] drops the synthetic leading "+" before unflattening.
            current_image = (
                self._env.compile(unflatten(context[1:]))
                if len(context) > 1
                else np.zeros(self._env.compiled_shape)
            )
            context_string = " ".join(context)
            next_shape_string = " ".join(next_shape)

            return target_image, current_image, context_string, next_shape_string

        def tokenize(context_string, next_shape_string):
            context_tokens = self._tokenizer._tokenize_one(context_string)
            next_shape_tokens = self._tokenizer._tokenize_one(next_shape_string)

            # Layout: context, <SOS>, next-shape tokens, then padding.
            all_tokens = (
                context_tokens + [self._tokenizer.sos_token] + next_shape_tokens
            )

            if len(all_tokens) >= self._tokenizer.max_sequence_length:
                raise ValueError(
                    f"Tokenized sequence too long: {len(all_tokens)} tokens"
                )

            # Loss is taken only on the next-shape tokens.
            context_mask = [0] * (len(context_tokens) + 1) + [1] * len(
                next_shape_tokens
            )
            context_mask += [0] * (
                self._tokenizer.max_sequence_length - len(context_mask)
            )

            all_tokens += [self._tokenizer.pad_token] * (
                self._tokenizer.max_sequence_length - len(all_tokens)
            )

            return all_tokens, context_mask

        def sample_batch():
            target_expression = self._env.sample_non_empty(sample_fn)
            target_image, current_image, context_string, next_shape_string = (
                convert_expression_to_training_data(target_expression)
            )
            tokens, mask = tokenize(context_string, next_shape_string)
            return target_image, current_image, tokens, mask

        batch = []

        # Best-effort: any failing example (render, parse, over-long sequence)
        # is logged and resampled.
        while len(batch) < self._batch_size:
            try:
                batch.append(sample_batch())
            except Exception as e:
                logging.error("Failed to sample batch: %s", e)
                continue

        target_images, current_images, tokenized, context_tokens_mask = zip(*batch)

        return (
            np.array(tokenized),
            np.array(context_tokens_mask),
            np.array(target_images).transpose(0, 3, 1, 2),
            np.array(current_images).transpose(0, 3, 1, 2),
        )

    def __iter__(self):
        worker_info = get_worker_info()

        # Seed per DataLoader worker so workers draw different examples.
        if worker_info is not None:
            np.random.seed(worker_info.id)
            random.seed(worker_info.id)

        # Built here rather than in __init__, so each iterating process
        # constructs its own environment/sampler/tokenizer.
        self._env: Environment = environments[self._env_name]()
        self._sampler = ConstrainedRandomSampler(self._env.grammar)
        self._tokenizer = Tokenizer(
            self._env.grammar,
            max_token_length=self._max_sequence_length,
            max_sequence_length=self._max_sequence_length,
        )

        while True:
            yield self._produce_batch()


# Lark grammar for a single composition step, "<op> <shape>" (for example
# "+ (Circle 1 2 3)"). ``sample_model_compose_kv`` builds its DecoderStates from
# it, so decoding can emit only one operator followed by one primitive.
# The rule/alias names match the CSG2DACompositor handlers in this file;
# presumably they must stay in sync with the real csg2da grammar — TODO confirm.
# NOTE(review): the "0 to 16" comment inside the grammar describes sixteen
# levels written 0-9 then A-F (i.e. values 0..15).
# The string is kept byte-for-byte as-is because it is parsed at runtime.
_fake_grammar_spec: str = r"""
s: opshape

shape: circle | quad
opshape: op " " shape

// Number quantized 0 to 16.
number: "0" -> zero | "1" -> one | "2" -> two | "3" -> three | "4" -> four | "5" -> five | "6" -> six | "7" -> seven | "8" -> eight | "9" -> nine | "A" -> ten | "B" -> eleven | "C" -> twelve | "D" -> thirteen | "E" -> fourteen | "F" -> fifteen

// angles [0, 45, 90, 135, 180, 225, 270, 315]
angle: "G" -> zerodeg | "H" -> onedeg | "I" -> twodeg | "J" -> threedeg | "K" -> fourdeg | "L" -> fivedeg | "M" -> sixdeg | "N" -> sevendeg

// (Circle radius x y)
circle: "(" "Circle" " " number " " number " " number ")"

// (Quad x0 y0 x1 y1 x2 y2 x3 y3)
// quad: "(" "Quad" " " number " " number " " number " " number " " number " " number " " number " " number ")"

// (Quad x y w h angle)
quad: "(" "Quad" " " number " " number " " number " " number " " angle ")"

op: "+" -> add | "-" -> subtract

%ignore /[\t\n\f\r]+/ 
"""


# Start symbol "s"; "circle" and "quad" are declared as the primitive rules.
fake_grammar: Grammar = Grammar(
    _fake_grammar_spec,
    start="s",
    primitives=["circle", "quad"],
)


def _compile_current_images(env: Environment, current_expressions: List[List[str]]):
    """Render every partial program, falling back to a blank canvas.

    Each row is a flat term list whose first element is the synthetic leading
    operator, so ``terms[1:]`` is what ``unflatten`` expects. Empty programs,
    and programs whose compilation raises, render as
    ``np.zeros(env.compiled_shape)``.
    """
    images = []
    for terms in current_expressions:
        try:
            if len(terms[1:]):
                images.append(env.compile(unflatten(terms[1:])))
            else:
                images.append(np.zeros(env.compiled_shape))
        except Exception as e:
            logging.error("Failed to compile expression: %s", e)
            images.append(np.zeros(env.compiled_shape))
    return np.array(images)


def sample_model_compose_kv(
    model: TreeDiffusion,
    env: Environment,
    tokenizer: Tokenizer,
    current_expressions: List[List[str]],
    target_images,
    temperature=1.0,
) -> List[List[str]]:
    """Sample one more ``op shape`` step per partial program using a KV cache.

    Args:
        model: model whose ``image_embeddings`` and ``transformer`` are used.
        env: environment used to render the current partial programs.
        tokenizer: tokenizer matching the model vocabulary.
        current_expressions: one flat term list per batch row, in the training
            context format (``["+", shape0, op1, shape1, ...]``, possibly empty).
        target_images: forwarded unchanged to ``model.image_embeddings``.
        temperature: logits are divided by this before the softmax.

    Returns:
        Shallow copies of ``current_expressions``. A row is extended by
        ``[op, shape]`` when its decoder produced a non-empty string, e.g. when
        ``<EOS>`` was not sampled first. Otherwise it is returned unchanged.
    """
    with torch.inference_mode():
        device = next(model.parameters()).device

        current_images = (
            torch.tensor(_compile_current_images(env, current_expressions))
            .float()
            .permute(0, 3, 1, 2)
            .to(device)
        )
        image_embeddings = model.image_embeddings(target_images, current_images)

        # Prefill layout per row: context tokens, <SOS>, then padding.
        # Row i starts sampling once its <SOS> (index len(context_i)) is fed.
        current_full_expressions = [" ".join(x) for x in current_expressions]
        context_tokens = [tokenizer._tokenize_one(x) for x in current_full_expressions]
        start_decoding_positions = [len(x) for x in context_tokens]
        max_length = tokenizer.max_token_length
        current_tokens = torch.tensor(
            [
                x
                + [tokenizer.sos_token]
                + [tokenizer.pad_token] * (max_length - len(x) - 1)
                for x in context_tokens
            ]
        ).to(device)

        decode_states = [
            DecoderState(fake_grammar, tokenizer, "") for _ in current_expressions
        ]
        for decode_state in decode_states:
            decode_state.force_set_rule()
            # Also allow <EOS> as the very first token, so a row can stop composing.
            decode_state._valid_tokens_mask[tokenizer.eos_token] = 1

        current_position = 0
        k_cache = None
        v_cache = None

        while (
            not all(x._decode_state == DecoderState.States.END for x in decode_states)
            and current_position < max_length - 1
        ):
            # Feed one token per step; the caches carry all earlier positions.
            logits, k_cache, v_cache = model.transformer(
                current_tokens[:, [current_position]],
                extra_emb=image_embeddings,
                k_cache=k_cache,
                v_cache=v_cache,
                start_idx=current_position,
            )
            logits = logits.cpu()

            logits_for_positions = logits[:, 0, :]
            logits_for_positions = logits_for_positions / temperature
            masks = np.stack([x.mask for x in decode_states])
            decode_mask = torch.tensor(masks).bool()

            # Rows still replaying their context, or already finished, discard
            # whatever they sample below. Let them sample from the full
            # vocabulary. Otherwise an all-false mask (e.g. a finished state
            # that allows nothing; can't be ruled out from here) would fill
            # the row with -inf. The softmax would then produce NaNs and
            # torch.multinomial would raise.
            inactive = torch.tensor(
                [
                    current_position < start_decoding_positions[i]
                    or state._decode_state == DecoderState.States.END
                    for i, state in enumerate(decode_states)
                ],
                dtype=torch.bool,
            )
            decode_mask[inactive] = True

            logits_for_positions = torch.where(
                decode_mask, logits_for_positions, -torch.tensor(float("inf"))
            )
            probs = torch.nn.functional.softmax(logits_for_positions, dim=-1)
            sampled_tokens = torch.multinomial(probs, 1).squeeze(-1)

            # Update decode states.
            for i, token in enumerate(np.array(sampled_tokens)):
                if current_position < start_decoding_positions[i]:
                    continue

                if decode_states[i]._decode_state == DecoderState.States.END:
                    continue

                if token == tokenizer.eos_token:
                    decode_states[i]._decode_state = DecoderState.States.END
                    continue

                decode_states[i].feed_token(token)
                current_tokens[i, current_position + 1] = sampled_tokens[i].item()

            current_position += 1

        # NOTE(review): a row that hits max_length before reaching END keeps
        # its partial string and is still appended below. Confirm whether
        # truncated shapes should be dropped instead.
        feeded_strings = [x.get_feeded_string() for x in decode_states]
        rv = [x.copy() for x in current_expressions]

        for r, f in zip(rv, feeded_strings):
            if not f:
                continue
            # Fake-grammar output is "<op> <shape>": one-char op, one space.
            op = f[0]
            shape = f[2:]
            r.extend([op, shape])

        return rv
