#!/usr/bin/env python3

import os
import time
from pathlib import Path
import argparse
import inspect

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from timm.scheduler import CosineLRScheduler
import lightning as L
from transformers import AutoTokenizer

from modern_hopfield_attention.model import  GPT2_ori
from modern_hopfield_attention.data import WikiText103
from modern_hopfield_attention.functional import token_cossim


class LitGPT2(L.LightningModule):
    """Lightning module that trains ``GPT2_ori`` on WikiText-103 with AdamW + cosine LR.

    Besides loss/perplexity logging, the first validation batch of every
    validation run (on global rank 0) stacks the model's hooked inputs, saves
    them to ``<save_dir>/hook/epochNNN/hook_input.pt`` and logs one mode of
    their token cosine similarity per layer.
    """

    def __init__(self, args: argparse.Namespace) -> None:
        """Build tokenizer, model, loss, optimizer/scheduler and dataloaders.

        Args:
            args: parsed CLI options. Fields read here: ``save_dir``,
                ``learning_rate``, ``betas``, ``weight_decay``, ``num_tokens``,
                ``num_heads``, ``embed_dim``, ``depth``, ``t_initial``,
                ``lr_min``, ``warmup_t``, ``warmup_lr_init``, ``dataset_dir``,
                ``num_proc`` and ``batch_size``.
        """
        super().__init__()
        # Persist every CLI option as a Lightning hyperparameter.
        self.save_hyperparameters(args)

        # args
        self.save_dir = args.save_dir
        self.learning_rate = args.learning_rate
        self.betas = args.betas
        self.weight_decay = args.weight_decay

        # tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained("gpt2")

        # Padded vocabulary size (50304 = 64 * 786). ``max`` guarantees the
        # embedding table can still index every id the tokenizer may emit.
        vocab_size = max(50304, self.tokenizer.vocab_size)
        self.model = GPT2_ori(
            vocab_size=vocab_size,
            sequence_length=args.num_tokens,
            num_heads=args.num_heads,
            embedding_dim=args.embed_dim,
            depth=args.depth,
        )

        # loss
        self.loss_fn = nn.CrossEntropyLoss()

        # optimizer & scheduler
        self.optimizer = self.setting_optimizer()
        self.scheduler = CosineLRScheduler(
            optimizer=self.optimizer,
            t_initial=args.t_initial,
            lr_min=args.lr_min,
            warmup_t=args.warmup_t,
            warmup_lr_init=args.warmup_lr_init,
            warmup_prefix=True,
        )

        # dataloaders
        train_set = WikiText103(
            root=args.dataset_dir,
            split="train",
            tokenizer=self.tokenizer,
            num_proc=args.num_proc,
        )
        print(f"{len(train_set)=}")
        # NOTE(review): training batches are not shuffled — confirm this is intended.
        self.train_loader = DataLoader(
            train_set,
            shuffle=False,
            batch_size=args.batch_size,
        )

        valid_set = WikiText103(
            root=args.dataset_dir,
            split="validation",
            tokenizer=self.tokenizer,
            num_proc=args.num_proc,
        )
        print(f"{len(valid_set)=}")
        self.valid_loader = DataLoader(
            valid_set,
            shuffle=False,
            batch_size=args.batch_size,
        )

    def setting_optimizer(
        self,
    ) -> torch.optim.AdamW:
        """Create AdamW that applies weight decay only to parameters with ``dim() >= 2``.

        Matmul weights and embeddings (2-D or more) are decayed with
        ``self.weight_decay``; biases and norm parameters (< 2-D) are not.

        Returns:
            An ``AdamW`` over all trainable parameters of ``self.model``, using
            the fused implementation when this PyTorch's ``AdamW`` accepts a
            ``fused`` argument.
        """
        param_dict = {
            name: param
            for name, param in self.model.named_parameters()
            if param.requires_grad
        }
        decay_params = [param for param in param_dict.values() if param.dim() >= 2]
        nodecay_params = [param for param in param_dict.values() if param.dim() < 2]
        optim_groups = [
            {"params": decay_params, "weight_decay": self.weight_decay},
            {"params": nodecay_params, "weight_decay": 0.0},
        ]
        num_decay_params = sum(param.numel() for param in decay_params)
        num_nodecay_params = sum(param.numel() for param in nodecay_params)
        print(
            f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters"
        )
        print(
            f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters"
        )

        fused_available = "fused" in inspect.signature(torch.optim.AdamW).parameters
        extra_args = dict(fused=True) if fused_available else dict()
        # NOTE(review): this is called from __init__, presumably before Lightning
        # moves the model to its accelerator, so the parameters are likely still
        # on CPU here. Some PyTorch releases reject ``fused=True`` for CPU
        # parameters at construction time — confirm against the installed version.
        optimizer = optim.AdamW(
            optim_groups, lr=self.learning_rate, betas=self.betas, **extra_args
        )
        print(f"using fused AdamW: {fused_available}")

        return optimizer

    def _next_token_loss(self, tokens: torch.Tensor) -> torch.Tensor:
        """Return the mean cross-entropy of predicting ``tokens[:, 1:]`` from ``tokens``.

        ``tokens`` is indexed as ``(batch, seq_len)``. Logits at position ``t``
        are scored against token ``t + 1``, so the last position is dropped.
        """
        targets = tokens[:, 1:].contiguous()
        logits = self.model(tokens)
        logits = logits[:, :-1, :].contiguous()
        logits = logits.view(-1, logits.size(-1))
        return self.loss_fn(logits, targets.view(-1))

    @staticmethod
    def _synchronized_time(reference: torch.Tensor) -> float:
        """Return ``time.perf_counter()`` after pending CUDA work on ``reference``'s device finishes.

        CUDA kernels are launched asynchronously. Without the synchronize, a
        wall-clock delta around the forward pass would mostly measure launch
        overhead rather than compute time.
        """
        if reference.is_cuda:
            torch.cuda.synchronize(reference.device)
        return time.perf_counter()

    def on_fit_start(self) -> None:
        """Register the model's hooks once, before training starts."""
        self.model.register_hooks()

    def on_train_batch_start(self, batch, batch_idx) -> None:
        """Clear hook state so it only reflects the current training batch."""
        self.model.clear_hooks()

    def training_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor:
        """Compute and log the next-token loss and perplexity for one batch.

        Returns:
            The scalar loss tensor (with grad) for Lightning to backpropagate.
        """
        loss = self._next_token_loss(batch)

        # Log detached tensors rather than ``.item()`` to avoid a host sync per step.
        logged = loss.detach()
        self.log(
            "train/loss",
            value=logged,
            prog_bar=True,
            sync_dist=True,
        )
        self.log(
            "train/perplexity",
            value=torch.exp(logged),
            sync_dist=True,
        )

        return loss

    def on_validation_start(self):
        """Clear hook state so validation starts from an empty record."""
        self.model.clear_hooks()

    def _save_and_log_hooks(self) -> None:
        """Save the stacked hook inputs for this epoch and log one similarity mode per layer."""
        # Per-layer records are stacked on dim 1, so dim 1 indexes layers.
        hook_input = torch.stack(self.model.hook_input, dim=1)

        epoch_dir = Path(self.save_dir) / f"hook/epoch{self.current_epoch:03}"
        os.makedirs(epoch_dir, 0o777, exist_ok=True)
        torch.save(hook_input, epoch_dir / "hook_input.pt")

        similarity = token_cossim(hook_input, vs_clstoken=False)
        num_layers = similarity.size(1)
        # Move the layer axis to the front before flattening, so each row holds
        # only that layer's values. A bare ``view(num_layers, -1)`` would
        # interleave entries of different layers whenever dim 0 is larger than 1.
        # NOTE(review): assumes token_cossim keeps the layer axis at dim 1 — confirm.
        per_layer = similarity.transpose(0, 1).reshape(num_layers, -1)
        modes = per_layer.mode(dim=-1).values

        for idx, mode in enumerate(modes):
            self.log(
                f"valid/cls_uni(layer={idx:02})",
                value=mode,
                sync_dist=True,
                rank_zero_only=True,
            )

    def validation_step(
        self,
        batch: torch.Tensor,
        batch_idx: int,
    ) -> None:
        """Log validation loss, perplexity and forward time; dump hooks on batch 0.

        ``valid/throughput`` is the elapsed seconds of one forward pass plus
        loss computation for this batch (a latency, despite the metric name).
        """
        start = self._synchronized_time(batch)
        loss = self._next_token_loss(batch)
        throughput_time = self._synchronized_time(batch) - start

        self.log(
            "valid/throughput",
            value=throughput_time,
            sync_dist=True,
        )

        # Only global rank 0 writes the hook dump, once per validation run.
        if batch_idx == 0 and self.global_rank == 0:
            self._save_and_log_hooks()

        logged = loss.detach()
        self.log(
            "valid/loss",
            value=logged,
            prog_bar=True,
            sync_dist=True,
        )
        self.log(
            "valid/perplexity",
            value=torch.exp(logged),
            sync_dist=True,
        )

    def configure_optimizers(self):  # -> tuple[list[AdamW], list[dict[str, Any]]]:
        """Return the optimizer and the timm scheduler, stepped once per epoch."""
        return [self.optimizer], [{"scheduler": self.scheduler, "interval": "epoch"}]

    def lr_scheduler_step(self, scheduler, metric) -> None:
        """Step the timm scheduler with an explicit epoch index (timm's API requires it).

        NOTE(review): with ``interval="epoch"`` this runs at the end of epoch
        ``current_epoch``, so the LR it sets (for index ``current_epoch``) is
        used during the next epoch. timm's own training loop passes
        ``epoch + 1`` — confirm which convention is intended.
        """
        scheduler.step(epoch=self.current_epoch)

    def train_dataloader(self) -> DataLoader:
        """Return the training DataLoader built in ``__init__``."""
        return self.train_loader

    def val_dataloader(self) -> DataLoader:
        """Return the validation DataLoader built in ``__init__``."""
        return self.valid_loader
