import os
import sys
import torch
import gpytorch
import numpy as np
from pathlib import Path
from typing import Dict, Any
import json

# Put the parent of this file's directory on sys.path; presumably this lets
# the project-local imports below (lib, tune, data) resolve — TODO confirm.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import lib
import delu
from tune import DirectTunerMixin
from data.data_processor import load_and_preprocess_dataset
from lib.metrics import evaluate_model

# Only the tuner-facing wrapper is exported by a star-import of this module.
__all__ = ["GaussianProcessModule"]

class GPClassificationModel(gpytorch.models.ApproximateGP):
    """Variational GP with a constant mean and a scaled stationary kernel.

    The inducing points are initialised from the leading rows of ``train_x``
    (at most 500) and are learned together with the other parameters.
    """

    def __init__(self, train_x, kernel='rbf', lengthscale=1.0):
        """Build the variational strategy, mean module and covariance module.

        Args:
            train_x: Training inputs; only its first ``min(500, N)`` rows are
                used, as the initial inducing locations.
            kernel: ``'rbf'``, ``'matern'`` (nu=1.5) or ``'rational_quadratic'``.
                Any other value falls back to the RBF kernel.
            lengthscale: Initial lengthscale assigned to the base kernel.
        """
        num_inducing = min(500, train_x.size(0))
        inducing_points = train_x[:num_inducing]
        strategy = gpytorch.variational.VariationalStrategy(
            self,
            inducing_points,
            gpytorch.variational.CholeskyVariationalDistribution(
                inducing_points.size(0)
            ),
            learn_inducing_locations=True,
        )
        super().__init__(strategy)
        self.mean_module = gpytorch.means.ConstantMean()

        # RBF doubles as the default for unrecognised kernel names.
        if kernel == 'matern':
            base_kernel = gpytorch.kernels.MaternKernel(nu=1.5)
        elif kernel == 'rational_quadratic':
            base_kernel = gpytorch.kernels.RQKernel()
        else:
            base_kernel = gpytorch.kernels.RBFKernel()

        # Set the lengthscale on the bare kernel before wrapping it, so the
        # ScaleKernel exposes it as ``covar_module.base_kernel.lengthscale``.
        base_kernel.lengthscale = lengthscale
        self.covar_module = gpytorch.kernels.ScaleKernel(base_kernel)

    def forward(self, x):
        """Return the GP prior at ``x`` as a multivariate normal."""
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x), self.covar_module(x)
        )


