from tqdm.notebook import tqdm
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import os
from torchmetrics import Accuracy
from torchinfo import summary
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
import numpy as np
import torch
from torchvision import transforms
from torch.utils.data.sampler import SubsetRandomSampler
import time
import multiprocessing
from pathlib import Path
from torchvision.models import resnet18, mobilenetv3, mobilenet_v3_small, mobilenet_v3_large
from torch.optim.lr_scheduler import ReduceLROnPlateau
from models.vit_small import ViT
from models.swin import swin_t
from sklearn.metrics import average_precision_score, f1_score

# The baseline model implementations are taken from https://github.com/kentaroy47/vision-transformers-cifar10

# ---------------------------------------------------------------------
# Datasets
size = 32  # 32 for CIFAR, 48 for FER
# Per-dataset normalisation: FER (one-channel statistics), CIFAR-10, CIFAR-100.
normalize1 = transforms.Normalize((0.485,), (0.229,))
normalize2 = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
normalize3 = transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
trans_FER = transforms.Compose([transforms.ToTensor(), normalize1])
trans_CIFAR10 = transforms.Compose([transforms.ToTensor(), normalize2])
trans_CIFAR100 = transforms.Compose([transforms.ToTensor(), normalize3])
trans2 = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])

# Additional data enhancements (augmented train pipeline / plain test pipeline):
# CIFAR10:
mean = (0.4914, 0.4822, 0.4465)
std = (0.2023, 0.1994, 0.2010)
# CIFAR100:
# mean = (0.5071, 0.4867, 0.4408)
# std = (0.2675, 0.2565, 0.2761)
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.Resize(size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])
transform_test = transforms.Compose([
    transforms.Resize(size),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])
# Active dataset selection: uncomment exactly one train/test pair.
# train_val_dataset = datasets.ImageFolder(root="data/fer2013/train", transform=trans_FER)
# test_dataset = datasets.ImageFolder(root="data/fer2013/test", transform=trans_FER)
train_val_dataset = datasets.CIFAR10("./data/", train=True, download=True, transform=trans_CIFAR10)
test_dataset = datasets.CIFAR10("./data/", train=False, download=True, transform=trans_CIFAR10)
# train_val_dataset = datasets.CIFAR10("./data/", train=True, download=True, transform=transform_train)
# test_dataset = datasets.CIFAR10("./data/", train=False, download=True, transform=transform_test)
# train_val_dataset = datasets.CIFAR100("./data/", train=True, download=True, transform=trans_CIFAR100)
# test_dataset = datasets.CIFAR100("./data/", train=False, download=True, transform=trans_CIFAR100)

# Hold out `valid_size` of the training set for validation. The fixed seed
# makes the train/validation split identical across runs.
valid_size = 0.1
random_seed = 1
batch_size = 128
num_workers = 1
pin_memory = False
num_train = len(train_val_dataset)
indices = list(range(num_train))
split = int(np.floor(valid_size * num_train))
np.random.seed(random_seed)
np.random.shuffle(indices)
train_idx, valid_idx = indices[split:], indices[:split]

train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)

# Training may drop the trailing partial batch (keeps every step the same size).
train_dataloader = torch.utils.data.DataLoader(
        train_val_dataset,
        batch_size=batch_size,
        sampler=train_sampler,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=True
    )

# Evaluation loaders keep the final partial batch so that every held-out
# sample is scored; dropping it would silently shrink the evaluation set.
val_dataloader = torch.utils.data.DataLoader(
        train_val_dataset,
        batch_size=batch_size,
        sampler=valid_sampler,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=False
    )

test_dataloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=False
    )
# ---------------------------------------------------------------------
# Models
class ResNet18ForCIFAR(nn.Module):
    """torchvision ResNet-18 with its classification head resized.

    The final ``fc`` layer is replaced by a ``num_classes``-way linear layer.

    Args:
        num_classes: Number of output logits.
        pretrained: If True, initialise the backbone from torchvision's
            ``IMAGENET1K_V1`` weights; otherwise use random initialisation.
            (Previously this flag was accepted but ignored.)
    """

    def __init__(self, num_classes: int = 10, pretrained: bool = False):
        super().__init__()
        self.model = resnet18(weights="IMAGENET1K_V1" if pretrained else None)
        # Read the head width from the existing layer instead of hard-coding it.
        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)

    def forward(self, x):
        """Return the logits produced by the wrapped ResNet-18."""
        return self.model(x)


