import os
import sys
sys.path.append("..")
import argparse
import numpy as np
import math
from random import SystemRandom
import matplotlib.pyplot as plt

# Argument parser for command-line options.
# NOTE: --patch_size and --stride are declared as floats even though they are
# used as window lengths, so args.npatch below is computed from float values.
# NOTE: --model has no default; the __main__ section dispatches on its value.
parser = argparse.ArgumentParser('ITS Forecasting')
parser.add_argument('--state', type=str, default='def')
parser.add_argument('-n',  type=int, default=int(1e8), help="Size of the dataset")
parser.add_argument('--epoch', type=int, default=1000, help="training epoches")
parser.add_argument('--patience', type=int, default=10, help="patience for early stop")
parser.add_argument('--history', type=int, default=24, help="number of hours (months for ushcn and ms for activity) as historical window")
parser.add_argument('--pred_window', type=int, default=1, help="number of hours (months for ushcn) as pred window")
parser.add_argument('--logmode', type=str, default="a", help='File mode of logging.')
parser.add_argument('--lr',  type=float, default=1e-3, help="Starting learning rate.")
parser.add_argument('--w_decay', type=float, default=0.0, help="weight decay.")
parser.add_argument('-b', '--batch_size', type=int, default=32)
parser.add_argument('--load', type=str, default=None, help="ID of the experiment to load for evaluation. If None, run a new experiment.")
parser.add_argument('--seed', type=int, default=1, help="Random seed")
parser.add_argument('--dataset', type=str, default='physionet', help="Dataset to load. Available: physionet, mimic, ushcn, activity")
parser.add_argument('--quantization', type=float, default=0.0, help="Quantization on the physionet dataset.")
parser.add_argument('--nhead', type=int, default=1, help="heads in Transformer")
parser.add_argument('--nlayer', type=int, default=1, help="# of layer in TSmodel")
parser.add_argument('-ps', '--patch_size', type=float, default=24, help="window size for a patch")
parser.add_argument('--stride', type=float, default=24, help="period stride for patch sliding")
parser.add_argument('-hd', '--hid_dim', type=int, default=64, help="Hidden dim of node embeddings")
parser.add_argument('--alpha', type=float, default=1, help="Proportion of Time decay")
parser.add_argument('--res', type=float, default=1, help="Res")
parser.add_argument('--gpu', type=str, default='0', help='which gpu to use.')
parser.add_argument('--dropout', type=float, default=0.1)
parser.add_argument('--model', type=str)


args = parser.parse_args()
args.npatch = int(np.ceil((args.history - args.patch_size) / args.stride)) + 1 # number of patches of size patch_size, sliding by stride, over the history window
# Restrict the visible GPU(s). This is set before `import torch` below so CUDA
# initialization sees it.
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
# cuBLAS workspace setting that PyTorch requires for deterministic cuBLAS
# results once torch.use_deterministic_algorithms(True) is enabled below.
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'

import torch
# Favor reproducibility over speed: no cuDNN autotuning, deterministic kernels only.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic=True
torch.use_deterministic_algorithms(True)
import utils as utils
from parse_datasets import parse_datasets
# `forecast` and `Model`, used in __main__, are not defined in this file;
# presumably they come from these star imports (verify).
from evaluation import *
from baselines.models.RTI import *
import warnings
# NOTE(review): this silences every warning, including ones that may flag real problems.
warnings.filterwarnings("ignore")

# Script name with the last three characters (".py") stripped; not used elsewhere in this file.
file_name = os.path.basename(__file__)[:-3]
# "cuda:0" is the first GPU left visible by CUDA_VISIBLE_DEVICES above.
args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
args.PID = os.getpid()
print("PID, device:", args.PID, args.device)
#####################################################################################################
def layer_of_patches(n_patch):
    """Return how many levels are needed to merge ``n_patch`` patches pairwise down to one.

    Each level halves the patch count, first padding an odd count up by one,
    i.e. ``n -> ceil(n / 2)``. The final single-patch level is counted too, so
    ``layer_of_patches(1) == 1``, ``layer_of_patches(2) == 2`` and
    ``layer_of_patches(5) == 4`` (5 -> 3 -> 2 -> 1).

    Args:
        n_patch: positive patch count.

    Returns:
        int: the number of levels.

    Raises:
        ValueError: if ``n_patch`` is smaller than 1.
    """
    if n_patch < 1:
        raise ValueError(f"n_patch must be >= 1, got {n_patch!r}")
    layers = 1
    n = n_patch
    while n > 1:
        # (n + 1) // 2 == ceil(n / 2): odd counts are padded to the next even count
        # before halving; integer division keeps the count integral.
        n = (n + 1) // 2
        layers += 1
    return layers


