from argparse import Namespace
import os
import pdb
from libauc.losses import AUCMLoss, CrossEntropyLoss
from libauc.optimizers import PESG, Adam
from libauc.models import densenet121 as DenseNet121
# from libauc.datasets import CheXpert
from Datasets.CheXpert.chexpert import CheXpert
import numpy as np
import torch

import torch 
from PIL import Image
import numpy as np
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split

from joblib.externals.loky.backend.context import get_context
import gc

CUDA_AVAILABLE = torch.cuda.is_available()

def load_model(model_id, load_path):
    """Load a previously saved DenseNet121 checkpoint.

    Args:
        model_id: Identifier interpolated into the checkpoint file name
            ``deep_auc_model{model_id}.pth``.
        load_path: Directory that contains the checkpoint file.

    Returns:
        A DenseNet121 holding the stored weights (moved to the GPU when
        ``CUDA_AVAILABLE`` is true), or ``None`` when the checkpoint file
        does not exist.
    """
    model_path = os.path.join(load_path, f'deep_auc_model{model_id}.pth')

    # Check for the file first so a missing checkpoint does not cost us the
    # construction of a full DenseNet121 that is thrown away immediately.
    if not os.path.exists(model_path):
        print("Model {} does not exist".format(model_path))
        return None

    trained_model = DenseNet121(pretrained=False, last_activation=None, activations='relu', num_classes=1)

    # Deserialize onto the CPU so checkpoints written on a GPU can still be
    # restored on a machine without one; the model is moved afterwards.
    state_dict = torch.load(model_path, map_location='cpu')
    trained_model.load_state_dict(state_dict)

    # Same device policy as the rest of this file: only use CUDA if present.
    if CUDA_AVAILABLE:
        trained_model.cuda()
    return trained_model


