import os
import numpy as np
import argparse
import ssl

# NOTE(review): in CPython's `ssl` module `_create_stdlib_context` is an alias of
# `_create_unverified_context`, so this swaps the process-wide default HTTPS
# context factory for one that skips certificate and hostname verification.
# Presumably a workaround for the MedMNIST `download=True` fetches below; it also
# affects every other HTTPS request made by this process — confirm it is still needed.
ssl._create_default_https_context = ssl._create_stdlib_context

import torch
from torch.utils.data import DataLoader

import lightning as L
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

from medmnist import (
    OrganMNIST3D, 
    NoduleMNIST3D, 
    AdrenalMNIST3D, 
    FractureMNIST3D, 
    VesselMNIST3D, 
    SynapseMNIST3D
)

from src.vi_model import ViHOT

from torch.utils.data import Sampler


def func(batch):
    """Collate function that unwraps one already-batched ``(images, labels)`` pair.

    Used as ``collate_fn`` for the training DataLoader in ``main``: its sampler
    yields whole lists of indices while the loader keeps the default batch size
    of 1, so ``batch`` is a list whose first element already holds a full batch.
    Only ``batch[0]`` is read.

    Returns:
        ``(images, labels)`` as tensors sharing memory with the NumPy inputs;
        the images have their first two axes swapped (presumably turning the
        dataset's channel-first layout into batch-first — confirm against the
        dataset's ``__getitem__``).
    """
    images, labels = batch[0]
    image_tensor = torch.from_numpy(images)
    label_tensor = torch.from_numpy(labels)
    return image_tensor.transpose(0, 1), label_tensor

class BalancedBatchSampler(Sampler):
    """Yield class-balanced lists of dataset indices.

    Every batch contains ``n_samples`` indices from each class in
    ``range(n_classes)``, i.e. ``n_classes * n_samples`` indices in total.
    Within a class, indices are drawn without replacement until fewer than
    ``n_samples`` unused ones remain, at which point that class's pool is reset.
    A class with fewer than ``n_samples`` members is filled with replacement.
    The "used" bookkeeping is stored on the instance, so it carries over from
    one iteration (epoch) to the next.

    Args:
        dataset: object exposing a ``labels`` array; ``labels == c`` selects
            the members of class ``c`` and ``len(labels)`` is the dataset size.
        n_classes: number of classes; labels are expected in ``0..n_classes-1``.
        n_samples: indices drawn per class per batch.

    Raises:
        ValueError: if ``n_classes`` or ``n_samples`` is below 1, or if some
            class in ``range(n_classes)`` has no samples.
    """

    def __init__(self, dataset, n_classes, n_samples):
        if n_classes < 1 or n_samples < 1:
            # A zero batch size would make __iter__ spin forever.
            raise ValueError(
                f"n_classes and n_samples must be >= 1, got {n_classes} and {n_samples}"
            )
        self.labels = dataset.labels
        self.n_classes = n_classes
        self.n_samples = n_samples
        self.batch_size = n_classes * n_samples
        # np.where(...)[0] yields row indices even when labels is 2-D, e.g. (N, 1).
        self.indices = {label: np.where(self.labels == label)[0] for label in range(self.n_classes)}
        empty = [label for label, idx in self.indices.items() if len(idx) == 0]
        if empty:
            raise ValueError(f"classes {empty} have no samples to draw from")
        # Full per-class pools, built once instead of on every batch.
        self._pools = {label: set(idx) for label, idx in self.indices.items()}
        self.used_indices = {label: set() for label in range(self.n_classes)}

    def __iter__(self):
        # Emit exactly len(self) batches so iteration agrees with __len__.
        for _ in range(len(self)):
            batch = []
            for label in range(self.n_classes):
                pool = self._pools[label]
                unused = pool - self.used_indices[label]
                if len(unused) < self.n_samples:
                    # Not enough fresh indices left: start a new pass over this class.
                    self.used_indices[label] = set()
                    unused = pool
                # Only a class smaller than n_samples needs sampling with replacement.
                replace = len(unused) < self.n_samples
                picked = np.random.choice(list(unused), self.n_samples, replace=replace)
                batch.extend(picked)
                self.used_indices[label].update(picked)
            yield batch

    def __len__(self):
        # Ceiling division: a final batch is produced for a partial remainder too.
        return (len(self.labels) + self.batch_size - 1) // self.batch_size

# import random
# L.seed_everything(random.randint(0, 10000))


# Report GPU visibility as soon as the module is loaded (this runs on import,
# not only from main()); the Trainer in main() is configured for one GPU.
print(f'Cuda {"IS" if torch.cuda.is_available() else "is NOT"} Available')
print('NUM CUDA DEVICES:', torch.cuda.device_count())


# Values accepted by --name, mapped to the MedMNIST 3D dataset class to load.
_DATASETS = {
    'fracture': FractureMNIST3D,
    'organ': OrganMNIST3D,
    'nodule': NoduleMNIST3D,
    'adrenal': AdrenalMNIST3D,
    'vessel': VesselMNIST3D,
    'synapse': SynapseMNIST3D,
}

