# -*- coding:utf-8 -*-
import numpy as np
import matplotlib.pyplot as plt
import torch
from utils import *
import torch.optim as optim
from Dataset import *
from predictor import *
from tqdm import tqdm
import time
import matplotlib.colors as mcolors
import os
import torch.optim.lr_scheduler as lr_scheduler
import logging
import json
import datetime


def configure_logging(log_file):
    """Configure the root logger to write DEBUG-and-above records to ``log_file``.

    Uses ``logging.basicConfig``, which is a no-op when the root logger already
    has handlers, so only the first call in a process takes effect; later calls
    only emit the "Logging started." record through the existing handlers.

    Args:
        log_file: Path of the log file. It is opened with ``filemode='w'``,
            i.e. truncated if it already exists.
    """
    logging.basicConfig(
        level=logging.DEBUG,
        format='%(asctime)s - %(levelname)s - %(message)s',
        # Previously hard-coded to '/load_balancing.log', silently ignoring the
        # per-run path the caller passed in.
        filename=log_file,
        filemode='w',
    )
    logging.info("Logging started.")


def _train_one_epoch(model, dataloader, criterion, optimizer, device, prompt_flag, desc):
    """Run one optimisation pass over ``dataloader``.

    Returns:
        Tuple ``(mean total loss, mean MSE loss)`` averaged over batches.
    """
    model.train()
    total_loss = 0.0
    total_mse = 0.0
    for data, label, _feature_idx in tqdm(dataloader, desc=desc):
        data = data.to(device)
        label = label.to(device)
        optimizer.zero_grad()
        # The model returns (output, gate_output, leaf_expert_ids, gate_logits);
        # only the prediction and the gate logits feed the loss.
        output, _gate_output, _leaf_expert_ids, gate_logits = model(data, prompt_flag)
        loss, mse_loss = criterion(output, label, gate_logits, mode='predict')
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        total_mse += mse_loss.item()
    n_batches = len(dataloader)
    return total_loss / n_batches, total_mse / n_batches


def _validate_one_epoch(model, dataset, dataloader, criterion, device, prompt_flag):
    """Evaluate ``model`` on ``dataloader`` without gradients.

    Also reports each batch's MSE to ``dataset.update_feature_loss`` for every
    feature index present in that batch.

    Returns:
        Tuple ``(mean total loss, mean MSE loss)`` averaged over batches.
    """
    model.eval()
    total_loss = 0.0
    total_mse = 0.0
    with torch.no_grad():
        for data, label, feature_idx in dataloader:
            data = data.to(device)
            label = label.to(device)
            outputs, _gate_output, _leaf_expert_ids, gate_logits = model(data, prompt_flag)
            v_loss, mse_loss = criterion(outputs, label, gate_logits, mode='predict')
            # NOTE(review): the batch-level MSE is attributed to every feature in
            # the batch, not a per-feature MSE; confirm this is intended.
            for idx in feature_idx.unique():
                dataset.update_feature_loss(idx.item(), mse_loss.item())
            total_loss += v_loss.item()
            total_mse += mse_loss.item()
    n_batches = len(dataloader)
    return total_loss / n_batches, total_mse / n_batches


