import sys
# Path to the project directory. NOTE: this is a relative path ("./..."), so it
# is resolved against the current working directory, not the script location.
PROJECT_PATH = "./Data_Pattern_Learnability"
# Put the project directory on sys.path; presumably this is what lets the
# `utils.*` imports below resolve — confirm against the repository layout.
sys.path.append(PROJECT_PATH)

import numpy as np
import matplotlib.pyplot as plt
from utils.process.data_spins import GeneratingData, compute_word_entropy
from utils.estimators import evoPredEstimator
import os
import json
from tqdm import tqdm 
from datetime import datetime
from scipy.stats import linregress

# Block size forwarded to GeneratingData.generate_chain_2(block_size=...).
# It is also embedded in the names of the cached JSON result files, so results
# obtained with different block sizes are stored in separate files.
block_size = 100000

def computing_entropy(force_recompute=False):
    """
    Compute, or load from cache, the empirical word entropies of the generated chain.

    Results are cached as JSON under
    ``{PROJECT_PATH}/results_data/spin_xps/entropy_results_{block_size}.json``.
    When that file exists and ``force_recompute`` is false, its content is
    returned without regenerating the chain.

    Args:
        force_recompute: If True, ignore any cached file, recompute everything
            and overwrite the cache.

    Returns:
        A 4-tuple ``(word_lengths, entropy_empirical, evoRate_empirical,
        predictive_info_1)``. The first three are lists, the last one is a
        numpy array. Both the cached path and the compute path return this
        same shape, so callers can always unpack four values.
    """
    results_path = f"{PROJECT_PATH}/results_data/spin_xps/entropy_results_{block_size}.json"

    if os.path.exists(results_path) and not force_recompute:
        print("Loading existing results...")
        with open(results_path, 'r') as f:
            results = json.load(f)

        # Return the same four values as the compute path below; the previous
        # 3-tuple made callers that unpack four values fail on cached runs.
        return (
            results['word_lengths'],
            results['entropy_empirical'],
            results['evoRate_empirical'],
            np.array(results['predictive_info_1']),  # Convert list to numpy array
        )

    print("Computing new results...")
    print("Starting generating the chain")
    data_generator = GeneratingData(chain_length=10**7)
    chain_1_bin = data_generator.generate_chain_2(block_size=block_size)
    word_lengths = list(range(1, 20))
    entropy_empirical = []
    evoRate_empirical = []

    # Each entropy S(n) is needed twice (as S(n) at step n and as S(n+1) at
    # step n-1), so memoize it instead of recomputing it over the whole chain.
    # NOTE(review): assumes compute_word_entropy is deterministic for a given
    # (chain, length) — confirm.
    entropy_by_length = {}

    def _word_entropy(length):
        if length not in entropy_by_length:
            entropy_by_length[length] = compute_word_entropy(chain_1_bin, length)
        return entropy_by_length[length]

    S0 = _word_entropy(1)
    for length_word in word_lengths:
        print(f"Processing word length: {length_word}")
        S1 = _word_entropy(length_word)
        entropy_empirical.append(float(S1))  # Convert to float for JSON serialization
        evorate = S1 + S0 - _word_entropy(length_word + 1)
        evoRate_empirical.append(float(evorate))  # Convert to float for JSON serialization

    # Fit a line to S(k) - 0.5*log(k); its slope is used as the extensive
    # (per-symbol) entropy coefficient and subtracted from S(k).
    k_array = np.array([int(k) for k in word_lengths])
    S = np.array(entropy_empirical)
    S_corrected = S - 0.5 * np.log(k_array)
    S0_1 = linregress(k_array, S_corrected).slope
    predictive_info_1 = S - k_array * S0_1

    # Save results
    results = {
        'timestamp': datetime.now().strftime("%Y%m%d_%H%M%S"),
        'word_lengths': word_lengths,
        'entropy_empirical': entropy_empirical,
        'evoRate_empirical': evoRate_empirical,
        'predictive_info_1': predictive_info_1.tolist()  # Convert numpy array to list
    }

    # Create folder if it does not exist
    os.makedirs(os.path.dirname(results_path), exist_ok=True)

    with open(results_path, 'w') as f:
        json.dump(results, f)

    print("Results saved to", results_path)

    return word_lengths, entropy_empirical, evoRate_empirical, predictive_info_1