def generate_model_and_fit_deepAUC(train_indices, validation_indices, save_path="./models", batch_size=32, lr=0.05, margin=1.0, epoch_decay=2e-3, weight_decay=1e-5, eval_every=10, epochs=2, model_id=0):
    """Train a DenseNet121 with ``AUCMLoss``/``PESG`` on a subset of the module-level ``train_set``.

    Args:
        train_indices: Indices into ``train_set`` used for training.
        validation_indices: Indices into ``train_set`` used for validation,
            or ``None`` to train without validation.
        save_path: Directory the checkpoint ``deep_auc_model{model_id}.pth``
            is written to; created if missing.
        batch_size: Batch size for the training and validation loaders.
        lr, margin, epoch_decay, weight_decay: Forwarded to ``PESG``.
        eval_every: Run validation after every ``eval_every`` training
            batches (only when ``validation_indices`` is given).
        epochs: Number of passes over the training subset.
        model_id: Identifier used in the checkpoint file name.

    Returns:
        The model in the state reached at the end of training. With
        validation enabled, the checkpoint on disk holds the weights with
        the best validation AUC seen, which may differ from the returned
        model; without validation, the final weights are saved.
    """
    checkpoint_path = os.path.join(save_path, f'deep_auc_model{model_id}.pth')

    def _validate(best_val_auc, epoch, idx):
        """Score the model on ``valid_loader``; checkpoint and return the best AUC."""
        model.eval()
        try:
            with torch.no_grad():
                test_pred = []
                test_true = []
                for jdx, data in enumerate(valid_loader):
                    print(f"valid {jdx}")
                    test_data, test_labels = data
                    test_data = test_data.cuda() if CUDA_AVAILABLE else test_data
                    y_pred = model(test_data)
                    test_pred.append(y_pred.cpu().detach().numpy())
                    test_true.append(test_labels.numpy())
        finally:
            # Always hand the model back in training mode, even if a batch fails.
            model.train()

        # Predictions and labels are collected in the same batch order, so the
        # shuffling of the validation loader does not affect the AUC.
        test_true = np.concatenate(test_true)
        test_pred = np.concatenate(test_pred)
        val_auc_mean = roc_auc_score(test_true, test_pred)

        if best_val_auc < val_auc_mean:
            best_val_auc = val_auc_mean
            torch.save(model.state_dict(), checkpoint_path)
        print ('Epoch=%s, BatchID=%s, Val_AUC=%.4f, Best_Val_AUC=%.4f'%(epoch, idx, val_auc_mean, best_val_auc ))

        return best_val_auc

    # Make the dataset return images (instead of just indices).
    if isinstance(train_set, torch.utils.data.dataset.Subset):
        train_set.dataset.image_loader_mode()
    else:
        train_set.image_loader_mode()

    train_subset = torch.utils.data.Subset(train_set, train_indices)
    train_loader = torch.utils.data.DataLoader(train_subset, batch_size=batch_size, shuffle=True, num_workers=2, multiprocessing_context=get_context('loky'))

    if validation_indices is not None:
        valid_subset = torch.utils.data.Subset(train_set, validation_indices)
        valid_loader = torch.utils.data.DataLoader(valid_subset, batch_size=batch_size, shuffle=True, num_workers=2, multiprocessing_context=get_context('loky'))

    model = DenseNet121(pretrained=True, last_activation=None, activations='relu', num_classes=1)
    model = model.cuda() if CUDA_AVAILABLE else model

    # exist_ok avoids the check-then-create race when several models are
    # trained concurrently into the same directory.
    os.makedirs(save_path, exist_ok=True)

    # define loss & optimizer
    loss_fn = AUCMLoss()
    optimizer = PESG(model,
                     loss_fn=loss_fn,
                     lr=lr,
                     margin=margin,
                     epoch_decay=epoch_decay,
                     weight_decay=weight_decay)

    best_val_auc = 0
    # Pre-bind the loop variables so the final validation below can still
    # report a position when no training batch ran (epochs == 0 or an empty
    # loader) instead of failing on an unbound name.
    epoch = idx = None
    for epoch in range(epochs):
        if epoch > 0:
            optimizer.update_regularizer(decay_factor=10)
        for idx, data in enumerate(train_loader):
            print(f"train {idx}")
            train_data, train_labels = data
            if CUDA_AVAILABLE:
                train_data, train_labels = train_data.cuda(), train_labels.cuda()
            # The network is built with last_activation=None, so the sigmoid
            # is applied here before the loss.
            y_pred = model(train_data)
            y_pred = torch.sigmoid(y_pred)
            loss = loss_fn(y_pred, train_labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # validation
            if validation_indices is not None and (idx + 1) % eval_every == 0:
                best_val_auc = _validate(best_val_auc, epoch, idx)

    # one last round of validation
    if validation_indices is not None:
        _validate(best_val_auc, epoch, idx)
    else:
        # No validation signal: persist whatever the final weights are.
        torch.save(model.state_dict(), checkpoint_path)
    return model
    

# Index of the CheXpert pathology used as the binary target:
# 0:Cardiomegaly, 1:Edema, 2:Consolidation, 3:Atelectasis, 4:Pleural Effusion
class_id = 1
root = '/data/datasets/CheXpert-v1.0-small/'

# Training split (upsampling enabled); shared by the training and inference
# helpers in this module.
train_set = CheXpert(
    csv_path=f'{root}train.csv',
    image_root_path=root,
    use_upsampling=True,
    use_frontal=True,
    image_size=224,
    mode='train',
    class_index=class_id,
    return_sensitive=True,
)
# Optional shortcuts for quick experiments on the first 1000 samples:
# train_set = torch.utils.data.Subset(train_set, list(range(0, 1000)))
# train_sampler = torch.utils.data.SubsetRandomSampler(np.arange(0, 1000))

# Held-out split read from valid.csv (no upsampling).
test_set = CheXpert(
    csv_path=f'{root}valid.csv',
    image_root_path=root,
    use_upsampling=False,
    use_frontal=True,
    image_size=224,
    mode='valid',
    class_index=class_id,
    return_sensitive=True,
)


def generate_model_and_fit(train_features, train_labels, args=None, model_id=0):
    """Load or train a DeepAUC model and wrap it in a scikit-learn-like API.

    Args:
        train_features: 1-D array of dataset indices (not image data).
        train_labels: Not read by this function; labels come from the dataset.
        args: Namespace-like object providing ``deep_auc_dict``, ``seed``,
            ``skip`` and ``trained_model_path``. Required despite the
            ``None`` default.
        model_id: Identifier used in the checkpoint file name.

    Returns:
        An ``SKLearnAPIWrapper`` around the loaded or freshly trained model.

    Raises:
        ValueError: If ``args`` is ``None``.
    """
    if args is None:
        raise ValueError('args cannot be None')

    class SKLearnAPIWrapper:
        """Inference facade exposing predict / predict_proba / score on dataset indices."""

        def __init__(self, model, batch_size=32, train_validation_split=0.8, subsample=None):
            self.model = model
            self.batch_size = batch_size
            self.train_validation_split = train_validation_split
            self.subsample = subsample

        def _make_loader(self, indices, dataset, subsample=None):
            """Build an order-preserving loader over ``dataset[indices]``."""
            if dataset is None:
                raise ValueError('dataset cannot be None')

            # enable returning images (instead of just indices)
            if isinstance(dataset, torch.utils.data.dataset.Subset):
                dataset.dataset.image_loader_mode()
            else:
                dataset.image_loader_mode()

            subset = torch.utils.data.Subset(dataset, indices)
            if subsample is not None:
                # NOTE(review): the random sub-selection is not reported back,
                # so callers cannot tell which samples the outputs belong to.
                subset = torch.utils.data.Subset(subset, np.random.choice(len(subset), subsample, replace=False))

            # shuffle=False: the outputs must line up with the requested
            # indices, which a shuffling loader would silently break.
            return torch.utils.data.DataLoader(subset, batch_size=self.batch_size, num_workers=2, shuffle=False, multiprocessing_context=get_context('loky'))

        def _forward_all(self, loader, as_probability, collect_labels=False, log_progress=False):
            """Run the model over ``loader``; return (outputs, labels or None)."""
            outputs = []
            labels = []
            self.model.eval()
            with torch.no_grad():
                for jdx, (data, label) in enumerate(loader):
                    if log_progress:
                        print(f"eval {jdx}")
                    if CUDA_AVAILABLE:
                        data = data.cuda()
                    y_pred = self.model(data)
                    if as_probability:
                        # The network emits raw scores (last_activation=None);
                        # training applies the same sigmoid before the loss.
                        y_pred = torch.sigmoid(y_pred)
                    outputs.append(y_pred.cpu().detach().numpy())
                    if collect_labels:
                        labels.append(label.numpy())
            outputs = np.concatenate(outputs)
            return outputs, (np.concatenate(labels) if collect_labels else None)

        def predict(self, indices, dataset=None):
            """Return hard 0/1 predictions (probability >= 0.5), aligned with ``indices``."""
            loader = self._make_loader(indices, dataset, subsample=self.subsample)
            probs, _ = self._forward_all(loader, as_probability=True, log_progress=True)
            return (probs >= 0.5).astype(int).squeeze()

        def predict_proba(self, indices, dataset=None):
            """Return positive-class probabilities, aligned with ``indices``."""
            loader = self._make_loader(indices, dataset, subsample=self.subsample)
            probs, _ = self._forward_all(loader, as_probability=True)
            return probs.squeeze()

        def score(self, indices, dataset=None):
            """Return the ROC AUC of the model on ``dataset[indices]``."""
            loader = self._make_loader(indices, dataset)
            # AUC only depends on the ranking, so the raw scores are used.
            preds, truths = self._forward_all(loader, as_probability=False, collect_labels=True)
            return roc_auc_score(truths, preds)

        def save_model(self, path):
            """Write the wrapped model's state dict to ``path``."""
            torch.save(self.model.state_dict(), path)

        def destruct(self):
            """Drop the wrapped model and release cached GPU memory."""
            del self.model
            gc.collect()
            torch.cuda.empty_cache()

    # DeepAUC-specific arguments
    d_args = Namespace(**args.deep_auc_dict)

    assert len(train_features.shape) == 1 # train_features should only contain indices
    if d_args.eval_every != -1:
        train_indices, validation_indices = train_test_split(train_features, train_size=d_args.train_validation_split, random_state=args.seed)
    else:
        # eval_every == -1 disables validation: train on every index.
        train_indices = train_features
        validation_indices = None

    # load and/or fit model
    if args.skip in ["training_teachers", "voting"] and isinstance(model_id, int):
        print("Loading pretrained teacher model {}".format(model_id))
        loaded_model = load_model(model_id, args.trained_model_path)
    elif args.skip in ["training_all", "voting"]:
        # Membership test, matching the branch above; comparing the stage
        # name to the whole list with == could never be true for a string.
        print("Loading pretrained model {}".format(model_id))
        loaded_model = load_model(model_id, args.trained_model_path)
    else:
        loaded_model = None

    if loaded_model is None:
        # Either loading was not requested or the checkpoint was missing.
        print("Training model {}".format(model_id))
        trained_model = generate_model_and_fit_deepAUC(train_indices, validation_indices,
                                                    save_path=args.trained_model_path,
                                                    batch_size=d_args.batch_size,
                                                    lr=d_args.lr, margin=d_args.margin,
                                                    epoch_decay=d_args.epoch_decay,
                                                    weight_decay=d_args.weight_decay,
                                                    eval_every=d_args.eval_every,
                                                    epochs=d_args.epochs,
                                                    model_id=model_id)
    else:
        trained_model = loaded_model

    # wrap it in a sklearn API to enable easy inference
    model = SKLearnAPIWrapper(trained_model, batch_size=d_args.batch_size, train_validation_split=d_args.train_validation_split)
    return model