def visualize_forecast(args, vals, quite_vals=None, sample_idx=0, batch_idx=0, feature_idx=0, save_path="forecast_plot.png"):
    """Plot history, ground-truth future and forecasts for one series, then save the figure.

    The figure is written to
    ``forecast/{args.dataset}/{args.history}history_{args.pred_window}pred/{save_path}``.
    The directory is created if needed. The figure is then shown with
    ``plt.show()`` and closed, so repeated calls do not accumulate open figures.

    Args:
        args: parsed options; ``dataset``, ``history``, ``pred_window`` and ``model`` are read.
        vals: ``(history_v, history_t, preds, trues, times)`` tuple. Each element is
            indexed as ``element[batch_idx][sample_idx, :, feature_idx]`` and the
            result must support ``.detach().cpu().numpy()`` (e.g. a list of
            per-batch torch tensors of rank 3).
        quite_vals: optional forecasts of the "Quite + model" variant, indexed
            the same way as ``preds``. NOTE(review): __main__ passes the full
            return value of ``forecast(...)`` called without ``ts=True``.
            Confirm that this is a per-batch sequence of prediction tensors, not
            a 5-tuple like ``vals``.
        sample_idx: index of the sample inside the selected batch.
        batch_idx: index of the batch to plot.
        feature_idx: index of the variable (channel) to plot.
        save_path: file name of the saved figure inside the output directory.
    """
    out_dir = f"forecast/{args.dataset}/{args.history}history_{args.pred_window}pred/"
    # exist_ok=True already handles the "directory already exists" case.
    os.makedirs(out_dir, exist_ok=True)

    hist_v, hist_t, preds, trues, times = vals

    def _series(seq):
        # Select one (sample, feature) series of batch ``batch_idx`` as a 1-D numpy array.
        return seq[batch_idx][sample_idx, :, feature_idx].detach().cpu().numpy()

    h_v = _series(hist_v)
    h_t = _series(hist_t)
    y_true = _series(trues)
    y_pred = _series(preds)
    quite_pred = _series(quite_vals) if quite_vals is not None else None
    t_future = _series(times)

    # Keep only points with a positive timestamp; non-positive timestamps are
    # presumably padding (note this also drops a genuine observation at t == 0).
    valid_hist = np.where(h_t > 0)[0]
    h_t, h_v = h_t[valid_hist], h_v[valid_hist]
    valid_fut = np.where(t_future > 0)[0]
    t_fut, y_true, y_pred = t_future[valid_fut], y_true[valid_fut], y_pred[valid_fut]
    if quite_pred is not None:
        quite_pred = quite_pred[valid_fut]

    fig = plt.figure(figsize=(12, 5))
    try:
        # Ground truth: history followed by the true future values, drawn as one line.
        full_t = np.concatenate([h_t, t_fut])
        full_y = np.concatenate([h_v, y_true])
        plt.plot(full_t, full_y, marker='o', linestyle='-', color='#2c7bb6', label='groundtruth', markersize=5, linewidth=1.5)

        # Forecasts are drawn only over the valid future timestamps.
        plt.plot(t_fut, y_pred, marker='o', linestyle='-', color='#fdae61', label=str(args.model), markersize=5, linewidth=1.5)
        if quite_pred is not None:
            plt.plot(t_fut, quite_pred, marker='o', linestyle='-', color="#61fd88", label="Quite + "+str(args.model), markersize=5, linewidth=1.5)

        # Boundary between history and forecast, drawn at the first forecast
        # timestamp. It is skipped when no valid future point exists, instead of
        # raising IndexError.
        if t_fut.size:
            plt.axvline(x=t_fut[0], color='gray', linestyle='--', linewidth=1.2, alpha=0.8)

        plt.xlabel('Time', fontsize=10)
        plt.ylabel('Value', fontsize=10)
        plt.title(f'Forecasting Visualization (Sample {sample_idx}, Feature {feature_idx})')
        plt.legend(loc='upper left')
        plt.grid(axis='y', linestyle=':', alpha=0.5)  # light horizontal grid only

        plt.tight_layout()
        plt.savefig(f"{out_dir}{save_path}")
        plt.show()
    finally:
        # Callers plot many (sample, feature) pairs. Matplotlib keeps every
        # figure alive until it is closed, so release it even on error.
        plt.close(fig)
    print(f"결과가 {out_dir}{save_path}에 저장되었습니다.")
    

