#!/usr/bin/env python3
import os
import time
from pathlib import Path
from typing import Type
from argparse import Namespace

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from timm.loss import SoftTargetCrossEntropy
from timm.scheduler import CosineLRScheduler
from timm.data import Mixup, create_loader
from timm.utils import accuracy
import lightning as L

from modern_hopfield_attention.model import MHAVisionTransformer, UniversalMHAViT
from modern_hopfield_attention.data import create_vision_dataset


class LitMHAViT(L.LightningModule):
    """Lightning wrapper that trains an MHA vision transformer with timm tooling.

    The constructor builds everything needed for ``Trainer.fit``: the model
    (``UniversalMHAViT`` when ``args.universal`` is truthy, otherwise
    ``MHAVisionTransformer``), the train/validation losses, an AdamW optimizer
    with a timm cosine schedule, the mixup/cutmix function, and both data
    loaders.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``save_dir``, ``universal``, the model options (``attn_alpha``,
            ``skip_alpha``, ``img_size``, ``patch_size``, ``in_chans``,
            ``num_classes``, ``embed_dim``, ``depth``, ``num_heads``,
            ``class_token``, ``drop_path_rate``), the optimizer options
            (``learning_rate``, ``eps``, ``betas``, ``weight_decay``), the
            scheduler options (``t_initial``, ``lr_min``, ``warmup_t``,
            ``warmup_lr_init``), the mixup options (``mixup``, ``cutmix``,
            ``label_smoothing``) and the data options (``dataset_type``,
            ``dataset_dir``, ``batch_size``, ``re_prob``, ``re_mode``,
            ``re_count``, ``re_split``, ``crop_pct``, ``auto_augment``).
    """

    def __init__(self, args: Namespace) -> None:
        super().__init__()
        # save
        self.save_hyperparameters(args)

        # args
        self.save_dir = args.save_dir

        # model
        self.model = self._build_backbone(args)

        # loss
        # Training targets are the soft labels produced by mixup, validation
        # targets are plain class indices, hence two different criteria.
        self.train_loss_fn = SoftTargetCrossEntropy()
        self.valid_loss_fn = nn.CrossEntropyLoss()

        # optimizer&scheduler
        self.optimizer = optim.AdamW(
            self.parameters(),
            lr=args.learning_rate,
            eps=args.eps,
            betas=args.betas,
            weight_decay=args.weight_decay,
        )
        self.scheduler = CosineLRScheduler(
            optimizer=self.optimizer,
            t_initial=args.t_initial,
            lr_min=args.lr_min,
            warmup_t=args.warmup_t,
            warmup_lr_init=args.warmup_lr_init,
            warmup_prefix=True,
        )

        # mixup
        # NOTE: the attribute name "mixip_fn" (sic) is kept unchanged so that
        # code outside this class that reads it keeps working.
        self.mixip_fn = Mixup(
            mixup_alpha=args.mixup,
            cutmix_alpha=args.cutmix,
            label_smoothing=args.label_smoothing,
            num_classes=args.num_classes,
        )

        # dataloader
        self.train_loader = self._build_loader(
            args,
            split="train",
            is_training=True,
            batch_size=args.batch_size,
            # random erasing
            re_prob=args.re_prob,
            re_mode=args.re_mode,
            re_count=args.re_count,
            re_split=args.re_split,
            # rand augment
            auto_augment=args.auto_augment,
        )
        # Validation batches have a fixed size of 32, independent of
        # ``args.batch_size``.
        # NOTE(review): the validation dataset is still created with
        # ``batch_size=args.batch_size`` -- confirm that this mismatch is
        # intended.
        self.valid_loader = self._build_loader(
            args,
            split="validation",
            is_training=False,
            batch_size=32,
        )

    @staticmethod
    def _build_backbone(args: Namespace) -> nn.Module:
        """Instantiate the model variant selected by ``args.universal``."""
        # Both variants are constructed with exactly the same keyword
        # arguments, so only the class differs.
        model_cls: Type[nn.Module] = (
            UniversalMHAViT if args.universal else MHAVisionTransformer
        )
        return model_cls(
            attn_alpha=args.attn_alpha,
            skip_alpha=args.skip_alpha,
            img_size=args.img_size,
            patch_size=args.patch_size,
            in_chans=args.in_chans,
            num_classes=args.num_classes,
            embed_dim=args.embed_dim,
            depth=args.depth,
            num_heads=args.num_heads,
            class_token=args.class_token,
            drop_path_rate=args.drop_path_rate,
        )

    def _build_loader(
        self,
        args: Namespace,
        *,
        split: str,
        is_training: bool,
        batch_size: int,
        **loader_kwargs,
    ) -> DataLoader:
        """Create the dataset for ``split`` and wrap it in a timm loader.

        Args:
            args: Namespace providing the dataset and image options.
            split: Dataset split name passed to ``create_vision_dataset``.
            is_training: Forwarded to both the dataset and the loader.
            batch_size: Batch size used by the loader.
            **loader_kwargs: Extra keyword arguments forwarded unchanged to
                ``create_loader`` (e.g. augmentation settings).

        Returns:
            The loader returned by ``create_loader``.
        """
        dataset = create_vision_dataset(
            args.dataset_type,
            args.dataset_dir,
            split=split,
            is_training=is_training,
            download=True,
            batch_size=args.batch_size,
            repeats=0,
        )
        return create_loader(
            dataset,
            input_size=(args.in_chans, args.img_size, args.img_size),
            batch_size=batch_size,
            is_training=is_training,
            # crop
            crop_pct=args.crop_pct,
            use_prefetcher=False,
            device=self.device,
            **loader_kwargs,
        )

    def _log_step_metrics(
        self,
        stage: str,
        loss: torch.Tensor,
        logits: torch.Tensor,
        target: torch.Tensor,
        batch_size: int,
        *,
        on_step: bool,
        on_epoch: bool | None,
    ) -> None:
        """Log the loss and top-1/5/10 accuracy under ``"<stage>/<name>"``.

        Args:
            stage: Key prefix, ``"train"`` or ``"valid"``.
            loss: Scalar loss tensor of the current batch.
            logits: Model output for the current batch.
            target: Class-index targets the accuracy is computed against.
            batch_size: Number of samples in the batch; passed to every
                ``self.log`` call so that all metrics are weighted the same
                way without relying on batch-size inference.
            on_step: Forwarded to ``self.log``.
            on_epoch: Forwarded to ``self.log`` (``None`` keeps the default).
        """
        top1, top5, top10 = accuracy(logits, target, topk=(1, 5, 10))
        metrics = {"loss": loss, "top1": top1, "top5": top5, "top10": top10}
        for name, value in metrics.items():
            self.log(
                f"{stage}/{name}",
                value=value.item(),
                batch_size=batch_size,
                on_step=on_step,
                on_epoch=on_epoch,
                sync_dist=True,
            )

    def on_fit_start(self) -> None:
        """Call the model's ``register_hooks()`` once fitting starts."""
        self.model.register_hooks()

    def on_train_batch_start(self, batch, batch_idx) -> None:
        """Call the model's ``clear_hooks()`` before each training batch."""
        self.model.clear_hooks()

    def training_step(self, batch: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        """Run one mixup training step and return the loss.

        Args:
            batch: ``(inputs, targets)`` pair from the training loader.

        Returns:
            The soft-target cross-entropy loss of the mixed batch.
        """
        x_plain, y_plain = batch

        # mixup
        x, y = self.mixip_fn(x_plain, y_plain)
        y_hat = self.model(x)

        loss = self.train_loss_fn(y_hat, y)

        # Accuracy is measured against the original (un-mixed) labels, since
        # the mixed targets ``y`` are not class indices.
        self._log_step_metrics(
            "train",
            loss,
            y_hat,
            y_plain,
            x_plain.size(0),
            on_step=True,
            on_epoch=None,
        )

        return loss

    def on_validation_start(self):
        """Call the model's ``clear_hooks()`` before validation begins."""
        self.model.clear_hooks()

    def validation_step(
        self,
        batch: tuple[torch.Tensor, torch.Tensor],
        batch_idx: int,
    ) -> None:
        """Evaluate one validation batch and log epoch-level metrics.

        Args:
            batch: ``(inputs, targets)`` pair from the validation loader.
            batch_idx: Index of the batch (unused).
        """
        x, y = batch

        y_hat = self.model(x)
        loss = self.valid_loss_fn(y_hat, y)

        self._log_step_metrics(
            "valid",
            loss,
            y_hat,
            y,
            x.size(0),
            on_step=False,
            on_epoch=True,
        )

    def configure_optimizers(self):  # -> tuple[list[AdamW], list[dict[str, Any]]]:
        """Return the optimizer and the per-epoch scheduler built in ``__init__``."""
        return [self.optimizer], [{"scheduler": self.scheduler, "interval": "epoch"}]

    def lr_scheduler_step(self, scheduler, metric) -> None:
        """Step the scheduler with the current epoch index.

        The scheduler's ``step`` is called with an explicit ``epoch``
        argument, which is why the default stepping logic is overridden.
        """
        scheduler.step(epoch=self.current_epoch)

    def train_dataloader(self) -> DataLoader:
        """Return the training loader created in ``__init__``."""
        return self.train_loader

    def val_dataloader(self) -> DataLoader:
        """Return the validation loader created in ``__init__``."""
        return self.valid_loader
