import torch
import torch.nn as nn
import time
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from util import CustomDataset,give_batch,heatmapVisual,plot_r2_trends,getRegressionMetrics,mean_and_se
import config as C
from model import MyModel
import logging
import time
import numpy as np
import os
import random
# random.seed(2024)
# np.random.seed(2024)
# torch.manual_seed(2024)
# torch.cuda.manual_seed(2024)
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False

# Order in which getRegressionMetrics returns its four scores.
_METRIC_NAMES = ("MSE", "RMSE", "MAE", "R2")


def _attach_file_logger(name, log_dir):
    """Attach a fresh INFO-level file handler for logger ``name``; return (logger, handler)."""
    logger = logging.getLogger(name)
    handler = logging.FileHandler(os.path.join(log_dir, f'{name}.log'), mode='w')
    logger.addHandler(handler)
    logger.setLevel(logging.INFO)
    return logger, handler


def _train_one_epoch(model, dataloader, optimizer, criterion, device):
    """Run one optimization pass over ``dataloader`` and return the summed batch loss."""
    model.train()
    epoch_loss = 0
    for batch_x, batch_y in dataloader:
        # Inputs are cast to int64 (presumably index ids for an embedding layer
        # — confirm against MyModel); targets get a trailing unit dim,
        # presumably to match an (N, 1) prediction for MSELoss.
        batch_x = batch_x.long().to(device)
        batch_y = batch_y.float().unsqueeze(1).to(device)
        optimizer.zero_grad()
        predictions = model(batch_x)
        loss = criterion(predictions, batch_y)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    return epoch_loss


def _summarize_tail(history, window, log_dir):
    """Report mean ± population std of every metric over the last ``window`` epochs.

    The test-set summary is written to ``log_dir/results.txt``; both train and
    test summaries are printed. Nothing happens when fewer than ``window``
    epochs were recorded (or ``window`` is not positive).
    """
    if window <= 0 or len(history['train']['R2']) < window:
        return
    stats = {}
    for split, metrics in history.items():
        stats[split] = {}
        for name, values in metrics.items():
            tail = np.asarray(values[-window:])
            stats[split][name] = (np.mean(tail), np.std(tail))

    with open(os.path.join(log_dir, 'results.txt'), 'w', encoding='utf-8') as file:
        for name in _METRIC_NAMES:
            mean, std = stats['test'][name]
            file.write(f'Test {name}: {mean:.4f} ± {std:.4f}\n')
    for split, label in (('train', 'Train'), ('test', 'Test')):
        for name in _METRIC_NAMES:
            mean, std = stats[split][name]
            print(f'{label} {name}: {mean:.4f} ± {std:.4f}')


