import torch
import torch.nn as nn
import numpy as np
from einops import rearrange

from models.transformer import Transformer as TransformerEncoderDecoder
from models.transformer_decoder import Transformer
from quant.models.cnn import Autoencoder
from rfcutils2.qpsk_helper_fn import qpsk_matched_filter_demod as qpsk_demod
from rfcutils2.qpsk_helper_fn import modulate_qpsk_signal as qpsk_mod
import utils.data_transforms as data_transforms
from utils.utils import describe_tensor


class QPSKTokenizer(nn.Module):
    """Ground-truth "tokenizer" for QPSK signals, built on the rfcutils2 helpers.

    Tokens are either the raw demodulated bits (``use_bits=True``) or one
    symbol index per consecutive pair of bits, packed as ``2 * b0 + b1``.
    """

    def __init__(self, use_bits=False):
        super().__init__()
        self.use_bits = use_bits

    def encode(self, x):
        """Demodulate an interleaved (real/imag) waveform into integer tokens.

        NOTE(review): assumes ``x`` is a batch in the interleaved waveform
        format and that ``qpsk_demod(...)[0]`` yields a (B, num_bits) array-like
        exposing ``.numpy()`` — confirm against rfcutils2.
        """
        complex_signal = data_transforms.interleaving_to_complex(x)
        device = complex_signal.device
        demodulated = qpsk_demod(complex_signal.cpu().numpy())[0]
        bits = torch.from_numpy(demodulated.numpy()).to(device).long()
        if not self.use_bits:
            # Pack each adjacent bit pair (b0, b1) into a symbol id in {0, 1, 2, 3}.
            return bits[:, ::2] * 2 + bits[:, 1::2]
        return bits

    def decode(self, x):
        """Map tokens back to an interleaved (real/imag) QPSK waveform.

        Inverse of ``encode``: symbol ids are split into their high and low
        bits, which are re-interleaved per symbol before modulation.
        """
        if self.use_bits:
            bits = x
        else:
            high_bits = torch.floor_divide(x, 2)
            low_bits = torch.remainder(x, 2)
            # (2, B, V) -> (B, 2V) ordered as high_0, low_0, high_1, low_1, ...
            bits = rearrange([high_bits, low_bits], "t b v -> b (v t)")
        modulated = qpsk_mod(bits.cpu().numpy())[0]
        waveform = torch.from_numpy(modulated.numpy()).to(x.device)
        return data_transforms.complex_to_interleaving(waveform)


def apply_model(input, func, w):
    """Run ``func`` independently on consecutive length-``w`` windows of each row.

    ``input`` has shape (B, N), where N is expected to be a multiple of ``w``.
    ``func`` must map a (B', w) batch of windows to a (B', w_out) batch.
    The per-window outputs are concatenated back in window order, giving a
    result of shape (B, (N // w) * w_out).
    """
    num_windows = input.shape[1] // w
    # Fold every window into the batch axis so func sees one window per row.
    windows = rearrange(input, "b (s w) -> (b s) w", w=w)
    per_window = func(windows)
    # Unfold: rows belonging to the same batch element are contiguous and in order.
    return rearrange(per_window, "(b s) wo -> b (s wo)", s=num_windows)


# The waveform2 format is a tensor of shape [B, signal_length * 2] in which
# the real and imaginary parts of the waveform are interleaved.