def training_a_predictive_model(k: int, model_type: str = "mlp", max_epochs: int = 1000,
                              patience: int = 10, num_batches_per_epoch: int = 50):
    """
    Train a next-symbol classifier on length-k windows of a generated chain.

    Version with fixed number of batches per epoch: every epoch consumes at
    most ``num_batches_per_epoch`` training batches and a fifth of that number
    (at least one) of validation batches. Training stops early once the
    validation loss has not improved for ``patience`` consecutive epochs.

    Args:
        k: Length of the input window; the target is the symbol right after it.
        model_type: One of "mlp", "lstm" or "transformer".
        max_epochs: Upper bound on the number of epochs.
        patience: Number of consecutive non-improving epochs tolerated before
            early stopping.
        num_batches_per_epoch: Maximum number of training batches per epoch
            (must be >= 1).

    Returns:
        ``(train_losses, val_losses, best_epoch)`` where the two lists hold the
        average loss per epoch and ``best_epoch`` is the 0-based index of the
        epoch with the lowest validation loss.

    Raises:
        ValueError: If ``model_type`` is unknown or ``num_batches_per_epoch``
            is smaller than 1.
    """
    import torch
    import torch.nn as nn
    from torch.utils.data import Dataset, DataLoader, random_split

    # Fail fast on bad arguments, before the expensive chain generation.
    if model_type not in ("mlp", "lstm", "transformer"):
        raise ValueError(f"Model type {model_type} not implemented")
    if num_batches_per_epoch < 1:
        raise ValueError("num_batches_per_epoch must be >= 1")

    # Dataset class
    class SpinDataset(Dataset):
        """Sliding windows over the chain: x = data[i:i+k], y = data[i+k]."""

        def __init__(self, data, k):
            self.data = torch.FloatTensor(data)
            self.k = k

        def __len__(self):
            return len(self.data) - self.k

        def __getitem__(self, idx):
            x = self.data[idx:idx+self.k]
            y = self.data[idx+self.k]
            return x, y

    # Model definitions
    class MLPPredictor(nn.Module):
        """Fully connected network k -> 64 -> 32 -> 2 logits."""

        def __init__(self, k):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(k, 64),
                nn.ReLU(),
                nn.Linear(64, 32),
                nn.ReLU(),
                nn.Linear(32, 2)  # 2 classes for softmax
            )

        def forward(self, x):
            return self.net(x)

    class LSTMPredictor(nn.Module):
        """Single-layer LSTM (hidden size 32) read out at the last time step."""

        def __init__(self, k):
            # k is accepted for a uniform constructor signature; it is unused.
            super().__init__()
            self.lstm = nn.LSTM(1, 32, batch_first=True)
            self.fc = nn.Linear(32, 2)

        def forward(self, x):
            x = x.unsqueeze(-1)  # Add dimension for feature
            out, _ = self.lstm(x)
            return self.fc(out[:, -1, :])

    class TransformerPredictor(nn.Module):
        """Transformer encoder with learned positional encoding, read out at the last position."""

        def __init__(self, k, d_model=32, nhead=4, num_layers=2):
            super().__init__()
            self.k = k
            self.d_model = d_model

            # Embedding to transform 1D sequence to d_model dimensions
            self.embedding = nn.Linear(1, d_model)

            # Positional encoding is important for the Transformer
            self.pos_encoder = nn.Parameter(torch.randn(1, k, d_model))

            # Transformer layers
            encoder_layer = nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=d_model*4,
                batch_first=True
            )
            self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

            # Final classification layer
            self.fc = nn.Linear(d_model, 2)

        def forward(self, x):
            x = x.unsqueeze(-1)  # [batch_size, k, 1]
            x = self.embedding(x)  # [batch_size, k, d_model]
            x = x + self.pos_encoder
            x = self.transformer(x)  # [batch_size, k, d_model]
            x = x[:, -1, :]  # Use the last position for prediction
            return self.fc(x)  # [batch_size, 2]

    # Generate and prepare data
    data_generator = GeneratingData(chain_length=10**9)
    chain_1_bin = data_generator.generate_chain_2(block_size=block_size)

    # Create datasets (80% train / 20% validation, split at random)
    full_dataset = SpinDataset(chain_1_bin, k)
    train_size = int(0.8 * len(full_dataset))
    val_size = len(full_dataset) - train_size
    train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

    # Dataloaders
    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False)

    # Initialize model (model_type was validated above)
    if model_type == "mlp":
        model = MLPPredictor(k)
    elif model_type == "lstm":
        model = LSTMPredictor(k)
    else:
        model = TransformerPredictor(k)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # Always evaluate at least one validation batch; with fewer than 5 training
    # batches the integer division would otherwise give 0 batches and a
    # division by zero when averaging.
    max_val_batches = max(1, num_batches_per_epoch // 5)

    # For early stopping
    best_val_loss = float('inf')
    best_model = None
    patience_counter = 0
    train_losses = []
    val_losses = []
    best_epoch = 0

    for epoch in tqdm(range(max_epochs)):
        # Training mode
        model.train()
        train_loss = 0
        train_batches = 0

        # Use only num_batches_per_epoch batches
        for i, (batch_x, batch_y) in enumerate(train_loader):
            if i >= num_batches_per_epoch:
                break

            optimizer.zero_grad()
            outputs = model(batch_x)
            # Maps -1 -> 0 and +1 -> 1 (0/1 inputs map onto themselves).
            # NOTE(review): assumes the chain symbols are +/-1 or 0/1 — confirm.
            labels = (batch_y + 1) // 2
            loss = criterion(outputs, labels.long())
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            train_batches += 1

        # Average over the batches actually processed, which can be fewer than
        # num_batches_per_epoch when the loader is short.
        avg_train_loss = train_loss / train_batches
        train_losses.append(avg_train_loss)

        # Evaluation mode (also using a fixed number of batches)
        model.eval()
        val_loss = 0
        val_batches = 0
        with torch.no_grad():
            for i, (batch_x, batch_y) in enumerate(val_loader):
                if i >= max_val_batches:
                    break

                outputs = model(batch_x)
                labels = (batch_y + 1) // 2
                loss = criterion(outputs, labels.long())
                val_loss += loss.item()
                val_batches += 1

        avg_val_loss = val_loss / val_batches
        val_losses.append(avg_val_loss)

        # Early stopping check
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            # state_dict() returns references to the live parameter tensors,
            # which later optimizer steps overwrite in place; clone them so the
            # snapshot really is the best model.
            best_model = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }
            patience_counter = 0
            best_epoch = epoch
        else:
            patience_counter += 1

        if (epoch + 1) % 10 == 0:
            print(f'Epoch [{epoch+1}/{max_epochs}], '
                  f'Train Loss: {avg_train_loss:.4f}, '
                  f'Val Loss: {avg_val_loss:.4f}')

        if patience_counter >= patience:
            print(f'Early stopping at epoch {epoch+1}')
            break

    # Restore best model (there is no snapshot if no epoch ran or improved)
    if best_model is not None:
        model.load_state_dict(best_model)

    return train_losses, val_losses, best_epoch


