import logging
import os
import time

import numpy as np
import torch
import torch.optim as optim
import tqdm
from torch.utils.data import DataLoader
from tqdm import tqdm  # the script calls tqdm(...) directly, so bind the callable, not the module

from dataset import *
from models import *
from train_val_test_func import *
from TDHNODE import *
from utils import *

# ---- optimisation / training-loop settings ----
num_epochs = 200        # upper bound on the number of epochs
learning_rate = 0.0001  # initial learning rate handed to the optimizer
weight_decay = 0.000001
batch_size = 128
patience = 5            # early-stopping patience, counted in epochs
# Run on the GPU whenever CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# ---- model settings ----
embedding_all_dim = 128
embedding_time_dim = 128
num_blocks = 12
n_head = 8
dropout = 0.1
forward_expansion = 4

# Timestamp of this run (local time); it names the log file below and is
# reused later in the script for other per-run output paths.
start_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())
log_path = f"./log/{start_time}" + ".log"
# logging.basicConfig(filename=...) opens the file but does not create missing
# parent directories, so make sure ./log exists before configuring logging.
os.makedirs(os.path.dirname(log_path), exist_ok=True)
logging.basicConfig(filename=log_path, filemode='a', datefmt='%H:%M:%S', level=logging.INFO,
                    format='%(asctime)s: \n%(message)s')
# Record every hyper-parameter at the top of the log so the run is reproducible.
logging.info(f'''Hyper-parameter:
NUM_EPOCHS: {num_epochs}
LEARNING_RATE: {learning_rate}
WEIGHT_DECAY: {weight_decay}
BATCH_SIZE: {batch_size}
PATIENCE: {patience}
DEVICE: {device}
EMBEDDING_ALL_DIM: {embedding_all_dim}
EMBEDDING_TIME_DIM: {embedding_time_dim}
NUM_BLOCKS: {num_blocks}
N_HEAD: {n_head}
DROPOUT: {dropout}
FORWARD_EXPANSION: {forward_expansion}
''')

# Dataset locations. Both candidate roots are defined here; `input_source`
# selects the one actually used for this run.
dataset_name = 'Hospital_2_12/'
Hospital_input_source = f'./data/{dataset_name}'
MIMIC_input_source = './data/MIMIC/slide_window_20/'
input_source = MIMIC_input_source

# One dataset object per split, constructed in train / val / test order.
train_set, val_set, test_set = [
    Hyper_Graph_Dataset_biomarker(input_source, split)
    for split in ('train', 'val', 'test')
]

# Batched loaders. None of them shuffles: the training loader says so
# explicitly and the other two rely on DataLoader's default.
train_dataloader = DataLoader(train_set, shuffle=False, batch_size=batch_size)
val_dataloader, test_dataloader = [
    DataLoader(split_set, batch_size=batch_size)
    for split_set in (val_set, test_set)
]


# Instantiate the network; the hidden size comes from `embedding_all_dim`.
model = TDHNODE(
    num_biomarkers=21,
    num_risk_factors=34,
    hidden_dim=embedding_all_dim,
)

# Keep a textual dump of the architecture in the run log.
logging.info(f"Model Architecture:\n{model}\n")

model.to(device)

# AdamW optimizer over all model parameters.
optimizer = optim.AdamW(
    model.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay,
)
# Plateau scheduler: mode='min' expects a metric that should decrease, and
# factor=0.5 halves the learning rate once the metric stalls past `patience`.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=3,
    verbose=True,
)

# Early-stopping state and per-epoch metric histories for the training loop.
best_val_loss = float('inf')
early_stop_counter = 0

loss_history, acc_history = [], []
valid_loss_history, valid_acc_history = [], []

# Dataset-specific statistics, 21 entries each.
# NOTE(review): presumably one entry per biomarker label, with `feature_counts`
# holding per-label occurrence counts -- confirm against the preprocessing code.
# A single if/elif/else keeps both tensors in sync and fails loudly for an
# unrecognised source instead of leaving the names undefined.
if input_source == Hospital_input_source:
    # for private
    feature_counts = torch.tensor(
        [8892, 1565, 1529, 1298, 115, 64, 16, 397, 671, 1142, 258, 241, 258, 587, 488, 7308, 1414, 195, 165, 312,
         37])
    pos_weight = torch.tensor(
        [1.5047, 13.1491, 12.9052, 16.5326, 321.6000, 267.8333, 402.2500,
         61.0385, 34.8444, 18.2024, 72.3182, 79.6500, 88.6111, 36.5116,
         41.4474, 2.2004, 26.3390, 88.6111, 160.3000, 63.5200, 536.6667],
        dtype=torch.float32)
