import torch
from torch import nn
import torch.nn.functional as F
from argparse import ArgumentParser
from transformers import AutoModelForCausalLM

from configs import Config
from utils import (
    dataclass_from_file,
    setup_seed,
    get_device,
    print_trainable_parameters,
)


class Autoencoder(nn.Module):
    """Purely linear autoencoder over embedding vectors.

    A single ``nn.Linear`` projects inputs from ``embedding_size`` down to
    ``hidden_size`` and a second ``nn.Linear`` projects them back up. There is
    no non-linearity between the two layers.

    Args:
        config: Project configuration. ``config.embedding_encoder`` must
            expose ``embedding_size`` and a mapping ``unsafe_config``; the
            optional ``"hidden_size"`` key of that mapping sets the
            bottleneck width and falls back to 768 when absent.
    """

    def __init__(self, config: "Config") -> None:
        super().__init__()
        encoder_cfg = config.embedding_encoder

        self.config = config
        self.embedding_size = encoder_cfg.embedding_size
        # The bottleneck width is optional in the config; 768 is the fallback.
        self.hidden_size = encoder_cfg.unsafe_config.get("hidden_size", 768)

        # Layers are created in the order (down, up) so that parameter
        # initialization consumes the RNG in a fixed sequence.
        self.down = nn.Linear(
            in_features=self.embedding_size, out_features=self.hidden_size
        )
        self.up = nn.Linear(
            in_features=self.hidden_size, out_features=self.embedding_size
        )

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` from the embedding space into the bottleneck space."""
        latent = self.down(x)
        return latent

    def decode(self, x: torch.Tensor) -> torch.Tensor:
        """Project a bottleneck vector ``x`` back into the embedding space."""
        restored = self.up(x)
        return restored

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the reconstruction of ``x`` (encode followed by decode)."""
        latent = self.encode(x)
        return self.decode(latent)


def pca_reconstruction_error(
    data: torch.Tensor,
    u: torch.Tensor,
    s: torch.Tensor,
    v: torch.Tensor,
    center: bool = True,
) -> torch.Tensor:
    """Return the mean per-row L2 error of a low-rank PCA reconstruction.

    ``torch.pca_lowrank`` centers its input by default, so the product
    ``u @ diag(s) @ v.T`` approximates ``data - data.mean(dim=-2)`` rather
    than ``data`` itself. Comparing the product against the raw matrix would
    report an error dominated by the mean row, so the same centering is
    applied to ``data`` here before measuring the residual.

    Args:
        data: The matrix that was passed to ``torch.pca_lowrank``.
        u: Left factor returned by ``torch.pca_lowrank``.
        s: Singular values returned by ``torch.pca_lowrank``.
        v: Right factor returned by ``torch.pca_lowrank``.
        center: Must match the ``center`` argument used for the
            decomposition (``True`` is the ``torch.pca_lowrank`` default).

    Returns:
        A scalar tensor: the L2 norm of each residual row, averaged over rows.
    """
    target = data.to(u.dtype)
    if center:
        target = target - target.mean(dim=-2, keepdim=True)
    # Scaling the columns of ``u`` by ``s`` equals ``u @ diag(s)`` without
    # materializing the dense diagonal matrix.
    reconstruction = (u * s) @ v.T
    return torch.norm(target - reconstruction, dim=-1).mean()


if __name__ == "__main__":
    setup_seed()
    device = get_device()

    torch.set_float32_matmul_precision("high")

    parser = ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    args = parser.parse_args()

    config = dataclass_from_file(Config, args.config)

    model = Autoencoder(config).to(config.dtype).to(device)

    # Only the input-embedding matrix of the pretrained model is used below.
    embeddings: nn.Embedding = AutoModelForCausalLM.from_pretrained(
        config.pretrained_model_name_or_path,
        torch_dtype=config.dtype,
        device_map=device,
    ).get_input_embeddings()
    data = embeddings.weight.detach().clone()

    # PCA baseline at the same bottleneck width as the autoencoder. The rank
    # is clamped because torch.pca_lowrank rejects q > min(rows, cols).
    data_fp32 = data.to(torch.float32)
    q = min(model.hidden_size, *data_fp32.shape)
    u, s, v = torch.pca_lowrank(data_fp32, q=q)

    print(u.shape, s.shape, v.shape)

    # reconstruct (error is measured against the centered matrix, which is
    # what the PCA factors actually approximate)
    print(pca_reconstruction_error(data_fp32, u, s, v))

    # NOTE: the autoencoder training loop below is currently disabled.

    # print_trainable_parameters(model)

    # if config.compile_model and device == "cuda":
    #     model = torch.compile(model)

    # optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

    # model.train()
    # for step in range(250):
    #     output = model(data)

    #     loss = F.mse_loss(output, data, reduction="sum")
    #     loss.backward()

    #     norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

    #     optimizer.step()
    #     optimizer.zero_grad()

    #     print(f"step {step:5d} | loss: {loss.item():.6f} | norm: {norm.item():.6f}")
