import torch
import torch.nn as nn
import numpy as np
import os
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('Agg')

from .ROOTPATH import ROOTPATH


class ChronosWrapper(nn.Module):
	"""Zero-shot forecaster that delegates to a pretrained Chronos pipeline.

	The wrapper exposes the ``forward(x_enc, x_mark_enc, x_dec, x_mark_dec, mask)``
	signature but only ``x_enc`` is used. Every feature channel of every batch
	element is forecast as an independent univariate series, the sample paths
	returned by the pipeline are aggregated (mean or median) into a point
	forecast, and the result is reshaped back to ``[B, prediction_length, D]``.

	The Chronos pipeline is loaded lazily on the first ``forward`` call.
	"""

	# Substring of the RuntimeError message used to detect a CPU/GPU mismatch
	# between the input tensor and the pipeline.
	_DEVICE_MISMATCH_MSG = 'Expected all tensors to be on the same device'

	# Accepted (lower-case) spellings for the ``torch_dtype`` config value.
	_DTYPE_ALIASES = {
		'bfloat16': torch.bfloat16,
		'bf16': torch.bfloat16,
		'float16': torch.float16,
		'fp16': torch.float16,
		'float32': torch.float32,
		'fp32': torch.float32,
	}

	def __init__(self, configs):
		"""Read settings from ``configs``; the model itself is not loaded here.

		Required attributes: ``seq_len``, ``pred_len``, ``enc_in``. All other
		attributes are optional and fall back to the defaults visible below.
		When plotting is enabled, ``plot_dir`` is created immediately.
		"""
		super().__init__()

		self.seq_len = configs.seq_len
		self.label_len = getattr(configs, 'label_len', 0)
		self.pred_len = configs.pred_len
		self.enc_in = configs.enc_in

		# Upper bound on the number of history steps handed to the pipeline.
		self.context_length = getattr(configs, 'context_length', self.seq_len)
		self.prediction_length = getattr(configs, 'prediction_length', self.pred_len)
		# Values <= 1 mean "let the pipeline pick its own sample count".
		self.num_samples = getattr(configs, 'num_samples', 1)
		# 'median' selects the median over samples; anything else means mean.
		self.agg_method = getattr(configs, 'agg_method', 'mean')
		self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
		self.device_map = getattr(configs, 'device_map', self.device)
		self.torch_dtype = getattr(configs, 'torch_dtype', 'bfloat16')
		self.model_path = getattr(configs, 'chronos_model_path', f'{ROOTPATH}/chronos-t5-small')
		self.chronos_force_cpu_input = getattr(configs, 'chronos_force_cpu_input', True)

		self.enable_plot = getattr(configs, 'enable_plot', True)
		self.plot_dir = getattr(configs, 'plot_dir', './plots_chronos')
		self.plot_counter = 0
		if self.enable_plot:
			os.makedirs(self.plot_dir, exist_ok=True)

		# Both are filled in by _initialize_model().
		self.pipeline = None
		self.model_device = None

	def _resolve_torch_dtype(self):
		"""Translate ``self.torch_dtype`` into a ``torch.dtype``.

		Accepts an actual ``torch.dtype`` object, or a case-insensitive name
		with or without a ``torch.`` prefix. Unknown names fall back to
		bfloat16 (the historical behaviour) but are reported on stdout.
		"""
		requested = self.torch_dtype
		if isinstance(requested, torch.dtype):
			return requested
		key = str(requested).lower()
		if key.startswith('torch.'):
			key = key[len('torch.'):]
		if key in self._DTYPE_ALIASES:
			return self._DTYPE_ALIASES[key]
		print(f"Unknown torch_dtype {requested!r}; falling back to bfloat16")
		return torch.bfloat16

	def _detect_model_device(self):
		"""Return the device of the first parameter found on the pipeline.

		Probes the pipeline attributes ``model``, ``inner_model`` and
		``base_model`` in that order; if none yields a parameter, falls back to
		CUDA when available and CPU otherwise.
		"""
		for attr_name in ('model', 'inner_model', 'base_model'):
			candidate = getattr(self.pipeline, attr_name, None)
			if candidate is None:
				continue
			try:
				return next(candidate.parameters()).device
			except (AttributeError, StopIteration, TypeError):
				# Not a module, or a module without parameters: try the next one.
				continue
		return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

	def _initialize_model(self):
		"""Load the Chronos pipeline once and record the device it lives on.

		Does nothing when a pipeline is already present. Any exception raised
		while importing or loading is printed and re-raised unchanged.
		"""
		if self.pipeline is not None:
			return

		torch_dtype = self._resolve_torch_dtype()
		try:
			# Imported lazily so the module can be imported without chronos.
			from chronos import ChronosPipeline

			self.pipeline = ChronosPipeline.from_pretrained(
				self.model_path,
				device_map=self.device_map,
				torch_dtype=torch_dtype,
			)
		except Exception as e:
			print(f"Failed to load Chronos model: {e}")
			raise

		if hasattr(self.pipeline, 'model'):
			self.inner_model = self.pipeline.model

		self.model_device = self._detect_model_device()
		print(f"Successfully loaded Chronos model from {self.model_path}")
		print(f"Device map: {self.device_map}, dtype: {torch_dtype}, resolved model device: {self.model_device}")

	def forward(self, x_enc, x_mark_enc=None, x_dec=None, x_mark_dec=None, mask=None):
		"""Forecast ``prediction_length`` steps for every channel of ``x_enc``.

		Args:
			x_enc: history tensor, unpacked as ``[B, L, D]``.
			x_mark_enc, x_dec, x_mark_dec, mask: accepted but ignored.

		Returns:
			Tensor ``[B, prediction_length, D]`` on ``x_enc``'s device and dtype.
		"""
		if self.pipeline is None:
			self._initialize_model()

		preds = self._forward_pretrained(x_enc)

		if self.enable_plot:
			# Plotting is best-effort: a failure must never break inference.
			try:
				self._plot_prediction(x_enc, preds)
			except Exception as e:
				print(f"Plot failed: {e}")
		return preds

	def _prepare_input(self, series_batch: torch.Tensor, cpu: bool) -> torch.Tensor:
		"""Place the series on the CPU or the model device, as float32.

		The series are deliberately not cast to the model's parameter dtype:
		a half-precision cast (bfloat16 in particular) would round the
		observed values before the pipeline processes them.
		"""
		if cpu:
			target = 'cpu'
		else:
			target = getattr(self, 'model_device', None)
			if target is None:
				target = 'cpu'
		return series_batch.to(device=target, dtype=torch.float32)

	def _run_predict(self, tensor_input: torch.Tensor):
		"""Call ``pipeline.predict`` without gradient tracking.

		The sample count is passed only when ``num_samples > 1``; otherwise
		``None`` is passed. If the pipeline rejects the three-argument call
		with ``TypeError``, the call is retried with two arguments.
		"""
		with torch.no_grad():
			try:
				return self.pipeline.predict(
					tensor_input,
					self.prediction_length,
					self.num_samples if self.num_samples > 1 else None
				)
			except TypeError:
				return self.pipeline.predict(tensor_input, self.prediction_length)

	def _predict_with_device_fallback(self, series_batch: torch.Tensor):
		"""Run prediction, retrying on the other device after a device mismatch.

		The first attempt uses CPU input when ``chronos_force_cpu_input`` is
		truthy or CUDA is unavailable, and the model device otherwise. Only a
		``RuntimeError`` carrying the device-mismatch message triggers the
		retry; a retry towards the model device also requires CUDA. Every other
		error propagates unchanged.
		"""
		prefer_cpu = bool(self.chronos_force_cpu_input) or (not torch.cuda.is_available())
		try:
			return self._run_predict(self._prepare_input(series_batch, prefer_cpu))
		except RuntimeError as err:
			is_mismatch = self._DEVICE_MISMATCH_MSG in str(err)
			can_retry = (not prefer_cpu) or torch.cuda.is_available()
			if not (is_mismatch and can_retry):
				raise
		return self._run_predict(self._prepare_input(series_batch, not prefer_cpu))

	def _forward_pretrained(self, x_enc: torch.Tensor) -> torch.Tensor:
		"""Produce point forecasts for each (batch, channel) series.

		Args:
			x_enc: history tensor, unpacked as ``[B, L, D]``.

		Returns:
			Tensor ``[B, prediction_length, D]`` on ``x_enc``'s device and dtype.

		Raises:
			RuntimeError: if the pipeline output is neither 2-D nor 3-D, or if
				prediction fails for a reason other than a recoverable device
				mismatch.
		"""
		B, L, D = x_enc.shape

		# Keep only the most recent `context_length` steps. A shorter history
		# is passed through as-is: repeating the last value to fill the window
		# would append fabricated observations after the true end of the
		# series and move the forecast origin.
		# NOTE(review): relies on the pipeline accepting any context length --
		# confirm against the installed chronos version.
		x_ctx = x_enc[:, -self.context_length:, :] if L > self.context_length else x_enc
		L_eff = x_ctx.shape[1]

		# One row per univariate series: [B, L_eff, D] -> [B * D, L_eff].
		series_batch = x_ctx.permute(0, 2, 1).reshape(B * D, L_eff)

		forecast = self._predict_with_device_fallback(series_batch)

		forecast_t = torch.as_tensor(forecast)
		if not forecast_t.is_floating_point():
			# mean() is undefined for integer tensors.
			forecast_t = forecast_t.float()

		if forecast_t.dim() == 2:
			forecast_t = forecast_t.unsqueeze(1)  # [num_series, 1, pred_len]
		elif forecast_t.dim() != 3:
			raise RuntimeError(f"Unexpected forecast tensor shape: {forecast_t.shape}")

		# Collapse the sample dimension into a point forecast.
		if self.agg_method == 'median':
			point = forecast_t.median(dim=1).values  # [num_series, pred_len]
		else:
			point = forecast_t.mean(dim=1)  # [num_series, pred_len]

		# [B * D, pred_len] -> [B, pred_len, D]
		out = point.reshape(B, D, self.prediction_length).permute(0, 2, 1).contiguous()
		return out.to(x_enc.device, dtype=x_enc.dtype)

	def _plot_prediction(self, x_enc: torch.Tensor, preds: torch.Tensor):
		"""Save a PNG of history and forecast for the first batch element.

		At most the first four features are drawn, one subplot each. The file
		is written to ``plot_dir`` as ``chronos_prediction_<counter>.png`` and
		``plot_counter`` is incremented after a successful save. Nothing is
		drawn when the batch, history, feature count or horizon is empty.
		"""
		B, L, D = x_enc.shape
		pred_len = preds.shape[1]
		max_feats = min(D, 4)
		if B == 0 or L == 0 or max_feats == 0 or pred_len == 0:
			return

		# The directory may not exist if plotting was enabled after __init__.
		os.makedirs(self.plot_dir, exist_ok=True)

		# squeeze=False always yields a 2-D array of axes, even for one subplot.
		fig, axes = plt.subplots(max_feats, 1, figsize=(12, 3 * max_feats), squeeze=False)
		try:
			for i, ax in enumerate(axes[:, 0]):
				# .float() first: NumPy has no bfloat16, so .numpy() would fail.
				hist = x_enc[0, :, i].detach().float().cpu().numpy()
				fut = preds[0, :, i].detach().float().cpu().numpy()
				t_hist = np.arange(len(hist))
				t_fut = np.arange(len(hist), len(hist) + pred_len)
				ax.plot(t_hist, hist, 'b-', label='History', linewidth=2)
				ax.plot(t_fut, fut, 'r-', label='Chronos Pred', linewidth=2)
				# Dashed connector between the last history point and first forecast point.
				ax.plot([len(hist)-1, len(hist)], [hist[-1], fut[0]], 'g--', alpha=0.6)
				ax.axvline(x=len(hist), color='gray', linestyle='--', alpha=0.5)
				ax.set_title(f'Feature {i+1}')
				ax.grid(alpha=0.3)
				if i == 0:
					ax.legend()
			fig.tight_layout()
			fname = f'chronos_prediction_{self.plot_counter:04d}.png'
			fpath = os.path.join(self.plot_dir, fname)
			fig.savefig(fpath, dpi=150, bbox_inches='tight')
		finally:
			# Always release the figure, otherwise a failed save leaks it.
			plt.close(fig)
		self.plot_counter += 1
		print(f"Chronos prediction plot saved to: {fpath}")


class Model(ChronosWrapper):
	"""Alias of ``ChronosWrapper`` under the name ``Model``; adds no behaviour."""
