import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from util import CustomDataset,give_batch,getClassificationMetrics,plot_confusion_matrix,plot_accuracy,heatmapVisual
import config as C
from model import EACR
import logging
import time
import numpy as np
import os

def _train_one_epoch(model, dataloader, optimizer, criterion, device):
    """Run one optimisation pass over ``dataloader`` and return the mean batch loss.

    Returns ``float('nan')`` when the loader yields no batches instead of
    dividing by zero.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    for x_batch, y_batch in dataloader:
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)
        optimizer.zero_grad()
        predictions = model(x_batch)
        loss = criterion(predictions, y_batch)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
    return total_loss / num_batches if num_batches else float('nan')


def train(model, train_dataloader, test_dataloader, epochs, data_length, device, log_dir,
          *, lr=3e-6, ckpt_dir='./ckpt', save_every=10):
    """Train ``model`` with Adam + cross-entropy and report per-epoch metrics.

    Each epoch:
    - runs one optimisation pass over ``train_dataloader``;
    - evaluates on both loaders via ``getClassificationMetrics``;
    - appends the two accuracies to ``log_dir/train_accuracy.log`` and
      ``log_dir/test_accuracy.log``;
    - writes a summary line through the root logger. The root logger is set up
      by ``logging.basicConfig`` to ``trains.log`` in the current directory;
      this is a no-op if the root logger is already configured.

    Every ``save_every`` epochs, the whole model is pickled to
    ``ckpt_dir/model_<epoch>.pth``. After the last epoch, the accuracy curves
    and the final test confusion matrix are plotted into ``log_dir``, and the
    final test metrics are printed.

    Args:
        model: The ``nn.Module`` to train; it is updated in place.
        train_dataloader: Iterable of ``(inputs, labels)`` batches for training.
        test_dataloader: Iterable of ``(inputs, labels)`` batches for evaluation.
        epochs: Number of training epochs; must be >= 1.
        data_length: Unused; only the disabled ``heatmapVisual`` call takes it.
        device: Device the batches are moved to.
        log_dir: Directory for the accuracy logs and plots; created if missing.
        lr: Adam learning rate.
        ckpt_dir: Checkpoint directory; created on first save.
        save_every: Checkpoint period in epochs; ``<= 0`` disables checkpoints.

    Raises:
        ValueError: If ``epochs < 1``. Nothing would be trained, and the final
            report needs the last epoch's metrics.
    """
    if epochs < 1:
        raise ValueError(f'epochs must be >= 1, got {epochs}')

    logging.basicConfig(filename='trains.log', filemode='w', level=logging.INFO,
                        format='%(asctime)s - %(levelname)s - %(message)s')
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = torch.nn.CrossEntropyLoss()

    os.makedirs(log_dir, exist_ok=True)
    train_accuracy_logger = logging.getLogger('train_accuracy')
    test_accuracy_logger = logging.getLogger('test_accuracy')
    train_accuracy_handler = logging.FileHandler(os.path.join(log_dir, 'train_accuracy.log'), mode='w')
    test_accuracy_handler = logging.FileHandler(os.path.join(log_dir, 'test_accuracy.log'), mode='w')
    # Named loggers are process-global; the handlers are removed again in the
    # `finally` below so repeated calls neither duplicate lines nor leak files.
    attached = ((train_accuracy_logger, train_accuracy_handler),
                (test_accuracy_logger, test_accuracy_handler))
    for logger, handler in attached:
        logger.addHandler(handler)
        logger.setLevel(logging.INFO)

    train_accuracy_list = []
    test_accuracy_list = []
    try:
        for epoch in range(epochs):
            print("---------------第{}轮训练开始:-------------------".format(epoch+1))
            start_time = time.time()
            mean_loss = _train_one_epoch(model, train_dataloader, optimizer, criterion, device)
            # Only the optimisation pass is timed, not the evaluation below.
            train_time = time.time() - start_time

            model.eval()
            with torch.no_grad():
                # Only the train accuracy is used; the other train metrics are discarded.
                train_accuracy, *_ = getClassificationMetrics(
                    model, train_dataloader, device, log_dir, prefix='train')
                (test_accuracy, test_precision, test_recall, test_f1,
                 test_specificity, test_auc, conf_matrix) = getClassificationMetrics(
                    model, test_dataloader, device, log_dir, prefix='test')
            train_accuracy_list.append(train_accuracy)
            test_accuracy_list.append(test_accuracy)
            train_accuracy_logger.info(f'Epoch {epoch + 1}, train_accuracy: {train_accuracy}')
            test_accuracy_logger.info(f'Epoch {epoch + 1}, test_accuracy: {test_accuracy}')
            print(f'Epoch {epoch + 1}/{epochs}, Loss: {mean_loss}, spend_time:{train_time:.2f} s')
            print(f'train_accuracy: {train_accuracy}, test_accuracy: {test_accuracy}')
            # Same 1-based epoch number and mean epoch loss as the console line above.
            logging.info('训练轮数:' + str(epoch + 1) + ' \t' + 'loss:' + str(mean_loss)
                         + '\t' + 'cost time:' + str(train_time))
            if save_every > 0 and (epoch + 1) % save_every == 0:
                os.makedirs(ckpt_dir, exist_ok=True)
                torch.save(model, os.path.join(ckpt_dir, 'model_{}.pth'.format(epoch + 1)))
    finally:
        for logger, handler in attached:
            logger.removeHandler(handler)
            handler.close()

    # heatmapVisual(model, train_dataloader, data_length, device, log_dir)  # disabled
    plot_accuracy(train_accuracy_list, test_accuracy_list, epochs, log_dir)
    # Confusion matrix of the test set after the final epoch.
    plot_confusion_matrix(conf_matrix, log_dir, 'test')
    print(f'test_accuracy: {test_accuracy}')
    print(f'test_precision: {test_precision}')
    print(f'test_recall: {test_recall}')
    print(f'test_f1: {test_f1}')
    print(f'test_specificity: {test_specificity}')
    print(f'test_auc: {test_auc}')
    


def test(dataloader, data_length, device, log_dir, *, ckpt_path='./ckpt/model_2000.pth'):
    """Evaluate a saved checkpoint on ``dataloader`` and print its metrics.

    The checkpoint is expected to be a whole pickled model, as written by
    ``torch.save(model, ...)`` in ``train``. The test confusion matrix is
    plotted into ``log_dir``.

    Args:
        dataloader: Iterable of ``(inputs, labels)`` batches to evaluate on.
        data_length: Unused; only the disabled ``heatmapVisual`` call takes it.
        device: Device the checkpoint tensors are mapped onto and passed to
            ``getClassificationMetrics``.
        log_dir: Output directory for the metrics and plots.
        ckpt_path: Path of the checkpoint to load.
    """
    import inspect

    # Map tensors onto the evaluation device, so a checkpoint saved on a GPU
    # can still be loaded on a CPU-only machine.
    load_kwargs = {'map_location': device}
    # torch >= 2.6 defaults to weights_only=True, which refuses to unpickle a
    # full nn.Module, so opt out explicitly where the argument exists.
    # SECURITY: this unpickles arbitrary objects; only load trusted checkpoints.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    model = torch.load(ckpt_path, **load_kwargs)
    model.eval()
    with torch.no_grad():
        accuracy, precision, recall, f1, specificity, auc, conf_matrix = getClassificationMetrics(
            model, dataloader, device, log_dir, 'test')
    # heatmapVisual(model, dataloader, data_length, device, log_dir)  # disabled
    plot_confusion_matrix(conf_matrix, log_dir, 'test')
    print(f'accuracy: {accuracy}')
    print(f'precision: {precision}')
    print(f'recall: {recall}')
    print(f'f1: {f1}')
    print(f'specificity: {specificity}')
    print(f'auc: {auc}')

def main():
    """Build loaders from ``config``, train EACR, then evaluate the saved checkpoint."""
    # Pull every setting from the config module up front (left-to-right order).
    (device, batch_size, epochs, data_length, data_path,
     result_dir, test_dir, embed_size, input_length) = (
        C.device, C.batch_size, C.epochs, C.data_length, C.path,
        C.resultpath, C.test_dir, C.embed_size, C.Resolution_data,
    )

    x_tr, x_te, y_tr, y_te = give_batch(data_path)
    train_set = CustomDataset(torch.tensor(x_tr), torch.tensor(y_tr))
    test_set = CustomDataset(torch.tensor(x_te), torch.tensor(y_te))
    # Only the training data is shuffled; evaluation order stays fixed.
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)

    net = EACR(input_length, embed_size)
    net.to(device)

    train(net, train_loader, test_loader, epochs, data_length, device, result_dir)
    test(test_loader, data_length, device, test_dir)


if __name__ == "__main__":
    main()