class MobileNetV3SmallCIFAR(nn.Module):
    """MobileNetV3 adapted for small inputs such as 32x32 CIFAR images.

    NOTE: despite the class name, the backbone built here is
    ``mobilenet_v3_large``. The name is kept so existing call sites and saved
    checkpoints keep working.

    Changes relative to the torchvision model:
      * the first convolution of the stem uses stride 1 instead of its
        original stride, preserving spatial resolution on small inputs;
      * only the last ``Linear`` of the classifier is replaced, with a
        ``num_classes``-way layer (its input width is read from the old layer).

    Args:
        num_classes: Number of output logits.
        pretrained: If True, start from torchvision's ``IMAGENET1K_V1``
            weights (the stem weights are copied into the stride-1 conv);
            otherwise use random initialisation. Previously this flag was
            accepted but ignored.
    """

    def __init__(self, num_classes: int = 10, pretrained: bool = False):
        super().__init__()

        # ---- load backbone ----
        self.model = mobilenet_v3_large(weights="IMAGENET1K_V1" if pretrained else None)

        # ---- adapt first conv to stride 1 ----
        # Rebuild the conv from the old layer's own hyper-parameters so that
        # only the stride changes; the kernel shape is identical, so the
        # weights can be copied over verbatim.
        old_conv = self.model.features[0][0]
        new_conv = nn.Conv2d(
            in_channels=old_conv.in_channels,
            out_channels=old_conv.out_channels,
            kernel_size=old_conv.kernel_size,
            stride=1,
            padding=old_conv.padding,
            bias=old_conv.bias is not None,
        )
        with torch.no_grad():
            new_conv.weight.copy_(old_conv.weight)
            if old_conv.bias is not None:
                new_conv.bias.copy_(old_conv.bias)
        # Only index 0 of the stem is swapped; the layers after it are kept.
        self.model.features[0][0] = new_conv

        # ---- replace the LAST classifier Linear only ----
        in_features_last = self.model.classifier[-1].in_features
        self.model.classifier[-1] = nn.Linear(in_features_last, num_classes)

    def forward(self, x):
        """Return the logits produced by the wrapped MobileNetV3."""
        return self.model(x)



def _build_vit(image_size: int, num_classes: int) -> nn.Module:
    """Build the ViT configuration used by this script.

    Used both for training and for re-creating the network before loading
    the saved ``state_dict``, so the two architectures cannot drift apart.
    """
    return ViT(
        image_size=image_size,
        patch_size=4,
        num_classes=num_classes,
        dim=512,
        depth=4,
        heads=6,
        mlp_dim=256,
        dropout=0.1,
        emb_dropout=0.1,
    )


def _train_one_epoch(model, dataloader, loss_fn, optimizer, metric, device):
    """Run one optimisation pass over ``dataloader``.

    Returns:
        ``(loss, acc)``: the per-sample mean loss and the accuracy that
        ``metric`` accumulated over every sample seen this epoch.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    model.train()  # set once per epoch rather than once per batch
    metric.reset()  # accumulate this epoch's statistics from zero
    total_loss, n_seen = 0.0, 0
    for X, y in dataloader:
        X, y = X.to(device), y.to(device)
        y_pred = model(X)
        loss = loss_fn(y_pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Assumes loss_fn reduces with 'mean' (the nn.CrossEntropyLoss()
        # default); re-weight by batch size to get a true per-sample mean.
        total_loss += loss.item() * y.size(0)
        n_seen += y.size(0)
        metric.update(y_pred.detach(), y)
    if n_seen == 0:
        raise ValueError("training dataloader yielded no batches")
    return total_loss / n_seen, metric.compute().item()


def _evaluate(model, dataloader, loss_fn, metric, device):
    """Evaluate ``model`` on every sample of ``dataloader`` without gradients.

    Returns:
        ``(loss, acc, probs, targets)``: per-sample mean loss, accuracy over
        all samples, softmax probabilities ``[N, C]`` and targets ``[N]``
        (both on CPU).

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    model.eval()
    metric.reset()  # reset outside inference_mode so states are normal tensors
    total_loss, n_seen = 0.0, 0
    all_probs, all_targets = [], []
    with torch.inference_mode():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            y_pred = model(X)

            total_loss += loss_fn(y_pred, y).item() * y.size(0)
            n_seen += y.size(0)
            metric.update(y_pred, y)

            all_probs.append(torch.softmax(y_pred, dim=1).cpu())
            all_targets.append(y.cpu())
    if n_seen == 0:
        raise ValueError("evaluation dataloader yielded no batches")
    return (
        total_loss / n_seen,
        metric.compute().item(),
        torch.cat(all_probs),
        torch.cat(all_targets),
    )


