import torch
import torch.nn as nn
import numpy as np
from typing import Optional, Dict, Any
import sys
import os
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('Agg')

from .ROOTPATH import ROOTPATH


class MoiraiWrapper(nn.Module):
    """Adapter exposing a pretrained MOIRAI 2.0 forecaster via ``forward(x_enc, ...)``.

    Every channel of the input is forecast independently as a univariate
    series (the model is built with ``target_dim=1``). The median (0.5
    quantile) of each forecast is returned. The uni2ts model is loaded lazily
    on the first ``forward`` call. When ``enable_plot`` is set, a diagnostic
    PNG is written on every ``forward`` call.
    """

    def __init__(self, configs):
        """Read settings from ``configs``; no model weights are loaded here.

        Required attributes: ``seq_len``, ``label_len``, ``pred_len``,
        ``enc_in``.

        Optional attributes, with defaults in parentheses:
        ``context_length`` (``seq_len``), ``prediction_length`` (``pred_len``),
        ``moirai_model_path`` (``{ROOTPATH}/moirai-2.0-R-small``),
        ``patch_size`` (``'auto'``), ``num_samples`` (100), ``batch_size`` (32),
        ``enable_plot`` (True), ``plot_dir`` (``'./plots_moirai'``).

        ``label_len``, ``enc_in``, ``patch_size`` and ``num_samples`` are
        stored but not used by this class.
        """
        super().__init__()

        self.seq_len = configs.seq_len
        self.label_len = configs.label_len
        self.pred_len = configs.pred_len
        self.enc_in = configs.enc_in

        self.context_length = getattr(configs, 'context_length', self.seq_len)
        self.prediction_length = getattr(configs, 'prediction_length', self.pred_len)
        self.model_path = getattr(configs, 'moirai_model_path', f'{ROOTPATH}/moirai-2.0-R-small')
        self.patch_size = getattr(configs, 'patch_size', 'auto')
        self.num_samples = getattr(configs, 'num_samples', 100)
        self.batch_size = getattr(configs, 'batch_size', 32)

        self.enable_plot = getattr(configs, 'enable_plot', True)
        self.plot_dir = getattr(configs, 'plot_dir', './plots_moirai')
        self.plot_counter = 0  # numbers the saved plot files

        if self.enable_plot:
            os.makedirs(self.plot_dir, exist_ok=True)
        # Populated lazily by _initialize_model().
        self.moirai_model = None
        self.predictor = None

    def _initialize_model(self):
        """Load the pretrained MOIRAI 2.0 module and build its predictor.

        Returns immediately once a predictor exists. ``uni2ts`` is imported
        here, so the dependency is only needed when a forecast is actually
        requested.

        Attributes are assigned only after both the model and its predictor
        were created. A failure part-way through therefore leaves the wrapper
        uninitialised, and the next call retries instead of getting stuck.
        """
        if self.predictor is not None:
            return

        from uni2ts.model.moirai2 import Moirai2Forecast, Moirai2Module

        model = Moirai2Forecast(
            module=Moirai2Module.from_pretrained(self.model_path),
            prediction_length=self.prediction_length,
            context_length=self.context_length,
            target_dim=1,
            feat_dynamic_real_dim=0,
            past_feat_dynamic_real_dim=0,
        )
        predictor = model.create_predictor(batch_size=self.batch_size)

        self.moirai_model = model
        self.predictor = predictor
        if hasattr(model, 'module'):
            self.inner_model = model.module

        print(f"Successfully loaded MOIRAI 2.0 model from {self.model_path}")

    def forward(self, x_enc, x_mark_enc=None, x_dec=None, x_mark_dec=None, mask=None):
        """Forecast every channel of ``x_enc``.

        Args:
            x_enc: Input tensor, unpacked as ``(batch, seq_len, n_features)``.
            x_mark_enc, x_dec, x_mark_dec, mask: Accepted for interface
                compatibility and ignored.

        Returns:
            The median forecast, shaped ``(batch, horizon, n_features)``, on
            ``x_enc``'s device and dtype. ``horizon`` is ``pred_len``, or
            fewer steps if the model's forecast is shorter.
        """
        if self.predictor is None:
            self._initialize_model()
        result = self._forward_pretrained(x_enc)

        if self.enable_plot:
            self._plot_prediction(x_enc, result)

        return result

    def _forward_pretrained(self, x_enc):
        """Run the MOIRAI predictor on each channel and collect median forecasts.

        Args:
            x_enc: Tensor unpacked as ``(batch, seq_len, n_features)``.

        Returns:
            Tensor ``(batch, horizon, n_features)`` in ``x_enc``'s dtype/device,
            where ``horizon`` is the model's forecast length truncated to
            ``pred_len``.

        Raises:
            RuntimeError: If the predictor has not been initialised, or if it
                returns a different number of forecasts than series submitted.
        """
        batch_size, seq_len, feature_dim = x_enc.shape

        if self.predictor is None:
            raise RuntimeError("MOIRAI predictor not initialized. Call _initialize_model() first.")

        num_series = batch_size * feature_dim
        if num_series == 0:
            return x_enc.new_zeros((batch_size, self.pred_len, feature_dim))

        # [B, L, D] -> [B, D, L] -> [B*D, L]: row i is batch i // D, feature i % D.
        # Detach and copy to host once (not per series). .numpy() refuses
        # tensors that require grad, hence the detach.
        series = (
            x_enc.detach()
            .permute(0, 2, 1)
            .reshape(num_series, seq_len)
            .cpu()
            .numpy()
        )

        # Fixed placeholder start date; the tensor carries no timestamps.
        # Presumably this only satisfies the dataset-entry schema.
        start = pd.Period('2023-01-01', freq='D')
        data_items = [
            {'target': series[i], 'start': start, 'item_id': f'series_{i}'}
            for i in range(num_series)
        ]

        # The 0.5 quantile is the median, not the mean.
        medians = [
            np.asarray(forecast.quantile(0.5))
            for forecast in self.predictor.predict(data_items)
        ]
        if len(medians) != num_series:
            raise RuntimeError(
                f"MOIRAI returned {len(medians)} forecasts for {num_series} input series"
            )

        # Stacking into one ndarray first avoids torch.tensor's slow path on a
        # list of arrays.
        predictions = torch.as_tensor(
            np.stack(medians), device=x_enc.device, dtype=x_enc.dtype
        )  # [B*D, H]

        # Infer the horizon from the forecasts rather than assuming it equals
        # pred_len. The model is configured with prediction_length, which may
        # differ from pred_len.
        output = predictions.reshape(batch_size, feature_dim, -1).permute(0, 2, 1)
        return output[:, :self.pred_len, :]

    def _plot_prediction(self, x_enc, predictions):
        """Save a best-effort PNG comparing input history and forecast.

        Plots up to the first four features of batch element 0 to
        ``{plot_dir}/moirai_prediction_{plot_counter:04d}.png``. Any error is
        printed and swallowed, so plotting never breaks a forward pass.
        """
        fig = None
        try:
            _, seq_len, feature_dim = x_enc.shape
            pred_len = predictions.shape[1]
            n_plots = min(feature_dim, 4)

            # Only batch element 0 is plotted. Detach so .numpy() also works on
            # tensors that require grad.
            history = x_enc[0].detach().cpu().numpy()          # [L, D]
            forecast = predictions[0].detach().cpu().numpy()   # [H, D]
            time_input = np.arange(seq_len)
            time_pred = np.arange(seq_len, seq_len + pred_len)

            # squeeze=False always yields a 2-D array of axes, even for one row.
            fig, axes = plt.subplots(n_plots, 1, figsize=(12, 3 * n_plots), squeeze=False)

            for feat_idx in range(n_plots):
                ax = axes[feat_idx, 0]
                input_data = history[:, feat_idx]
                pred_data = forecast[:, feat_idx]

                ax.plot(time_input, input_data, 'b-', label='Input (Historical)', linewidth=2)
                ax.plot(time_pred, pred_data, 'r-', label='Prediction', linewidth=2)
                if seq_len > 0 and pred_len > 0:
                    # Dashed connector between the last observation and the first forecast.
                    ax.plot([seq_len - 1, seq_len], [input_data[-1], pred_data[0]], 'g--', alpha=0.7)
                # Drawn before legend() so the 'Prediction Start' label is included.
                ax.axvline(x=seq_len, color='gray', linestyle='--', alpha=0.5, label='Prediction Start')

                ax.set_title(f'Feature {feat_idx + 1}: Input vs Prediction')
                ax.set_xlabel('Time Steps')
                ax.set_ylabel('Value')
                ax.legend()
                ax.grid(True, alpha=0.3)

            fig.tight_layout()

            plot_filename = f'moirai_prediction_{self.plot_counter:04d}.png'
            plot_path = os.path.join(self.plot_dir, plot_filename)
            fig.savefig(plot_path, dpi=150, bbox_inches='tight')

            self.plot_counter += 1
            print(f"Prediction plot saved to: {plot_path}")

        except Exception as e:
            print(f"Failed to create prediction plot: {e}")
        finally:
            # Close only this figure, never unrelated ones, even when saving fails.
            if fig is not None:
                plt.close(fig)
    

class Model(MoiraiWrapper):
    """Expose :class:`MoiraiWrapper` under the generic name ``Model``.

    Adds no behaviour of its own. The generic name is presumably so a
    surrounding experiment framework can import it by name — TODO confirm.
    """
