import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
from transformers import ViTModel, ViTImageProcessor
import cv2
import numpy as np
from torch.amp import GradScaler, autocast
from sklearn.model_selection import train_test_split
import pandas as pd
from collections import Counter
from sklearn.metrics import precision_score, recall_score, confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
from torchmetrics.classification import MulticlassF1Score
import os
from torchvision import transforms
import json

# Configuration
class cfg:
    """Run-wide hyperparameters, read as class attributes (never instantiated)."""

    # Seed handed to train_test_split so the train/test split is reproducible.
    SEED = 42
    # Fraction of the labelled ids held out as the test split.
    test_size = 0.2
    # Number of target classes; labels are shifted to 0..class_num-1.
    class_num = 9
    # Samples per DataLoader batch (gradient accumulation happens on top of this).
    batch_size = 8
    # Number of epochs passed to train_evaluate.
    num_epochs = 100

# Device selection: the current CUDA device when available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load the label table and derive a reproducible (non-stratified) train/test split.
# NOTE(review): images are read from separate *_train_big / *_test_big folders, so this
# split presumably has to match how those folders were populated — keep SEED and
# test_size in sync with whatever produced them before changing anything here.
labels = pd.read_csv("./data/big2015/big2015_Labels.csv")
x = labels['Id'].to_numpy()
# Shift class ids down by one (assumes the CSV stores 1-based ids) so they
# index 0..cfg.class_num-1 as CrossEntropyLoss expects.
y = labels['Class'].to_numpy().astype(int) - 1
x_train_name, x_test_name, y_train_lable, y_test_lable = train_test_split(
    x,
    y,
    random_state=cfg.SEED,
    test_size=cfg.test_size,
)

# Folders holding one "<id>.png" image per sample for each split.
_big2015_root = './data/big2015'
img_train_path = f'{_big2015_root}/gray_images_file_name_train_big'
img_test_path = f'{_big2015_root}/gray_images_file_name_test_big'

# Light geometric augmentation for training samples only. It is applied to the
# channels-first image tensor built in MultimodalDataset.__getitem__.
train_transform = transforms.Compose(
    [
        transforms.RandomRotation(degrees=10),
        transforms.RandomHorizontalFlip(p=0.5),
    ]
)

# Evaluation pipeline: an empty Compose returns its input unchanged.
test_transform = transforms.Compose(list())

# Dataset definition (only grayscale images)
class MultimodalDataset(Dataset):
    """Dataset producing ViT-ready inputs from one PNG image per sample id.

    Despite the name, only the image modality is read. Each item is a pair
    ``(inputs, label)``:

    - ``inputs`` is what ``ViTImageProcessor`` returns for ``return_tensors="pt"``.
    - ``label`` is a 0-d ``torch.long`` tensor.

    Args:
        gray_img_path: folder containing ``<id>.png`` for every id.
        fl_names: sample ids, position-aligned with ``y_label``.
        y_label: integer class labels, indexable by position.
        is_train: when True, use the module-level ``train_transform``
            (augmentation); otherwise use ``test_transform``.
    """

    def __init__(self, gray_img_path, fl_names, y_label, is_train=True):
        self.gray_root_path = gray_img_path
        self.fl_names = fl_names
        self.gray_img_file_path = [
            os.path.join(self.gray_root_path, "{}.png".format(stem)) for stem in fl_names
        ]
        self.y_data = y_label
        # One processor per dataset. It applies the checkpoint's own
        # preprocessing (resizing / normalisation) to each image.
        self.processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
        self.is_train = is_train
        self.transform = test_transform if not is_train else train_transform

    def __len__(self):
        return len(self.y_data)

    def __getitem__(self, index):
        img_path = self.gray_img_file_path[index]
        # cv2.imread's default IMREAD_COLOR flag yields 3-channel BGR even for
        # single-channel files. It returns None (rather than raising) on failure.
        bgr = cv2.imread(img_path)
        if bgr is None:
            raise FileNotFoundError(f"Grayscale image not found: {img_path}")
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

        # torchvision transforms take CHW tensors; the processor gets HWC back.
        chw = torch.from_numpy(rgb).permute(2, 0, 1)
        augmented = self.transform(chw)
        hwc = augmented.permute(1, 2, 0).numpy()
        inputs = self.processor(images=hwc, return_tensors="pt")

        label = torch.tensor(self.y_data[index], dtype=torch.long)
        return inputs, label

