import dill
import numpy as np
import argparse
from collections import defaultdict
from sklearn.metrics import jaccard_score, precision_score, recall_score, f1_score, accuracy_score, average_precision_score
from torch.optim import Adam
import os

import torch
import time
from model import MambaHealthModel

from util import llprint, ddi_rate_score, get_n_params
import torch.nn.functional as F

# Seed the torch and NumPy random number generators.
torch.manual_seed(1203)
np.random.seed(2048)

# Setting
model_name = 'MambaHealth1'
# NOTE(review): this default points at saved/MambaHealth while checkpoints are
# written under saved/<model_name> (saved/MambaHealth1) - confirm it is intended.
resume_path = 'saved/MambaHealth/best.model'

# exist_ok=True creates the checkpoint directory without a separate
# exists() check, which was open to a check-then-create race.
os.makedirs(os.path.join("saved", model_name), exist_ok=True)

# Training settings
parser = argparse.ArgumentParser()
# Test mode is on by default, so passing `--Test` alone could never change it
# and the training branch was unreachable from the command line.  `--no-Test`
# clears the same destination to enable training; the default is unchanged.
parser.add_argument('--Test', action='store_true', default=True, help="test mode")
parser.add_argument('--no-Test', dest='Test', action='store_false', help="train mode (disables test mode)")
parser.add_argument('--model_name', type=str, default=model_name, help="model name")
parser.add_argument('--resume_path', type=str, default=resume_path, help='resume path')
parser.add_argument('--lr', type=float, default=3e-5, help='learning rate')
parser.add_argument('--target_ddi', type=float, default=0.06, help='target ddi')
parser.add_argument('--kp', type=float, default=0.05, help='coefficient of P signal')
parser.add_argument('--dim', type=int, default=64, help='dimension')

args = parser.parse_args()

# Evaluate
def multi_label_metric(y_gt, y_pred, y_pred_prob):
    """Score multi-label predictions under samples, macro and micro averaging.

    Args:
        y_gt: ground-truth label matrix (an array supporting ``.flatten()``).
        y_pred: thresholded predictions with the same layout as ``y_gt``.
        y_pred_prob: predicted scores; only used for average precision.

    Returns:
        A 12-tuple in this order: samples-averaged Jaccard, average precision,
        precision, recall and F1; accuracy over the flattened matrices;
        macro precision, recall, F1; micro precision, recall, F1.
    """
    per_sample = {'average': 'samples'}
    prf_scorers = (precision_score, recall_score, f1_score)

    sample_jaccard = jaccard_score(y_gt, y_pred, **per_sample)
    sample_prauc = average_precision_score(y_gt, y_pred_prob, **per_sample)
    sample_p, sample_r, sample_f1 = (
        scorer(y_gt, y_pred, **per_sample) for scorer in prf_scorers
    )

    # Accuracy is taken over every individual cell, hence the flattening.
    flat_accuracy = accuracy_score(y_gt.flatten(), y_pred.flatten())

    # Macro triple first, then micro triple, each as (precision, recall, f1).
    pooled = [
        scorer(y_gt, y_pred, average=mode)
        for mode in ('macro', 'micro')
        for scorer in prf_scorers
    ]

    return (
        sample_jaccard, sample_prauc, sample_p, sample_r, sample_f1,
        flat_accuracy, *pooled,
    )