if __name__ == "__main__":
    multiprocessing.freeze_support()
    num_classes = 10  # 10 for CIFAR, 7 for FER
    model = _build_vit(size, num_classes)
    # model = ResNet18ForCIFAR(num_classes=num_classes)
    # model = MobileNetV3SmallCIFAR(num_classes=num_classes)

    # Use the configured image size rather than a hard-coded 32.
    print(summary(model=model, input_size=(1, 3, size, size)))

    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)
    accuracy = Accuracy(task='multiclass', num_classes=num_classes)
    lr_patience = 10
    # Monitor validation accuracy directly in "max" mode. Stepping a "min"
    # scheduler on -val_acc inverts the relative-threshold test for negative
    # values, so marginally *worse* epochs were counted as improvements.
    scheduler = ReduceLROnPlateau(optimizer, "max", patience=lr_patience)

    # Experiment tracking
    timestamp = datetime.now().strftime("%Y-%m-%d")
    experiment_name = "CIFAR"
    model_name = "CNN"  # NOTE(review): the active model is a ViT - confirm this label
    log_dir = os.path.join("runs", timestamp, experiment_name, model_name)
    writer = SummaryWriter(log_dir)

    # device-agnostic setup
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    accuracy = accuracy.to(device)
    model = model.to(device)

    EPOCHS = 100

    print("training started")
    try:
        for epoch in tqdm(range(EPOCHS)):
            tic = time.time()
            train_loss, train_acc = _train_one_epoch(
                model, train_dataloader, loss_fn, optimizer, accuracy, device
            )
            val_loss, val_acc, _, _ = _evaluate(
                model, val_dataloader, loss_fn, accuracy, device
            )
            scheduler.step(val_acc)

            writer.add_scalars(main_tag="Loss", tag_scalar_dict={"train/loss": train_loss, "val/loss": val_loss},
                               global_step=epoch)
            writer.add_scalars(main_tag="Accuracy", tag_scalar_dict={"train/acc": train_acc, "val/acc": val_acc},
                               global_step=epoch)
            toc = time.time()
            print(
                f"Epoch: {epoch}| Train time: {toc - tic: .2f}|  lr: {optimizer.param_groups[0]['lr']: .5f} Train loss: {train_loss: .5f}| Train acc: {train_acc: .5f}| Val loss: {val_loss: .5f}| Val acc: {val_acc: .5f}")
    finally:
        # close() flushes pending TensorBoard events to disk.
        writer.close()

    MODEL_PATH = Path("models")
    MODEL_PATH.mkdir(parents=True, exist_ok=True)

    # MODEL_NAME = "ResNet_CIFAR100.pth"
    # MODEL_NAME = "MobileNetl_FER.pth"
    # NOTE(review): the file name says FER; confirm it matches the dataset
    # selected at module level before relying on it.
    MODEL_NAME = "ViT_FER.pth"
    MODEL_SAVE_PATH = MODEL_PATH / MODEL_NAME

    # Saving the model
    print(f"Saving the model: {MODEL_SAVE_PATH}")
    torch.save(obj=model.state_dict(), f=MODEL_SAVE_PATH)

    # Reload into a freshly built network (same configuration as training).
    model_loaded = _build_vit(size, num_classes)
    # model_loaded = ResNet18ForCIFAR(num_classes = num_classes)
    # model_loaded = MobileNetV3SmallCIFAR(num_classes = num_classes)
    # map_location places the stored tensors directly on the evaluation device.
    model_loaded.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=device))
    model_loaded.to(device)

    start_time = time.time()
    test_loss, test_acc, probs, targets = _evaluate(
        model_loaded, test_dataloader, loss_fn, accuracy, device
    )
    end_time = time.time()
    num_cls = probs.size(1)

    # ----------  macro mAP (per-class AP on one-hot targets, then averaged)  ----------
    y_true = torch.nn.functional.one_hot(targets, num_classes=num_cls)
    mAP = average_precision_score(y_true.numpy(), probs.numpy(),
                                  average='macro')

    # ----------  macro F1  -----------
    pred1 = probs.argmax(1)
    f1 = f1_score(targets.numpy(), pred1.numpy(), average='macro')

    # ----------  top-k (k=5, capped at the number of classes)  ----------
    topk = min(5, num_cls)
    _, predk = probs.topk(topk, dim=1, largest=True, sorted=False)
    topk_ok = (predk == targets.view(-1, 1)).any(dim=1).float().mean().item() * 100
    print(f"- test mAP: {mAP:.3f} - test F1: {f1:.3f} - test top{topk}: {topk_ok:.2f}")

    print(f"Test loss: {test_loss: .5f}| Test acc: {test_acc: .5f}")
    # Mean wall time per test sample (forward pass plus metric bookkeeping).
    print(f"infer time={(end_time - start_time) / len(targets)}")