# Custom batch collation function (only grayscale images)
def collate_fn(batch):
    """Merge a list of ``(inputs, label)`` samples into one batch.

    Each ``inputs["pixel_values"]`` presumably already carries a leading batch
    axis of size 1 (processor output with ``return_tensors="pt"``). The tensors
    are therefore concatenated along dim 0 rather than stacked.

    Returns:
        ``({"pixel_values": Tensor}, labels)`` where ``labels`` is a 1-D long tensor.
    """
    sample_inputs, sample_labels = zip(*batch)
    stacked = torch.cat(tuple(item["pixel_values"] for item in sample_inputs), dim=0)
    label_tensor = torch.tensor(sample_labels, dtype=torch.long)
    return {"pixel_values": stacked}, label_tensor

# Model definition (only grayscale images)
class MultiModalModel(nn.Module):
    """ViT-based image classifier (only the grayscale-image branch is used).

    Pipeline: ViT encoder -> mean over all output tokens -> Linear -> LayerNorm
    -> Linear classifier.

    Args:
        num_classes: number of output logits.
        hidden_size: width of the projection layer before the classifier.
        pretrained_name: Hugging Face checkpoint used for the ViT backbone.
    """

    def __init__(self, num_classes, hidden_size=768, pretrained_name="google/vit-base-patch16-224"):
        super().__init__()
        self.vit = ViTModel.from_pretrained(pretrained_name)
        # Take the encoder width from the checkpoint config instead of hard-coding
        # 768, so other ViT sizes plug in without shape mismatches.
        self.vit_dim = self.vit.config.hidden_size
        self.gray_fc = nn.Linear(self.vit_dim, hidden_size)
        self.norm = nn.LayerNorm(hidden_size)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, images):
        """Return class logits of shape ``(batch, num_classes)``.

        Args:
            images: mapping with a ``"pixel_values"`` tensor, e.g. the first
                element produced by ``collate_fn``.
        """
        # Follow the module's own parameters rather than a module-level device
        # global, so the model works wherever it has been moved.
        pixel_values = images["pixel_values"].to(self.fc.weight.device)
        tokens = self.vit(pixel_values).last_hidden_state  # (batch, seq_len, vit_dim)
        gray_features = tokens.mean(dim=1)
        gray_features = self.gray_fc(gray_features)
        gray_features = self.norm(gray_features)
        return self.fc(gray_features)

# Training and evaluation function
save_dir = './results/mmt_vit/mmt_vit_results/gray_only_epoch'


