import os
import sys
from typing import Any, Dict, Optional

import matplotlib
# Select the non-interactive, file-rendering Agg backend *before* pyplot is
# imported, so plotting works on headless hosts and pyplot never initialises
# a different default backend first.
matplotlib.use('Agg')
import matplotlib.pyplot as plt  # noqa: E402  (must follow matplotlib.use)
import numpy as np
import torch
import torch.nn as nn

from .ROOTPATH import ROOTPATH


class TimeMoEWrapper(nn.Module):
    """Adapter that runs a pretrained TimeMoE causal model behind a
    ``forward(x_enc, x_mark_enc, x_dec, x_mark_dec, mask)`` interface.

    Every channel of ``x_enc`` is forecast independently: the
    ``[batch, time, channel]`` input is flattened to ``[batch*channel, time]``
    rows, passed to ``generate``, and the last ``prediction_length`` generated
    values are reshaped back to ``[batch, prediction_length, channel]``.

    The pretrained model is loaded lazily on the first ``forward`` call.
    """

    def __init__(self, configs):
        """Read required and optional settings from ``configs``.

        Required attributes: ``seq_len``, ``label_len``, ``pred_len``,
        ``enc_in``. Optional attributes (with fallbacks): ``context_length``,
        ``prediction_length``, ``timemoe_model_path``, ``max_new_tokens``,
        ``device_map``, ``enable_plot``, ``plot_dir``.
        """
        super(TimeMoEWrapper, self).__init__()

        self.seq_len = configs.seq_len
        self.label_len = configs.label_len
        self.pred_len = configs.pred_len
        self.enc_in = configs.enc_in

        # Optional overrides; each falls back to the matching required field.
        self.context_length = getattr(configs, 'context_length', self.seq_len)
        self.prediction_length = getattr(configs, 'prediction_length', self.pred_len)
        self.model_path = getattr(configs, 'timemoe_model_path', f'{ROOTPATH}/TimeMoE-200M')
        # NOTE: not read by this wrapper; the generation length is controlled
        # by ``prediction_length``. Kept so existing configs keep working.
        self.max_new_tokens = getattr(configs, 'max_new_tokens', self.pred_len)
        self.device_map = getattr(configs, 'device_map', 'cuda')  # 'cpu' or 'cuda'

        self.enable_plot = getattr(configs, 'enable_plot', True)
        self.plot_dir = getattr(configs, 'plot_dir', './plots_timemoe')
        self.plot_counter = 0  # running index used in saved plot file names

        if self.enable_plot:
            os.makedirs(self.plot_dir, exist_ok=True)

        # Populated by _initialize_model() on the first forward() call.
        self.timemoe_model = None

    def _initialize_model(self):
        """Load the pretrained TimeMoE model once and cache it on ``self``.

        CUDA availability is checked *before* loading, so a CPU-only host
        falls back to CPU instead of failing inside the CUDA load.
        FlashAttention-2 is requested only for non-CPU placements. If that
        request is rejected (ImportError / ValueError from ``from_pretrained``),
        loading is retried with the default attention implementation.

        Raises:
            Exception: any error from importing transformers or loading the
                model; it is printed and re-raised unchanged.
        """
        if self.timemoe_model is not None:
            return

        try:
            from transformers import AutoModelForCausalLM

            # Decide placement first; accept 'cuda' as well as 'cuda:<idx>'.
            if str(self.device_map).startswith('cuda') and not torch.cuda.is_available():
                print("Warning: CUDA not available, falling back to CPU")
                self.device_map = 'cpu'

            load_kwargs = {'device_map': self.device_map, 'trust_remote_code': True}
            if self.device_map == 'cpu':
                model = AutoModelForCausalLM.from_pretrained(self.model_path, **load_kwargs)
            else:
                try:
                    model = AutoModelForCausalLM.from_pretrained(
                        self.model_path,
                        attn_implementation='flash_attention_2',
                        **load_kwargs,
                    )
                except (ImportError, ValueError) as err:
                    # Presumably flash_attn is missing or unsupported here;
                    # the default attention implementation still works.
                    print(f"Warning: flash_attention_2 unavailable ({err}); using default attention")
                    model = AutoModelForCausalLM.from_pretrained(self.model_path, **load_kwargs)
        except Exception as e:
            print(f"Failed to load TimeMoE model: {e}")
            raise

        self.timemoe_model = model
        self.inner_model = model

        print(f"Successfully loaded TimeMoE model from {self.model_path}")
        print(f"Device mapping: {self.device_map}")
        print(f"Model parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

    def forward(self, x_enc, x_mark_enc=None, x_dec=None, x_mark_dec=None, mask=None):
        """Forecast ``prediction_length`` steps for every channel of ``x_enc``.

        Args:
            x_enc: tensor indexed as ``[batch, time, channel]``.
            x_mark_enc, x_dec, x_mark_dec, mask: accepted for interface
                compatibility; ignored.

        Returns:
            Tensor of shape ``[batch, prediction_length, channel]`` on
            ``x_enc``'s device.
        """
        if self.timemoe_model is None:
            self._initialize_model()

        result = self._forward_pretrained(x_enc)

        if self.enable_plot:
            self._plot_prediction(x_enc, result)

        return result

    @staticmethod
    def _fit_to_context(x_enc, context_length):
        """Crop or left-pad ``x_enc`` along the time axis to ``context_length``.

        Longer inputs keep their most recent ``context_length`` steps. Shorter
        inputs are padded at the *front* by repeating the first observation.
        That keeps the latest real observation as the last step the model
        sees, so the forecast continues directly from it. Padding at the end
        would shift the forecast horizon by the pad length.

        Raises:
            ValueError: if ``x_enc`` has no time steps.
        """
        seq_len = x_enc.shape[1]
        if seq_len == 0:
            raise ValueError("x_enc must contain at least one time step")
        if seq_len > context_length:
            return x_enc[:, -context_length:, :]
        if seq_len < context_length:
            front_pad = x_enc[:, :1, :].repeat(1, context_length - seq_len, 1)
            return torch.cat([front_pad, x_enc], dim=1)
        return x_enc

    def _forward_pretrained(self, x_enc):
        """Run ``generate`` channel-wise and return ``[B, prediction_length, D]``."""
        batch_size, _, feature_dim = x_enc.shape
        x_enc = self._fit_to_context(x_enc, self.context_length)
        seq_len = x_enc.shape[1]

        # [B, L, D] -> [B, D, L] -> [B*D, L]: one univariate series per row.
        x_reshaped = x_enc.permute(0, 2, 1).reshape(batch_size * feature_dim, seq_len)

        try:
            device = next(self.timemoe_model.parameters()).device
        except StopIteration:
            device = x_enc.device
        x_reshaped = x_reshaped.to(device)

        with torch.no_grad():
            output = self.timemoe_model.generate(
                x_reshaped,
                max_new_tokens=self.prediction_length,
            )
            pred_part = output[:, -self.prediction_length:]  # [B*D, prediction_length]

        predictions = pred_part.reshape(batch_size, feature_dim, self.prediction_length).permute(0, 2, 1)

        # Hand the forecast back on the caller's device (no-op if already there).
        return predictions.to(x_enc.device)

    def _plot_prediction(self, x_enc, predictions):
        """Save a PNG of batch element 0: history vs. forecast for up to 4 channels.

        Best-effort: any plotting error is printed and swallowed so it never
        interrupts forecasting. Only the figure created here is closed.
        """
        fig = None
        try:
            _, seq_len, feature_dim = x_enc.shape
            pred_len = predictions.shape[1]

            batch_idx = 0
            max_features_to_plot = min(feature_dim, 4)

            # squeeze=False always yields a 2-D array of axes, even for one row.
            fig, axes = plt.subplots(
                max_features_to_plot, 1,
                figsize=(12, 3 * max_features_to_plot),
                squeeze=False,
            )

            time_input = np.arange(seq_len)
            time_pred = np.arange(seq_len, seq_len + pred_len)

            for feat_idx, ax in enumerate(axes[:, 0]):
                input_data = x_enc[batch_idx, :, feat_idx].detach().cpu().numpy()
                pred_data = predictions[batch_idx, :, feat_idx].detach().cpu().numpy()

                ax.plot(time_input, input_data, 'b-', label='Input (Historical)', linewidth=2)
                ax.plot(time_pred, pred_data, 'r-', label='TimeMoE Prediction', linewidth=2)
                if seq_len > 0 and pred_len > 0:
                    # Bridge the last observed point to the first forecast point.
                    ax.plot([seq_len - 1, seq_len], [input_data[-1], pred_data[0]], 'g--', alpha=0.7)
                # Drawn before legend() so its label is included in the legend.
                ax.axvline(x=seq_len, color='gray', linestyle='--', alpha=0.5, label='Prediction Start')

                ax.set_title(f'Feature {feat_idx + 1}: TimeMoE Input vs Prediction')
                ax.set_xlabel('Time Steps')
                ax.set_ylabel('Value')
                ax.legend()
                ax.grid(True, alpha=0.3)

            fig.tight_layout()

            plot_filename = f'timemoe_prediction_{self.plot_counter:04d}.png'
            plot_path = os.path.join(self.plot_dir, plot_filename)
            fig.savefig(plot_path, dpi=150, bbox_inches='tight')

            self.plot_counter += 1
            print(f"TimeMoE prediction plot saved to: {plot_path}")

        except Exception as e:
            print(f"Failed to create TimeMoE prediction plot: {e}")
        finally:
            if fig is not None:
                plt.close(fig)


class Model(TimeMoEWrapper):
    """``TimeMoEWrapper`` exposed under the conventional class name ``Model``.

    Adds no behavior of its own. Presumably the surrounding framework looks
    models up by this name; verify against the caller.
    """
