import sys
# Project root. This is a RELATIVE path, so it is resolved against the current
# working directory: run the script from the directory containing the project.
PROJECT_PATH = "./Data_Pattern_Learnability"
# Make project packages (e.g. `utils`) importable. The path is appended, so
# entries already on sys.path take precedence over it.
sys.path.append(PROJECT_PATH)

import yaml
import json
import logging
from datetime import datetime
import numpy as np
import torch
from pathlib import Path
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
import argparse

from utils.estimators import evoPredEstimator
from utils.process.process import AutoRegressiveGenerator, MultivariateMarkovGenerator

class SequenceExperimentManager:
    """Run sequence-based mutual-information estimation experiments from a YAML config.

    For every enabled data configuration, one long sequence is generated per
    parameter combination and saved to ``results/`` as ``.npy``. Every enabled
    estimator is then trained on that sequence for every ``(k, k')`` estimation
    window. Training curves are plotted to ``plots/``, and the collected results
    are written by :meth:`save_results`.

    Config keys read here: ``experiment_name``, ``training`` (``batch_size``,
    ``iterations``, ``learning_rate``), ``data_configs``, ``estimation_windows``
    and ``estimators``.
    """

    # Number of trailing training estimates averaged into the reported final estimate.
    FINAL_ESTIMATE_WINDOW = 100

    def __init__(self, config_path):
        """Load the config, create the output folders and configure logging.

        Args:
            config_path: Path to the YAML experiment config.
        """
        with open(config_path, 'r', encoding='utf-8') as f:
            self.config = yaml.safe_load(f)

        self.device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
        print(f"Using device: {self.device}")

        # Each run gets its own timestamped directory under results_data/.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        experiment_name = self.config['experiment_name']
        self.exp_dir = Path(f"{PROJECT_PATH}/results_data/{experiment_name}_{timestamp}")
        self.results_dir = self.exp_dir / "results"
        self.plots_dir = self.exp_dir / "plots"

        for directory in (self.exp_dir, self.results_dir, self.plots_dir):
            directory.mkdir(parents=True, exist_ok=True)

        self.logger = self._setup_logger()

    def _setup_logger(self):
        """Return the "SequenceManager" logger, writing to this run's log file and the console."""
        logger = logging.getLogger("SequenceManager")
        logger.setLevel(logging.INFO)

        # getLogger() returns the same object for the same name across the whole
        # process. Without this cleanup, a second manager would stack handlers:
        # every line would be logged twice and still be written to the previous
        # run's log file. Detach and close whatever is attached first.
        for handler in list(logger.handlers):
            logger.removeHandler(handler)
            handler.close()

        file_handler = logging.FileHandler(self.exp_dir / "experiment.log")
        file_handler.setFormatter(logging.Formatter('%(asctime)s - %(message)s'))

        console_handler = logging.StreamHandler()
        console_handler.setFormatter(logging.Formatter('%(message)s'))

        logger.addHandler(file_handler)
        logger.addHandler(console_handler)
        return logger

    def create_sequence_generator(self, data_config, **kwargs):
        """Build the sequence generator selected by ``data_config['type']``.

        Args:
            data_config: Data configuration entry; only ``type`` is read here.
            **kwargs: Generator parameters. ``autoregressive`` needs ``p``, ``rho``
                and ``dim``. ``multivariate_markov`` needs ``correlation_matrix``
                and ``dim``.

        Returns:
            An ``AutoRegressiveGenerator`` or ``MultivariateMarkovGenerator``.

        Raises:
            ValueError: If the data type is not one of the two supported types.
        """
        data_type = data_config["type"]
        if data_type == "autoregressive":
            return AutoRegressiveGenerator(
                batch_size=self.config['training']['batch_size'],
                p=kwargs['p'],
                rho=kwargs['rho'],
                dim=kwargs['dim'],
            )
        if data_type == "multivariate_markov":
            return MultivariateMarkovGenerator(
                batch_size=self.config['training']['batch_size'],
                correlation_matrix=kwargs['correlation_matrix'],
                dim=kwargs['dim'],
            )
        # Fail here rather than returning None and crashing later on the
        # .generate_long_array() call.
        raise ValueError(f"Unsupported data type: {data_type!r}")

    def run_experiments(self):
        """Run every enabled data configuration, then save all collected results.

        Only ``autoregressive`` configs have a parameter sweep: each combination of
        ``params['dimensions']`` x ``params['p']`` x ``params['rho']`` is run.
        Other data types are reported with a warning and skipped.
        """
        all_results = []

        self.logger.info("\n" + "="*50)
        self.logger.info("Starting Sequence-based Experiments")
        self.logger.info("="*50)

        for data_config in self.config['data_configs']:
            if not data_config.get('enabled', True):
                continue

            self.logger.info(f"\nProcessing {data_config['type'].upper()}")
            self.logger.info(f"Configuration: {data_config['params']}")

            if data_config['type'] != "autoregressive":
                # Make the skip visible instead of silently producing no results.
                self.logger.warning(
                    f"No parameter sweep implemented for type '{data_config['type']}'; skipping"
                )
                continue

            params = data_config['params']
            for dim in params['dimensions']:
                for p in params['p']:
                    for rho in params['rho']:
                        result = self._run_sequence_experiment(
                            data_config=data_config,
                            dim=dim,
                            p=p,
                            rho=rho,
                            sequence_length=data_config['sequence_length']
                        )
                        all_results.append(result)

        self.save_results(all_results)

    def _run_sequence_experiment(self, data_config, **kwargs):
        """Generate one sequence and train every enabled estimator on every window.

        Args:
            data_config: Data configuration entry (``type`` is read).
            **kwargs: Generator parameters plus ``sequence_length``. ``dim`` is
                always required; for autoregressive data, ``p`` and ``rho`` are
                required too.

        Returns:
            A dict with ``type``, ``parameters``, ``sequence_length`` and
            ``window_estimates``. ``window_estimates`` maps a window label to a
            dict keyed by estimator name.
        """
        self.logger.info("\nStarting new sequence experiment:")
        self.logger.info(f"Type: {data_config['type']}")
        self.logger.info(f"Parameters: {kwargs}")

        generator = self.create_sequence_generator(data_config, **kwargs)
        sequence = generator.generate_long_array(kwargs['sequence_length'])
        self._save_sequence(data_config, sequence, kwargs)
        self.logger.info(f"Generated sequence of length {kwargs['sequence_length']}")

        exp_id = self._create_experiment_id(data_config, kwargs)
        self.logger.info(f"Experiment ID: {exp_id}")

        results = {
            'type': data_config['type'],
            'parameters': kwargs,
            'sequence_length': kwargs['sequence_length'],
            'window_estimates': {}
        }

        for window in self.config['estimation_windows']:
            window_id = self._window_id(window['k'], window['k_prime'])
            results['window_estimates'][window_id] = self._estimate_window(
                sequence, window, window_id, exp_id, kwargs['dim']
            )

        return results

    def _save_sequence(self, data_config, sequence, kwargs):
        """Save ``sequence`` as a ``.npy`` file in ``results_dir`` and return its path."""
        if data_config['type'] == "autoregressive":
            filename = f"sequence_{data_config['type']}_dim{kwargs['dim']}_p{kwargs['p']}_rho{kwargs['rho']}.npy"
        else:
            filename = f"sequence_{data_config['type']}_dim{kwargs['dim']}.npy"

        if isinstance(sequence, torch.Tensor):
            # detach() so that a tensor carrying autograd history can still be converted.
            sequence = sequence.detach().cpu().numpy()

        save_path = self.results_dir / filename
        np.save(save_path, sequence)
        self.logger.info(f"Sequence saved: {save_path}")
        return save_path

    def _estimate_window(self, sequence, window, window_id, exp_id, dim):
        """Train every enabled estimator on ``sequence`` for one ``(k, k')`` window.

        All training curves go into a single figure, saved as
        ``plots/<exp_id>_<window_id>.png``.

        Returns:
            A dict keyed by estimator name. Each entry holds the training curve,
            the final estimate and the estimator settings.
        """
        self.logger.info(f"\nEstimating with window {window_id}")

        fig = plt.figure(figsize=(12, 6))
        try:
            plt.title(f"Sequence Estimation\n{exp_id}\n{window_id}")
            window_results = {}

            for estimator_config in self.config['estimators']:
                if not estimator_config.get('enabled', True):
                    continue

                name = estimator_config['name']
                critic_params = estimator_config.get('critic_params', {})
                self.logger.info(f"\nTraining {name}")
                self.logger.info(f"Type: {estimator_config['estimator_type']}")
                self.logger.info(f"Critic: {estimator_config['critic_type']}")

                estimator = evoPredEstimator(
                    dim_data=dim,
                    k=window['k'],
                    k_prime=window['k_prime'],
                    batch_size=self.config['training']['batch_size'],
                    estimator=estimator_config['estimator_type'],
                    type_of_critic=estimator_config['critic_type'],
                    device=self.device,
                    **critic_params
                )

                raw_estimates = estimator.train_estimator_on_large_sequence(
                    sequence=sequence,
                    iterations=self.config['training']['iterations'],
                    lr=self.config['training']['learning_rate']
                )
                # Normalise to a float array so the stored curve is a plain list
                # of floats. An ndarray or tensor would make json.dump fail later.
                estimates = np.asarray(raw_estimates, dtype=float)
                final_estimate = self._final_estimate(estimates)

                self.logger.info(f"Final estimate: {final_estimate:.4f}")
                plt.plot(estimates, label=name)

                window_results[name] = {
                    'all_estimates': estimates.tolist(),
                    'final_estimate': final_estimate,
                    'estimator_type': estimator_config['estimator_type'],
                    'critic_type': estimator_config['critic_type'],
                    'critic_params': critic_params
                }

            self._save_plot(plt, f"{exp_id}_{window_id}")
        finally:
            # _save_plot closes the figure on success; this also releases it if training raised.
            plt.close(fig)

        return window_results

    def _final_estimate(self, estimates):
        """Return the mean of the last FINAL_ESTIMATE_WINDOW estimates, or NaN if there are none."""
        tail = estimates[-self.FINAL_ESTIMATE_WINDOW:]
        return float(tail.mean()) if tail.size else float('nan')

    def _create_experiment_id(self, data_config, kwargs):
        """Build a readable identifier from the data type and its parameters.

        Example: ``autoregressive_dim=2_p=1_rho=0.5``.
        """
        exp_id = f"{data_config['type']}_dim={kwargs['dim']}"
        if data_config['type'] == "autoregressive":
            exp_id += f"_p={kwargs['p']}_rho={kwargs['rho']}"
        return exp_id

    @staticmethod
    def _window_id(k, k_prime):
        """Return the label used as the key for one estimation window."""
        return f"k={k}_kprime={k_prime}"

    @staticmethod
    def _parse_window_id(window_id):
        """Invert :meth:`_window_id` and return ``(k, k_prime)`` as ints."""
        k_part, k_prime_part = window_id.split('_', 1)
        return int(k_part.partition('=')[2]), int(k_prime_part.partition('=')[2])

    def _save_plot(self, plt, exp_id):
        """Label, save and close the current figure as ``plots/<exp_id>.png``.

        Args:
            plt: The ``matplotlib.pyplot`` module, as passed by the caller.
            exp_id: Base filename for the PNG.
        """
        plt.xlabel('Iteration')
        plt.ylabel('MI Estimate')
        # Only draw a legend when something was plotted; otherwise matplotlib
        # warns that no labelled artists were found (all estimators disabled).
        if plt.gca().get_legend_handles_labels()[0]:
            plt.legend()
        plt.grid(True)
        plt.savefig(self.plots_dir / f"{exp_id}.png")
        plt.close()

    def save_results(self, results):
        """Write the detailed JSON results, a flat CSV summary and a copy of the config.

        Args:
            results: List of result dicts as returned by ``_run_sequence_experiment``.
        """
        with open(self.results_dir / "detailed_results.json", 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=4)

        summary_rows = []
        for result in results:
            base_info = {
                'type': result['type'],
                'sequence_length': result['sequence_length'],
                **result['parameters']
            }

            for window_id, window_results in result['window_estimates'].items():
                if not window_results:
                    continue
                # One parse per window, not per estimator.
                k, k_prime = self._parse_window_id(window_id)
                for est_name, est_results in window_results.items():
                    summary_rows.append({
                        **base_info,
                        'k': k,
                        'k_prime': k_prime,
                        'estimator': est_name,
                        'estimator_type': est_results['estimator_type'],
                        'critic_type': est_results['critic_type'],
                        'final_estimate': est_results['final_estimate']
                    })

        pd.DataFrame(summary_rows).to_csv(self.results_dir / "summary_results.csv", index=False)

        with open(self.exp_dir / "config.yaml", 'w', encoding='utf-8') as f:
            yaml.dump(self.config, f)


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description='Run sequence-based experiments')
    parser.add_argument('--config', type=str,
                        default=f'{PROJECT_PATH}/configs/ar_sequence_config.yaml',
                        help='Path to config file')
    args = parser.parse_args()

    manager = SequenceExperimentManager(args.config)
    try:
        manager.run_experiments()
    except Exception as e:
        # logger.exception also records the traceback in experiment.log. The bare
        # raise re-raises the original exception so the process exits non-zero.
        manager.logger.exception(f"Error during experiments: {e}")
        raise
    else:
        manager.logger.info("\nAll experiments completed successfully!")
