import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from transformers import ViTModel, ViTImageProcessor
import cv2
import numpy as np
from torch.amp import autocast
from sklearn.metrics import precision_score, recall_score, confusion_matrix
from sklearn.model_selection import train_test_split
import seaborn as sns
import matplotlib.pyplot as plt
from torchmetrics.classification import MulticlassF1Score
import os
import json
from collections import Counter
import shutil
from torch.optim.lr_scheduler import CosineAnnealingLR

class cfg:
    """Run-wide settings, read elsewhere in this file as plain class attributes."""

    # Seed applied to the torch and NumPy random number generators.
    SEED = 42

    # Number of target classes (size of the classifier output).
    class_num = 25

    # Training schedule: maximum epoch count and samples per mini-batch.
    num_epochs = 100
    batch_size = 8

# Seed the torch (CPU and every CUDA device) and NumPy random number
# generators from cfg.SEED so repeated runs draw the same random numbers.
torch.manual_seed(cfg.SEED)
np.random.seed(cfg.SEED)
torch.cuda.manual_seed_all(cfg.SEED)

# Global compute device: the GPU when CUDA is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def organize_malimg_data(malimg_root, output_root, expected_num_classes=25):
    """Copy each class folder's images from ``malimg_root`` into ``output_root``.

    Every immediate sub-directory of ``malimg_root`` is treated as one class.
    Regular files inside it whose names end in .png/.jpg/.jpeg/.bmp
    (case-insensitive) are copied with ``shutil.copy2`` into a folder of the
    same name under ``output_root``.

    Args:
        malimg_root: Directory that holds one sub-directory per class.
        output_root: Destination directory; created if it does not exist.
        expected_num_classes: Number of class folders expected. A different
            count only prints a warning; it never raises.

    Returns:
        The class folder names, sorted. ``os.listdir`` returns entries in
        arbitrary order, so sorting keeps any class -> index mapping built
        from this list identical across runs and machines.

    Raises:
        FileNotFoundError: If ``malimg_root`` does not exist.
    """
    if not os.path.exists(malimg_root):
        raise FileNotFoundError(f"malimg root directory does not exist: {malimg_root}")

    image_suffixes = ('.png', '.jpg', '.jpeg', '.bmp')

    selected_classes = sorted(
        entry for entry in os.listdir(malimg_root)
        if os.path.isdir(os.path.join(malimg_root, entry))
    )
    if len(selected_classes) != expected_num_classes:
        print(f"Warning: Found {len(selected_classes)} classes, expected {expected_num_classes}")

    os.makedirs(output_root, exist_ok=True)
    for cls_name in selected_classes:
        src_dir = os.path.join(malimg_root, cls_name)
        dst_dir = os.path.join(output_root, cls_name)
        os.makedirs(dst_dir, exist_ok=True)
        for img_name in os.listdir(src_dir):
            if not img_name.lower().endswith(image_suffixes):
                continue
            src_path = os.path.join(src_dir, img_name)
            # A directory that merely has an image-like name would make
            # copy2 fail, so only regular files are copied.
            if os.path.isfile(src_path):
                shutil.copy2(src_path, os.path.join(dst_dir, img_name))
    return selected_classes

