import torch
import torch.nn as nn
import pandas as pd
import numpy as np
from .base import ColumnVectorizer
from datetime import datetime
import torch.nn.functional as F
import random
from datetime import timedelta

class DateTimeVectorizer(ColumnVectorizer):
    """Trainable vectorizer that embeds datetime values and decodes them back.

    Each value is split into year, month, day, hour and minute; seconds and
    finer resolution are not encoded. The year is scaled as
    ``(year - 2000) / 100`` and the four periodic components are sine/cosine
    encoded, giving a 9-dimensional feature vector that an MLP maps to
    ``output_dim``. Decoding runs a shared MLP trunk followed by one
    regression head (year) and four classification heads (month, day, hour,
    minute).
    """

    def __init__(self, output_dim, hidden_dim=128):
        """
        Initialize the datetime vectorizer.

        Args:
            output_dim (int): The dimension of the output vectors (D).
            hidden_dim (int): Hidden dimension for the MLP encoder/decoder.
        """
        # NOTE(review): "timedelta64[ns]" is accepted here, but the extraction
        # path converts via pd.to_datetime and reads calendar fields — confirm
        # that timedelta input is actually supported.
        super().__init__(output_dim=output_dim, accepted_dtype=[
            "datetime64[ns]", "datetime64[ms]", "datetime64[us]", "datetime64[s]",
            "datetime64[m]", "datetime64[h]", "datetime64[D]", "datetime64[W]",
            "datetime64[M]", "datetime64[Y]", "timedelta64[ns]",
            "object"  # For string datetime representations
        ])

        # Periods used for cyclic encoding and as class counts of the heads
        self.MONTH_PERIOD = 12
        self.DAY_PERIOD = 31
        self.HOUR_PERIOD = 24
        self.MINUTE_PERIOD = 60

        # Input dimension: year + 2*(month + day + hour + minute) = 1 + 2*4 = 9
        self.input_dim = 9

        # Encoder network: 9 handcrafted features -> output_dim embedding
        self.encoder = nn.Sequential(
            nn.Linear(self.input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )

        # Decoder trunk shared by all component heads
        self.decoder = nn.Sequential(
            nn.Linear(output_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )

        # Separate heads for different components
        self.year_head = nn.Linear(hidden_dim, 1)  # Normalized year regression
        self.month_head = nn.Linear(hidden_dim, self.MONTH_PERIOD)  # Month classification
        self.day_head = nn.Linear(hidden_dim, self.DAY_PERIOD)  # Day classification
        self.hour_head = nn.Linear(hidden_dim, self.HOUR_PERIOD)  # Hour classification
        self.minute_head = nn.Linear(hidden_dim, self.MINUTE_PERIOD)  # Minute classification

        self._init_weights()

    def _init_weights(self):
        """Apply Xavier-normal weights and zero biases to every linear layer."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_normal_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def is_trainable(self):
        """Return True: this vectorizer has learnable parameters."""
        return True

    def to(self, device):
        """Move the vectorizer and all of its sub-networks to ``device``.

        Returns:
            DateTimeVectorizer: ``self``, to allow call chaining.
        """
        # presumably the base class also records ``self.device`` here — verify
        # against ColumnVectorizer.to
        super().to(device)
        self.encoder = self.encoder.to(device)
        self.decoder = self.decoder.to(device)
        self.year_head = self.year_head.to(device)
        self.month_head = self.month_head.to(device)
        self.day_head = self.day_head.to(device)
        self.hour_head = self.hour_head.to(device)
        self.minute_head = self.minute_head.to(device)
        return self

    def _cyclic_encode(self, x, period):
        """Encode ``x`` as ``[sin, cos]`` of its phase within ``period``.

        Args:
            x (torch.Tensor): Component values (cast to float before encoding).
            period (int): Length of one full cycle.

        Returns:
            torch.Tensor: ``x`` with one extra trailing dimension of size 2.
        """
        x = x.float()
        sin_x = torch.sin(2 * np.pi * x / period)
        cos_x = torch.cos(2 * np.pi * x / period)
        return torch.stack([sin_x, cos_x], dim=-1)

    def _extract_datetime_components(self, dt_series):
        """Extract year, month, day, hour and minute tensors from datetimes.

        Non-datetime input is parsed with ``pd.to_datetime`` (unparseable
        entries become NaT). Missing values are imputed with the mean of the
        non-missing values, or with 2000-01-01 00:00:00 when every value is
        missing, and a warning with the number of imputed values is printed.

        Args:
            dt_series (pd.Series): Datetime values, or values parseable as
                datetimes. Other sequences are wrapped in a Series first.

        Returns:
            tuple[torch.Tensor, ...]: ``(years, months, days, hours, minutes)``
            on ``self.device``. Months and days are shifted to be 0-based.
        """
        # pd.to_datetime returns a DatetimeIndex (no ``.dt`` accessor) for
        # list-like input, so normalise to a Series up front.
        if not isinstance(dt_series, pd.Series):
            dt_series = pd.Series(dt_series)

        if not pd.api.types.is_datetime64_any_dtype(dt_series):
            try:
                dt_series = pd.to_datetime(dt_series, format='mixed', errors='coerce')
            except (ValueError, TypeError) as e:
                print(f"Warning: Some values could not be converted to datetime: {str(e)}")
                # Retry with pandas' default format inference; repeating the
                # identical call could only fail the same way again.
                dt_series = pd.to_datetime(dt_series, errors='coerce')

        # Handle missing values (NaT). The count is taken before filling so
        # the warning reports how many values were actually imputed.
        missing_mask = dt_series.isna()
        missing_count = int(missing_mask.sum())
        if missing_count:
            if missing_count == len(dt_series):
                # Nothing to average over: fall back to a fixed default
                fill_value = pd.Timestamp('2000-01-01 00:00:00')
            else:
                fill_value = dt_series[~missing_mask].mean()
            dt_series = dt_series.fillna(fill_value)

            print(f"Warning: Imputed {missing_count} missing datetime values")

        # Extract components (month/day become 0-based class indices)
        years = torch.tensor(dt_series.dt.year.values, device=self.device)
        months = torch.tensor(dt_series.dt.month.values - 1, device=self.device)
        days = torch.tensor(dt_series.dt.day.values - 1, device=self.device)
        hours = torch.tensor(dt_series.dt.hour.values, device=self.device)
        minutes = torch.tensor(dt_series.dt.minute.values, device=self.device)

        return years, months, days, hours, minutes

    def _vectorize(self, column, config):
        """Transform a datetime column into embeddings.

        Args:
            column (pd.Series): Datetime values to embed.
            config (dict): Unused by this vectorizer.

        Returns:
            torch.Tensor: Encoder output with ``output_dim`` features per row.
        """
        years, months, days, hours, minutes = self._extract_datetime_components(column)
        n_rows = years.shape[0]

        # Normalize year to handle the scale (centered on 2000, in centuries)
        year_normalized = (years - 2000) / 100

        # Cyclic encoding of the periodic components, two features each
        periodic = [
            self._cyclic_encode(values, period).reshape(n_rows, -1)
            for values, period in (
                (months, self.MONTH_PERIOD),
                (days, self.DAY_PERIOD),
                (hours, self.HOUR_PERIOD),
                (minutes, self.MINUTE_PERIOD),
            )
        ]

        # Feature order: year, month, day, hour, minute
        features = torch.cat([year_normalized.unsqueeze(-1), *periodic], dim=-1)

        return self.encoder(features)

    def _compute_loss(self, reconstructed_values, target_column, config):
        """
        Compute loss for datetime reconstruction using cross entropy for month, day, hour, minute,
        and MSE for year.

        Args:
            reconstructed_values (dict): Train-mode output of
                ``_inverse_vectorize`` with keys ``'year'``, ``'month_logits'``,
                ``'day_logits'``, ``'hour_logits'`` and ``'minute_logits'``.
            target_column (pd.Series): Series of target datetime values
            config (dict): Unused by this vectorizer.

        Returns:
            torch.Tensor: Unweighted sum of the five component losses
        """
        target_y, target_m, target_d, target_h, target_min = self._extract_datetime_components(target_column)

        # Convert targets to correct dtype and device (long for cross_entropy)
        target_y = target_y.float().to(self.device)
        target_m = target_m.long().to(self.device)
        target_d = target_d.long().to(self.device)
        target_h = target_h.long().to(self.device)
        target_min = target_min.long().to(self.device)

        # Flatten instead of a bare squeeze(): with a single row, squeeze()
        # would drop the batch dimension too and no longer match the target.
        year_pred = reconstructed_values['year'].reshape(-1)

        # NOTE(review): the year error is measured in raw years while the
        # other terms are cross-entropies, so it may dominate the sum —
        # consider weighting.
        year_loss = F.mse_loss(year_pred, target_y)
        month_loss = F.cross_entropy(reconstructed_values['month_logits'], target_m)
        day_loss = F.cross_entropy(reconstructed_values['day_logits'], target_d)
        hour_loss = F.cross_entropy(reconstructed_values['hour_logits'], target_h)
        minute_loss = F.cross_entropy(reconstructed_values['minute_logits'], target_min)

        return year_loss + month_loss + day_loss + hour_loss + minute_loss

    def _inverse_vectorize(self, tensor, config, mode='inference'):
        """
        Decode embeddings back to datetime components.
        In train mode, returns the decoded components for loss computation.
        In inference mode, returns reconstructed datetime objects.

        Args:
            tensor (torch.Tensor): The tensor to inverse transform.
            config (dict): Configuration dictionary (unused here).
            mode (str): 'train' for loss inputs; any other value is treated
                as inference.

        Returns:
            Union[pd.Series, dict]:
                - If mode='inference': Reconstructed datetime objects as pandas
                  Series; component combinations that do not form a valid date
                  become ``pd.NaT``.
                - If mode='train': Dictionary containing decoded components for loss computation
        """
        # Shared decoder features feeding every head
        features = self.decoder(tensor)

        year_pred = self.year_head(features)  # Normalized year value
        month_logits = self.month_head(features)
        day_logits = self.day_head(features)
        hour_logits = self.hour_head(features)
        minute_logits = self.minute_head(features)

        if mode == "train":
            return {
                'year': year_pred * 100 + 2000,  # Denormalize
                'month_logits': month_logits,
                'day_logits': day_logits,
                'hour_logits': hour_logits,
                'minute_logits': minute_logits
            }

        # For inference, pick the most likely class per component
        year = (year_pred * 100 + 2000).round().long()
        month = torch.argmax(month_logits, dim=-1) + 1  # 1-based
        day = torch.argmax(day_logits, dim=-1) + 1  # 1-based
        hour = torch.argmax(hour_logits, dim=-1)
        minute = torch.argmax(minute_logits, dim=-1)

        # One bulk transfer per component instead of a .cpu().item() round
        # trip for every field of every row.
        component_lists = [
            values.reshape(-1).tolist()
            for values in (year, month, day, hour, minute)
        ]

        datetimes = []
        for y_val, m_val, d_val, h_val, min_val in zip(*component_lists):
            try:
                datetimes.append(datetime(y_val, m_val, d_val, h_val, min_val))
            except (ValueError, OverflowError):
                # Invalid dates (e.g., February 30) or a year outside what
                # datetime can represent
                datetimes.append(pd.NaT)

        return pd.Series(datetimes)

if __name__ == "__main__":
    # Manual smoke test / micro-benchmark for DateTimeVectorizer. It prints
    # results only; nothing here is asserted.
    import time

    def _speed_ratio(reference_seconds, measured_seconds):
        """Return reference/measured, or inf if the measured time registered as zero."""
        if measured_seconds > 0:
            return reference_seconds / measured_seconds
        return float("inf")

    def _component_accuracy(originals, reconstructions, component):
        """Fraction of rows whose reconstructed `component` equals the original (NaT rows count as misses)."""
        hits = sum(
            getattr(original, component) == getattr(reconstructed, component)
            for original, reconstructed in zip(originals, reconstructions)
            if pd.notna(reconstructed)
        )
        return hits / len(originals)

    def _same_datetime(left, right):
        """Compare two reconstructed values, treating a pair of NaT values as equal."""
        left_missing, right_missing = pd.isna(left), pd.isna(right)
        if left_missing or right_missing:
            return bool(left_missing and right_missing)
        return pd.Timestamp(left) == pd.Timestamp(right)

    def random_dates(n):
        """Generate n random dates between 2000 and 2023"""
        start_date = datetime(2000, 1, 1)
        end_date = datetime(2023, 12, 31)
        delta_seconds = (end_date - start_date).total_seconds()

        return [
            start_date + timedelta(seconds=random.uniform(0, delta_seconds))
            for _ in range(n)
        ]

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Create sample datetime data
    dates = pd.Series([
        pd.Timestamp('2015-03-28T16:04:48'),
        pd.Timestamp('2021-06-22 14:45:00'),
        pd.Timestamp('2019-12-31 23:59:00'),
        pd.Timestamp('2022-03-10 12:00:00'),
        pd.Timestamp('2018-07-04 18:15:00')
    ])

    # Same instants as ISO strings, to show how pandas parses them
    datetime_series = pd.Series([
        '2015-03-28T16:04:48',
        '2021-06-22T14:45:00',
        '2019-12-31T23:59:00',
        '2022-03-10T12:00:00',
        '2018-07-04T18:15:00'
    ])
    datetime_series = pd.to_datetime(datetime_series, format='mixed', errors='raise')
    print(datetime_series)

    # Initialize vectorizer
    output_dim = 64
    vectorizer = DateTimeVectorizer(output_dim=output_dim, hidden_dim=128)
    vectorizer.to(device)

    print("\n" + "=" * 50)
    print("Testing DateTimeVectorizer")
    print("=" * 50)

    print("Original datetimes:")
    print(dates)
    print("\n")

    # Vectorize
    print("Vectorizing datetimes...")
    embeddings = vectorizer._vectorize(dates, config={})
    print(f"Embedding shape: {embeddings.shape}")
    print(f"Embedding sample (first 5 dimensions):\n{embeddings[0, :5]}")
    print("\n")

    # Test training mode reconstruction
    print("Testing training mode reconstruction...")
    train_output = vectorizer._inverse_vectorize(embeddings, config={}, mode='train')
    print(f"Training mode output keys: {train_output.keys()}")
    print(f"Year predictions shape: {train_output['year'].shape}")
    print(f"Month logits shape: {train_output['month_logits'].shape}")
    print(f"Day logits shape: {train_output['day_logits'].shape}")
    print(f"Hour logits shape: {train_output['hour_logits'].shape}")
    print(f"Minute logits shape: {train_output['minute_logits'].shape}")

    # Test loss computation
    print("\nTesting loss computation...")
    loss = vectorizer._compute_loss(train_output, dates, config={})
    print(f"Reconstruction loss: {loss.item()}")

    # Test inference mode
    print("\nTesting inference mode reconstruction...")
    reconstructed_dates = vectorizer._inverse_vectorize(embeddings, config={}, mode='inference')

    # Compare results
    print("\nReconstruction results:")
    for original, reconstructed in zip(dates, reconstructed_dates):
        print(f"Original: {original} | Reconstructed: {reconstructed}")

    # Per-component accuracy of the (untrained) round trip
    year_accuracy = _component_accuracy(dates, reconstructed_dates, "year")
    month_accuracy = _component_accuracy(dates, reconstructed_dates, "month")
    day_accuracy = _component_accuracy(dates, reconstructed_dates, "day")
    hour_accuracy = _component_accuracy(dates, reconstructed_dates, "hour")
    minute_accuracy = _component_accuracy(dates, reconstructed_dates, "minute")

    print("\nComponent accuracy:")
    print(f"Year accuracy: {year_accuracy:.2%}")
    print(f"Month accuracy: {month_accuracy:.2%}")
    print(f"Day accuracy: {day_accuracy:.2%}")
    print(f"Hour accuracy: {hour_accuracy:.2%}")
    print(f"Minute accuracy: {minute_accuracy:.2%}")

    # Test batch processing for datetime vectorization
    print("\n" + "="*50)
    print("Testing Batch Processing for DateTimeVectorizer")
    print("="*50)

    # Create multiple columns of datetime data (deliberately different lengths)
    datetime_columns = [
        pd.Series([
            pd.Timestamp('2015-03-28T16:04:48'),
            pd.Timestamp('2021-06-22T14:45:00'),
            pd.Timestamp('2019-12-31T23:59:00')
        ]),
        pd.Series([
            pd.Timestamp('2022-03-10T12:00:00'),
            pd.Timestamp('2018-07-04T18:15:00'),
            pd.Timestamp('2020-01-01T00:00:00'),
            pd.Timestamp('2017-11-23T08:30:00')
        ]),
        pd.Series([
            pd.Timestamp('2016-02-29T09:20:10'),  # Leap year
            pd.Timestamp('2023-05-15T13:45:30')
        ])
    ]
    configs = [{}, {}, {}]  # Empty configs since datetime vectorizer doesn't require config

    # Fresh vectorizer so batch and sequential runs share the same weights
    vectorizer = DateTimeVectorizer(output_dim=64, hidden_dim=128)
    vectorizer.to(device)

    # Test 1: Batch vectorization
    print("\nTest 1: Batch Vectorization")

    # perf_counter is monotonic and high-resolution, unlike time.time()
    start_time = time.perf_counter()
    batch_vectors = vectorizer.vectorize_batch(datetime_columns, configs)
    batch_time = time.perf_counter() - start_time

    print(f"Number of output tensors: {len(batch_vectors)}")
    for i, tensor in enumerate(batch_vectors):
        print(f"Tensor {i} shape: {tensor.shape}")

    # Time sequential processing for comparison
    start_time = time.perf_counter()
    sequential_vectors = [vectorizer.vectorize(col, {}) for col in datetime_columns]
    sequential_time = time.perf_counter() - start_time

    print(f"Sequential processing time: {sequential_time:.4f}s")
    print(f"Batch processing time: {batch_time:.4f}s")
    print(f"Speed ratio: {_speed_ratio(sequential_time, batch_time):.2f}x")

    # Test 2: Batch inverse vectorization (inference mode)
    print("\nTest 2: Batch Inverse Vectorization (Inference Mode)")

    start_time = time.perf_counter()
    batch_reconstructed = vectorizer.inverse_vectorize_batch(batch_vectors, configs)
    batch_inverse_time = time.perf_counter() - start_time

    start_time = time.perf_counter()
    sequential_reconstructed = [vectorizer.inverse_vectorize(vec, {}) for vec in sequential_vectors]
    sequential_inverse_time = time.perf_counter() - start_time

    print(f"Sequential inverse time: {sequential_inverse_time:.4f}s")
    print(f"Batch inverse time: {batch_inverse_time:.4f}s")
    print(f"Speed ratio: {_speed_ratio(sequential_inverse_time, batch_inverse_time):.2f}x")

    # Verify results match between batch and sequential processing
    print("\nVerifying batch and sequential results match:")
    all_match = True

    for i, (batch_col, seq_col) in enumerate(zip(batch_reconstructed, sequential_reconstructed)):
        if len(batch_col) != len(seq_col):
            print(f"❌ Column {i}: Length mismatch - batch: {len(batch_col)}, sequential: {len(seq_col)}")
            all_match = False
            continue

        # A pair of NaT values (both paths decoded the same invalid date) is
        # an agreement, not a mismatch.
        match_count = sum(_same_datetime(b, s) for b, s in zip(batch_col, seq_col))

        match_percent = match_count / len(batch_col) * 100 if len(batch_col) else 100.0
        if match_percent < 100:
            print(f"❌ Column {i}: Only {match_percent:.2f}% of values match")
            all_match = False
        else:
            print(f"✓ Column {i}: All values match")

    if all_match:
        print("✅ All batch results match sequential processing")
    else:
        print("❌ Some batch results differ from sequential processing")

    # Test 3: Batch inverse vectorization (training mode)
    print("\nTest 3: Batch Inverse Vectorization (Training Mode)")

    start_time = time.perf_counter()
    batch_train_outputs, batch_loss = vectorizer.inverse_vectorize_batch(
        batch_vectors, configs, mode='train', target_columns=datetime_columns
    )
    batch_train_time = time.perf_counter() - start_time

    start_time = time.perf_counter()
    sequential_train_outputs = []
    sequential_losses = []

    for vec, col in zip(sequential_vectors, datetime_columns):
        output, loss = vectorizer.inverse_vectorize(vec, {}, mode='train', target_column=col)
        sequential_train_outputs.append(output)
        sequential_losses.append(loss)

    sequential_train_time = time.perf_counter() - start_time
    total_sequential_loss = sum(sequential_losses)

    print(f"Sequential training time: {sequential_train_time:.4f}s")
    print(f"Batch training time: {batch_train_time:.4f}s")
    print(f"Speed ratio: {_speed_ratio(sequential_train_time, batch_train_time):.2f}x")
    print(f"Batch loss: {batch_loss.item():.6f}")
    print(f"Sum of sequential losses: {total_sequential_loss.item():.6f}")
    print(f"Loss difference: {abs(batch_loss.item() - total_sequential_loss.item()):.6f}")

    # Test 4: Performance with larger data
    print("\nTest 4: Performance with Larger Data")

    # Create 10 columns with 100 dates each
    num_columns = 10
    dates_per_column = 100
    large_columns = [pd.Series(random_dates(dates_per_column)) for _ in range(num_columns)]
    large_configs = [{} for _ in range(num_columns)]

    start_time = time.perf_counter()
    large_batch_vectors = vectorizer.vectorize_batch(large_columns, large_configs)
    large_batch_time = time.perf_counter() - start_time

    start_time = time.perf_counter()
    large_sequential_vectors = [vectorizer.vectorize(col, {}) for col in large_columns]
    large_sequential_time = time.perf_counter() - start_time

    print(f"Large dataset - {num_columns} columns with {dates_per_column} dates each")
    print(f"Sequential vectorization time: {large_sequential_time:.4f}s")
    print(f"Batch vectorization time: {large_batch_time:.4f}s")
    print(f"Speed ratio: {_speed_ratio(large_sequential_time, large_batch_time):.2f}x")

    # Batch inverse vectorization timing
    start_time = time.perf_counter()
    large_batch_reconstructed = vectorizer.inverse_vectorize_batch(large_batch_vectors, large_configs)
    large_batch_inverse_time = time.perf_counter() - start_time

    print(f"Batch inverse vectorization time: {large_batch_inverse_time:.4f}s")