if __name__ == '__main__':
    print(args)
    utils.setup_seed(args.seed)

    # Baseline models and the dataset layout they are evaluated on.
    NON_PATCH_MODELS = ('itransformer', 's_mamba')
    PATCH_MODELS = ('patchtst', 'timexer', 'patchmixer', 'tsmixer')

    def load_eval_model():
        """Build ``Model(args)``, load the checkpoint for the current ``args.mode`` and set eval mode."""
        ckpt = f"./models/{args.dataset}_{args.history}history_{args.pred_window}pred_{args.model}_{args.mode}.pt"
        net = Model(args).to(args.device)
        # map_location lets checkpoints saved on a GPU load on a CPU-only machine.
        net.load_state_dict(torch.load(ckpt, map_location=args.device))
        net.eval()
        return net

    # --- Baseline model (checkpoint mode "False") ---
    args.mode = "False"
    args.irr_emb = False
    # NOTE: overrides any --hid_dim given on the command line (presumably to match the checkpoint).
    args.hid_dim = 256
    if args.model in NON_PATCH_MODELS:
        data_obj = parse_datasets(args, patch_ts=False, max_ts=True)
    elif args.model in PATCH_MODELS:
        data_obj = parse_datasets(args, patch_ts=True, max_ts=True)
    else:
        # Otherwise data_obj would be undefined and the script would fail with a NameError below.
        parser.error(f"unsupported --model {args.model!r}; expected one of {NON_PATCH_MODELS + PATCH_MODELS}")

    ### Model setting ###
    args.ndim = data_obj["input_dim"]
    args.npatch = int(math.ceil((args.history - args.patch_size) / args.stride)) + 1
    args.patch_layer = layer_of_patches(args.npatch)
    args.scale_patch_size = args.patch_size / (args.history + args.pred_window)
    args.task = 'forecasting'

    model = load_eval_model()
    with torch.no_grad():
        vals = forecast(model, data_obj["val_dataloader"], data_obj["n_val_batches"], ts=True)
    del model  # release the baseline weights before loading the second model

    # --- Quite + model (checkpoint mode "self") ---
    args.mode = "self"
    args.irr_emb = True
    args.hid_dim = 64
    data_obj = parse_datasets(args, patch_ts=True, max_ts=False)

    model = load_eval_model()
    with torch.no_grad():
        quite_vals = forecast(model, data_obj["val_dataloader"], data_obj["n_val_batches"])

    # NOTE(review): ``n`` ranges over args.batch_size but is passed as batch_idx
    # (sample_idx stays 0), while the file name says "s{n}". Confirm whether
    # sample_idx=n was intended.
    for n in range(args.batch_size):
        for i in range(args.ndim):
            try:
                visualize_forecast(args, vals, quite_vals, batch_idx=n, feature_idx=i, save_path=f's{n}_f{i}_{args.model}.png')
            except Exception as err:
                # Best-effort plotting: report the failure and keep going, without
                # swallowing KeyboardInterrupt/SystemExit or hiding the cause.
                print(f"error: batch {n}, feature {i}: {err!r}")