def _save_loss_curves(train_loss_arr, valid_loss_arr, out_dir):
    """Plot training/validation loss per epoch and save it as ``out_dir/loss_curves.png``."""
    fig = plt.figure(figsize=(10, 5))
    plt.plot(train_loss_arr, label='Training Loss')
    plt.plot(valid_loss_arr, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss Curves')
    plt.legend()
    plt.savefig(os.path.join(out_dir, "loss_curves.png"))
    plt.show()
    # Release the figure so repeated training runs do not accumulate open figures.
    plt.close(fig)


def train(model, train_dataset, train_dataloader, valid_dataset, valid_dataloader, criterion, optimizer,
          num_epochs, device, log_file, expert_save_path, model_save_path,
          lambda_mse=10, lambda_load_balance=0.01, lambda_smooth=0.001,
          window_size=5, patience=200):
    """Train ``model`` with early stopping and LR-on-plateau scheduling.

    Each epoch runs one training pass and one validation pass. Whenever the mean
    validation loss improves, the whole model object is saved with
    ``torch.save(model, model_save_path)``. Training stops after ``patience``
    consecutive epochs without improvement. The loss curves are saved to
    ``expert_save_path/loss_curves.png``.

    Args:
        model: Network whose forward is ``model(data, prompt_flag)`` returning
            ``(output, gate_output, leaf_expert_ids, gate_logits)``.
        train_dataset: Unused; kept for interface compatibility.
        train_dataloader / valid_dataloader: Yield ``(data, label, feature_idx)``.
        valid_dataset: Receives per-feature loss updates via ``update_feature_loss``.
        criterion: IGNORED. The loss is always a freshly built ``HURSTLoss``
            configured from the ``lambda_*`` / ``window_size`` arguments.
        optimizer: Torch optimizer; also wrapped by ``ReduceLROnPlateau``
            (factor 0.1, patience 10, min_lr 1e-5).
        num_epochs: Maximum number of epochs.
        device: Device that batches are moved to.
        log_file: Passed to ``configure_logging``.
        expert_save_path: Directory for the loss-curve figure.
        model_save_path: File path for the best model checkpoint.
        lambda_mse, lambda_load_balance, lambda_smooth: ``HURSTLoss`` weights.
            The defaults match the values set in ``main``.
        window_size: ``HURSTLoss`` window size. Previously this read an
            undefined name; the default 5 matches ``main``.
        patience: Epochs without validation improvement before stopping.

    Returns:
        ``(train_loss_arr, valid_loss_arr)``: mean total loss per epoch.
    """
    train_loss_arr = []
    valid_loss_arr = []
    train_mse_loss_arr = []
    valid_mse_loss_arr = []
    best_loss = np.inf
    patience_counter = 0
    configure_logging(log_file)
    prompt_flag = 1
    criterion = HURSTLoss(epsilon=0.1,
                          window_size=window_size,
                          num_sample_windows=20,
                          lambda_mse=lambda_mse,
                          lambda_load_balance=lambda_load_balance,
                          lambda_smooth=lambda_smooth)
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, min_lr=1e-5)

    for epoch in range(num_epochs):
        start_time = time.time()
        avg_train_loss, avg_train_mse_loss = _train_one_epoch(
            model, train_dataloader, criterion, optimizer, device, prompt_flag,
            desc=f"Epoch {epoch + 1}/{num_epochs}")
        train_loss_arr.append(avg_train_loss)
        train_mse_loss_arr.append(avg_train_mse_loss)

        avg_valid_loss, avg_valid_mse_loss = _validate_one_epoch(
            model, valid_dataset, valid_dataloader, criterion, device, prompt_flag)
        valid_loss_arr.append(avg_valid_loss)
        valid_mse_loss_arr.append(avg_valid_mse_loss)

        epoch_time = time.time() - start_time
        # Read the LR from the optimizer: older PyTorch releases do not provide
        # ReduceLROnPlateau.get_last_lr(), and this is the same value.
        current_lr = optimizer.param_groups[0]['lr']

        print(f"Epoch {epoch + 1}, Time: {epoch_time:.2f}s, Training Loss: {avg_train_loss},Training MSE_Loss: {avg_train_mse_loss}, Validation Loss: {avg_valid_loss}, Validation MSE_Loss: {avg_valid_mse_loss},Learning Rate: {current_lr}")
        scheduler.step(avg_valid_loss)

        if avg_valid_loss < best_loss:
            best_loss = avg_valid_loss
            patience_counter = 0
            torch.save(model, model_save_path)
            logging.info(f"Saved model with Validation Loss: {best_loss}")
        else:
            # Only non-improving epochs count toward patience. Previously the
            # check ran after best_loss was updated, so improving epochs counted too.
            patience_counter += 1
            if patience_counter >= patience:
                logging.info("Early stopping triggered. Stopping training...")
                break

    _save_loss_curves(train_loss_arr, valid_loss_arr, expert_save_path)

    return train_loss_arr, valid_loss_arr