elif input_source == MIMIC_input_source:
    # for MIMIC
    feature_counts = torch.tensor(
        [222, 398, 245, 249, 13, 18, 2, 179, 274, 187, 43, 24, 41, 299, 41, 212, 61, 115, 188, 3, 242])
    # Explicit float dtype: these literals are integers, and clamping an
    # integer tensor with float bounds yields a version-dependent dtype.
    pos_weight = torch.tensor(
        [35, 26, 23, 46, 1108, 720, 4805, 42, 38, 56, 183,
         575, 334, 26, 326, 56, 227, 105, 48, 4805, 54],
        dtype=torch.float32)
else:
    raise ValueError(f"No label statistics defined for input_source={input_source!r}")

# Inverse-log-count weighting, rescaled so the weights average to 1.
feature_weight = 1 / torch.log(feature_counts + 1)
feature_weight = feature_weight / feature_weight.mean()

# Bound the positive-class weights to the range [10, 50].
pos_weight = pos_weight.clamp(min=10.0, max=50.0)

# Main loop: one training pass and one validation pass per epoch, then LR
# scheduling, early stopping and a checkpoint.
for epoch in tqdm(range(num_epochs), desc="Epoch"):
    print(f"Epoch {epoch + 1}/{num_epochs}")

    # --- training pass ---
    train_loss, train_acc, train_multilabel_acc, all_embeddings = train_model(
        model, train_dataloader, optimizer, device, feature_weight, pos_weight
    )
    loss_history.append(train_loss)
    acc_history.append(train_acc)

    # --- validation pass (test_model is run on the validation loader) ---
    avg_loss, avg_acc, df_results, all_embeddings_test = test_model(
        model, val_dataloader, device, feature_weight, pos_weight
    )
    val_loss, val_acc = avg_loss, avg_acc
    valid_loss_history.append(val_loss)
    valid_acc_history.append(val_acc)

    # The plateau scheduler is driven by the validation loss.
    scheduler.step(val_loss)

    # Same summary goes to stdout (with a trailing blank line) and to the log.
    train_summary = f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}"
    val_summary = f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}"
    print(train_summary + "\n")
    print(val_summary + "\n")
    logging.info(train_summary)
    logging.info(val_summary)

    # --- early stopping: count consecutive epochs without a new best loss ---
    improved = val_loss < best_val_loss
    if improved:
        best_val_loss = val_loss
    early_stop_counter = 0 if improved else early_stop_counter + 1

    if early_stop_counter >= patience:
        print("Early stopping triggered.")
        logging.info("Early stopping triggered.")
        break

    # --- checkpoint: written for every epoch that reaches this point,
    # whether or not the validation loss improved ---
    curr_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())
    model_saved_path = f"./model_saved/{start_time}/"
    os.makedirs(model_saved_path, exist_ok=True)
    checkpoint_file = f"{model_saved_path}{embedding_all_dim}_{curr_time}.pth"
    torch.save(model.state_dict(), checkpoint_file)
    logging.info(f"Model saved at {checkpoint_file}")

# Persist the embeddings returned by the last training pass.
# np.save does not create missing directories, so create the target folder
# first; otherwise the export fails only after the whole training run is done.
embedding_dir = "./embedding_saved/"
os.makedirs(embedding_dir, exist_ok=True)
embedding_path = os.path.join(embedding_dir, f"patient_embedding_{start_time}.npy")
np.save(embedding_path, all_embeddings.cpu().detach().numpy())
logging.info(f"Embeddings saved at {embedding_path}")

# Multi-label accuracy reported by the last training pass.
logging.info(f"Final train multi-label accuracy: {train_multilabel_acc}")

# Curve specifications handed to plot_lines: training and validation loss,
# both drawn against the left y-axis, indexed by epoch position.
train_loss_curve = dict(
    x=range(len(loss_history)), y=loss_history,
    label='train loss', style='o-r', marker_size=5, y_axis='left',
)
valid_loss_curve = dict(
    x=range(len(valid_loss_history)), y=valid_loss_history,
    label='valid loss', style='x--b', marker_size=6, y_axis='left',
)
# Accuracy curves are currently disabled; add them to the list to plot them
# on the right y-axis:
#   dict(x=range(len(acc_history)), y=acc_history, label='train acc', style='s:g', marker_size=4, y_axis='right')
#   dict(x=range(len(valid_acc_history)), y=valid_acc_history, label='valid acc', style='D-.m', marker_size=3, y_axis='right')
lines_data = [train_loss_curve, valid_loss_curve]

# Output folder is derived from `dataset_name`; create it if needed.
plot_path = f"plot/{dataset_name}"
os.makedirs(plot_path, exist_ok=True)

file_name = plot_path + f"/{start_time}_loss.png"

# Title is the model class name followed by the embedding width.
plot_lines(
    lines_data,
    xlabel="X Axis",
    ylabel_left="Loss",
    ylabel_right="ACC",
    title=f"{model.__class__.__name__}{embedding_all_dim}",
    save_string=file_name,
)

logging.info(f"Loss plot saved at {file_name}")

logging.info("Training completed.")