class QuantOutputTransformer(nn.Module):
    """Transformer that predicts discrete tokens of a target waveform.

    Target waveforms (waveform2 format) are converted to and from token
    sequences by a frozen tokenizer: either an ``Autoencoder`` restored from
    ``tokenizer_path`` (``tokenizer_type="autoencoder"``) or a ``QPSKTokenizer``
    (``tokenizer_type="gt"``). The conditioning signal ``cond`` is fed to the
    transformer as-is, or first tokenized by ``mixture_tokenizer`` when
    ``tokenize_input`` is set.

    With ``llm_style=True`` the transformer is the one imported as
    ``TransformerEncoderDecoder`` and receives the target tokens shifted right
    by one (start token first) plus ``cond``. Otherwise it is the
    ``Transformer`` from ``models.transformer_decoder`` and maps ``cond``
    directly to logits.

    Raises:
        ValueError: on an unknown ``tokenizer_type``, or when a separate mixture
            tokenizer is requested without its checkpoint path or config.
    """

    def __init__(
        self,
        tokenizer_path,
        tokenizer_config,
        transformer_config,
        llm_style,
        tokenize_input=False,
        use_same_mixture_tokenizer=False,
        mixture_tokenizer_config=None,
        mixture_tokenizer_path=None,
        tokenizer_type="autoencoder",
    ):
        super().__init__()
        # Validate the configuration before loading any checkpoint, using real
        # exceptions (asserts are stripped under ``python -O``).
        if tokenizer_type not in ("autoencoder", "gt"):
            raise ValueError(f"Unknown tokenizer type: {tokenizer_type!r}")
        if tokenize_input and not use_same_mixture_tokenizer:
            if mixture_tokenizer_path is None:
                raise ValueError(
                    "mixture_tokenizer_path is required when tokenize_input=True "
                    "and use_same_mixture_tokenizer=False"
                )
            if mixture_tokenizer_config is None:
                raise ValueError(
                    "mixture_tokenizer_config is required when tokenize_input=True "
                    "and use_same_mixture_tokenizer=False"
                )

        if tokenizer_type == "autoencoder":
            self.tokenizer = Autoencoder(**tokenizer_config)
            # load_state_dict copies into the module's own parameters, so loading
            # onto the CPU only removes the dependency on the checkpoint's device.
            tokenizer_state_dict = torch.load(tokenizer_path, map_location="cpu")
            # Entries ending in "conv_mask" are skipped and loading is non-strict.
            filtered_state_dict = {
                k: v for k, v in tokenizer_state_dict.items()
                if not k.endswith("conv_mask")
            }
            self.tokenizer.load_state_dict(filtered_state_dict, strict=False)
            self.tokenizer.requires_grad_(False)
        else:  # tokenizer_type == "gt"
            self.tokenizer = QPSKTokenizer(**tokenizer_config)

        self.transformer = (
            TransformerEncoderDecoder(**transformer_config)
            if llm_style
            else Transformer(**transformer_config)
        )
        self.llm_style = llm_style
        self.tokenize_input = tokenize_input
        if self.tokenize_input:
            if use_same_mixture_tokenizer:
                self.mixture_tokenizer = self.tokenizer
            else:
                self.mixture_tokenizer = Autoencoder(**mixture_tokenizer_config)
                # NOTE(review): unlike the target tokenizer, this checkpoint is
                # loaded strictly and keeps "conv_mask" entries — confirm intended.
                mixture_state_dict = torch.load(mixture_tokenizer_path, map_location="cpu")
                self.mixture_tokenizer.load_state_dict(mixture_state_dict)
                self.mixture_tokenizer.requires_grad_(False)

    @staticmethod
    def _window_size(tokenizer, getter_name, signal):
        """Return ``tokenizer.<getter_name>()``, or ``signal.shape[1]`` if absent.

        Tokenizers without a fixed window (e.g. ``QPSKTokenizer``) are then
        applied to each full row in a single window.
        """
        getter = getattr(tokenizer, getter_name, None)
        return signal.shape[1] if getter is None else getter()

    def tokenizer_encode(self, signal):
        """Tokenize ``signal`` window-by-window with the target tokenizer."""
        window = self._window_size(self.tokenizer, "get_input_length", signal)
        return apply_model(signal, self.tokenizer.encode, window)

    def tokenizer_decode(self, tokens):
        """Decode ``tokens`` window-by-window with the target tokenizer."""
        window = self._window_size(self.tokenizer, "get_token_count", tokens)
        return apply_model(tokens, self.tokenizer.decode, window)

    def mixture_tokenizer_encode(self, signal):
        """Tokenize the conditioning ``signal`` with the mixture tokenizer."""
        window = self._window_size(self.mixture_tokenizer, "get_input_length", signal)
        return apply_model(signal, self.mixture_tokenizer.encode, window)

    def encode(self, target):
        """Tokenize a waveform2-format ``target`` (must be 2-D).

        Raises:
            ValueError: if ``target`` is not 2-D.
        """
        if len(target.shape) != 2:
            raise ValueError(
                f"target should be in waveform2 format (2-D), got shape {tuple(target.shape)}"
            )
        return self.tokenizer_encode(target)

    def forward(self, cond, target=None):
        """Return transformer logits for ``cond`` (and ``target`` tokens in llm_style).

        Raises:
            ValueError: if ``llm_style`` is set and ``target`` is None.
        """
        if self.tokenize_input:
            cond = self.mixture_tokenizer_encode(cond)
        if not self.llm_style:
            return self.transformer(cond)
        if target is None:
            raise ValueError("target tokens are required when llm_style=True")
        # Shift the targets right by one and put the start token first.
        # torch.roll returns a new tensor, so the caller's target is untouched.
        decoder_input = torch.roll(target, shifts=1, dims=1)
        decoder_input[:, 0] = self.transformer.start_token
        return self.transformer(decoder_input, cond)

    @torch.inference_mode()
    def decode_logits(self, logits):
        """Greedy-decode ``logits`` (argmax over dim 2) back to a waveform.

        NOTE(review): assumes logits are laid out as (B, T, vocab).
        """
        tokens = torch.argmax(logits, dim=2)
        return self.tokenizer_decode(tokens)

    @torch.inference_mode()
    def generate(self, cond, beam_k=1):
        """Generate a waveform for ``cond``.

        In llm_style, ``transformer.generate`` returns tokens and ``beam_k`` is
        forwarded to it. Otherwise ``transformer.generate`` returns logits,
        which are decoded greedily and ``beam_k`` is ignored.
        """
        if self.tokenize_input:
            cond = self.mixture_tokenizer_encode(cond)
        if self.llm_style:
            tokens = self.transformer.generate(cond, beam_k=beam_k)
        else:
            logits = self.transformer.generate(cond)
            tokens = torch.argmax(logits, dim=2)
        return self.tokenizer_decode(tokens)