class GaussianProcessModule(DirectTunerMixin):
    """Tunable wrapper that trains variational GP classifiers with GPyTorch.

    A two-class problem is fitted with a single ``GPClassificationModel`` and
    a Bernoulli likelihood. Problems with more classes train one such binary
    model per class (one-vs-rest). The fitted weights and a JSON summary are
    written to the trial's output directory.
    """

    def __init__(
        self,
        *,
        kernel: str = 'rbf',
        length_scale: float = 1.0,
        length_scale_bounds: tuple = (1e-5, 1e5), 
        alpha: float = 1e-5, 
        learning_rate: float = 0.1,
        training_iterations: int = 100,
        multi_class: str = 'one_vs_rest',  
        random_state: int = 42,
        **kwargs
    ) -> None:
        """Store the hyperparameters and reset all fitted state.

        Extra keyword arguments are accepted and ignored.

        NOTE(review): ``_train_and_evaluate`` takes its hyperparameters from
        the ``config`` mapping it receives, not from these attributes —
        confirm that this is the intended contract with the tuner.
        """
        self.kernel = kernel
        self.length_scale = length_scale
        self.length_scale_bounds = length_scale_bounds
        self.alpha = alpha
        self.learning_rate = learning_rate
        self.training_iterations = training_iterations
        self.multi_class = multi_class
        self.random_state = random_state

        self.device = lib.get_device()
        # Binary problems populate ``model``/``likelihood``; multi-class
        # problems populate the per-class lists instead.
        self.model = None
        self.likelihood = None
        self.models = []
        self.likelihoods = []
        self.is_classification = None
        self.is_binary = None
        self.num_classes = None
        self.classes_ = None

    def create_model_from_params(self, params):
        """Return a fresh ``GaussianProcessModule`` built from ``params``."""
        return GaussianProcessModule(**params)

    def get_optimization_target(self, stats):
        """Return the validation ``score`` from ``stats``, or ``-inf`` if absent."""
        metrics = stats.get('metrics', {}).get(lib.VAL, {})
        return metrics.get('score', -float('inf'))

    def _fit_binary_gp(
        self,
        X_train,
        targets,
        *,
        kernel,
        length_scale,
        learning_rate,
        training_iterations,
        progress=None,
    ):
        """Fit one variational GP binary classifier and report its final ELBO.

        Args:
            X_train: Training inputs passed to ``GPClassificationModel``.
            targets: Float tensor of 0/1 labels aligned with ``X_train``.
            kernel: Kernel name forwarded to ``GPClassificationModel``.
            length_scale: Initial kernel lengthscale.
            learning_rate: Adam learning rate.
            training_iterations: Number of full-batch optimisation steps.
            progress: Optional callable ``(step, loss_value) -> str``; its
                result is printed on every 20th step. ``None`` disables output.

        Returns:
            ``(model, likelihood, elbo, lengthscale)``. Both modules are left
            in eval mode; ``elbo`` and ``lengthscale`` are Python floats.
        """
        likelihood = gpytorch.likelihoods.BernoulliLikelihood().to(self.device)
        model = GPClassificationModel(
            X_train, kernel=kernel, lengthscale=length_scale
        ).to(self.device)

        model.train()
        likelihood.train()

        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
        mll = gpytorch.mlls.VariationalELBO(
            likelihood, model, num_data=X_train.size(0)
        )

        for step in range(training_iterations):
            optimizer.zero_grad()
            # The forward pass must be recomputed every step and must come
            # from *this* model, otherwise its parameters get no gradient.
            loss = -mll(model(X_train), targets)
            loss.backward()
            optimizer.step()
            if progress is not None and step % 20 == 0:
                print(progress(step, loss.item()))

        model.eval()
        likelihood.eval()

        # Re-run the forward pass so the reported ELBO reflects the parameters
        # after the final optimizer step rather than a stale training output.
        with torch.no_grad(), gpytorch.settings.fast_pred_var():
            elbo = mll(model(X_train), targets).item()
        lengthscale = model.covar_module.base_kernel.lengthscale.item()

        return model, likelihood, elbo, lengthscale

    def _train_and_evaluate(
        self,
        model,
        config: Dict[str, Any],
        trial_number,
        output_dir,
    ) -> Dict[str, Any]:
        """Train on the configured dataset, save the model and return stats.

        Args:
            model: Not used by this method; kept so the call signature stays
                unchanged for existing callers.
            config: Run configuration. Reads ``seed`` (default 0),
                ``data.dataset_id`` and the hyperparameters ``kernel``,
                ``length_scale``, ``alpha`` (top level or under ``model``) and
                ``learning_rate``, ``training_iterations`` (top level or under
                ``training``).
            trial_number: Identifier copied into the returned stats.
            output_dir: Directory that receives ``best_model.pt`` and
                ``stats.json``; created if missing.

        Returns:
            The stats dictionary that is also written to ``stats.json``.

        Raises:
            KeyError: If ``config`` has no ``data`` section.
        """
        timer = delu.Timer()

        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        seed = config.get('seed', 0)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

        dataset_id = config['data'].get('dataset_id')
        delu.random.seed(seed)
        data = load_and_preprocess_dataset(
            dataset_id=dataset_id,
            config=config,
            device=self.device
        )

        X_train = data['X_train']
        X_val = data['X_val']
        X_test = data['X_test']
        y_train = data['y_train'].squeeze()
        y_val = data['y_val'].squeeze()
        y_test = data['y_test'].squeeze()

        input_dim = data['input_dim']
        self.num_classes = data['num_classes']
        dataset_info = data['dataset_info']
        self.classes_ = data['y_candidates'].cpu().numpy()
        self.is_classification = True
        self.is_binary = self.num_classes == 2

        training_config = config.get('training', {})
        model_config = config.get('model', {})

        # A top-level key wins over the nested section. Because of ``or``, a
        # falsy top-level value (0, '') is treated the same as a missing one.
        kernel = config.get('kernel') or model_config.get('kernel', 'rbf')
        length_scale = config.get('length_scale') or model_config.get('length_scale', 1.0)
        alpha = config.get('alpha') or model_config.get('alpha', 1e-5)
        learning_rate = config.get('learning_rate') or training_config.get('learning_rate', 0.1)
        training_iterations = config.get('training_iterations') or training_config.get('training_iterations', 100)

        # Clamp to ranges that keep optimisation stable and bounded in time.
        length_scale = max(length_scale, 1e-5)
        # NOTE(review): alpha is clamped and recorded below but is never
        # applied to the GP in this method — confirm whether it should be.
        alpha = max(alpha, 1e-8)
        learning_rate = max(learning_rate, 0.001)
        training_iterations = max(50, min(training_iterations, 500))

        fit_options = {
            'kernel': kernel,
            'length_scale': length_scale,
            'learning_rate': learning_rate,
            'training_iterations': training_iterations,
        }

        if self.is_binary:
            print(f"Training binary classifier with {y_train.shape[0]} samples...")

            def binary_progress(step, loss_value):
                return f'Iteration {step + 1}/{training_iterations} - Loss: {loss_value:.3f}'

            (
                self.model,
                self.likelihood,
                log_marginal_likelihood,
                optimized_length_scale,
            ) = self._fit_binary_gp(
                X_train, y_train.float(), progress=binary_progress, **fit_options
            )
        else:
            self.models = []
            self.likelihoods = []
            log_marginal_likelihood = 0.0
            optimized_length_scale = 0.0

            # One-vs-rest: class ``c`` is the positive label of its own model.
            # assumes labels are encoded as 0..num_classes-1 — TODO confirm
            for class_index in range(self.num_classes):
                binary_y = (y_train == class_index).float()

                # Only the first class reports progress, to limit console output.
                if class_index == 0:
                    def class_progress(step, _loss_value, c=class_index):
                        return f'Class {c + 1}/{self.num_classes} - Iteration {step + 1}/{training_iterations}'
                else:
                    class_progress = None

                class_model, class_likelihood, elbo, lengthscale = self._fit_binary_gp(
                    X_train, binary_y, progress=class_progress, **fit_options
                )

                self.models.append(class_model)
                self.likelihoods.append(class_likelihood)
                log_marginal_likelihood += elbo
                optimized_length_scale += lengthscale

            # Report per-class averages so the figures are comparable with
            # the binary case.
            log_marginal_likelihood /= self.num_classes
            optimized_length_scale /= self.num_classes

        # NOTE(review): the fitted model(s) are not passed to evaluate_model;
        # confirm against lib.metrics that this call produces metrics for them.
        metrics = evaluate_model(X_val, y_val, X_test, y_test)
        stats = {
            'algorithm': 'GPyTorchGaussianProcessClassifier',
            'dataset': dataset_info.get('name', f'dataset_{dataset_id}'),
            'metrics': metrics,
            'trial_number': trial_number,
            'time': lib.format_seconds(timer()),
            'model_info': {
                'n_features': input_dim,
                'n_samples_train': y_train.shape[0],
                'kernel': kernel,
                'length_scale': length_scale,
                'alpha': alpha,
                'learning_rate': learning_rate,
                'training_iterations': training_iterations,
                'log_marginal_likelihood': log_marginal_likelihood,
                'optimized_length_scale': optimized_length_scale,
                'device': str(self.device),
            }
        }

        # The checkpoint layout differs only in how the weights are keyed:
        # singular keys for the binary model, plural lists for one-vs-rest.
        if self.is_binary:
            weights = {
                'model_state_dict': self.model.state_dict(),
                'likelihood_state_dict': self.likelihood.state_dict(),
            }
        else:
            weights = {
                'model_state_dicts': [m.state_dict() for m in self.models],
                'likelihood_state_dicts': [lik.state_dict() for lik in self.likelihoods],
            }

        model_data = {
            **weights,
            'classes_': self.classes_,
            'is_binary': self.is_binary,
            'num_classes': self.num_classes,
            'kernel': kernel,
            'lengthscale': length_scale,
            'device': str(self.device),
            'model_config': {
                'kernel': kernel,
                'length_scale': length_scale,
                'alpha': alpha,
                'learning_rate': learning_rate,
                'training_iterations': training_iterations,
            }
        }

        model_file = output_path / 'best_model.pt'
        torch.save(model_data, str(model_file))
        print(f"Model saved to: {model_file}")

        stats_file = output_path / 'stats.json'
        lib.dump_json(stats, str(stats_file), indent=4)
        print(f"Stats saved to: {stats_file}")

        return stats
                   
def main():
    """Command-line entry point: train one GP classifier from a TOML config.

    The output directory sits next to the config file. A config whose stem
    starts with ``trial_`` writes to a directory named after that stem; any
    other config writes to ``trial_0``.
    """
    import argparse

    cli = argparse.ArgumentParser(description='Run GPyTorch Gaussian Process Classifier')
    cli.add_argument('config_path', type=str, help='Path to configuration TOML file')
    raw_config_path = cli.parse_args().config_path

    # Load first so that a bad config fails before anything is printed.
    config = lib.load_toml(raw_config_path)

    config_file = Path(raw_config_path)
    launched_by_tuner = config_file.stem.startswith('trial_')
    if not launched_by_tuner:
        output_dir = config_file.parent / 'trial_0'
        print(f"Standalone mode: Output dir: {output_dir}")
    else:
        output_dir = config_file.parent / config_file.stem
        print(f"Optuna mode: Config path: {config_file}")
        print(f"Optuna mode: Output dir: {output_dir}")

    runner = GaussianProcessModule()
    stats = runner._train_and_evaluate(
        model=runner,
        config=config,
        trial_number=0,
        output_dir=output_dir,
    )

    print("Training completed. Stats:")
    print(json.dumps(stats, indent=4))


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()