def eval(model, data_eval, voc_size, epoch):
    """Evaluate ``model`` on ``data_eval`` and report averaged metrics.

    For each patient, every visit is predicted from the visit prefix
    ``patient[:adm_idx + 1]``; the 12 values from ``multi_label_metric`` are
    computed per patient and then averaged over patients.

    Args:
        model: callable returning a 3-tuple whose first element holds the
            medication logits; only row 0 of that element is used.
        data_eval: sequence of patients, each a sequence of visits where
            ``adm[2]`` holds the ground-truth medication indices.
        voc_size: vocabulary sizes; ``voc_size[2]`` is the medication count.
        epoch: accepted for call compatibility; not used here.

    Returns:
        ``(ddi_rate, ja, prauc, avg_p, avg_r, avg_f1, accuracy, macro_p,
        macro_r, macro_f1, micro_p, micro_r, micro_f1, avg_med)``.
    """
    model.eval()

    smm_record = []
    # One bucket per value returned by multi_label_metric, in the same order.
    metric_buckets = [[] for _ in range(12)]
    med_cnt, visit_cnt = 0, 0

    # Inference only: without no_grad every forward pass would build an
    # autograd graph that is never used.
    with torch.no_grad():
        for step, patient in enumerate(data_eval):
            y_gt, y_pred, y_pred_prob, y_pred_label = [], [], [], []
            for adm_idx, adm in enumerate(patient):
                # Only the first return value is needed for evaluation.
                target_output, _, _ = model(patient[:adm_idx + 1])

                y_gt_tmp = np.zeros(voc_size[2])
                y_gt_tmp[adm[2]] = 1
                y_gt.append(y_gt_tmp)

                # prediction prob
                probs = torch.sigmoid(target_output).detach().cpu().numpy()[0]
                y_pred_prob.append(probs)

                # prediction med set (threshold at 0.5)
                y_pred_tmp = np.where(probs >= 0.5, 1, 0).astype(probs.dtype)
                y_pred.append(y_pred_tmp)

                # prediction label
                y_pred_label_tmp = np.where(y_pred_tmp == 1)[0]
                y_pred_label.append(sorted(y_pred_label_tmp))
                visit_cnt += 1
                med_cnt += len(y_pred_label_tmp)

            smm_record.append(y_pred_label)
            metrics = multi_label_metric(np.array(y_gt), np.array(y_pred), np.array(y_pred_prob))
            for bucket, value in zip(metric_buckets, metrics):
                bucket.append(value)
            llprint('\rtest step: {} / {}'.format(step, len(data_eval)))

    # DDI rate
    ddi_rate = ddi_rate_score(smm_record, path='../data/ddi_A_final.pkl')

    # Average medication count; guarded so an empty evaluation set does not
    # raise ZeroDivisionError.
    avg_med = float(med_cnt) / float(visit_cnt) if visit_cnt else 0.0

    means = [np.mean(bucket) for bucket in metric_buckets]

    llprint('\nDDI Rate: {:.4f}, Jaccard: {:.4f}, PRAUC: {:.4f}, AVG_PRC: {:.4f}, AVG_RECALL: {:.4f}, AVG_F1: {:.4f}, Accuracy: {:.4f}, Macro_P: {:.4f}, Macro_R: {:.4f}, Macro_F1: {:.4f}, Micro_P: {:.4f}, Micro_R: {:.4f}, Micro_F1: {:.4f}, AVG_MED: {:.4f}\n'.format(
        float(ddi_rate), *means, avg_med
    ))

    return (ddi_rate, *means, avg_med)

def compute_loss(model, seq_input, adm, voc_size, device, ddi_rate_score, args):
    """Compute the training loss for one visit, balanced by the current DDI rate.

    Args:
        model: callable returning ``(logits, loss_ddi, loss_ehr)`` for
            ``seq_input``; row 0 of ``logits`` is used for the DDI check.
        seq_input: the visit prefix fed to the model.
        adm: the current visit; ``adm[2]`` holds the target medication indices.
        voc_size: vocabulary sizes; ``voc_size[2]`` is the medication count.
        device: torch device the target tensors are moved to.
        ddi_rate_score: callable ``(records, path=...)`` returning the DDI
            rate of the predicted medication set.
        args: namespace providing ``target_ddi`` and ``kp``.

    Returns:
        The scalar loss tensor to back-propagate.
    """
    # Multi-hot target for the binary cross-entropy term.
    loss_bce_target = np.zeros((1, voc_size[2]))
    loss_bce_target[:, adm[2]] = 1

    # Index-list target for the margin term: label indices first, then -1 padding.
    loss_multi_target = np.full((1, voc_size[2]), -1)
    for idx, item in enumerate(adm[2]):
        loss_multi_target[0][idx] = item

    result, loss_ddi, loss_ehr = model(seq_input)
    probs = torch.sigmoid(result)

    loss_bce = F.binary_cross_entropy_with_logits(result, torch.FloatTensor(loss_bce_target).to(device))
    loss_multi = F.multilabel_margin_loss(probs, torch.LongTensor(loss_multi_target).to(device))
    loss_pred = 0.95 * loss_bce + 0.05 * loss_multi

    # DDI rate of the medication set predicted at the 0.5 threshold.
    y_label = np.where(probs.detach().cpu().numpy()[0] >= 0.5)[0]
    current_ddi_rate = ddi_rate_score([[y_label]], path='../data/ddi_A_final.pkl')

    if current_ddi_rate <= args.target_ddi:
        return loss_pred + loss_ddi + loss_ehr

    # The weight of the prediction loss shrinks linearly from 1 towards 0 as
    # the DDI rate exceeds the target, and is clamped at 0.  The previous
    # min(0, ...) could only yield 0 or a negative weight, which dropped the
    # prediction loss or rewarded increasing it.
    beta = max(0, 1 + (args.target_ddi - current_ddi_rate) / args.kp)
    return beta * loss_pred + (1 - beta) * (loss_ddi + loss_ehr)