def _save_prediction_figure(true_value, predicted, mask, path):
    """Save side-by-side masked heatmaps of ground truth and prediction to ``path``.

    Both panels use the ground-truth min/max as colour limits so they are
    directly comparable. The arrays are transposed and drawn with ``origin='lower'``.
    """
    fig, (ax_true, ax_pred) = plt.subplots(1, 2, figsize=(10, 5))
    vmin = np.min(true_value)
    vmax = np.max(true_value)
    im_true = ax_true.imshow((true_value * mask).transpose(), cmap='hot', interpolation='nearest',
                             origin='lower', vmin=vmin, vmax=vmax)
    ax_true.set_title('True Value')
    ax_true.axis('off')
    ax_pred.imshow((predicted * mask).transpose(), cmap='hot', interpolation='nearest',
                   origin='lower', vmin=vmin, vmax=vmax)
    ax_pred.set_title('Predictions')
    ax_pred.axis('off')
    cbar = fig.colorbar(im_true, ax=[ax_true, ax_pred])
    cbar.set_label('Value')
    fig.savefig(path)
    plt.close(fig)


def test(test_dataset, test_dataloader, criterion, device, run_dir,
         model_path=r'pth/predict_base_model/20250324_212603/zero-shot-without-Parking.pth',
         mask_path=r'/mask_128_64.npy'):
    """Evaluate a saved model on the test split and save one heatmap per batch.

    Predictions and labels are denormalised with ``test_dataset.data_denormalization``
    before the metrics are computed. Metrics are computed over every element of
    each sample; the mask is only applied to the figures.

    Args:
        test_dataset: Provides ``data_denormalization(tensor, feature_idx)``.
        test_dataloader: Yields ``(data, label, feature_idx)`` batches.
        criterion: Unused; kept for interface compatibility.
        device: Device that the model and inputs are placed on.
        run_dir: Directory that receives ``test_prediction_<batch_idx>.png`` files.
        model_path: Path to a fully pickled model (``torch.save(model, ...)``).
        mask_path: ``.npy`` mask multiplied into the plotted grids.

    Returns:
        ``(avg_mse, avg_mae)``: per-sample MSE and MAE averaged over all samples.

    Raises:
        ValueError: if ``test_dataloader`` yields no samples.
    """
    # NOTE(review): this loads a full pickled module. On PyTorch >= 2.6,
    # torch.load defaults to weights_only=True, which rejects such files;
    # trusted checkpoints may need weights_only=False.
    model = torch.load(model_path, map_location=device)
    model.eval()
    mask = np.load(mask_path)
    total_mse = 0.0
    total_mae = 0.0
    total_samples = 0
    with torch.no_grad():
        for batch_idx, (data, label, feature_idx) in enumerate(test_dataloader):
            data = data.to(device)
            output, _gate_output, _leaf_expert_ids, _ = model(data, prompt_flag=1)
            # NOTE(review): the whole batch is denormalised with the first
            # sample's feature index; this assumes single-feature batches.
            first_feature = feature_idx[0]
            output = test_dataset.data_denormalization(output, first_feature)
            label = test_dataset.data_denormalization(label, first_feature).to(device)

            diff = output.reshape(output.size(0), -1) - label.reshape(label.size(0), -1)
            total_mse += torch.sum(torch.mean(diff ** 2, dim=1)).item()
            total_mae += torch.sum(torch.mean(torch.abs(diff), dim=1)).item()
            total_samples += len(label)

            # Plot the first sample, channel 0, of each batch. The file is keyed
            # by batch index; keying by batch size made full batches overwrite each other.
            _save_prediction_figure(label[0, :, :, 0].cpu().numpy(),
                                    output[0, :, :, 0].cpu().numpy(),
                                    mask,
                                    os.path.join(run_dir, f"test_prediction_{batch_idx}.png"))

    if total_samples == 0:
        raise ValueError("test_dataloader yielded no samples")
    avg_mse = total_mse / total_samples
    avg_mae = total_mae / total_samples

    print(f'Overall Test MSE: {avg_mse} MAE:{avg_mae}')
    # Previously returned (avg_mse, avg_mse), dropping the MAE.
    return avg_mse, avg_mae