def _train_one_epoch(model, loader, criterion, optimizer, scaler, accum_steps, use_amp):
    """Run one optimisation pass over ``loader``; return ``(mean_loss, accuracy)`` as floats.

    Gradients are accumulated over ``accum_steps`` batches per optimizer step.
    A shorter trailing group at the end of the epoch is stepped as well.
    """
    model.train()
    optimizer.zero_grad()
    n_batches = len(loader)
    loss_sum = 0.0
    n_correct = 0
    for i, (images, labels) in enumerate(loader):
        images = {k: v.to(device) for k, v in images.items()}
        labels = labels.to(device)

        with autocast(device.type, enabled=use_amp):
            outputs = model(images)
            loss = criterion(outputs, labels)

        # Average over the real size of the current accumulation group, so the
        # trailing partial group is not down-weighted by a full accum_steps.
        group_start = (i // accum_steps) * accum_steps
        group_size = min(accum_steps, n_batches - group_start)
        scaler.scale(loss / group_size).backward()
        if (i + 1) % accum_steps == 0 or (i + 1) == n_batches:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        loss_sum += loss.item() * labels.size(0)
        n_correct += (outputs.argmax(dim=1) == labels).sum().item()

    n_samples = len(loader.dataset)
    return loss_sum / n_samples, n_correct / n_samples


def _evaluate_epoch(model, loader, criterion, use_amp):
    """Evaluate ``model`` on ``loader`` without gradients; return ``(mean_loss, accuracy)``."""
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    with torch.no_grad():
        for images, labels in loader:
            images = {k: v.to(device) for k, v in images.items()}
            labels = labels.to(device)
            with autocast(device.type, enabled=use_amp):
                outputs = model(images)
                loss = criterion(outputs, labels)
            loss_sum += loss.item() * labels.size(0)
            n_correct += (outputs.argmax(dim=1) == labels).sum().item()

    n_samples = len(loader.dataset)
    return loss_sum / n_samples, n_correct / n_samples


def train_evaluate(model, train_loader, test_loader, criterion, optimizer, num_epochs, accum_steps=4, save_dir=save_dir, scheduler=None):
    """Train ``model``, evaluate on ``test_loader`` each epoch, then restore the best epoch.

    Each epoch:

    - writes ``epoch{N}.pth`` (model/optimizer/scheduler state plus metric
      history) to ``save_dir``;
    - logs loss/accuracy to TensorBoard.

    After training, the checkpoint with the highest test accuracy is reloaded
    into ``model`` and re-saved to ``.../model/gray_only.pth``.

    NOTE(review): the best epoch is chosen on the test split, so the reported
    best test accuracy is optimistically biased.

    Args:
        model: network to train (inputs are moved to ``device``).
        train_loader / test_loader: loaders yielding ``(images_dict, labels)``.
        criterion: loss function.
        optimizer: optimizer over the model's parameters.
        num_epochs: number of training epochs.
        accum_steps: batches accumulated per optimizer step.
        save_dir: folder for the per-epoch checkpoints.
        scheduler: LR scheduler stepped once per epoch. If omitted, the
            module-level ``scheduler`` (created in the ``__main__`` section) is
            used when it exists; otherwise no LR scheduling is done.

    Returns:
        ``model`` with the best-epoch weights loaded.
    """
    if scheduler is None:
        # Backward compatibility: this function used to read the module-level
        # ``scheduler`` implicitly.
        scheduler = globals().get("scheduler")

    # Mixed precision only makes sense on CUDA. On CPU both objects become no-ops
    # instead of emitting "CUDA is not available" warnings.
    use_amp = device.type == "cuda"
    scaler = GradScaler("cuda", enabled=use_amp)
    os.makedirs(save_dir, exist_ok=True)
    train_losses, train_accs, test_accs, test_losses = [], [], [], []

    writer = SummaryWriter('./results/mmt_vit/mmt_vit_results/runs/gray_only')
    try:
        for epoch in range(num_epochs):
            epoch_loss, epoch_acc = _train_one_epoch(
                model, train_loader, criterion, optimizer, scaler, accum_steps, use_amp
            )
            train_losses.append(epoch_loss)
            train_accs.append(epoch_acc)

            test_loss, test_acc = _evaluate_epoch(model, test_loader, criterion, use_amp)
            test_losses.append(test_loss)
            test_accs.append(test_acc)

            writer.add_scalar("Train/Loss", epoch_loss, epoch)
            writer.add_scalar("Train/Accuracy", epoch_acc, epoch)
            writer.add_scalar("Test/Loss", test_loss, epoch)
            writer.add_scalar("Test/Accuracy", test_acc, epoch)

            print(f"Epoch {epoch}/{num_epochs}, Train Loss: {epoch_loss:.8f}, Train Acc: {epoch_acc:.8f}, Test Loss: {test_loss:.8f}, Test Acc: {test_acc:.8f}")
            if scheduler is not None:
                scheduler.step()

            checkpoint_path = os.path.join(save_dir, f'epoch{epoch}.pth')
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict() if scheduler is not None else None,
                'train_losses': train_losses,
                'train_accs': train_accs,
                'test_losses': test_losses,
                'test_accs': test_accs
            }, checkpoint_path)

        if test_accs:
            # np.argmax returns the first epoch that reaches the maximum accuracy.
            best_epoch = int(np.argmax(test_accs))
            best_checkpoint_path = os.path.join(save_dir, f'epoch{best_epoch}.pth')
            checkpoint = torch.load(best_checkpoint_path, map_location=device)
            model.load_state_dict(checkpoint['model_state_dict'])
            print(f"Best model loaded from {best_checkpoint_path}, best test accuracy: {test_accs[best_epoch]:.8f}")
        else:
            print("No epochs were run; keeping the current model weights.")

        save_path = './results/mmt_vit/mmt_vit_results/model'
        os.makedirs(save_path, exist_ok=True)
        model_path = os.path.join(save_path, 'gray_only.pth')
        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict() if scheduler is not None else None
        }, model_path)
    finally:
        # Flush and close TensorBoard files even if training raises.
        writer.close()

    return model

