# -*- coding:utf-8 -*-
import numpy as np
import matplotlib.pyplot as plt
import torch
from model import *
from utils import *
import torch.optim as optim
from Dataset import *
from predictor import *
from tqdm import tqdm
import time
import matplotlib.colors as mcolors
import os
import torch.optim.lr_scheduler as lr_scheduler
import logging
import json
import datetime


def freeze_pretrained_model(model):
    """Select which parameters of ``model.pretrained_model`` stay trainable.

    * ``moe_layer``: a parameter is trainable exactly when its name contains
      ``'parent_gate'``; every other parameter is frozen.
    * ``encoder``: parameters whose names contain neither ``"attn"`` nor
      ``"norm"`` are frozen; the remaining ones keep whatever
      ``requires_grad`` value they already had.

    Args:
        model: Object exposing ``pretrained_model.moe_layer`` and
            ``pretrained_model.encoder``, each providing ``named_parameters()``.
    """
    backbone = model.pretrained_model

    # Mixture-of-experts layer: only the parent gate remains trainable.
    for param_name, weight in backbone.moe_layer.named_parameters():
        weight.requires_grad = 'parent_gate' in param_name

    # Encoder: attention / norm parameters are left untouched, the rest frozen.
    for param_name, weight in backbone.encoder.named_parameters():
        if "attn" in param_name or "norm" in param_name:
            continue
        weight.requires_grad = False

def configure_logging(log_file):
    """Send root-logger output (DEBUG and above) to ``log_file``.

    The file is opened in ``'w'`` mode, so an existing file at that path is
    truncated. ``logging.basicConfig`` does nothing when the root logger
    already has handlers, so only the first call in a process takes effect;
    later calls merely emit the "Logging started." record to the handlers
    that are already installed.

    Args:
        log_file: Path of the log file. Its parent directory is created if it
            does not exist yet. A falsy value is passed through unchanged, in
            which case ``basicConfig`` falls back to a stream handler.
    """
    # Previously the path was hard-coded to '/load_balancing.log' and the
    # argument was silently ignored.
    log_dir = os.path.dirname(log_file) if log_file else ""
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    logging.basicConfig(
        level=logging.DEBUG,
        format='%(asctime)s - %(levelname)s - %(message)s',
        filename=log_file,
        filemode='w',
    )
    logging.info("Logging started.")