def main(force_recompute=False):
    """
    Train every configured model for each window length and store the losses
    next to the universal learning curve so they can be compared.

    If the results file already exists and ``force_recompute`` is false, the
    cached JSON content is returned as-is. Otherwise the models are trained,
    the results are written to
    ``{PROJECT_PATH}/results_data/spin_xps/results_exp_{block_size}.json``
    and the results dictionary is returned.
    """
    results_path = f"{PROJECT_PATH}/results_data/spin_xps/results_exp_{block_size}.json"

    use_cache = os.path.exists(results_path) and not force_recompute
    if use_cache:
        print("Loading existing results...")
        with open(results_path, 'r') as cache_file:
            return json.load(cache_file)

    # Universal learning curve = discrete derivative of the predictive information.
    # The entropy computation itself is always allowed to reuse its own cache.
    entropy_outputs = computing_entropy(force_recompute=False)
    word_lengths, entropy_empirical, evoRate_empirical, predictive_info_1 = entropy_outputs
    universal_learning_curve = np.diff(predictive_info_1).tolist()
    curve_length = len(universal_learning_curve)

    # Architectures to train (the Transformer entry is currently disabled)
    models_config = [
        {"type": "mlp", "name": "MLP"},
        {"type": "lstm", "name": "LSTM"},
        # {"type": "transformer", "name": "Transformer"}
    ]

    # Largest window length to train on
    max_k = 19

    results = {
        "universal_learning_curve": universal_learning_curve,
        "entropy_empirical": entropy_empirical,
        "evoRate_empirical": evoRate_empirical,
        "word_lengths": word_lengths,
        "models": {}
    }

    for model_config in models_config:
        print(f"\nTraining {model_config['name']} models...")
        per_k_results = {}

        k = 1
        while k <= max_k:
            print(f"\nTraining with k={k}")
            train_losses, val_losses, best_epoch = training_a_predictive_model(
                k=k,
                model_type=model_config["type"],
                max_epochs=1000,
                patience=50,
                num_batches_per_epoch=50
            )

            last_train_loss = train_losses[-1]
            last_val_loss = val_losses[-1]

            # The curve has one entry fewer than there are word lengths, so
            # the largest k has no matching value.
            curve_index = k - 1
            if curve_index < curve_length:
                theoretical_optimal = last_val_loss - universal_learning_curve[curve_index]
            else:
                theoretical_optimal = None

            per_k_results[f"k_{k}"] = {
                "train_losses": train_losses,
                "val_losses": val_losses,
                "best_epoch": best_epoch,
                "final_train_loss": last_train_loss,
                "final_val_loss": last_val_loss,
                "theoretical_optimal": theoretical_optimal
            }
            k += 1

        results["models"][model_config["name"]] = per_k_results

    # Persist everything as JSON, creating the target folder when needed
    os.makedirs(os.path.dirname(results_path), exist_ok=True)
    with open(results_path, 'w') as out_file:
        json.dump(results, out_file)

    print("\nResults saved to", results_path)
    return results

# Script entry point: force_recompute=True makes main() ignore an existing
# results_exp_*.json file and retrain all models. main() still calls
# computing_entropy(force_recompute=False), so cached entropy results are reused.
if __name__ == "__main__":
    results = main(force_recompute=True)