# Testing function
def _save_confusion_heatmap(matrix, class_labels, fmt, xlabel, ylabel, title, out_path):
    """Render ``matrix`` as an annotated heat-map, save it to ``out_path``, show it, then close it."""
    plt.figure(figsize=(8, 6))
    sns.heatmap(matrix, annot=True, fmt=fmt, cmap='Blues',
                xticklabels=class_labels,
                yticklabels=class_labels)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    plt.savefig(out_path)
    plt.show()
    # Release the figure so repeated evaluations do not accumulate open figures.
    plt.close()


def test(model, test_loader):
    """Evaluate ``model`` on ``test_loader``, persist the metrics, and return them.

    Computes:

    - accuracy;
    - support-weighted precision/recall (scikit-learn);
    - F1 via torchmetrics' ``MulticlassF1Score`` (default averaging, i.e. macro);
    - raw and row-normalised confusion matrices.

    The metrics are written to ``.../gray_only_metrics/gray_only.json``. Both
    confusion matrices are saved as PNG heat-maps in the same folder.

    Returns:
        dict: the metrics dictionary that is written to JSON.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    model.eval()
    use_amp = device.type == "cuda"
    f1_metric = MulticlassF1Score(num_classes=cfg.class_num).to(device)
    y_true = []
    y_pred = []

    with torch.no_grad():
        for images, labels in test_loader:
            images = {k: v.to(device) for k, v in images.items()}
            labels = labels.to(device)
            with autocast(device.type, enabled=use_amp):
                outputs = model(images)
            # Derive every metric from this single set of predictions. Re-deriving
            # them from fp16 softmax probabilities can create ties that disagree
            # with the logit argmax.
            preds = outputs.argmax(dim=1)
            f1_metric.update(preds, labels)
            y_true.extend(labels.cpu().tolist())
            y_pred.extend(preds.cpu().tolist())

    if not y_true:
        raise ValueError("test_loader yielded no samples; cannot compute test metrics")

    class_ids = list(range(cfg.class_num))
    class_labels = [f"Class {i}" for i in class_ids]

    accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
    precision = precision_score(y_true, y_pred, average='weighted')
    recall = recall_score(y_true, y_pred, average='weighted')
    # NOTE: MulticlassF1Score defaults to average="macro", whereas precision and
    # recall above are support-weighted. The reported F1 is therefore macro-F1.
    f1 = f1_metric.compute()

    # Compute confusion matrices (raw counts and normalised per true class).
    conf_matrix = confusion_matrix(y_true, y_pred, labels=class_ids)
    conf_matrix_norm = confusion_matrix(y_true, y_pred, labels=class_ids, normalize='true')

    results = {
        "accuracy": float(accuracy),
        "precision": float(precision),
        "recall": float(recall),
        "f1_score": float(f1),
        "confusion_matrix": conf_matrix.tolist(),
        "normalized_confusion_matrix": conf_matrix_norm.tolist(),
        "class_labels": class_labels
    }

    json_path = './results/mmt_vit/mmt_vit_results/gray_only_metrics/gray_only.json'
    os.makedirs(os.path.dirname(json_path), exist_ok=True)
    # Best-effort save: report I/O or serialisation problems without aborting
    # the rest of the evaluation output.
    try:
        with open(json_path, 'w') as f:
            json.dump(results, f, indent=4)
    except (OSError, TypeError) as e:
        print(f"Error saving JSON results: {e}")
    else:
        print(f"Test results saved to {json_path}")

    print(f'Accuracy: {accuracy:.4f}')
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1 Score: {f1:.4f}")

    _save_confusion_heatmap(
        conf_matrix, class_labels, 'd',
        'Predicted Labels', 'True Labels', 'Confusion Matrix',
        './results/mmt_vit/mmt_vit_results/gray_only_metrics/confusion_matrix_gray_only.png',
    )
    _save_confusion_heatmap(
        conf_matrix_norm, class_labels, '.2%',
        'Predicted', 'True', 'Normalized Confusion Matrix',
        './results/mmt_vit/mmt_vit_results/gray_only_metrics/normalized_confusion_matrix_gray_only.png',
    )

    return results

# Main program
# NOTE: kept at module level (not wrapped in main()) because train_evaluate falls
# back to the module-level ``scheduler`` defined here.
if __name__ == '__main__':
    # Datasets and loaders; augmentation is enabled only for the training split.
    dataset_train = MultimodalDataset(
        img_train_path, x_train_name, y_train_lable, is_train=True
    )
    dataset_test = MultimodalDataset(
        img_test_path, x_test_name, y_test_lable, is_train=False
    )
    train_loader = DataLoader(dataset_train, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn)
    test_loader = DataLoader(dataset_test, batch_size=cfg.batch_size, shuffle=False, collate_fn=collate_fn)

    model = MultiModalModel(num_classes=cfg.class_num, hidden_size=768).to(device)

    # Partial fine-tuning of the ViT backbone:
    # - encoder blocks with index >= first_trainable_layer stay trainable (the
    #   last four for the 12-block vit-base checkpoint);
    # - every other ViT parameter is frozen;
    # - the new heads (gray_fc, norm, fc) are untouched and remain trainable.
    first_trainable_layer = 8
    for name, param in model.vit.named_parameters():
        # Encoder block parameters are named "encoder.layer.<idx>.<...>".
        param.requires_grad = (
            "encoder.layer" in name and int(name.split('.')[2]) >= first_trainable_layer
        )

    # Inverse-frequency class weights; a class absent from the training split gets 1.0.
    train_counts = Counter(y_train_lable)
    class_weights = torch.tensor(
        [1.0 / max(train_counts.get(i, 1), 1) for i in range(cfg.class_num)], dtype=torch.float
    ).to(device)
    print(f"Class weights: {class_weights}")
    criterion = nn.CrossEntropyLoss(weight=class_weights)
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001, weight_decay=1e-5)
    # train_evaluate calls scheduler.step() once per EPOCH, so the one-cycle schedule
    # must be sized in epochs. Sizing it per batch (steps_per_epoch=len(train_loader))
    # would leave the LR near its warm-up start (max_lr / div_factor) for the whole run.
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.001, total_steps=cfg.num_epochs)

    # Sanity-check one forward pass before training.
    model.eval()
    print("Debugging initial forward pass...")
    with torch.no_grad():
        for images, batch_labels in train_loader:
            images = {k: v.to(device) for k, v in images.items()}
            outputs = model(images)
            probs = F.softmax(outputs, dim=1)
            print(f"Initial sample probabilities: {probs[0]}")
            print(f"Initial predictions: {torch.argmax(probs, dim=1)[:5]}")
            print(f"True labels: {batch_labels[:5]}")
            break

    print("Starting training...")
    model = train_evaluate(model, train_loader, test_loader, criterion, optimizer, num_epochs=cfg.num_epochs)

    print("Testing the best model...")
    test(model, test_loader)