def train(model, train_dataset, train_dataloader, valid_dataset, valid_dataloader,criterion, optimizer, num_epochs, device, log_file, expert_save_path, model_save_path):
    """Run the train/validate loop with LR scheduling and early stopping.

    Each epoch performs one optimisation pass over ``train_dataloader`` and
    one gradient-free pass over ``valid_dataloader``. Whenever the mean
    validation loss improves, the whole model object is saved to
    ``model_save_path``. Training stops early after ``patience`` (200)
    consecutive epochs without improvement. Finally the loss curves are
    written to ``<expert_save_path>/loss_curves.png`` and shown.

    Args:
        model: Module called as ``model(data, prompt_flag)`` and returning a
            4-tuple ``(output, gate_output, leaf_expert_ids, gate_logits)``.
        train_dataset: Unused here; kept for interface compatibility.
        train_dataloader: Yields ``(data, label, feature_idx)`` batches.
        valid_dataset: Unused here; kept for interface compatibility.
        valid_dataloader: Yields ``(data, label, feature_idx)`` batches.
        criterion: Called as ``criterion(output, label, gate_logits,
            mode='predict')``; returns ``(total_loss, mse_loss)``.
        optimizer: Optimizer updating the model parameters.
        num_epochs: Maximum number of epochs.
        device: Device the batches are moved to.
        log_file: Path handed to ``configure_logging``.
        expert_save_path: Existing directory receiving ``loss_curves.png``.
        model_save_path: Destination of the best checkpoint.

    Returns:
        ``(train_loss_arr, valid_loss_arr)``: per-epoch mean total losses.
        The per-epoch mean MSE losses are only printed.
    """
    train_loss_arr = []
    valid_loss_arr = []
    best_loss = np.inf
    patience_counter = 0
    patience = 200

    configure_logging(log_file)
    prompt_flag = 1

    # NOTE(review): if the optimizer's initial lr is already below min_lr
    # (1e-5), this scheduler never changes the learning rate.
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, min_lr=1e-5)
    for epoch in range(num_epochs):
        start_time = time.time()

        # ---- training pass ----
        model.train()
        avg_train_loss = 0.0
        avg_train_mse_loss = 0.0
        for data, label, _feature_idx in tqdm(train_dataloader, desc=f"Epoch {epoch + 1}/{num_epochs}"):
            data = data.to(device)
            label = label.to(device)
            optimizer.zero_grad()
            output, _gate_output, _leaf_expert_ids, gate_logits = model(data, prompt_flag)
            train_loss, mse_loss = criterion(output, label, gate_logits, mode='predict')
            train_loss.backward()
            optimizer.step()
            avg_train_loss += train_loss.item()
            avg_train_mse_loss += mse_loss.item()
        avg_train_loss /= len(train_dataloader)
        avg_train_mse_loss /= len(train_dataloader)
        train_loss_arr.append(avg_train_loss)

        # ---- validation pass (no gradients) ----
        model.eval()
        avg_valid_loss = 0.0
        avg_valid_mse_loss = 0.0
        with torch.no_grad():
            for data, label, _feature_idx in valid_dataloader:
                data = data.to(device)
                label = label.to(device)
                outputs, _gate_output, _leaf_expert_ids, gate_logits = model(data, prompt_flag)
                v_loss, mse_loss = criterion(outputs, label, gate_logits, mode='predict')
                avg_valid_loss += v_loss.item()
                avg_valid_mse_loss += mse_loss.item()
        avg_valid_loss /= len(valid_dataloader)
        avg_valid_mse_loss /= len(valid_dataloader)
        valid_loss_arr.append(avg_valid_loss)

        epoch_time = time.time() - start_time
        # Read the LR used during this epoch straight from the optimizer:
        # ReduceLROnPlateau does not provide get_last_lr() on every PyTorch
        # release, whereas param_groups is always available.
        current_lr = optimizer.param_groups[0]['lr']

        print(f"Epoch {epoch + 1}, Time: {epoch_time:.2f}s, Training Loss: {avg_train_loss},Training MSE_Loss: {avg_train_mse_loss}, Validation Loss: {avg_valid_loss}, Validation MSE_Loss: {avg_valid_mse_loss},Learning Rate: {current_lr}")

        scheduler.step(avg_valid_loss)

        if avg_valid_loss < best_loss:
            best_loss = avg_valid_loss
            patience_counter = 0
            torch.save(model, model_save_path)
            logging.info(f"Saved model with Validation Loss: {best_loss}")
        else:
            # `else` rather than a second `if ... >= best_loss`: after an
            # improvement best_loss equals avg_valid_loss, so the old check
            # bumped the counter right after resetting it, and a NaN loss
            # never counted towards patience at all.
            patience_counter += 1
            if patience_counter >= patience:
                logging.info("Early stopping triggered. Stopping training...")
                break

    plt.figure(figsize=(10, 5))
    plt.plot(train_loss_arr, label='Training Loss')
    plt.plot(valid_loss_arr, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss Curves')
    plt.legend()
    plt.savefig(os.path.join(expert_save_path, "loss_curves.png"))
    plt.show()

    return train_loss_arr, valid_loss_arr


def main():
    """Fine-tune the saved Chicago base predictor on a fraction of the data.

    Steps: load ``/Chicago/data.npy``, carve train/validation splits sized by
    ``tune_rate``, load the pickled base model, freeze most of its pretrained
    parameters, create a timestamped run directory holding the log, the
    config JSON and the checkpoint, then call ``train``.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Keep index 1 of the last axis, then restore a trailing axis of size 1.
    # NOTE(review): the meaning of that channel is not visible here.
    input_data = np.load(r'/Chicago/data.npy')[...,1]
    input_data = torch.from_numpy(input_data).to(device).float().unsqueeze(-1)
    tune_rate = 0.2

    # Recorded in the run config below; the model itself comes from the
    # checkpoint, so these values are not used to build it here.
    num_days = 7
    model_dim = 128
    dim_feedforward = 128
    prompt_dim = 128
    kernel_size = 7

    dropout = 0.1
    num_epochs = 2000
    r = 0
    batch_size = 16
    lr = 0.000001

    # Train/validation use 60% / 20% of the `tune_rate` share of the data.
    train_size = int(0.6 * len(input_data)*tune_rate)
    valid_size = int(0.2 * len(input_data)*tune_rate)

    x_train_data = input_data[:train_size]
    x_valid_data = input_data[train_size:train_size + valid_size]
    # NOTE(review): the test split takes ALL remaining samples, not a
    # `tune_rate`-sized share. Confirm this is intended.
    x_test_data = input_data[train_size + valid_size:]

    train_dataset = preDataset(x_train_data)
    valid_dataset = preDataset(x_valid_data)
    test_dataset = preDataset(x_test_data)

    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,collate_fn=train_dataset.collate_fn)
    valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False,collate_fn=valid_dataset.collate_fn)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,collate_fn=test_dataset.collate_fn)

    predict_base_mode = 'pth/predict_base_model/chicago/20250419_190458/without_1.pth'

    # The checkpoint is a fully pickled model object (it is used directly as
    # a module below), so `weights_only=False` is stated explicitly: newer
    # PyTorch releases default to `weights_only=True`, which rejects such
    # files. SECURITY: unpickling executes arbitrary code -- only load
    # checkpoints from a trusted source.
    downstream_model = torch.load(predict_base_mode, map_location=device, weights_only=False)

    downstream_model = downstream_model.to(device)

    freeze_pretrained_model(downstream_model)

    criterion =  DownStream_DistanceLoss()
    optimizer = optim.Adam(downstream_model.parameters(), lr=lr, betas=(0.9, 0.999), eps=1e-08)

    # One timestamped directory per run: log file, config JSON, checkpoint
    # and an "experts" sub-directory (where train() stores the loss plot).
    current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    base_dir = r'pth/tune_model/chicago'
    run_dir = os.path.join(base_dir, current_time)
    os.makedirs(run_dir, exist_ok=True)
    log_file = os.path.join(run_dir, f"{current_time}_log.txt")
    configure_logging(log_file)
    model_save_path = os.path.join(run_dir, "1_20%.pth")
    expert_save_path = os.path.join(run_dir, "experts")
    os.makedirs(expert_save_path, exist_ok=True)
    config_file = os.path.join(run_dir, f"{current_time}_config.json")

    config = {
        "current_time": current_time,
        "num_days": num_days,
        "model_dim": model_dim,
        "dim_feedforward": dim_feedforward,
        "prompt_dim": prompt_dim,
        "kernel_size": kernel_size,
        "dropout": dropout,
        "num_epochs": num_epochs,
        "r": r,
        "batch_size": batch_size,
        "lr": lr,
        # The fraction of data used for tuning defines this run, so it is
        # stored alongside the split sizes it produced.
        "tune_rate": tune_rate,
        "train_size": train_size,
        "valid_size": valid_size,
        "device": str(device)
    }
    with open(config_file, 'w') as f:
        json.dump(config, f, indent=4)
    logging.info(f"Configuration saved to {config_file}")

    train(downstream_model, train_dataset, train_dataloader, valid_dataset,valid_dataloader, criterion, optimizer, num_epochs, device, log_file, expert_save_path, model_save_path)

# Run the fine-tuning pipeline only when this file is executed as a script.
if __name__ == '__main__':
    main()