class MalimgDataset(Dataset):
    """Image dataset that feeds grayscale images to a ViT image processor.

    The dataset is either given explicit ``image_paths``/``labels`` (used for
    the train/test splits) or, when either of them is missing or empty,
    discovers images by scanning ``root_dir/<class name>/`` for every name in
    ``selected_classes``.

    Args:
        root_dir: Directory containing one sub-directory per class.
        selected_classes: Class names; a class's position in this sequence is
            its integer label.
        image_paths: Optional pre-computed list of image file paths.
        labels: Optional integer labels, parallel to ``image_paths``.

    Raises:
        FileNotFoundError: If a scan is needed and ``root_dir`` does not exist.
        ValueError: If ``image_paths`` and ``labels`` are both given but have
            different lengths, or if no image is found at all.
    """

    _IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp')

    def __init__(self, root_dir, selected_classes, image_paths=None, labels=None):
        self.root_dir = root_dir
        self.classes = selected_classes
        self.class_to_idx = {cls_name: idx for idx, cls_name in enumerate(self.classes)}
        self.image_paths = list(image_paths) if image_paths is not None else []
        self.labels = list(labels) if labels is not None else []

        if not self.image_paths or not self.labels:
            if not os.path.exists(root_dir):
                raise FileNotFoundError(f"Root directory does not exist: {root_dir}")
            self.image_paths, self.labels = self._scan_class_dirs(
                root_dir, self.classes, self.class_to_idx
            )
        elif len(self.image_paths) != len(self.labels):
            # Misaligned inputs would otherwise surface later as an IndexError
            # or as silently wrong labels.
            raise ValueError(
                f"image_paths ({len(self.image_paths)}) and labels "
                f"({len(self.labels)}) must have the same length"
            )

        if not self.image_paths:
            raise ValueError(f"No valid images found in {root_dir} for classes {selected_classes}")

        # Loaded last so invalid inputs fail before the processor is fetched.
        self.processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")

    @staticmethod
    def _scan_class_dirs(root_dir, classes, class_to_idx):
        """Collect (paths, labels) for every image file under the class folders.

        File names are visited in sorted order: ``os.listdir`` order is
        arbitrary, and a seeded split of the resulting lists is reproducible
        only if the lists themselves are.
        """
        paths = []
        labels = []
        for cls_name in classes:
            cls_dir = os.path.join(root_dir, cls_name)
            if not os.path.isdir(cls_dir):
                print(f"Warning: Directory {cls_dir} does not exist")
                continue
            for img_name in sorted(os.listdir(cls_dir)):
                if img_name.lower().endswith(MalimgDataset._IMAGE_EXTENSIONS):
                    paths.append(os.path.join(cls_dir, img_name))
                    labels.append(class_to_idx[cls_name])
        return paths, labels

    def __getitem__(self, index):
        """Return ``(pixel_values, label)`` for the sample at ``index``.

        The image is read as single-channel grayscale, expanded to three
        channels, and passed through the ViT image processor.

        Raises:
            FileNotFoundError: If OpenCV cannot read the file.
        """
        img_path = self.image_paths[index]
        gray_image = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        if gray_image is None:
            raise FileNotFoundError(f"Image not found or invalid: {img_path}")

        # Replicate the gray channel so the processor receives a 3-channel image.
        rgb_image = cv2.cvtColor(gray_image, cv2.COLOR_GRAY2RGB)
        inputs = self.processor(images=rgb_image, return_tensors="pt")
        label = torch.tensor(self.labels[index], dtype=torch.long)
        # Drop the leading batch dimension of size 1 added by the processor.
        return inputs["pixel_values"].squeeze(0), label

    def __len__(self):
        """Return the number of samples."""
        return len(self.image_paths)

def collate_fn(batch):
    """Merge ``(pixel_values, label)`` samples into one batch.

    Returns a ``({"pixel_values": stacked_tensor}, labels)`` pair, where the
    sample tensors are stacked along a new leading dimension and the labels
    become a single ``torch.long`` tensor.
    """
    image_tensors, targets = zip(*batch)
    batched_inputs = {"pixel_values": torch.stack(image_tensors, dim=0)}
    batched_targets = torch.tensor(targets, dtype=torch.long)
    return batched_inputs, batched_targets