def _load_pickle(path):
    """Load one dill-serialised object from ``path`` and close the file."""
    # NOTE: dill.load can execute arbitrary code; only load trusted files.
    with open(path, 'rb') as handle:
        return dill.load(handle)


def main():
    """Run the bootstrapped test (``args.Test``) or the full training loop.

    The records are split 4/5 for training; the remainder is halved into
    ``data_test`` (first half) and ``data_eval`` (second half, used for the
    per-epoch evaluation).  In test mode the checkpoint at
    ``args.resume_path`` is evaluated on 10 bootstrap samples of ``data_test``
    and mean/std are printed as a LaTeX table row.  Otherwise the model is
    trained for 50 epochs, a checkpoint is saved every epoch, and the best
    one (by Jaccard on ``data_eval``) is finally evaluated on ``data_test``.
    """
    # Load data
    data_path = '../data/records_final.pkl'
    voc_path = '../data/voc_final.pkl'

    ehr_adj_path = '../data/ehr_adj_final.pkl'
    ddi_adj_path = '../data/ddi_A_final.pkl'

    # Use CUDA when present instead of assuming it unconditionally.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    ehr_adj = _load_pickle(ehr_adj_path)
    ddi_adj = _load_pickle(ddi_adj_path)
    data = _load_pickle(data_path)

    voc = _load_pickle(voc_path)
    diag_voc, pro_voc, med_voc = voc['diag_voc'], voc['pro_voc'], voc['med_voc']

    split_point = int(len(data) * 4/5)
    data_train = data[:split_point]
    eval_len = int(len(data[split_point:]) / 2)
    data_test = data[split_point:split_point + eval_len]
    data_eval = data[split_point+eval_len:]

    voc_size = (len(diag_voc.idx2word), len(pro_voc.idx2word), len(med_voc.idx2word))
    model = MambaHealthModel(
        vocab_size=voc_size,
        ddi_adj=ddi_adj,
        ehr_adj=ehr_adj,
        emb_dim=args.dim,
        d_state=32,
        d_conv=4,
        expand=2,
        num_layers=1,
        dropout_prob=0.5,
        num_heads=8,
        device=device
    ).to(device=device)

    if args.Test:
        # map_location keeps the load working when the checkpoint was saved
        # on a device that is not available here.
        model.load_state_dict(torch.load(args.resume_path, map_location=device))
        tic = time.time()

        result = []
        print("\n----- 10 rounds of bootstrapping test -----")
        for _ in range(10):
            test_sample = np.random.choice(range(len(data_test)), round(len(data_test) * 0.8), replace=True)
            test_sample = [data_test[i] for i in test_sample]
            result.append(eval(model, test_sample, voc_size, 0))

        result = np.array(result)
        mean = result.mean(axis=0)
        std = result.std(axis=0)

        # "\\pm" keeps the emitted text "$\pm$" without relying on the
        # invalid escape sequence "\p".
        outstring = "".join(
            "{:.4f} $\\pm$ {:.4f} & ".format(m, s) for m, s in zip(mean, std)
        )

        print("\n----- final (mean $\\pm$ std) -----")
        print(outstring)

        print('Test time: {}'.format(time.time() - tic))
        return

    print('Parameters:', get_n_params(model))

    optimizer = Adam(list(model.parameters()), lr=args.lr)

    history = defaultdict(list)
    best_epoch, best_ja = 0, 0
    best_model_path = args.resume_path

    EPOCH = 50

    for epoch in range(EPOCH):
        tic = time.time()
        print('\nEpoch {} --------------------------'.format(epoch + 1))

        model.train()
        for step, patient in enumerate(data_train):
            for idx, adm in enumerate(patient):
                seq_input = patient[:idx+1]

                loss = compute_loss(model, seq_input, adm, voc_size, device, ddi_rate_score, args)

                optimizer.zero_grad()
                # NOTE(review): retain_graph=True is kept from the original;
                # it is only needed if the model reuses graph state across
                # forward passes - confirm against the model implementation.
                loss.backward(retain_graph=True)
                optimizer.step()

            llprint('\rTraining step: {} / {}'.format(step, len(data_train)))

        print()
        tic2 = time.time()
        metrics = eval(model, data_eval, voc_size, epoch)
        toc = time.time()
        ddi_rate, ja, prauc, avg_p, avg_r, avg_f1, accuracy, macro_p, macro_r, macro_f1, micro_p, micro_r, micro_f1, avg_med = metrics
        # Training time ends where evaluation starts; previously the reported
        # training time also included the evaluation pass.
        print('Training time: {}, Evaluation time: {}'.format(tic2 - tic, toc - tic2))

        epoch_metrics = {
            'ja': ja,
            'ddi_rate': ddi_rate,
            'avg_p': avg_p,
            'avg_r': avg_r,
            'avg_f1': avg_f1,
            'prauc': prauc,
            'accuracy': accuracy,
            'macro_p': macro_p,
            'macro_r': macro_r,
            'macro_f1': macro_f1,
            'micro_p': micro_p,
            'micro_r': micro_r,
            'micro_f1': micro_f1,
            'med': avg_med,
        }
        for key, value in epoch_metrics.items():
            history[key].append(value)

        if epoch >= 5:
            # Running average over the last five epochs.
            recent_keys = ('ddi_rate', 'med', 'ja', 'avg_f1', 'prauc', 'accuracy',
                           'macro_p', 'macro_r', 'macro_f1', 'micro_p', 'micro_r', 'micro_f1')
            print('DDI: {:.4f}, Med: {:.4f}, Jaccard: {:.4f}, F1: {:.4f}, PRAUC: {:.4f}, Accuracy: {:.4f}, Macro_P: {:.4f}, Macro_R: {:.4f}, Macro_F1: {:.4f}, Micro_P: {:.4f}, Micro_R: {:.4f}, Micro_F1: {:.4f}'.format(
                *(np.mean(history[key][-5:]) for key in recent_keys)
                ))

        current_model_path = os.path.join('saved', args.model_name,
            'Epoch_{}_TARGET_{:.2f}_JA_{:.4f}_DDI_{:.4f}.model'.format(epoch, args.target_ddi, ja, ddi_rate))
        # Passing the path lets torch open and close the file itself.
        torch.save(model.state_dict(), current_model_path)

        # The first epoch is never recorded as best.
        if epoch != 0 and best_ja < ja:
            best_epoch = epoch
            best_ja = ja
            best_model_path = os.path.join('saved', args.model_name,
            'Epoch_{}_TARGET_{:.2f}_DDI_{:.4f}.model'.format(epoch, args.target_ddi, ddi_rate))
            args.resume_path = best_model_path
            print(f"Saving best model to: {best_model_path}")
            torch.save(model.state_dict(), best_model_path)

        print('Best epoch: {}'.format(best_epoch))

    history_path = os.path.join('saved', args.model_name, 'history_{}.pkl'.format(args.model_name))
    with open(history_path, 'wb') as handle:
        dill.dump(history, handle)

    # If no epoch improved on best_ja this still points at the initial
    # args.resume_path.
    if best_model_path:
        model.load_state_dict(torch.load(best_model_path, map_location=device))
        print("Loaded best model from epoch: {}".format(best_epoch))
        eval(model, data_test, voc_size, best_epoch)

# Call main() only when this file is executed as a script.
if __name__ == '__main__':
    main()