def train(model, train_dataloader, test_dataloader, epochs, data_langth, device, log_dir,
          lr=0.0001, ckpt_dir="./ckpt", save_every=10, summary_window=10):
    """Train ``model`` with Adam + MSE loss and track regression metrics per epoch.

    After every epoch both loaders are re-scored with ``getRegressionMetrics``
    and the R2 values are appended to ``log_dir/train_R2.log`` and
    ``log_dir/test_R2.log``. Every ``save_every`` epochs the whole model object
    is saved to ``ckpt_dir/model_<epoch>.pth``. At the end, R2 trends are
    plotted and, if at least ``summary_window`` epochs ran, mean ± std of each
    metric over the last ``summary_window`` epochs is reported.

    Args:
        model: network to optimize; must already be on ``device`` (this
            function does not move it).
        train_dataloader: yields ``(x, y)`` batches used for training and for
            the per-epoch train metrics.
        test_dataloader: evaluation loader scored after every epoch.
        epochs: number of passes over ``train_dataloader``.
        data_langth: unused; kept for call-site compatibility.
        device: device each batch is moved to.
        log_dir: directory for logs, plots and ``results.txt`` (created if missing).
        lr: Adam learning rate.
        ckpt_dir: checkpoint directory (created on first save).
        save_every: checkpoint interval in epochs; ``<= 0`` disables saving.
        summary_window: number of trailing epochs summarized at the end.
    """
    os.makedirs(log_dir, exist_ok=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()
    # Named loggers are process-global, so the handlers attached here are
    # detached and closed in ``finally``; otherwise repeated calls would leak
    # open files and write every line to every previous log file too.
    train_R2_logger, train_R2_handler = _attach_file_logger('train_R2', log_dir)
    test_R2_logger, test_R2_handler = _attach_file_logger('test_R2', log_dir)
    history = {split: {name: [] for name in _METRIC_NAMES} for split in ('train', 'test')}
    try:
        for epoch in range(epochs):
            # Banner text (Chinese): "epoch N training starts".
            print("---------------第{}轮训练开始:-------------------".format(epoch+1))
            start_time = time.time()
            epoch_loss = _train_one_epoch(model, train_dataloader, optimizer, criterion, device)
            end_time = time.time()

            model.eval()
            with torch.no_grad():
                train_MSE, train_RMSE, train_MAE, train_R2 = getRegressionMetrics(
                    model, train_dataloader, device, log_dir, prefix='train')
                test_MSE, test_RMSE, test_MAE, test_R2 = getRegressionMetrics(
                    model, test_dataloader, device, log_dir, prefix='test')
            for name, value in zip(_METRIC_NAMES, (train_MSE, train_RMSE, train_MAE, train_R2)):
                history['train'][name].append(value)
            for name, value in zip(_METRIC_NAMES, (test_MSE, test_RMSE, test_MAE, test_R2)):
                history['test'][name].append(value)

            train_R2_logger.info(f'Epoch {epoch + 1}, train_R2: {train_R2}')
            test_R2_logger.info(f'Epoch {epoch + 1}, test_R2: {test_R2}')
            # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
            mean_loss = epoch_loss / max(len(train_dataloader), 1)
            print(f'Epoch {epoch + 1}/{epochs}, Loss: {mean_loss}, spend_time:{end_time - start_time:.2f} s')
            print(f'train_R2: {train_R2}, test_R2: {test_R2}')

            if save_every > 0 and (epoch + 1) % save_every == 0:
                # Saves the full pickled module (not a state_dict), so loading
                # it later needs the model's class to be importable.
                os.makedirs(ckpt_dir, exist_ok=True)
                torch.save(model, os.path.join(ckpt_dir, "model_{}.pth".format(epoch + 1)))
    finally:
        for logger, handler in ((train_R2_logger, train_R2_handler),
                                (test_R2_logger, test_R2_handler)):
            logger.removeHandler(handler)
            handler.close()

    model.eval()
    plot_r2_trends(history['train']['R2'], log_dir, prefix='train')
    plot_r2_trends(history['test']['R2'], log_dir, prefix='test')
    _summarize_tail(history, summary_window, log_dir)
        
def test(dataloader, data_length, device, log_dir, ckpt_path='./ckpt/model_680.pth'):
    """Load a saved model checkpoint and print its regression metrics on ``dataloader``.

    Args:
        dataloader: evaluation loader passed to ``getRegressionMetrics``.
        data_length: unused (only the commented-out heatmap call needed it);
            kept for call-site compatibility.
        device: device the checkpoint's tensors are mapped onto.
        log_dir: output directory forwarded to ``getRegressionMetrics``.
        ckpt_path: file written by ``train`` via ``torch.save(model, ...)``;
            defaults to the previously hard-coded epoch-680 checkpoint.
    """
    import inspect

    # The checkpoint is a whole pickled nn.Module, not a state_dict. PyTorch
    # >= 2.6 defaults torch.load to weights_only=True, which rejects such
    # files, so opt out explicitly where the keyword exists (older releases
    # lack it). NOTE: this unpickles arbitrary code — only load trusted files.
    load_kwargs = {'map_location': device}
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    model = torch.load(ckpt_path, **load_kwargs)
    model.eval()
    with torch.no_grad():
        MSE, RMSE, MAE, R2 = getRegressionMetrics(model, dataloader, device, log_dir, prefix='test')
    # Optional attention/feature visualisation:
    # heatmapVisual(model, dataloader, data_length, device, log_dir)
    print(f'MSE: {MSE}')
    print(f'RMSE: {RMSE}')
    print(f'MAE: {MAE}')
    print(f'R2: {R2}')



if __name__ == "__main__":
    # --- settings, all taken from the project-level `config` module ---
    device = C.device
    batch_size = C.batch_size
    epochs = C.epochs
    data_length = C.data_length
    data_path = C.path
    log_dir = C.resultpath
    test_dir = C.test_dir
    embed_size = C.embed_size
    input_length = C.Resolution_data

    # --- data: split arrays unpacked as (train X, test X, train y, test y) ---
    X_train, X_test, y_train, y_test = give_batch(data_path)
    train_set = CustomDataset(torch.tensor(X_train), torch.tensor(y_train))
    test_set = CustomDataset(torch.tensor(X_test), torch.tensor(y_test))
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)

    # --- model + training ---
    net = MyModel(embed_size, input_length, data_length)
    net.to(device)
    train(net, train_loader, test_loader, epochs, data_length, device, log_dir)
    # To evaluate a saved checkpoint instead of training, call:
    # test(test_loader, data_length, device, test_dir)