import json
import dataclasses as dc
from pathlib import Path
from litgpt.prompts import PromptStyle, Default as NoPrompt
from typing import Sequence, Iterable, Final, Any
from core.tokenization import Vocabulary
from io import TextIOBase

from .formulation import CoT, StepwiseCoT, MDPStep, MDPCoT


# Registry mapping a style name to its `Style` subclass. Entries are added by
# `Style.__init_subclass__` (for subclasses declared with `name=...`) and
# looked up by `Style.from_name`.
_NAMED_TEXT_STYLES: dict[str, type['Style']] = {}


def _cat(*items: str | None, sep: str = ''):
    return sep.join(item for item in items if item)


@dc.dataclass
class Segment:
    """
    The style to conveniently present a segment of textual content:
    `{beg_token}{space}item{sep}item{sep}...item{space}{end_token}`
    """

    beg_token: Final[str | None] = None
    end_token: Final[str | None] = None
    space: Final[str] = ' '
    sep: Final[str] = ' '

    def __call__(self, *items: str):
        return _cat(self.beg_token, self.content(*items), self.end_token, sep=self.space)
    
    def content(self, *items: str):
        return self.sep.join(items)
    
    def prompt(self, *items: str):
        return self.prefix(*items)

    def prefix(self, *items: str):
        if items:
            return _cat(self.beg_token, self.content(*items), sep=self.space)
        else:
            return self.beg_token or ''
        
    def next_content(self, *items: str, sep_prefix: bool = False) -> str:
        if items:
            front = self.sep if sep_prefix else (self.space if self.beg_token else '')
            return front + self.content(*items)
        else:
            return ''
    
    def suffix(self, *items: str, sep_prefix: bool = False) -> str:
        if items:
            front = self.sep if sep_prefix else (self.space if self.beg_token else '')
            return front + _cat(self.content(*items), self.end_token, sep=self.space)
        else:
            if (sep_prefix or self.beg_token is not None) and self.end_token:
                return self.space + self.end_token
            else:
                return ''


class Style[In, Thought, Out]:
    """
    The class that specifies the textual representation of reasoning.
    """

    all = Segment(sep='\n\n')
    input = Segment('<in>', '</in>')
    thought = Segment('<thought>', '</thought>', space='\n', sep='\n')
    outcome = Segment('<out>', '</out>')

    @classmethod 
    def special_tokens(cls):
        return Vocabulary(
            beg_of_reasoning=cls.all.beg_token,
            end_of_reasoning=cls.all.end_token,
            beg_of_input=cls.input.beg_token,
            end_of_input=cls.input.end_token,
            beg_of_thought=cls.thought.beg_token,
            end_of_thought=cls.thought.end_token,
            beg_of_outcome=cls.outcome.beg_token,
            end_of_outcome=cls.outcome.end_token,
        )
    
    def __init_subclass__(cls, name: str | None = None, **kwargs) -> None:
        super().__init_subclass__()
        if name is not None:
           _NAMED_TEXT_STYLES[name] = cls
    
    @classmethod
    def from_name(cls, name: str) -> 'Style':
        return _NAMED_TEXT_STYLES[name]()
    
    def repr_input(self, input: In, **kwargs) -> str:
        """Represent a input by text."""
        return input if isinstance(input, str) else json.dumps(input)
    
    def repr_thought(self, thought: Thought, **kwargs) -> str:
        """Represent a reasoning by text."""
        return thought if isinstance(thought, str) else json.dumps(thought)
    
    def repr_outcome(self, outcome: Out, **kwargs) -> str:
        """Represent the outcome by text."""
        return outcome if isinstance(outcome, str) else json.dumps(outcome)
    
    def apply_input(self, input: In, **kwargs) -> str:
        """The textual input."""
        return self.input(self.repr_input(input, **kwargs))
    
    def apply_thought(self, thought: Thought, **kwargs) -> str:
        """The textual thought."""
        return self.thought(self.repr_thought(thought, **kwargs))
    
    def apply_outcome(self, outcome: Out, **kwargs):
        """The textual thought."""
        return self.outcome(self.repr_outcome(outcome, **kwargs))
    
    def prompt_all(self, input: In, **kwargs):
        return self.all.prefix(
            self.apply_input(input, **kwargs),
            self.thought.prefix()
        )

    def apply_all(self, input: In, thought: Thought, outcome: Out, **kwargs):
        kwargs.update(input=input, thought=thought, outcome=outcome)
        return self.all(
            self.apply_input(**kwargs),
            self.apply_thought(**kwargs),
            self.apply_outcome(**kwargs),
        )

    def _sftdata_next_token(self, cot: CoT[In, Thought, Out], eos: str = ''):
        kwargs = cot.as_dict()
        prompt = self.prompt_all(**kwargs)
        output = self.all.suffix(
            self.thought.suffix(self.repr_thought(**kwargs)),
            self.apply_outcome(**kwargs)
        ) + eos
        if prompt and output:
            yield dict(instruction=prompt, output=output)

    def save_instances(
        self,
        samples: Sequence[CoT[In, Thought, Out]],
        file: str | Path,
        **kwargs,
    ):
        file = Path(file)
        fmt = ''.join(file.suffixes)
        if not isinstance(samples, list):
            samples = list(samples)

        f = open(file, 'wt')
        self._dump_instances(samples, f, fmt, **kwargs)
        f.close()

    def _sftdata(
        self, cot: CoT[In, Thought, Out], *,
        token_level: bool = True,
        eos: str = '',
        **_,
    ) -> Iterable[dict[str, str]]:
        
        if token_level:
            yield from self._sftdata_next_token(cot, eos)

    def _dump_text(
        self,
        samples: list[CoT[In, Thought, Out]],
        f: TextIOBase,
        *,
        sep: str = '\n\n',
        bos: str = '',
        eos: str = '',
        **_,
    ):
        for i, sample in enumerate(samples):
            temp = sample.as_dict()
            if i > 0:
                f.write(sep + bos + self.apply_all(**temp) + eos)
            else:
                f.write(bos + self.apply_all(**temp) + eos)

    def _dump_instances(
        self,
        samples: list[CoT[In, Thought, Out]],
        f: TextIOBase,
        fmt: str,
        **kwargs,
    ):

        if fmt.endswith('.sft.json'):
            items = [item for s in samples for item in self._sftdata(s, **kwargs)]
            json.dump(items, f, indent=4)
        elif fmt.endswith('.sft.jsonl'):
            for s in samples:
                for item in self._sftdata(s, **kwargs):
                    json.dump(item, f, indent=4)
        elif fmt.endswith('.json'):
            items = [s.as_dict() for s in samples]
            json.dump(items, f, indent=4)
        elif fmt.endswith('.jsonl'):
            for s in samples:
                json.dump(s.as_dict(), f)
        elif fmt == '.txt':
            self._dump_text(samples, f, **kwargs)
        else:
            raise NotImplementedError(f"\"{fmt}\" is not a supported format.")