class GrayOnlyModel(nn.Module):
    """ViT backbone followed by a small transformer encoder and a linear classifier.

    Args:
        num_classes: Size of the output logit vector.
        hidden_size: Width of the projection, transformer encoder and
            classifier input.
        num_heads: Attention heads per transformer encoder layer.
        num_layers: Number of transformer encoder layers.
    """

    def __init__(self, num_classes, hidden_size=768, num_heads=8, num_layers=6):
        super().__init__()
        self.vit = ViTModel.from_pretrained("google/vit-base-patch16-224")
        # Read the backbone width from the loaded config instead of
        # hard-coding it (768 for this checkpoint).
        self.vit_dim = self.vit.config.hidden_size
        self.gray_fc = nn.Linear(self.vit_dim, hidden_size)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_size,
            nhead=num_heads,
            dim_feedforward=2048,
            dropout=0.3,
            batch_first=True,
        )
        self.fusion_transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.norm = nn.LayerNorm(hidden_size)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, images):
        """Compute class logits.

        Args:
            images: Mapping with a ``"pixel_values"`` tensor accepted by the
                ViT backbone.

        Returns:
            A tensor of logits with one row per sample and ``num_classes``
            columns.
        """
        # Move the input to wherever this module's weights live rather than to
        # a module-level global, so the model also works when it is placed on
        # a device other than the script's default.
        pixel_values = images["pixel_values"].to(self.gray_fc.weight.device)

        # Average the backbone's last hidden state over dim 1 (the token axis).
        pooled = self.vit(pixel_values).last_hidden_state.mean(dim=1)
        projected = self.gray_fc(pooled)

        # The encoder is batch_first, so each sample is presented to it as a
        # sequence of length 1 and squeezed back afterwards.
        encoded = self.fusion_transformer(projected.unsqueeze(1)).squeeze(1)
        return self.fc(self.norm(encoded))

def _model_device(model):
    """Return the device of the model's first parameter (CPU if it has none)."""
    try:
        return next(model.parameters()).device
    except StopIteration:
        return torch.device('cpu')


def _collect_predictions(model, data_loader, num_classes):
    """Run ``model`` over ``data_loader`` and gather the evaluation metrics.

    Returns:
        ``(metrics, y_true, y_pred)`` where ``metrics`` holds ``accuracy``,
        weighted ``precision``, ``recall`` and ``f1`` as Python floats, and
        ``y_true``/``y_pred`` are flat lists of integer labels.
    """
    model.eval()
    eval_device = _model_device(model)
    # Mixed precision is only enabled on CUDA; passing enabled=False elsewhere
    # avoids the "CUDA is not available" warning on CPU-only hosts.
    use_amp = eval_device.type == 'cuda'
    # Weighted averaging matches precision_score/recall_score below, so the
    # three reported numbers are comparable. torchmetrics defaults to macro.
    f1_metric = MulticlassF1Score(num_classes=num_classes, average='weighted').to(eval_device)

    y_true = []
    y_pred = []
    with torch.no_grad():
        for images, labels in data_loader:
            images = {k: v.to(eval_device) for k, v in images.items()}
            labels = labels.to(eval_device)
            with autocast('cuda', enabled=use_amp):
                outputs = model(images)
            # argmax of the logits equals argmax of their softmax, so no
            # probabilities need to be materialised.
            preds = outputs.argmax(dim=1)
            f1_metric.update(preds, labels)
            y_pred.extend(preds.cpu().tolist())
            y_true.extend(labels.cpu().tolist())

    total_samples = len(y_true)
    if total_samples == 0:
        # Nothing was evaluated; report zeros rather than failing in the
        # metric computations.
        empty = {"accuracy": 0.0, "precision": 0.0, "recall": 0.0, "f1": 0.0}
        return empty, y_true, y_pred

    total_correct = sum(1 for p, t in zip(y_pred, y_true) if p == t)
    metrics = {
        "accuracy": total_correct / total_samples,
        "precision": float(precision_score(y_true, y_pred, average='weighted', zero_division=0)),
        "recall": float(recall_score(y_true, y_pred, average='weighted', zero_division=0)),
        "f1": float(f1_metric.compute().item()),
    }
    return metrics, y_true, y_pred


