import torch
import torch.nn as nn
import numpy as np
from typing import Optional, Dict, Any
import sys
import os
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('Agg')

from .ROOTPATH import ROOTPATH


class TimerWrapper(nn.Module):
    """Adapter that exposes a pretrained Timer checkpoint through a forecasting ``forward``.

    The underlying model is loaded lazily on the first ``forward`` call.
    Every input channel is forecast independently: a ``[B, L, D]`` input is
    flattened to ``[B*D, L]`` univariate series, passed to
    ``timer_model.generate``, and reshaped back to ``[B, prediction_length, D]``.
    """

    # Longest look-back window fed to Timer; longer inputs keep only their
    # most recent MAX_CONTEXT_LENGTH steps.
    MAX_CONTEXT_LENGTH = 2880

    def __init__(self, configs):
        """Read hyper-parameters from ``configs``; the model itself is not loaded here.

        Args:
            configs: object with ``seq_len``, ``label_len``, ``pred_len`` and
                ``enc_in`` attributes, plus optional ``context_length``,
                ``prediction_length``, ``timer_model_path``, ``patch_length``,
                ``max_new_tokens``, ``enable_plot`` and ``plot_dir``.
        """
        super(TimerWrapper, self).__init__()

        self.seq_len = configs.seq_len
        self.label_len = configs.label_len
        self.pred_len = configs.pred_len
        self.enc_in = configs.enc_in

        self.context_length = getattr(
            configs, 'context_length', min(self.MAX_CONTEXT_LENGTH, self.seq_len)
        )
        self.prediction_length = getattr(configs, 'prediction_length', self.pred_len)
        # Build the ROOTPATH-based default only when no path is configured.
        model_path = getattr(configs, 'timer_model_path', None)
        if model_path is None:
            model_path = f'{ROOTPATH}/timer-base-84m'
        self.model_path = model_path
        # Kept as configuration attributes; not read by the methods below.
        self.patch_length = getattr(configs, 'patch_length', 96)
        self.max_new_tokens = getattr(configs, 'max_new_tokens', self.pred_len)

        self.enable_plot = getattr(configs, 'enable_plot', True)
        self.plot_dir = getattr(configs, 'plot_dir', './plots_timer')
        self.plot_counter = 0

        if self.enable_plot:
            os.makedirs(self.plot_dir, exist_ok=True)

        self.timer_model = None

        if self.seq_len > self.MAX_CONTEXT_LENGTH:
            print(f"Warning: seq_len {self.seq_len} exceeds Timer's max context length 2880, will truncate")
        # Honour a configured context_length (never below the input length's
        # needs), capped at the maximum supported context.
        self.effective_seq_len = min(self.seq_len, self.context_length, self.MAX_CONTEXT_LENGTH)

    def _initialize_model(self):
        """Load the Timer checkpoint from ``self.model_path`` (no-op if already loaded).

        The model is moved to CUDA when available and is also exposed as
        ``self.inner_model``. Any loading error is printed and re-raised.
        """
        if self.timer_model is not None:
            return

        try:
            from transformers import AutoModelForCausalLM

            # trust_remote_code executes Python shipped with the checkpoint:
            # only point model_path at trusted checkpoints.
            self.timer_model = AutoModelForCausalLM.from_pretrained(
                self.model_path,
                trust_remote_code=True
            )

            if torch.cuda.is_available():
                self.timer_model = self.timer_model.cuda()

            print(f"Successfully loaded Timer model from {self.model_path}")
            print(f"Model parameters: {sum(p.numel() for p in self.timer_model.parameters() if p.requires_grad):,}")

            self.inner_model = self.timer_model

        except Exception as e:
            print(f"Failed to load Timer model: {e}")
            raise

    def forward(self, x_enc, x_mark_enc=None, x_dec=None, x_mark_dec=None, mask=None):
        """Forecast ``x_enc``; the mark/decoder/mask arguments are accepted but unused.

        Args:
            x_enc: tensor of shape ``[B, L, D]``.

        Returns:
            Tensor of shape ``[B, prediction_length, D]``.
        """
        if self.timer_model is None:
            self._initialize_model()

        result = self._forward_pretrained(x_enc)

        if self.enable_plot:
            self._plot_prediction(x_enc, result)

        return result

    def _forward_pretrained(self, x_enc):
        """Forecast every channel of ``x_enc`` independently with Timer.

        Args:
            x_enc: tensor of shape ``[B, L, D]``; only the last
                ``effective_seq_len`` steps are used.

        Returns:
            Tensor of shape ``[B, prediction_length, D]`` on ``x_enc``'s device.
        """
        batch_size, seq_len, feature_dim = x_enc.shape

        if seq_len > self.effective_seq_len:
            x_enc = x_enc[:, -self.effective_seq_len:, :]
            seq_len = self.effective_seq_len

        # Each channel becomes its own univariate series: [B, L, D] -> [B*D, L].
        x_reshaped = x_enc.permute(0, 2, 1).reshape(batch_size * feature_dim, seq_len)

        # _initialize_model may have moved the model to CUDA independently of
        # where the caller keeps its data, so align device (and float dtype).
        param = next(iter(self.timer_model.parameters()), None)
        if param is not None:
            target_dtype = param.dtype if param.is_floating_point() else x_reshaped.dtype
            x_reshaped = x_reshaped.to(device=param.device, dtype=target_dtype)

        with torch.no_grad():
            output = self.timer_model.generate(
                x_reshaped,
                max_new_tokens=self.prediction_length
            )
            # Keep the trailing prediction_length steps of the output.
            pred_part = output[:, -self.prediction_length:]  # [B*D, pred_len]

        predictions = pred_part.reshape(batch_size, feature_dim, self.prediction_length).permute(0, 2, 1)

        # Hand results back on the caller's device.
        return predictions.to(x_enc.device)

    def _plot_prediction(self, x_enc, predictions):
        """Best-effort: save a PNG of input vs. forecast for batch item 0.

        Plots up to the first four channels into ``self.plot_dir`` and bumps
        ``self.plot_counter``. Errors are printed, never raised, so plotting
        cannot break inference.
        """
        fig = None
        try:
            batch_size, seq_len, feature_dim = x_enc.shape
            pred_len = predictions.shape[1]

            batch_idx = 0
            max_features_to_plot = min(feature_dim, 4)

            fig, axes = plt.subplots(max_features_to_plot, 1, figsize=(12, 3 * max_features_to_plot))
            if max_features_to_plot == 1:
                axes = [axes]

            time_input = np.arange(seq_len)
            time_pred = np.arange(seq_len, seq_len + pred_len)

            for feat_idx in range(max_features_to_plot):
                ax = axes[feat_idx]

                # detach/float so grad-tracking or bf16 tensors convert to numpy.
                input_data = x_enc[batch_idx, :, feat_idx].detach().float().cpu().numpy()
                pred_data = predictions[batch_idx, :, feat_idx].detach().float().cpu().numpy()
                ax.plot(time_input, input_data, 'b-', label='Input (Historical)', linewidth=2)
                ax.plot(time_pred, pred_data, 'r-', label='Timer Prediction', linewidth=2)
                ax.plot([seq_len-1, seq_len], [input_data[-1], pred_data[0]], 'g--', alpha=0.7)
                # Draw before legend() so the "Prediction Start" label is listed.
                ax.axvline(x=seq_len, color='gray', linestyle='--', alpha=0.5, label='Prediction Start')
                ax.set_title(f'Feature {feat_idx + 1}: Timer Input vs Prediction')
                ax.set_xlabel('Time Steps')
                ax.set_ylabel('Value')
                ax.legend()
                ax.grid(True, alpha=0.3)

            fig.tight_layout()

            plot_filename = f'timer_prediction_{self.plot_counter:04d}.png'
            plot_path = os.path.join(self.plot_dir, plot_filename)
            fig.savefig(plot_path, dpi=150, bbox_inches='tight')

            self.plot_counter += 1
            print(f"Timer prediction plot saved to: {plot_path}")

        except Exception as e:
            print(f"Failed to create Timer prediction plot: {e}")
        finally:
            # Close only our own figure; don't disturb figures owned by callers.
            if fig is not None:
                plt.close(fig)


class Model(TimerWrapper):
    """``TimerWrapper`` exposed under the name ``Model``, with no added behavior.

    NOTE(review): presumably exists so a generic loader can import ``Model``
    from this module — confirm against the caller.
    """
