import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math

from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader

from sklearn.metrics import roc_auc_score
from sklearn.model_selection import StratifiedKFold

from datasets.loader import load_data, DummyDataset
from KFCore import KFCore
from hflayers.activation import HopfieldCore


class ModelWrapper(nn.Module):
    """Bag-level classifier for multiple-instance learning (MIL).

    Every instance in a bag is embedded by a shared ``Linear + ReLU`` layer. The
    embeddings are passed through a mode-specific core (Karcher-flow ``KFCore``
    or Hopfield ``HopfieldCore``). The core output is mean-pooled over its
    sequence axis and mapped to ``num_classes`` logits.

    Args:
        mode: One of ``MODES``. ``*_attention`` attends the bag to itself,
            ``*_pooling`` uses ``num_states`` learned static queries, and
            ``*_layer`` uses ``num_memories`` learned static key/value pairs.
        input_dim: Feature size of each instance.
        hidden_dim: Width of the embedder, the core and the classifier input.
        num_classes: Number of output logits per bag.
        beta: Passed to ``KFCore`` as its last argument (unused by ``hf_*`` modes).
        num_states: Number of learned static queries (``*_pooling`` modes only).
        num_memories: Number of learned static keys/values (``*_layer`` modes only).
        bag_dropout: Probability of zeroing an instance embedding while training.

    Raises:
        ValueError: If ``mode`` is not one of ``MODES``.
    """

    MODES = ("kf_attention", "kf_pooling", "kf_layer",
             "hf_attention", "hf_pooling", "hf_layer")

    def __init__(self, mode, input_dim, hidden_dim, num_classes, beta, num_states, num_memories, bag_dropout=0.75):
        super().__init__()
        self.mode = mode
        self.bag_dropout = bag_dropout

        # Shared per-instance embedding.
        self.embedder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU()
        )

        # Karcher Flow models: static parameters are batch-first (1, n, hidden_dim).
        if mode == "kf_attention":
            self.core = KFCore(hidden_dim, hidden_dim, hidden_dim, hidden_dim, beta)
        elif mode == "kf_pooling":
            self.core = KFCore(hidden_dim, hidden_dim, hidden_dim, hidden_dim, beta)
            self.static_query = nn.Parameter(torch.randn(1, num_states, hidden_dim) * 0.02)
        elif mode == "kf_layer":
            self.core = KFCore(hidden_dim, hidden_dim, hidden_dim, hidden_dim, beta)
            self.static_key = nn.Parameter(torch.randn(1, num_memories, hidden_dim) * 0.02)
            self.static_value = nn.Parameter(torch.randn(1, num_memories, hidden_dim) * 0.02)

        # Hopfield (HNIAYN) models: static parameters are sequence-first (n, 1, hidden_dim).
        elif mode == "hf_attention":
            self.core = HopfieldCore(embed_dim=hidden_dim, num_heads=1)
        elif mode == "hf_pooling":
            self.core = HopfieldCore(embed_dim=hidden_dim, num_heads=1, query_as_static=True)
            self.static_query = nn.Parameter(torch.randn(num_states, 1, hidden_dim) * 0.02)
        elif mode == "hf_layer":
            self.core = HopfieldCore(embed_dim=hidden_dim, num_heads=1,
                                     key_as_static=True, value_as_static=True)
            self.static_key = nn.Parameter(torch.randn(num_memories, 1, hidden_dim) * 0.02)
            self.static_value = nn.Parameter(torch.randn(num_memories, 1, hidden_dim) * 0.02)
        else:
            # Without this, an unknown mode would leave the model with no core at all.
            raise ValueError(f"Unknown mode {mode!r}; expected one of {self.MODES}")

        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Classify a batch of bags.

        Args:
            x: Tensor of shape ``(B, bag_size, input_dim)``.

        Returns:
            Classifier logits with the trailing axis squeezed, i.e. shape ``(B,)``
            when ``num_classes == 1``.
        """
        B, bag_size, input_dim = x.shape

        # Embed every instance independently, then restore the bag axis.
        # reshape (rather than view) also accepts non-contiguous inputs.
        embeds = self.embedder(x.reshape(B * bag_size, input_dim))
        embeds = embeds.view(B, bag_size, -1)  # Batch-First (B, bag_size, hidden_dim)

        # Bag dropout: zero whole instance embeddings at training time.
        # NOTE(review): unlike nn.Dropout, surviving embeddings are not rescaled
        # by 1 / (1 - p), and dropped (zeroed) instances still take part in the
        # core and the mean below -- confirm this is intended.
        if self.training and self.bag_dropout > 0.0:
            keep = torch.rand(B, bag_size, 1, device=embeds.device) > self.bag_dropout
            embeds = embeds * keep.to(embeds.dtype)

        # The Hopfield modes are fed Sequence-First tensors (bag_size, B, hidden_dim).
        seq_embeds = embeds.transpose(0, 1)

        # Each branch pools the core output over its sequence axis into z: (B, hidden_dim).
        if self.mode == "kf_attention":
            z = self.core(embeds, embeds, embeds).mean(dim=1)
        elif self.mode == "kf_pooling":
            q = self.static_query.expand(B, -1, -1)
            z = self.core(q, embeds, embeds).mean(dim=1)
        elif self.mode == "kf_layer":
            k = self.static_key.expand(B, -1, -1)
            v = self.static_value.expand(B, -1, -1)
            z = self.core(embeds, k, v).mean(dim=1)

        # `z, *_ = ...` assumes HopfieldCore returns a tuple whose first item is
        # the attention output -- TODO confirm against hflayers.
        elif self.mode == "hf_attention":
            z, *_ = self.core(seq_embeds, seq_embeds, seq_embeds)
            z = z.mean(dim=0)
        elif self.mode == "hf_pooling":
            q = self.static_query.expand(-1, B, -1)
            z, *_ = self.core(q, seq_embeds, seq_embeds)
            z = z.mean(dim=0)
        elif self.mode == "hf_layer":
            k = self.static_key.expand(-1, B, -1)
            v = self.static_value.expand(-1, B, -1)
            z, *_ = self.core(seq_embeds, k, v)
            z = z.mean(dim=0)
        else:
            raise ValueError(f"Unknown mode {self.mode!r}; expected one of {self.MODES}")

        # Previously `z = embeds.mean(dim=1)` overwrote the core output here,
        # so no mode ever used its core; z now comes from the branch above.
        return self.classifier(z).squeeze(-1)


def train(args, model, device, train_loader, optimizer, epoch):
    """Run one optimisation pass over ``train_loader``.

    Each batch is a ``(data, target, extra)`` triple; ``extra`` is ignored.
    Targets are cast to float and scored against the model output with
    ``binary_cross_entropy_with_logits``, so the output shape must match the
    target shape. ``args`` and ``epoch`` are accepted for call-site
    compatibility but are not used.
    """
    model.train()
    for inputs, labels, _unused in train_loader:
        inputs = inputs.to(device)
        labels = labels.float().to(device)

        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = F.binary_cross_entropy_with_logits(logits, labels)
        batch_loss.backward()
        optimizer.step()

def test(model, device, test_loader):
    """Evaluate ``model`` on ``test_loader`` and return the ROC-AUC.

    Each batch is a ``(data, target, extra)`` triple; ``extra`` is ignored.
    Model outputs are treated as logits and converted to probabilities with a
    sigmoid before scoring.

    Returns:
        ``sklearn.metrics.roc_auc_score(labels, probs)`` over the whole loader.
        Note that sklearn raises ``ValueError`` if only one class is present.
    """
    model.eval()
    probs, labels = [], []
    with torch.no_grad():
        for data, target, _ in test_loader:
            output = model(data.to(device))
            probs.extend(torch.sigmoid(output).cpu().numpy())
            # Targets are only collected, never fed to the model, so they stay
            # on the CPU. Calling .numpy() on a CUDA tensor raises a TypeError,
            # which the previous `.to(device)` triggered whenever a GPU was used.
            labels.extend(target.cpu().numpy())

    return roc_auc_score(labels, probs)


def _parse_args():
    """Define the command-line interface and return the parsed namespace."""
    parser = argparse.ArgumentParser(description="Core MIL")
    parser.add_argument('--model', type=str, default='kf_pooling',
                        choices=['kf_attention', 'kf_layer', 'kf_pooling',
                                 'hf_attention', 'hf_layer', 'hf_pooling'])
    parser.add_argument('--dataset', type=str, default='tiger',
                        choices=['tiger', 'fox', 'elephant'])
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=14, metavar='N',
                        help='number of epochs to train (default: 14)')
    parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
                        help='learning rate (default: 0.001)')
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-accel', action='store_true',
                        help='disables accelerator')
    parser.add_argument('--dry-run', action='store_true',
                        help='quickly check a single pass')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    # NOTE(review): --log-interval and --save-model are parsed but not acted on
    # anywhere in this file.
    parser.add_argument('--log-interval', type=int, default=10000, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model', action='store_true',
                        help='For Saving the current Model')
    # Not read in this file; presumably consumed by load_data(args) -- confirm.
    parser.add_argument('--multiply', action='store_true')
    return parser.parse_args()


def _as_object_array(bags):
    """Return a 1-D object array whose i-th element is ``np.asarray(bags[i])``.

    ``np.array(bags, dtype=object)`` builds a multi-dimensional object array
    when all bags share a shape, and fails to broadcast when only their leading
    dimensions agree. Filling a pre-allocated 1-D array element by element
    always yields exactly one entry per bag.
    """
    out = np.empty(len(bags), dtype=object)
    for i, bag in enumerate(bags):
        out[i] = np.asarray(bag)
    return out


def main():
    """Run stratified 10-fold cross-validation for one model/dataset pair.

    A fresh model, optimizer and LR scheduler are built for each fold. The fold
    is trained for ``--epochs`` epochs and then scored by test ROC-AUC.
    ``--dry-run`` limits the run to one epoch of the first fold.

    Returns:
        The mean ROC-AUC over the evaluated folds.
    """
    args = _parse_args()

    torch.manual_seed(args.seed)
    use_accel = not args.no_accel and torch.cuda.is_available()
    device = torch.device("cuda" if use_accel else "cpu")

    features, labels = load_data(args)
    features = _as_object_array(features)
    labels = np.array(labels)
    input_dim = features[0].shape[-1]

    hidden_dim = 128
    model_params = dict(
        input_dim=input_dim,
        hidden_dim=hidden_dim,
        num_classes=1,
        beta=1 / math.sqrt(hidden_dim),
        num_states=1,
        num_memories=64,
        bag_dropout=0.75
    )

    epochs = 1 if args.dry_run else args.epochs
    all_results = []

    skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=args.seed)
    fold_scores = []

    for fold, (train_idx, test_idx) in enumerate(skf.split(features, labels)):
        train_set = DummyDataset(features[train_idx], labels[train_idx])
        test_set = DummyDataset(features[test_idx], labels[test_idx])

        train_loader = DataLoader(
            train_set, batch_size=args.batch_size, shuffle=True,
            collate_fn=train_set.collate
        )
        # Honour --test-batch-size (previously ignored in favour of --batch-size).
        test_loader = DataLoader(
            test_set, batch_size=args.test_batch_size, shuffle=False,
            collate_fn=test_set.collate
        )

        model = ModelWrapper(mode=args.model, **model_params).to(device)
        optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-4)
        scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)

        for epoch in range(1, epochs + 1):
            train(args, model, device, train_loader, optimizer, epoch)
            scheduler.step()

        auc = test(model, device, test_loader)
        fold_scores.append(auc)
        print(f"Fold {fold + 1} AUC: {auc:.4f}")

        if args.dry_run:
            break

    mean_auc = np.mean(fold_scores)
    all_results.append(mean_auc)
    print(f"Repetition Mean AUC: {mean_auc:.4f}")

    print("\nFinal Results")
    print("Overall Mean AUC:", np.mean(all_results))

    return np.mean(all_results)


if __name__ == "__main__":
    # Script entry point; the mean AUC returned by main() is discarded here.
    main()