def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    input_data = np.load(r'/nyc_data.npy')
    input_data = np.clip(input_data, 0, None)
    input_data = torch.from_numpy(input_data).to(device).float()
    prompt_dim = 128
    kernel_size = 7
    window_size =5
    dropout = 0.1
    num_epochs = 1000
    r = 0
    batch_size = 16
    lr = 0.00001
    lambda_mse = 10
    lambda_load_balance = 0.01 
    lambda_smooth = 0.001
    train_size = int(0.6 * len(input_data))
    valid_size = int(0.2 * len(input_data))
    test_size = len(input_data) - train_size - valid_size

    
    x_train_data = input_data[:train_size]
    x_valid_data = input_data[train_size:train_size + valid_size]
    x_test_data = input_data[train_size + valid_size:]
    #x_test_data = test_data

    train_dataset = preDataset(x_train_data)
    valid_dataset = preDataset(x_valid_data)
    test_dataset = preDataset(x_test_data)

    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,collate_fn=train_dataset.collate_fn)
    valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False,collate_fn=valid_dataset.collate_fn)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,collate_fn=test_dataset.collate_fn)

    # load pre-trained model 
    pretrained_model_save_path = r'/model_structure.pth'
    pretrained_model = torch.load(pretrained_model_save_path,map_location=device)
    pretrained_params = torch.load(r'/model_param.pth', 
                                   map_location=device)
    pretrained_model.load_state_dict(pretrained_params)


    for name, param in pretrained_model.moe_layer.named_parameters():
        if 'gate' or 'expert'in name:
            param.requires_grad = True  
        else:
            param.requires_grad = False 

    for name, param in pretrained_model.encoder.named_parameters():
        if "attn" not in name and "norm" not in name:
            param.requires_grad = False

    st_embed_input_dim = 64 
    encoder_dim = 128  

    downstream_model = STpredictor(pretrained_model,  st_embed_input_dim, encoder_dim,   prompt_dim, kernel_size,device)
    downstream_model = downstream_model.to(device)
    criterion = DownStream_DistanceLoss()
    optimizer = optim.Adam(downstream_model.parameters(), lr=lr, betas=(0.9, 0.999), eps=1e-08)

    current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    base_dir = r'pth/predict_base_model/nyc/'
    run_dir = os.path.join(base_dir, current_time)
    os.makedirs(run_dir, exist_ok=True)
    log_file = os.path.join(run_dir, f"{current_time}_log.txt")
    configure_logging(log_file)
    model_save_path = os.path.join(run_dir, "one-for-all.pth") 
    expert_save_path = os.path.join(run_dir, "experts")    
    os.makedirs(expert_save_path, exist_ok=True)
    config_file = os.path.join(run_dir, f"{current_time}_config.json")    

    config = {
        "current_time": current_time,
        "prompt_dim": prompt_dim,
        "kernel_size": kernel_size,  
        "dropout": dropout,
        "num_epochs": num_epochs,
        "r": r,
        "batch_size": batch_size,
        "lr": lr,
        "train_size": train_size,
        "valid_size": valid_size,
        "test_size": test_size,
        "device": str(device)
    }
    with open(config_file, 'w') as f:
        json.dump(config, f, indent=4)
    logging.info(f"Configuration saved to {config_file}")
    
    train(downstream_model, train_dataset, train_dataloader, valid_dataset,valid_dataloader, criterion, optimizer, num_epochs, device, log_file, expert_save_path, model_save_path)

if __name__ == "__main__":
    # Run the fine-tuning pipeline only when executed as a script, not on import.
    main()