class DirectStyle[In, Out](Style[In, Any, Out]):
    """The style of reasoning without presenting the intermediate thought."""

    thought  = Segment(beg_token=None, end_token=None)

    def repr_thought(self, thought: Any, **kwargs) -> str:
        return ""


class StepwiseStyle[In, Step, Out](Style[In, list[Step], Out]):
    """The style for reasoning with an explicit multi-step structure."""

    step = Segment('<step>', '</step>')
    ellipsis_token: str = '<...>'

    def repr_step(self, idx: int, step: Step, **kwargs) -> str:
        """Represent a step by text"""
        return step if isinstance(step, str) else json.dumps(step)
    
    def apply_step(self, idx: int, step: Step, **kwargs) -> str:
        return self.step(self.repr_step(idx, step, **kwargs))

    def repr_thought(self, thought: list[Step], **kwargs) -> str:
        steps = (self.apply_step(i, step, **kwargs) for i, step in enumerate(thought))
        return self.thought.content(*steps)
    
    @classmethod
    def special_tokens(cls):
        tokens = super().special_tokens()
        tokens.add_kwtokens(dict(
            beg_of_step=cls.step.beg_token,
            end_of_step=cls.step.end_token,
            ellipsis=cls.ellipsis_token,
        ))
        return tokens

    def _sftdata_step(self, cot: StepwiseCoT[In, Step, Out],
                      include_init=True, include_outcome=True):

        kwargs = cot.as_dict()
        input = self.apply_input(**kwargs)
        
        for i in range(-1 if include_init else 0, len(cot.thought)):
            if i >= 0:
                step = self.apply_step(i, cot.thought[i], **kwargs)
                prompt = self.thought.prefix(self.ellipsis_token, step)
                sep_prefix = True
            else:  # initialize state
                prompt = self.all.prefix(
                    input,
                    self.thought.prefix(),
                )
                sep_prefix = False

            if i + 1 < len(cot.thought):
                next_step = self.apply_step(i + 1, cot.thought[i + 1], **kwargs)
                output = self.thought.next_content(next_step, sep_prefix=sep_prefix)
            elif include_outcome:
                output = self.all.suffix(
                    self.thought.suffix(sep_prefix=sep_prefix),
                    self.apply_outcome(**kwargs)
                )
            else:
                continue
            
            if prompt and output:
                yield dict(instruction=prompt, output=output)

    def _sftdata(
        self, cot: CoT[In, list[Step], Out], *,
        token_level: bool = False,
        step_level: bool = True,
        predict_outcome: bool = True,
        predict_init: bool = True,
        eos: str = '',
        **_,
    ) -> Iterable[dict[str, str]]:
        
        yield from super()._sftdata(cot, token_level=token_level, eos=eos, **_)

        if step_level:
            yield from self._sftdata_step(cot,
                include_init=predict_init,
                include_outcome=predict_outcome
            )


