import torch
import torch.nn as nn
import numpy as np
from typing import Optional, Dict, Any, List
import sys
import os
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('Agg')

from .ROOTPATH import ROOTPATH


class TimesFMWrapper(nn.Module):
    """Inference-only adapter exposing a pretrained TimesFM model as an ``nn.Module``.

    Settings are read from a ``configs`` object. The TimesFM checkpoint is
    loaded lazily on the first forward pass. Every (batch, feature) pair is
    forecast as an independent univariate series, and a PNG comparing the
    input window with the forecast can optionally be written per call.
    """

    def __init__(self, configs):
        """Read settings from ``configs``; the checkpoint itself is not loaded here.

        Required attributes on ``configs``: ``seq_len``, ``label_len``,
        ``pred_len``, ``enc_in``.

        Optional attributes (default in parentheses): ``context_length``
        (``min(512, seq_len)``), ``prediction_length`` (``pred_len``),
        ``timesfm_model_path``, ``input_patch_len`` (32), ``output_patch_len``
        (128), ``num_layers`` (20), ``model_dims`` (1280),
        ``frequency_indicator`` (0), ``enable_plot`` (True), ``plot_dir``
        (``'./plots_timesfm'``).
        """
        super().__init__()

        self.seq_len = configs.seq_len
        self.label_len = configs.label_len
        self.pred_len = configs.pred_len
        self.enc_in = configs.enc_in

        self.context_length = getattr(configs, 'context_length', min(512, self.seq_len))
        self.prediction_length = getattr(configs, 'prediction_length', self.pred_len)
        self.model_path = getattr(configs, 'timesfm_model_path', f'{ROOTPATH}/timesfm-1.0-200m-pytorch/torch_model.ckpt')

        self.input_patch_len = getattr(configs, 'input_patch_len', 32)

        # Round the context length UP to the next multiple of the patch length.
        if self.context_length % self.input_patch_len != 0:
            new_context_length = ((self.context_length - 1) // self.input_patch_len + 1) * self.input_patch_len
            print(f"Adjusting context_length from {self.context_length} to {new_context_length} to be divisible by input_patch_len {self.input_patch_len}")
            self.context_length = new_context_length

        self.output_patch_len = getattr(configs, 'output_patch_len', 128)
        self.num_layers = getattr(configs, 'num_layers', 20)
        self.model_dims = getattr(configs, 'model_dims', 1280)

        self.frequency_indicator = getattr(configs, 'frequency_indicator', 0)
        self.enable_plot = getattr(configs, 'enable_plot', True)
        self.plot_dir = getattr(configs, 'plot_dir', './plots_timesfm')
        self.plot_counter = 0  # suffix of the next plot file name

        if self.enable_plot:
            os.makedirs(self.plot_dir, exist_ok=True)

        # Populated by _initialize_model() on the first forward pass.
        self.timesfm_model = None

        # Inputs longer than this are cut down to their most recent steps.
        if self.seq_len > 512:
            print(f"Warning: seq_len {self.seq_len} exceeds TimesFM 1.0's max context length 512, will truncate")
            self.effective_seq_len = 512
        else:
            self.effective_seq_len = self.seq_len

    def _initialize_model(self):
        """Load the TimesFM checkpoint once; later calls are no-ops.

        Raises:
            Exception: whatever the ``timesfm`` import or checkpoint loading
                raises; it is logged and re-raised unchanged.
        """
        if self.timesfm_model is not None:
            return

        try:
            # Imported lazily so the module can be imported without timesfm installed.
            from timesfm import TimesFm, TimesFmHparams, TimesFmCheckpoint

            hparams = TimesFmHparams(
                context_len=self.context_length,
                per_core_batch_size=32,
                input_patch_len=self.input_patch_len,
                output_patch_len=self.output_patch_len,
                num_layers=self.num_layers,
                model_dims=self.model_dims,
                horizon_len=self.prediction_length,
                backend='gpu' if torch.cuda.is_available() else 'cpu',
            )
            checkpoint = TimesFmCheckpoint(path=self.model_path)

            self.timesfm_model = TimesFm(hparams=hparams, checkpoint=checkpoint)

            # Keep a handle on the underlying network when the library exposes one.
            if hasattr(self.timesfm_model, '_model'):
                self.inner_model = self.timesfm_model._model

            print(f"Successfully loaded TimesFM model from {self.model_path}")
            print(f"Max context length: {self.context_length}")
            print(f"Max horizon length: {self.prediction_length}")

        except Exception as e:
            print(f"Failed to load TimesFM model: {e}")
            raise  # bare raise keeps the original traceback intact

    def forward(
        self,
        x_enc: torch.Tensor,
        x_mark_enc: Optional[torch.Tensor] = None,
        x_dec: Optional[torch.Tensor] = None,
        x_mark_dec: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forecast ``prediction_length`` steps for every feature of ``x_enc``.

        Args:
            x_enc: history tensor of shape ``[batch, seq_len, features]``.
            x_mark_enc, x_dec, x_mark_dec, mask: accepted but not used here.

        Returns:
            Tensor of shape ``[batch, prediction_length, features]``.
        """
        if self.timesfm_model is None:
            self._initialize_model()

        result = self._forward_pretrained(x_enc)

        if self.enable_plot:
            self._plot_prediction(x_enc, result)

        return result

    def _forward_pretrained(self, x_enc: torch.Tensor) -> torch.Tensor:
        """Run the pretrained forecaster on each (batch, feature) series.

        If the forecast call fails, the error is printed and the last observed
        value of each series is repeated over the horizon instead.

        Args:
            x_enc: history tensor of shape ``[batch, seq_len, features]``.

        Returns:
            Tensor of shape ``[batch, prediction_length, features]`` on the
            same device and with the same dtype as ``x_enc``.

        Raises:
            RuntimeError: if the TimesFM model has not been initialized.
        """
        batch_size, seq_len, feature_dim = x_enc.shape

        if self.timesfm_model is None:
            raise RuntimeError("TimesFM model not initialized. Call _initialize_model() first.")

        # Keep only the most recent steps when the window is too long.
        if seq_len > self.effective_seq_len:
            x_enc = x_enc[:, -self.effective_seq_len:, :]
            seq_len = self.effective_seq_len

        # NOTE: the context is deliberately NOT zero-padded to a multiple of
        # input_patch_len here. Literal zeros prepended to the series would be
        # treated as real observations; per the TimesFM documentation the
        # inference code pads/truncates contexts of any length itself.

        num_series = batch_size * feature_dim
        x_reshaped = x_enc.permute(0, 2, 1).reshape(num_series, seq_len)  # [B*D, L]

        # detach() so inputs that require grad can be exported to NumPy;
        # a single host transfer covers every series.
        forecast_input = list(x_reshaped.detach().cpu().numpy())

        with torch.no_grad():
            try:
                freq_input = [self.frequency_indicator] * num_series

                point_forecast, _ = self.timesfm_model.forecast(
                    forecast_input,
                    freq=freq_input,
                )

                # np.asarray handles both a list of per-series arrays and a
                # single 2-D array without per-element tensor construction.
                pred_part = torch.as_tensor(
                    np.asarray(point_forecast),
                    device=x_enc.device,
                    dtype=x_enc.dtype,
                )

                if pred_part.dim() == 1:
                    pred_part = pred_part.unsqueeze(0)

                horizon = pred_part.shape[1]
                if horizon > self.prediction_length:
                    pred_part = pred_part[:, :self.prediction_length]
                elif horizon < self.prediction_length:
                    # Too short: hold the last forecast value for the remainder.
                    pad_size = self.prediction_length - horizon
                    last_values = pred_part[:, -1:].expand(-1, pad_size)
                    pred_part = torch.cat([pred_part, last_values], dim=1)

            except Exception as e:
                # Best-effort fallback: repeat each series' last observation.
                print(f"TimesFM 1.0 forecast failed: {e}")
                pred_part = x_reshaped[:, -1:].expand(-1, self.prediction_length)

        # [B*D, H] -> [B, D, H] -> [B, H, D]
        predictions = pred_part.reshape(batch_size, feature_dim, self.prediction_length).permute(0, 2, 1)

        return predictions

    def _plot_prediction(self, x_enc: torch.Tensor, predictions: torch.Tensor) -> None:
        """Save a PNG of input vs. forecast for the first sample of the batch.

        At most the first four features are drawn, one subplot each. Any
        failure is printed and swallowed so plotting never interrupts
        inference. Only the figure created here is closed.

        Args:
            x_enc: history tensor of shape ``[batch, seq_len, features]``.
            predictions: forecast tensor of shape ``[batch, pred_len, features]``.
        """
        fig = None
        try:
            batch_size, seq_len, feature_dim = x_enc.shape
            pred_len = predictions.shape[1]

            batch_idx = 0
            max_features_to_plot = min(feature_dim, 4)

            # squeeze=False always yields a 2-D axes array, even for one subplot.
            fig, axes_grid = plt.subplots(
                max_features_to_plot, 1,
                figsize=(12, 3 * max_features_to_plot),
                squeeze=False,
            )
            axes = axes_grid[:, 0]

            time_input = np.arange(seq_len)
            time_pred = np.arange(seq_len, seq_len + pred_len)

            for feat_idx in range(max_features_to_plot):
                ax = axes[feat_idx]
                input_data = x_enc[batch_idx, :, feat_idx].detach().cpu().numpy()
                pred_data = predictions[batch_idx, :, feat_idx].detach().cpu().numpy()

                ax.plot(time_input, input_data, 'b-', label='Input (Historical)', linewidth=2)
                ax.plot(time_pred, pred_data, 'r-', label='TimesFM Prediction', linewidth=2)
                # Dashed connector between the last observation and the first forecast.
                ax.plot([seq_len-1, seq_len], [input_data[-1], pred_data[0]], 'g--', alpha=0.7)
                # Draw the boundary marker BEFORE legend() so its label is listed.
                ax.axvline(x=seq_len, color='gray', linestyle='--', alpha=0.5, label='Prediction Start')

                ax.set_title(f'Feature {feat_idx + 1}: TimesFM 1.0 Input vs Prediction (freq={self.frequency_indicator})')
                ax.set_xlabel('Time Steps')
                ax.set_ylabel('Value')
                ax.legend()
                ax.grid(True, alpha=0.3)

            fig.tight_layout()

            plot_filename = f'timesfm_prediction_{self.plot_counter:04d}.png'
            plot_path = os.path.join(self.plot_dir, plot_filename)
            fig.savefig(plot_path, dpi=150, bbox_inches='tight')

            self.plot_counter += 1
            print(f"TimesFM prediction plot saved to: {plot_path}")

        except Exception as e:
            print(f"Failed to create TimesFM prediction plot: {e}")
        finally:
            # Close only our own figure; never touch figures owned by other code.
            if fig is not None:
                plt.close(fig)

class Model(TimesFMWrapper):
    """Alias of ``TimesFMWrapper`` under the name ``Model``; adds no behavior."""