def _print_metrics(metrics):
    """Print the four headline metrics in the script's report format."""
    print(f'\nMalimg 25-Class Gray Only Results:')
    print(f'Accuracy: {metrics["accuracy"]:.4f}')
    print(f"Precision: {metrics['precision']:.4f}")
    print(f"Recall: {metrics['recall']:.4f}")
    print(f"F1 Score: {metrics['f1']:.4f}")


def _save_confusion_heatmap(matrix, fmt, title, class_names, out_path):
    """Render ``matrix`` as an annotated heatmap and write it to ``out_path``."""
    plt.figure(figsize=(12, 10))
    sns.heatmap(matrix, annot=True, fmt=fmt, cmap='Blues',
                xticklabels=class_names, yticklabels=class_names)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title(title)
    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()
    plt.savefig(out_path)
    plt.close()


def test_metrics(model, test_loader):
    """Evaluate ``model`` on ``test_loader`` and print the headline metrics.

    The model is put in eval mode. Precision, recall and F1 are all
    weighted-averaged over ``cfg.class_num`` classes.

    Returns:
        Dict with keys ``accuracy``, ``precision``, ``recall`` and ``f1``.
    """
    metrics, _, _ = _collect_predictions(model, test_loader, cfg.class_num)
    _print_metrics(metrics)
    return metrics


def test_with_confusion_matrix(model, test_loader, selected_classes, save_dir='./results/mmt_vit/big2015_yz/yz_results_25'):
    """Evaluate ``model`` and save metrics plus confusion matrices to ``save_dir``.

    Writes a JSON report and two PNG heatmaps (raw counts and row-normalised)
    into ``save_dir``, which is created if missing, and prints the headline
    metrics.

    Args:
        model: Model to evaluate; it is put in eval mode.
        test_loader: Loader yielding ``(inputs_dict, labels)`` batches.
        selected_classes: Class names used as heatmap tick labels and stored
            in the JSON report.
        save_dir: Output directory for the report and figures.

    Returns:
        Dict with keys ``accuracy``, ``precision``, ``recall`` and ``f1``.
    """
    metrics, y_true, y_pred = _collect_predictions(model, test_loader, cfg.class_num)

    class_ids = list(range(cfg.class_num))
    conf_matrix = confusion_matrix(y_true, y_pred, labels=class_ids)
    conf_matrix_norm = confusion_matrix(y_true, y_pred, labels=class_ids, normalize='true')

    results = {
        "accuracy": metrics["accuracy"],
        "precision": metrics["precision"],
        "recall": metrics["recall"],
        "f1_score": metrics["f1"],
        "confusion_matrix": conf_matrix.tolist(),
        "normalized_confusion_matrix": conf_matrix_norm.tolist(),
        "class_labels": selected_classes
    }
    os.makedirs(save_dir, exist_ok=True)
    json_path = os.path.join(save_dir, 'malimg_25-class_finetuned_results.json')
    # Saving the report is best-effort: a failure is reported but does not
    # prevent the metrics from being printed and returned.
    try:
        with open(json_path, 'w') as f:
            json.dump(results, f, indent=4)
        print(f"Test results saved to {json_path}")
    except (OSError, TypeError, ValueError) as e:
        print(f"Error saving JSON results: {e}")

    _print_metrics(metrics)

    _save_confusion_heatmap(
        conf_matrix, 'd', 'Confusion Matrix (Malimg 25-Class Gray Only)',
        selected_classes, os.path.join(save_dir, 'confusion_matrix_malimg_25class.png'),
    )
    _save_confusion_heatmap(
        conf_matrix_norm, '.2f', 'Normalized Confusion Matrix (Malimg 25-Class Gray Only)',
        selected_classes, os.path.join(save_dir, 'normalized_confusion_matrix_malimg_25class.png'),
    )

    return metrics

