import os
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('Agg')

from .ROOTPATH import ROOTPATH

from toto.model.toto import Toto
from toto.data.util.dataset import MaskedTimeseries
from toto.inference.forecaster import TotoForecaster



class TotoWrapper(nn.Module):
	"""Zero-shot forecasting wrapper around the pretrained Datadog Toto model.

	The pretrained weights are loaded lazily on the first ``forward`` call, so
	constructing the wrapper does no I/O beyond optionally creating the plot
	directory. Every channel of every input sample is passed to Toto as one
	variate, and ``forecast.median`` is returned.

	Config attributes read (``getattr`` defaults in parentheses):
		seq_len, pred_len, enc_in: required (read directly).
		label_len (0): stored only; not used by this class.
		toto_model_id ('Datadog/Toto-Open-Base-1.0'): source used when the
			local checkpoint directory does not exist.
		toto_local_dir (f'{ROOTPATH}/Toto-Open-Base-1.0'): preferred local
			checkpoint directory.
		device ('cuda' if available, else 'cpu').
		toto_num_samples (100), toto_samples_per_batch (10): passed to
			``TotoForecaster.forecast``.
		toto_time_interval_seconds (900): per-variate interval handed to Toto.
		toto_compile (False): attempt ``Toto.compile()`` after loading.
		enable_plot (True), plot_dir ('./plots_toto'): save a PNG of the first
			sample's history and forecast after every ``forward`` call.
	"""

	def __init__(self, configs):
		super().__init__()

		self.seq_len = configs.seq_len
		self.label_len = getattr(configs, 'label_len', 0)
		self.pred_len = configs.pred_len
		self.enc_in = configs.enc_in

		self.model_id = getattr(configs, 'toto_model_id', 'Datadog/Toto-Open-Base-1.0')
		self.local_model_dir = getattr(configs, 'toto_local_dir', f'{ROOTPATH}/Toto-Open-Base-1.0')
		self.device = getattr(configs, 'device', 'cuda' if torch.cuda.is_available() else 'cpu')
		self.num_samples = getattr(configs, 'toto_num_samples', 100)
		self.samples_per_batch = getattr(configs, 'toto_samples_per_batch', 10)
		self.time_interval_seconds = getattr(configs, 'toto_time_interval_seconds', 60 * 15)
		self.use_compile = getattr(configs, 'toto_compile', False)

		self.enable_plot = getattr(configs, 'enable_plot', True)
		self.plot_dir = getattr(configs, 'plot_dir', './plots_toto')
		self.plot_counter = 0
		if self.enable_plot:
			os.makedirs(self.plot_dir, exist_ok=True)

		# Populated lazily by _initialize_model() on the first forward() call.
		self.toto = None
		self.forecaster = None

	def _initialize_model(self):
		"""Load the pretrained Toto model and build its forecaster (no-op if loaded).

		The local checkpoint directory is used when it exists; otherwise the
		model id is passed to ``Toto.from_pretrained``. When ``toto_compile`` is
		set and compilation raises, a message is printed and the model is used
		uncompiled.
		"""
		if self.toto is not None:
			return

		model_source = self.local_model_dir if os.path.exists(self.local_model_dir) else self.model_id
		toto = Toto.from_pretrained(model_source).to(self.device)
		# This wrapper only runs inference (forward() uses no_grad), so force eval mode.
		toto.eval()
		if self.use_compile and hasattr(toto, 'compile'):
			try:
				toto.compile()
			except Exception as exc:
				# Compilation is an optional speed-up; keep going in eager mode.
				print(f"Toto compile failed, continuing without compilation: {exc}")
		# Assign only once loading finished so self.toto and self.forecaster are
		# always set together.
		self.toto = toto
		self.forecaster = TotoForecaster(toto.model)

	@staticmethod
	def _series_group_ids(batch_size, n_vars, seq_len, device=None, dtype=torch.float32):
		"""Return a (batch_size * n_vars, seq_len) tensor of per-sample group ids.

		Row ``b * n_vars + d`` holds the value ``b`` at every time step, which
		matches the row order produced by ``x.permute(0, 2, 1).reshape(B * D, L)``
		in ``forward``.
		"""
		ids = torch.arange(batch_size, device=device, dtype=dtype).repeat_interleave(n_vars)
		return ids.unsqueeze(-1).expand(-1, seq_len).contiguous()

	def forward(self, x_enc, x_mark_enc=None, x_dec=None, x_mark_dec=None, mask=None):
		"""Forecast ``pred_len`` steps for every channel of ``x_enc``.

		Args:
			x_enc: tensor of shape (B, L, D); each of the D channels of each of
				the B samples is forecast from its L past values.
			x_mark_enc, x_dec, x_mark_dec, mask: accepted for interface
				compatibility and ignored.

		Returns:
			Tensor of shape (B, pred_len, D) holding ``forecast.median``, on
			``x_enc``'s device and dtype.
		"""
		if self.toto is None:
			self._initialize_model()

		B, L, D = x_enc.shape
		# Flatten to (B * D, L): row b * D + d is channel d of sample b.
		series = x_enc.permute(0, 2, 1).reshape(B * D, L).to(self.device)
		timestamp_seconds = torch.zeros_like(series)
		time_interval_seconds = torch.full((B * D,), self.time_interval_seconds, device=self.device)
		# NOTE(review): Toto's id_mask is understood to group variates that may
		# attend to each other (confirm against the Toto docs). An all-zero mask
		# would merge every sample of the batch into one group and make each
		# forecast depend on the batch composition, so each sample gets its own id.
		id_mask = self._series_group_ids(B, D, L, device=series.device, dtype=series.dtype)
		inputs = MaskedTimeseries(
			series=series,
			padding_mask=torch.full_like(series, True, dtype=torch.bool),
			id_mask=id_mask,
			timestamp_seconds=timestamp_seconds,
			time_interval_seconds=time_interval_seconds,
		)

		with torch.no_grad():
			forecast = self.forecaster.forecast(
				inputs,
				prediction_length=self.pred_len,
				num_samples=self.num_samples,
				samples_per_batch=self.samples_per_batch,
			)
			pred = forecast.median  # reshaped below as [B*D, pred_len]

		pred_tensor = pred.reshape(B, D, self.pred_len).permute(0, 2, 1).to(x_enc.device, dtype=x_enc.dtype)

		if self.enable_plot:
			# Plotting is best-effort diagnostics; never let it break inference.
			try:
				self._plot_prediction(x_enc, pred_tensor)
			except Exception as exc:
				print(f"Toto plot failed: {exc}")

		return pred_tensor

	def _plot_prediction(self, x_enc, preds):
		"""Save a PNG of the first sample's history followed by its forecast.

		Up to four channels are drawn, one subplot each. Files are named
		``toto_prediction_NNNN.png`` inside ``self.plot_dir``; the counter only
		advances after a successful save.

		Args:
			x_enc: history tensor of shape (B, L, D).
			preds: forecast tensor of shape (B, pred_len, D).
		"""
		B, L, D = x_enc.shape
		pred_len = preds.shape[1]
		# Nothing to draw. This also avoids hist[-1] / fut[0] on empty arrays
		# and plt.subplots(0, 1).
		if B == 0 or L == 0 or D == 0 or pred_len == 0:
			return
		max_feats = min(D, 4)
		# squeeze=False always yields a 2-D array of Axes, even for one subplot.
		fig, axes = plt.subplots(max_feats, 1, figsize=(12, 3 * max_feats), squeeze=False)
		try:
			t_hist = np.arange(L)
			t_fut = np.arange(L, L + pred_len)
			for i in range(max_feats):
				ax = axes[i, 0]
				hist = x_enc[0, :, i].detach().cpu().numpy()
				fut = preds[0, :, i].detach().cpu().numpy()
				ax.plot(t_hist, hist, 'b-', label='History', linewidth=2)
				ax.plot(t_fut, fut, 'r-', label='Toto Prediction (Median)', linewidth=2)
				# Dashed bridge from the last observed point to the first forecast point.
				ax.plot([L - 1, L], [hist[-1], fut[0]], 'g--', alpha=0.6)
				ax.axvline(x=L, color='gray', linestyle='--', alpha=0.5)
				ax.set_title(f'Feature {i + 1}')
				ax.grid(alpha=0.3)
				if i == 0:
					ax.legend()

			fig.tight_layout()
			fname = f'toto_prediction_{self.plot_counter:04d}.png'
			fpath = os.path.join(self.plot_dir, fname)
			fig.savefig(fpath, dpi=150, bbox_inches='tight')
		finally:
			# Always release this figure. forward() swallows plotting errors, so
			# a figure left open here would otherwise accumulate across calls.
			plt.close(fig)
		self.plot_counter += 1


class Model(TotoWrapper):
	"""Public alias for ``TotoWrapper``; adds no behavior of its own.

	Presumably the surrounding framework looks up a class named ``Model`` in
	each model module — confirm against the caller.
	"""