# Download/cache directory passed to every MedMNIST split.
_DATA_ROOT = 'data/medmnist'


def _parse_args():
    """Parse the command-line options of a training run."""
    parser = argparse.ArgumentParser(
        "Med mnist", add_help=False
    )
    parser.add_argument(
        "--name", default='organ', type=str, choices=sorted(_DATASETS),
        help="name of the dataset"
    )
    parser.add_argument(
        "--hidden", default=1, type=int, help="number of hidden dimensions"
    )
    parser.add_argument(
        "--patch_size", default=5, type=int, help="patch_size"
    )
    parser.add_argument(
        "--nhead", default=1, type=int, help="number of heads"
    )
    parser.add_argument(
        "--linear", default=1, type=int, help="linear attention"
    )
    parser.add_argument(
        "--nblocks", default=1, type=int, help="number of blocks"
    )
    parser.add_argument(
        "--dropout", default=0., type=float, help="dropout"
    )
    parser.add_argument(
        "--pooling", default='mean', type=str, help="pooling method"
    )
    return parser.parse_args()


def _class_weights(labels, num_classes):
    """Return normalised inverse-frequency weights, one per class.

    Args:
        labels: integer class labels in ``0..num_classes-1`` (any shape; flattened).
        num_classes: number of classes.

    Returns:
        float64 NumPy array summing to 1 whose entry ``c`` is proportional to
        ``total_count / count[c]``. Every class is expected to occur at least
        once; a zero count makes the division produce inf/NaN weights.
    """
    # bincount counts each integer label exactly; an equal-width histogram over
    # the observed [min, max] range mis-bins labels when an extreme class is absent.
    counts = np.bincount(np.asarray(labels).ravel().astype(np.int64), minlength=num_classes)
    weight = counts.sum() / counts
    return weight / weight.sum()


def main():
    """Train ViHOT on a MedMNIST 3D dataset and print the test results.

    Downloads the train/val/test splits selected by ``--name``, trains with
    class-balanced batches, passes inverse-frequency class weights to the
    model as ``ce_weight``, keeps the single checkpoint with the highest
    ``val_auc`` and evaluates that checkpoint on the test split.
    """
    args = _parse_args()

    data_cls = _DATASETS[args.name]
    # MedMNIST dataset classes expect `root` to exist already; creating it is a
    # no-op when it does.
    os.makedirs(_DATA_ROOT, exist_ok=True)
    train_dataset = data_cls(split="train", download=True, root=_DATA_ROOT)
    val_dataset = data_cls(split="val", download=True, root=_DATA_ROOT)
    test_dataset = data_cls(split="test", download=True, root=_DATA_ROOT)
    num_classes = len(train_dataset.info['label'])

    # The sampler yields whole index lists and the loader keeps its default batch
    # size of 1, so `func` receives a single pre-batched (images, labels) pair.
    # 33 // num_classes samples per class keeps each batch at no more than 33 items.
    train_dataloader = DataLoader(
        train_dataset,
        collate_fn=func,
        sampler=BalancedBatchSampler(train_dataset, n_classes=num_classes, n_samples=33 // num_classes)
    )
    val_dataloader = DataLoader(val_dataset, batch_size=32)
    test_dataloader = DataLoader(test_dataset, batch_size=32)

    weight = torch.from_numpy(_class_weights(train_dataset.labels, num_classes)).cuda().float()

    model = ViHOT(
        d_hidden=args.hidden,
        n_blocks=args.nblocks,
        n_head=args.nhead,
        n_class=num_classes,
        dropout=args.dropout,
        patch_size=args.patch_size,
        pooling=args.pooling,
        use_linear_att=bool(args.linear),
        feature_map='SMReg',
        lr=1e-3,
        ce_weight=weight
    )

    name = f'patch_size={args.patch_size}-d_hidden={args.hidden}-n_blocks={args.nblocks}-n_head={args.nhead}-dropout={args.dropout}-pooling={args.pooling}-linear={args.linear}'
    proj = 'ViHOT-' + args.name
    checkpoint_callback = ModelCheckpoint(
        dirpath=f"experiments/{proj}/weights",
        filename=name + "-{epoch:02d}-{val_auc:.2f}",
        save_top_k=1,
        monitor="val_auc",
        mode='max'
    )
    trainer = L.Trainer(
        max_epochs=100,
        devices=1,
        accelerator="gpu",
        num_nodes=1,
        callbacks=[checkpoint_callback],
        gradient_clip_val=5.,
        accumulate_grad_batches=4,
        enable_progress_bar=False
    )

    trainer.fit(model, train_dataloader, val_dataloader)
    # Evaluate the best checkpoint tracked by `checkpoint_callback`.
    print(trainer.test(ckpt_path='best', dataloaders=test_dataloader))


# Run training only when executed as a script, not when imported.
if __name__ == '__main__':
    main()

