import os
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import (
    AutoTokenizer,
    AutoConfig,
    BertModel,
    DataCollatorWithPadding,
    get_scheduler,
)
import numpy as np

class OnlineCostModelTrainer:
    def __init__(self, model, tokenizer, orig_model_names, arm_to_llm, 
                 cost_per_token, input_cost_per_token, device, checkpoint_dir, lr=1e-6, 
                 update_freq=10):
        """
        Args:
            model: Your BertRegressionModel
            device: cuda/cpu
            lr: Learning rate for online updates (smaller than initial training)
            update_freq: Update model every N observations
        """
        self.model = model
        self.tokenizer = tokenizer
        self.orig_model_names = orig_model_names
        self.arm_to_llm = arm_to_llm
        self.cost_per_token = cost_per_token
        self.input_cost_per_token = input_cost_per_token
        self.device = device
        self.optimizer = optim.AdamW(model.parameters(), lr=lr)
        self.criterion = nn.L1Loss()
        self.update_freq = update_freq

        self.checkpoint_dir = checkpoint_dir
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        # Buffer to accumulate observations
        self.buffer = {
            'prompts': [],
            'model_indices': [],
            'model_names': [],  # Track which arm
            'actual_lengths': [],
            'predicted_lengths': []  # Store predictions before update
        }
        self.update_count = 0
        
        # ===== NEW: Comprehensive Accuracy Tracking =====
        self.accuracy_stats = {
            'overall': {
                'predictions': [],
                'actuals': [],
                'errors': [],
                'absolute_errors': [],
                'relative_errors': [],
                'squared_errors': []
            }
        }
        
        # Per-model tracking
        self.per_model_stats = {}
        
        # Track accuracy before and after updates
        self.pre_update_errors = []
        self.post_update_errors = []
        
        # Track metrics over time
        self.metrics_history = {
            'round': [],
            'mae': [],  # Mean Absolute Error
            'rmse': [],  # Root Mean Squared Error
            'mape': [],  # Mean Absolute Percentage Error
            'r2': []  # R-squared
        }
        
    def add_observation(self, prompt, model_name, actual_output_length, 
                       predicted_length):
        """
        Store an observation for a later batch update.

        The observation is appended to the buffer and folded into the
        running accuracy statistics. Once the buffer holds at least
        ``update_freq`` observations, ``update_model`` is called.

        Args:
            prompt: Input prompt text
            model_name: Arm name (e.g., 'base', 'finetune_med')
            actual_output_length: Observed token count
            predicted_length: Model's prediction for this prompt

        Raises:
            KeyError: If ``model_name`` is missing from ``arm_to_llm``
                (assuming it is a dict).
            ValueError: If the mapped LLM name is missing from
                ``orig_model_names`` (assuming it is a list).
        """
        # Resolve the one-hot index BEFORE touching the buffer: if the arm is
        # unknown the lookup raises, and no column may have been extended yet,
        # otherwise the parallel lists end up with different lengths.
        model_index = self.orig_model_names.index(self.arm_to_llm[model_name])

        self.buffer['prompts'].append(prompt)
        self.buffer['model_names'].append(model_name)
        self.buffer['model_indices'].append(model_index)
        self.buffer['actual_lengths'].append(actual_output_length)
        self.buffer['predicted_lengths'].append(predicted_length)

        # Update accuracy statistics
        self._update_accuracy_stats(
            model_name,
            predicted_length,
            actual_output_length
        )

        # Update model when buffer is full
        if len(self.buffer['prompts']) >= self.update_freq:
            self.update_model()
    
    def _update_accuracy_stats(self, model_name, predicted, actual):
        """Append one prediction/actual pair to the overall and per-arm logs.

        Args:
            model_name: Arm name used as the key in ``per_model_stats``.
            predicted: Predicted length for the observation.
            actual: Observed length for the observation.
        """
        signed = predicted - actual
        magnitude = abs(signed)
        sample = {
            'predictions': predicted,
            'actuals': actual,
            'errors': signed,
            'absolute_errors': magnitude,
            # The denominator is floored at 1.0 so an actual length of zero
            # cannot cause a division by zero.
            'relative_errors': magnitude / max(actual, 1.0),
            'squared_errors': signed ** 2,
        }

        # Log across all arms.
        overall_log = self.accuracy_stats['overall']
        for series, value in sample.items():
            overall_log[series].append(value)

        # Log for this arm, created the first time the arm is seen.
        if model_name not in self.per_model_stats:
            fresh_log = {series: [] for series in sample}
            fresh_log['count'] = 0
            self.per_model_stats[model_name] = fresh_log

        arm_log = self.per_model_stats[model_name]
        for series, value in sample.items():
            arm_log[series].append(value)
        arm_log['count'] += 1
    
    def update_model(self):
        """Perform one gradient step on accumulated observations"""
        if len(self.buffer['prompts']) == 0:
            return
        
        # Calculate pre-update error on this batch
        pre_errors = [
            abs(self.buffer['predicted_lengths'][i] - 
                self.buffer['actual_lengths'][i])
            for i in range(len(self.buffer['prompts']))
        ]
        self.pre_update_errors.extend(pre_errors)
        
        self.model.train()
        
        # Tokenize batch of prompts
        toks = self.tokenizer(
            self.buffer['prompts'],
            truncation=True,
            padding="max_length",
            max_length=256,
            return_tensors="pt"
        ).to(self.device)
        
        # Create one-hot encodings
        batch_size = len(self.buffer['prompts'])
        onehots = torch.zeros(batch_size, len(self.orig_model_names), 
                             device=self.device)
        for i, idx in enumerate(self.buffer['model_indices']):
            onehots[i, idx] = 1.0
            
        # Actual lengths as labels
        labels = torch.tensor(self.buffer['actual_lengths'], 
                             dtype=torch.float32, device=self.device)
        
        # Forward pass and update
        self.optimizer.zero_grad()
        preds = self.model(
            toks["input_ids"],
            toks["attention_mask"],
            onehots
        )
        loss = self.criterion(preds, labels)
        loss.backward()
        
        # Gradient clipping to prevent instability
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        
        self.optimizer.step()
        
        # Calculate post-update error on this batch
        self.model.eval()
        with torch.no_grad():
            new_preds = self.model(
                toks["input_ids"],
                toks["attention_mask"],
                onehots
            ).cpu().numpy()
        
        post_errors = [
            abs(new_preds[i] - self.buffer['actual_lengths'][i])
            for i in range(len(self.buffer['prompts']))
        ]
        self.post_update_errors.extend(post_errors)
        
        self.update_count += 1
        avg_pre = np.mean(pre_errors)
        avg_post = np.mean(post_errors)
        print(f"  [Cost Model Update #{self.update_count}] "
              f"Loss: {loss.item():.4f} | "
              f"Pre-update MAE: {avg_pre:.2f} | "
              f"Post-update MAE: {avg_post:.2f}")
        
        # Clear buffer
        self.buffer = {
            'prompts': [],
            'model_indices': [],
            'model_names': [],
            'actual_lengths': [],
            'predicted_lengths': []
        }
        
        self.model.eval()
    
    def compute_metrics(self, round_num=None):
        """
        Summarize accuracy over every observation recorded so far.

        Args:
            round_num: Optional round label. When given, the MAE, RMSE, MAPE
                and R-squared values are also appended to
                ``metrics_history`` under that round.

        Returns:
            dict: Keys 'mae', 'rmse', 'mape' (percent), 'r2', 'bias' (mean
            signed error) and 'n_samples'; an empty dict when nothing has
            been recorded yet.
        """
        log = self.accuracy_stats['overall']
        if len(log['actuals']) == 0:
            return {}

        targets = np.array(log['actuals'])
        squared = log['squared_errors']

        # R-squared ingredients: residual and total sums of squares.
        residual_ss = np.sum(squared)
        total_ss = np.sum((targets - np.mean(targets)) ** 2)

        metrics = {
            'mae': np.mean(log['absolute_errors']),
            'rmse': np.sqrt(np.mean(squared)),
            'mape': np.mean(log['relative_errors']) * 100,
            # With zero variance in the actuals R-squared is undefined;
            # report 0 instead of dividing by zero.
            'r2': 1 - (residual_ss / total_ss) if total_ss > 0 else 0,
            'bias': np.mean(log['errors']),
            'n_samples': len(targets),
        }

        # Track over time
        if round_num is not None:
            self.metrics_history['round'].append(round_num)
            for key in ('mae', 'rmse', 'mape', 'r2'):
                self.metrics_history[key].append(metrics[key])

        return metrics
    
    def compute_per_model_metrics(self):
        """
        Summarize accuracy separately for every arm seen so far.

        Returns:
            dict: Maps each arm name to a dict with keys 'mae', 'rmse',
            'mape' (percent), 'r2', 'bias' and 'n_samples'. Arms whose
            'count' is zero are left out.
        """
        report = {}

        for arm, log in self.per_model_stats.items():
            n_obs = log['count']
            if n_obs == 0:
                continue

            targets = np.array(log['actuals'])
            squared = log['squared_errors']

            # R-squared ingredients: residual and total sums of squares.
            residual_ss = np.sum(squared)
            total_ss = np.sum((targets - np.mean(targets)) ** 2)

            report[arm] = {
                'mae': np.mean(log['absolute_errors']),
                'rmse': np.sqrt(np.mean(squared)),
                'mape': np.mean(log['relative_errors']) * 100,
                # Zero variance in the actuals: report 0 instead of
                # dividing by zero.
                'r2': 1 - (residual_ss / total_ss) if total_ss > 0 else 0,
                'bias': np.mean(log['errors']),
                'n_samples': n_obs,
            }

        return report
    
    def print_summary(self):
        """Print a comprehensive summary of cost model accuracy"""
        print("\n" + "="*70)
        print("COST MODEL ACCURACY SUMMARY")
        print("="*70)
        
        # Overall metrics
        overall_metrics = self.compute_metrics()
        print(f"\n Overall Metrics (n={overall_metrics['n_samples']}):")
        print(f"  MAE (tokens):  {overall_metrics['mae']:.2f}")
        print(f"  RMSE (tokens): {overall_metrics['rmse']:.2f}")
        print(f"  MAPE:          {overall_metrics['mape']:.2f}%")
        print(f"  R2:            {overall_metrics['r2']:.4f}")
        print(f"  Bias (tokens): {overall_metrics['bias']:.2f}")
        
        # Per-model metrics
        per_model = self.compute_per_model_metrics()
        print(f"\n Per-Model Metrics:")
        print(f"  {'Model':<20} {'MAE':<10} {'MAPE':<10} {'R2':<10} {'N':<8}")
        print(f"  {'-'*20} {'-'*10} {'-'*10} {'-'*10} {'-'*8}")
        
        for model_name in sorted(per_model.keys()):
            metrics = per_model[model_name]
            print(f"  {model_name:<20} "
                  f"{metrics['mae']:<10.2f} "
                  f"{metrics['mape']:<10.2f}% "
                  f"{metrics['r2']:<10.4f} "
                  f"{metrics['n_samples']:<8}")
        
        # Online learning improvement
        if self.pre_update_errors and self.post_update_errors:
            pre_mean = np.mean(self.pre_update_errors)
            post_mean = np.mean(self.post_update_errors)
            improvement = ((pre_mean - post_mean) / pre_mean) * 100
            
            print(f"\n Online Learning Impact:")
            print(f"  Pre-update MAE:  {pre_mean:.2f} tokens")
            print(f"  Post-update MAE: {post_mean:.2f} tokens")
            print(f"  Improvement:     {improvement:+.2f}%")
            print(f"  Updates applied: {self.update_count}")
        
        print("="*70 + "\n")
    
    def save_stats(self, filepath_prefix):
        """Save all statistics to pickle files"""
        import pickle
        
        # Save overall metrics
        pickle.dump(
            self.accuracy_stats,
            open(f"{filepath_prefix}_accuracy_stats.pkl", "wb")
        )
        
        # Save per-model stats
        pickle.dump(
            self.per_model_stats,
            open(f"{filepath_prefix}_per_model_stats.pkl", "wb")
        )
        
        # Save metrics history
        pickle.dump(
            self.metrics_history,
            open(f"{filepath_prefix}_metrics_history.pkl", "wb")
        )
        
        # Save summary metrics
        summary = {
            'overall': self.compute_metrics(),
            'per_model': self.compute_per_model_metrics(),
            'online_learning': {
                'pre_update_mae': np.mean(self.pre_update_errors) 
                    if self.pre_update_errors else None,
                'post_update_mae': np.mean(self.post_update_errors)
                    if self.post_update_errors else None,
                'n_updates': self.update_count
            }
        }
        pickle.dump(
            summary,
            open(f"{filepath_prefix}_summary.pkl", "wb")
        )
    
    def final_update(self):
        """Flush observations still sitting in the buffer at the end of a run."""
        leftover = self.buffer['prompts']
        if len(leftover) == 0:
            # Nothing buffered: no update needed.
            return
        self.update_model()
			
    def save_checkpoint(self, run_id, round_num=None, is_final=False):

        if is_final:
            filename = f"run_{run_id}_final.pth"
        elif round_num is not None:
            filename = f"run_{run_id}_round_{round_num}.pth"
        else:
            filename = f"run_{run_id}_latest.pth"
        
        filepath = os.path.join(self.checkpoint_dir, filename)
        
        checkpoint = {
            # Model state
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            
            # Training progress
            'run_id': run_id,
            'round_num': round_num,
            'update_count': self.update_count,
            
            # Accuracy tracking
            'accuracy_stats': self.accuracy_stats,
            'per_model_stats': self.per_model_stats,
            'pre_update_errors': self.pre_update_errors,
            'post_update_errors': self.post_update_errors,
            'metrics_history': self.metrics_history,
            
            # Buffer (in case you want to resume mid-update)
            'buffer': self.buffer,
            
            # Configuration
            'config': {
                'lr': self.optimizer.param_groups[0]['lr'],
                'update_freq': self.update_freq,
                'orig_model_names': self.orig_model_names,
                'arm_to_llm': self.arm_to_llm,
                'cost_per_token': self.cost_per_token,
                'input_cost_per_token': self.input_cost_per_token
            }
        }
        
        torch.save(checkpoint, filepath)
        print(f" Saved cost model checkpoint: {filepath}")
        return filepath
    
    def load_checkpoint(self, filepath, device=None):
        """
        Load model checkpoint and restore all training state
        
        Args:
            filepath: Path to checkpoint file
            device: Device to load model on (if None, uses self.device)
        
        Returns:
            dict: Checkpoint metadata (run_id, round_num, etc.)
        """
        if device is None:
            device = self.device
        
        checkpoint = torch.load(filepath, map_location=device)
        
        # Restore model and optimizer
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        
        # Restore training progress
        self.update_count = checkpoint['update_count']
        
        # Restore accuracy tracking
        self.accuracy_stats = checkpoint['accuracy_stats']
        self.per_model_stats = checkpoint['per_model_stats']
        self.pre_update_errors = checkpoint['pre_update_errors']
        self.post_update_errors = checkpoint['post_update_errors']
        self.metrics_history = checkpoint['metrics_history']
        
        # Restore buffer
        self.buffer = checkpoint['buffer']
        
        # Restore configuration
        config = checkpoint['config']
        self.update_freq = config['update_freq']
        self.orig_model_names = config['orig_model_names']
        self.arm_to_llm = config['arm_to_llm']
        self.cost_per_token = config['cost_per_token']
        self.input_cost_per_token = config['input_cost_per_token']
        
        print(f"  Loaded cost model checkpoint from: {filepath}")
        print(f"   Run: {checkpoint['run_id']}, Round: {checkpoint['round_num']}, "
              f"Updates: {self.update_count}")
        
        return {
            'run_id': checkpoint['run_id'],
            'round_num': checkpoint['round_num'],
            'update_count': checkpoint['update_count']
        }
    
    def cleanup_intermediate_checkpoints(self, run_id):
        """
        Delete all intermediate checkpoints for a run, keeping only the final one
        
        Args:
            run_id: Run number to clean up checkpoints for
        """
        import glob
        
        # Pattern to find all round checkpoints (not final)
        pattern = os.path.join(self.checkpoint_dir, f"run_{run_id}_round_*.pth")
        intermediate_checkpoints = glob.glob(pattern)
        
        # Also check for "latest" checkpoint
        latest_path = os.path.join(self.checkpoint_dir, f"run_{run_id}_latest.pth")
        if os.path.exists(latest_path):
            intermediate_checkpoints.append(latest_path)
        
        # Delete all intermediate checkpoints
        deleted_count = 0
        for cp_path in intermediate_checkpoints:
            try:
                os.remove(cp_path)
                deleted_count += 1
            except Exception as e:
                print(f"  Warning: Could not delete {cp_path}: {e}")
        
        if deleted_count > 0:
            print(f"Cleaned up {deleted_count} intermediate checkpoint(s) for run {run_id}")
        
        
    
    @staticmethod
    def get_latest_checkpoint(run_id, checkpoint_dir="cost_model_checkpoints"):
        """
        Find the latest checkpoint for a given run
        
        Args:
            run_id: Run number to find checkpoint for
            checkpoint_dir: Directory containing checkpoints
        
        Returns:
            str: Path to latest checkpoint, or None if not found
        """
        import glob
        
        # Look for final checkpoint first
        final_path = os.path.join(checkpoint_dir, f"run_{run_id}_final.pth")
        if os.path.exists(final_path):
            return final_path
        
        # Look for round checkpoints
        pattern = os.path.join(checkpoint_dir, f"run_{run_id}_round_*.pth")
        checkpoints = glob.glob(pattern)
        
        if checkpoints:
            # Extract round numbers and find the latest
            rounds = []
            for cp in checkpoints:
                try:
                    round_num = int(cp.split('_round_')[1].split('.pth')[0])
                    rounds.append((round_num, cp))
                except:
                    continue
            
            if rounds:
                rounds.sort(reverse=True)
                return rounds[0][1]
        
        # Look for latest checkpoint
        latest_path = os.path.join(checkpoint_dir, f"run_{run_id}_latest.pth")
        if os.path.exists(latest_path):
            return latest_path
        
        return None