class MDPStyle[In, State, Action, Out](StepwiseStyle[In, MDPStep[State, Action], Out]):
    """The outoput style for a CoT with an explicit MDP structure."""

    step = Segment(None, None, space=' ', sep='\n')
    state = Segment('<state>', '</state>')
    action = Segment('<action>', '</action>')

    def repr_state(self, idx: int, state: State, **kwargs) -> str:
        """Represent a state by text."""
        return state if isinstance(state, str) else json.dumps(state)
    
    def repr_action(self, idx: int, action: Action, **kwargs) -> str:
        """Represent an action by text."""
        return action if isinstance(action, str) else json.dumps(action)

    def apply_state(self, idx: int, state: State, **kwargs):
        return self.state(self.repr_state(idx, state, **kwargs))
    
    def apply_action(self, idx: int, action: Action, **kwargs) -> str:
        return self.action(self.repr_action(idx, action, **kwargs))

    def repr_step(self, idx: int, step: MDPStep[State, Action], **kwargs) -> str | None:
        state_ = step['state']; action_ = step['action']
        state = self.apply_state(idx, state_, **kwargs)
        if action_ is None:
            return self.step.content(state)
        else:
            action = self.apply_action(idx, action_, **kwargs)
            return self.step.content(state, action)
 
    @classmethod
    def special_tokens(cls):
        tokens = super().special_tokens()
        tokens.add_kwtokens(dict(
            beg_of_state=cls.state.beg_token,
            end_of_state=cls.state.end_token,
            beg_of_action=cls.action.beg_token,
            end_of_action=cls.action.end_token,
        ))
        return tokens

    def _sftdata_transition(
        self,
        cot: MDPCoT[In, State, Action, Out],
        include_init=True,
    ):
        kwargs = cot.as_dict()
        input = self.apply_input(**kwargs)
        
        for i in range(-1 if include_init else 0, len(cot.thought)):

            if i + 1 >= len(cot.thought):
                break

            if i >= 0:
                step = self.apply_step(i, cot.thought[i], **kwargs)
                prompt = self.thought.prefix(self.ellipsis_token, step)
                sep_prefix = True
            else:  # initialize first state
                prompt = self.all.prefix(
                    input,
                    self.thought.prefix(),
                )
                sep_prefix = False

            next_state = self.apply_state(i + 1, cot.thought[i + 1]['state'], **kwargs)
            output = self.thought.next_content(
                self.step.prefix(next_state),
                sep_prefix=sep_prefix
            )

            if prompt and output:
                yield dict(instruction=prompt, output=output)

    def _sftdata_pi(
        self,
        cot: MDPCoT[In, State, Action, Out],
        include_outcome=True,
    ):
        kwargs = cot.as_dict()
        
        for i in range(len(cot.thought)):
            state = self.apply_state(i, cot.thought[i]['state'], **kwargs)
            prompt = self.thought.prefix(
                self.ellipsis_token,
                self.step.prefix(state),
            )
            action = cot.thought[i]['action']

            if action is None:
                assert i + 1 == len(cot.thought)
                output = self.thought.suffix(
                    self.step.suffix(sep_prefix=True),
                    sep_prefix=True,
                )
                if include_outcome:
                    output = self.all.suffix(output, self.apply_outcome(**kwargs))
            else:
                output = self.step.suffix(
                    self.action(self.repr_action(i, action, **kwargs)),
                    sep_prefix=True,
                )

            if prompt and output:
                yield dict(instruction=prompt, output=output)

    def _sftdata(
        self,
        cot: CoT[In, list[MDPStep[State, Action]], Out], *,
        token_level: bool = False,
        step_level: bool = False,
        state_level: bool = True,
        policy_level: bool = True,
        predict_outcome: bool = True,
        predict_init: bool = True,
        eos: str = '',
        **_,
    ) -> Iterable[dict[str, str]]:
        
        yield from super()._sftdata(
            cot,
            token_level=token_level,
            step_level=step_level,
            predict_outcome=predict_outcome, 
            predict_init=predict_init,
            eos=eos,
            **_,
        )

        if state_level:
            yield from self._sftdata_transition(cot, include_init=predict_init)
        
        if policy_level:
            yield from self._sftdata_pi(cot, include_outcome=predict_outcome)