def get_class_folders(malimg_root):
    """Return the sorted names of the sub-directories of ``malimg_root``.

    Prints a warning when the number of folders differs from
    ``cfg.class_num``, then prints the folders that were found.

    Raises:
        FileNotFoundError: If ``malimg_root`` does not exist.
    """
    if not os.path.exists(malimg_root):
        raise FileNotFoundError(f"malimg root directory does not exist: {malimg_root}")

    class_folders = sorted(
        entry for entry in os.listdir(malimg_root)
        if os.path.isdir(os.path.join(malimg_root, entry))
    )
    folder_count = len(class_folders)

    if folder_count != cfg.class_num:
        print(f"Warning: Found {folder_count} folders, expected {cfg.class_num}")

    print(f"Found {folder_count} class folders: {class_folders}")
    return class_folders

if __name__ == '__main__':
    # End-to-end run: organise the data, build stratified train/test splits,
    # load the pretrained gray-only checkpoint, fine-tune, then evaluate the
    # best checkpoint and write the reports.
    malimg_root = './data/big2015_yz/malimg_25/data_in'
    output_root = './data/big2015_yz/malimg_25/data_out'
    results_dir = './results/mmt_vit/big2015_yz/yz_results_25'
    best_model_path = os.path.join(results_dir, 'malimg_25class_finetuned_best.pth')

    # torch.save does not create parent directories, so make sure the output
    # directory exists before the first checkpoint is written.
    os.makedirs(results_dir, exist_ok=True)

    print("Organizing malimg data...")
    # Sorted so the class -> index mapping does not depend on directory
    # listing order.
    selected_classes = sorted(organize_malimg_data(malimg_root, output_root))
    if len(selected_classes) != cfg.class_num:
        raise ValueError(f"Expected {cfg.class_num} classes, but found {len(selected_classes)}")

    # MalimgDataset raises ValueError when it finds no images, so a separate
    # emptiness check afterwards is unnecessary.
    try:
        full_dataset = MalimgDataset(output_root, selected_classes)
    except ValueError as e:
        print(f"Error: {e}")
        raise SystemExit(1)

    train_paths, test_paths, train_labels, test_labels = train_test_split(
        full_dataset.image_paths, full_dataset.labels, test_size=0.2, random_state=cfg.SEED, stratify=full_dataset.labels
    )

    train_dataset = MalimgDataset(output_root, selected_classes, image_paths=train_paths, labels=train_labels)
    test_dataset = MalimgDataset(output_root, selected_classes, image_paths=test_paths, labels=test_labels)

    train_loader = DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn)
    test_loader = DataLoader(test_dataset, batch_size=cfg.batch_size, shuffle=False, collate_fn=collate_fn)

    print(f"Train dataset: {len(train_dataset)} images, Label distribution: {Counter(train_dataset.labels)}")
    print(f"Test dataset: {len(test_dataset)} images, Label distribution: {Counter(test_dataset.labels)}")
    print(f"Class to index: {train_dataset.class_to_idx}")

    # Inverse-square-root frequency weights; classes absent from the training
    # split are treated as having one sample.
    class_counts = Counter(train_dataset.labels)
    class_weights = torch.tensor([1.0 / (max(class_counts.get(i, 1), 1) ** 0.5) for i in range(cfg.class_num)], dtype=torch.float).to(device)

    model = GrayOnlyModel(num_classes=cfg.class_num, hidden_size=768).to(device)

    model_path = './results/mmt_vit/mmt_vit_results/model/gray_only.pth'
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model weights not found at {model_path}")

    # Load only the checkpoint tensors whose name and shape match this model;
    # everything else keeps its freshly constructed value.
    checkpoint = torch.load(model_path, map_location=device)
    state_dict = checkpoint['model_state_dict']
    model_state_dict = model.state_dict()
    state_dict = {k: v for k, v in state_dict.items() if k in model_state_dict and v.shape == model_state_dict[k].shape}
    model_state_dict.update(state_dict)
    model.load_state_dict(model_state_dict)
    print(f"Loaded model from {model_path} ")

    # The classifier head is always re-initialised for the new label set.
    nn.init.xavier_uniform_(model.fc.weight)
    nn.init.zeros_(model.fc.bias)

    # Trainable parameters: ViT encoder blocks with index >= 8, plus the
    # gray_fc projection and the fc classifier. Everything else is frozen.
    for name, param in model.vit.named_parameters():
        param.requires_grad = "encoder.layer" in name and int(name.split('.')[2]) >= 8

    for name, param in model.named_parameters():
        if name.startswith("vit."):
            # Already decided above. Previously this loop also visited the ViT
            # parameters and re-froze every one of them, which silently
            # cancelled the unfreezing of the upper encoder blocks.
            continue
        param.requires_grad = "fc" in name  # matches both gray_fc.* and fc.*

    # Print the statistics of one training batch as a sanity check.
    for images, _ in train_loader:
        pixel_values = images["pixel_values"]
        print(f"malimg (train): Shape: {pixel_values.shape}, Mean: {pixel_values.mean().item():.4f}, Std: {pixel_values.std().item():.4f}")
        break

    print("Finetuning on malimg 25-class dataset...")
    optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-5)
    # NOTE(review): T_max=50 is shorter than cfg.num_epochs, so the learning
    # rate rises again after epoch 50 if training runs that long. Confirm
    # whether T_max should track num_epochs.
    scheduler = CosineAnnealingLR(optimizer, T_max=50)
    criterion = nn.CrossEntropyLoss(weight=class_weights)
    num_epochs = cfg.num_epochs
    best_test_accuracy = 0.0
    best_saved = False
    patience = 3
    counter = 0
    # Mixed precision only on CUDA; enabled=False avoids the autocast warning
    # on CPU-only hosts.
    # NOTE(review): no gradient scaler is used with autocast here, so fp16
    # gradients may underflow. Consider adding one.
    use_amp = device.type == 'cuda'

    # NOTE(review): the test split is also used to select the best checkpoint
    # and to stop early, so the final test metrics are optimistically biased.
    for epoch in range(num_epochs):
        running_loss = 0.0
        model.train()  # test_metrics() switches to eval mode every epoch
        for images, labels in train_loader:
            images = {k: v.to(device) for k, v in images.items()}
            labels = labels.to(device)
            optimizer.zero_grad()
            with autocast('cuda', enabled=use_amp):
                outputs = model(images)
                loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * labels.size(0)
        avg_loss = running_loss / len(train_dataset) if len(train_dataset) > 0 else 0.0
        print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {avg_loss:.4f}, LR: {scheduler.get_last_lr()[0]:.6f}")

        test_results = test_metrics(model, test_loader)
        test_accuracy = test_results['accuracy']
        print(f"Epoch {epoch+1}/{num_epochs}, Test Accuracy: {test_accuracy:.4f}")

        if test_accuracy > best_test_accuracy:
            best_test_accuracy = test_accuracy
            best_saved = True
            counter = 0
            torch.save({'model_state_dict': model.state_dict()}, best_model_path)
            print(f"Saved best model with test accuracy {best_test_accuracy:.4f} at epoch {epoch+1}")
        else:
            # Early stopping: give up after `patience` epochs in a row
            # without an improvement in test accuracy.
            counter += 1
            if counter >= patience:
                print(f"Early stopping at epoch {epoch+1}: no improvement for {patience} epochs")
                break
        scheduler.step()

    # Restore the best weights only if this run actually wrote a checkpoint;
    # otherwise a stale file from an earlier run could be picked up.
    if best_saved:
        checkpoint = torch.load(best_model_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        print("Warning: no checkpoint was saved during training; evaluating the final weights")

    print("Final testing on malimg 25-class dataset...")
    final_results = test_with_confusion_matrix(model, test_loader, selected_classes, save_dir=results_dir)

    # Short summary with the four headline metrics only; the detailed report
    # is written by test_with_confusion_matrix under a different file name.
    with open(os.path.join(results_dir, 'malimg_25class_finetuned_results.json'), 'w') as f:
        json.dump({k: float(v) for k, v in final_results.items()}, f, indent=4)