import os
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('Agg')

from .ROOTPATH import ROOTPATH
from visionts import VisionTS



class VisionTSWrapper(nn.Module):
	"""Forecasting adapter around a lazily built ``VisionTS`` model.

	Exposes a ``forward(x_enc, x_mark_enc, x_dec, x_mark_dec, mask)``
	signature, but only ``x_enc`` is consumed. Inference always runs under
	``torch.no_grad()`` and the wrapped model is put in eval mode when it is
	built, so this wrapper produces no gradients. When ``enable_plot`` is set,
	every successful forward call also writes a PNG comparing history and
	forecast for the first sample in the batch.
	"""

	def __init__(self, configs):
		super().__init__()

		self.seq_len = configs.seq_len
		self.label_len = getattr(configs, 'label_len', 0)
		self.pred_len = configs.pred_len
		self.enc_in = configs.enc_in

		# VisionTS construction / configuration options (``visionts_*`` keys).
		self.arch = getattr(configs, 'visionts_arch', 'mae_base')
		self.finetune_type = getattr(configs, 'visionts_finetune_type', 'ln')
		self.ckpt_dir = getattr(configs, 'visionts_ckpt_dir', f'{ROOTPATH}/VisionTS')
		self.load_ckpt = getattr(configs, 'visionts_load_ckpt', True)
		self.periodicity = getattr(configs, 'visionts_periodicity', 1)
		self.norm_const = getattr(configs, 'visionts_norm_const', 0.4)
		self.align_const = getattr(configs, 'visionts_align_const', 0.4)
		self.interpolation = getattr(configs, 'visionts_interpolation', 'bilinear')
		self.use_fp64 = getattr(configs, 'visionts_fp64', False)
		self.export_image = getattr(configs, 'visionts_export_image', False)

		# Device the model is placed on when first built.
		self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
		self.enable_plot = getattr(configs, 'enable_plot', True)
		self.plot_dir = getattr(configs, 'plot_dir', './plots_visionts')
		self.plot_counter = 0
		if self.enable_plot:
			os.makedirs(self.plot_dir, exist_ok=True)

		# Built on the first forward() call; see _initialize_model().
		self.model = None

	def _apply_config(self, context_len: int) -> None:
		"""Push the forecasting configuration into ``self.model`` for ``context_len``."""
		self.model.update_config(
			context_len=context_len,
			pred_len=self.pred_len,
			periodicity=self.periodicity,
			norm_const=self.norm_const,
			align_const=self.align_const,
			interpolation=self.interpolation,
		)

	def _initialize_model(self) -> None:
		"""Build, configure and place the VisionTS model (no-op once built)."""
		if self.model is not None:
			return

		self.model = VisionTS(
			arch=self.arch,
			finetune_type=self.finetune_type,
			ckpt_dir=self.ckpt_dir,
			load_ckpt=self.load_ckpt,
		)
		self._apply_config(self.seq_len)
		self.model.to(self.device)
		self.model.eval()

	def _maybe_update_config(self, context_len: int) -> None:
		"""Reconfigure the model only when the incoming context length changed.

		If the model exposes no ``context_len`` attribute, it is reconfigured
		on every call.
		"""
		if getattr(self.model, 'context_len', None) != context_len:
			self._apply_config(context_len)

	def _model_device(self) -> torch.device:
		"""Return the device holding the wrapped model's weights.

		Falls back to ``self.device`` if the model exposes no parameters.
		"""
		first_param = next(self.model.parameters(), None)
		if first_param is None:
			return torch.device(self.device)
		return first_param.device

	def forward(self, x_enc, x_mark_enc=None, x_dec=None, x_mark_dec=None, mask=None):
		"""Forecast from ``x_enc`` with the wrapped VisionTS model.

		Args:
			x_enc: input history tensor; dimension 1 is taken as the context
				length (assumed layout (batch, time, features) — the plotting
				helper indexes it that way).
			x_mark_enc, x_dec, x_mark_dec, mask: accepted for interface
				compatibility and ignored.

		Returns:
			The model output moved back to ``x_enc``'s device and dtype, or
			``None`` if the wrapped model returned ``None``.
		"""
		self._initialize_model()
		self._maybe_update_config(x_enc.shape[1])

		# Follow the weights rather than the construction-time device, so a
		# later ``wrapper.to(...)`` does not cause a device mismatch.
		x_in = x_enc.to(self._model_device())
		with torch.no_grad():
			preds = self.model(x_in, export_image=self.export_image, fp64=self.use_fp64)

		if preds is None:
			return preds

		# NOTE(review): if export_image makes VisionTS return extra outputs
		# (e.g. a tuple), the .to() call below would fail — confirm against
		# the VisionTS API before enabling export_image.
		preds = preds.to(device=x_enc.device, dtype=x_enc.dtype)

		if self.enable_plot:
			# Plotting is best-effort: a failure is reported, never raised.
			try:
				self._plot_prediction(x_enc, preds)
			except Exception as exc:
				print(f"VisionTS plot failed: {exc}")

		return preds

	def _plot_prediction(self, x_enc, preds):
		"""Save a history-vs-forecast figure for the first sample in the batch.

		Plots at most the first four feature channels, one subplot each, and
		writes ``visionts_prediction_<counter>.png`` into ``self.plot_dir``.
		The figure is always closed, even if drawing or saving fails.
		"""
		batch_size, _, feature_dim = x_enc.shape
		if batch_size == 0 or feature_dim == 0:
			return

		max_feats = min(feature_dim, 4)
		# squeeze=False always yields a 2-D array of axes, even for one row.
		fig, axes = plt.subplots(max_feats, 1, figsize=(12, 3 * max_feats), squeeze=False)
		try:
			pred_len = preds.shape[1]
			for i in range(max_feats):
				ax = axes[i, 0]
				# Cast to float first: numpy cannot hold e.g. bfloat16 tensors.
				hist = x_enc[0, :, i].detach().float().cpu().numpy()
				fut = preds[0, :, i].detach().float().cpu().numpy()
				t_hist = np.arange(len(hist))
				t_fut = np.arange(len(hist), len(hist) + pred_len)
				ax.plot(t_hist, hist, 'b-', label='History', linewidth=2)
				ax.plot(t_fut, fut, 'r-', label='VisionTS Pred', linewidth=2)
				if hist.size and fut.size:
					# Dashed connector bridging last observed and first predicted point.
					ax.plot([len(hist) - 1, len(hist)], [hist[-1], fut[0]], 'g--', alpha=0.6)
				ax.axvline(x=len(hist), color='gray', linestyle='--', alpha=0.5)
				ax.set_title(f'Feature {i + 1}')
				ax.grid(alpha=0.3)
				if i == 0:
					ax.legend()

			fig.tight_layout()
			fname = f'visionts_prediction_{self.plot_counter:04d}.png'
			fpath = os.path.join(self.plot_dir, fname)
			fig.savefig(fpath, dpi=150, bbox_inches='tight')
		finally:
			# Release this exact figure so repeated failures cannot leak figures.
			plt.close(fig)

		self.plot_counter += 1
		print(f"VisionTS prediction plot saved to: {fpath}")


class Model(VisionTSWrapper):
	"""``VisionTSWrapper`` exposed under the name ``Model``.

	Adds no behavior of its own; presumably exists so code that looks up a
	``Model`` class in this module finds the wrapper — TODO confirm.
	"""
