
import os
import sys
import json
import logging
import datetime
import subprocess
import random
from typing import Optional

import torch
import numpy as np
import transformers
import peft
from transformers import TrainerCallback, AutoTokenizer
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns



def compute_metrics(eval_pred):
    """Compute classification accuracy for a Trainer evaluation step.

    ``eval_pred`` is unpacked into ``(logits, labels)``; the predicted class
    for each example is the arg-max over the last axis of ``logits``.

    Returns:
        dict with a single ``'accuracy'`` entry.
    """
    scores, gold = eval_pred
    guesses = np.argmax(scores, axis=-1)
    acc = accuracy_score(gold, guesses)
    return {'accuracy': acc}

def setup_tokenizer(model_name: str, token: str=None):
    """Load a fast tokenizer and register any missing pad/unk special tokens.

    A ``'[PAD]'`` pad token is added when the tokenizer has no pad token id,
    and a ``'[UNK]'`` unk token is added when it has no unk token. Both are
    registered with a single ``add_special_tokens`` call.

    Args:
        model_name: Name or path passed to ``AutoTokenizer.from_pretrained``.
        token: Optional auth token forwarded to ``from_pretrained``.

    Returns:
        ``(tokenizer, n_added)`` where ``n_added`` is the count reported by
        ``add_special_tokens`` (0 when nothing was missing). Presumably callers
        use it to resize model embeddings; confirm at the call site.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True, token=token)

    # (attribute name, replacement value, whether it is missing)
    candidates = (
        ('pad_token', '[PAD]', tokenizer.pad_token_id is None),
        ('unk_token', '[UNK]', tokenizer.unk_token is None),
    )
    missing = {name: value for name, value, absent in candidates if absent}

    if not missing:
        return tokenizer, 0

    added = tokenizer.add_special_tokens(missing)
    logging.info(f"Added {added} special tokens: {list(missing.keys())}")
    return tokenizer, added


class EnhancedTrainMetricsCallback(TrainerCallback):
    """
    Track per-epoch train/validation loss and accuracy and save learning curves.

    ``on_epoch_end`` reads ``self.trainer``, which ``__init__`` does not set:
    the caller must assign it (e.g. ``callback.trainer = trainer``) after the
    Trainer is built; until then the hook is a no-op. At each epoch end the
    callback re-evaluates the training and evaluation datasets. On the main
    process only, it then rewrites ``learning_curves.json`` and
    ``learning_curves.png`` in ``output_dir``.

    The ``*_accuracy`` keys presumably come from a ``compute_metrics`` that
    returns ``{'accuracy': ...}`` (such as the one in this module), prefixed
    by ``metric_key_prefix``. Verify the Trainer is configured with it.

    NOTE(review): ``trainer.evaluate`` fires the ``on_evaluate`` callback
    chain from inside ``on_epoch_end``. Confirm this does not interfere with
    the Trainer's own epoch-end evaluation / best-checkpoint logic.
    """

    def __init__(self, output_dir):
        self.output_dir = output_dir  # directory receiving learning_curves.{json,png}
        self.train_losses = []
        self.train_accuracies = []
        self.val_losses = []
        self.val_accuracies = []
        self.epochs = []  # one entry per completed epoch, index-aligned with the lists above

    def on_epoch_end(self, args, state, control, **kwargs):
        """Evaluate on train and eval sets, record the metrics, save curves.

        Returns ``control`` unchanged.
        """
        trainer = getattr(self, 'trainer', None)
        if trainer is None:
            return control

        # Run both evaluations on every process: in distributed runs evaluate
        # may need all ranks to participate. Running them before touching any
        # history list means a failure in either call cannot leave the train
        # and validation lists with different lengths.
        train_metrics = trainer.evaluate(trainer.train_dataset, metric_key_prefix="train")
        val_metrics = trainer.evaluate(trainer.eval_dataset, metric_key_prefix="eval")

        # state.epoch is a float; round instead of truncating so a value like
        # 2.9999999 is recorded as epoch 3.
        epoch = int(round(state.epoch))
        self.epochs.append(epoch)
        self.train_losses.append(train_metrics['train_loss'])
        self.train_accuracies.append(train_metrics['train_accuracy'])
        self.val_losses.append(val_metrics['eval_loss'])
        self.val_accuracies.append(val_metrics['eval_accuracy'])

        # Only the main process logs and writes files; otherwise every rank
        # would race to overwrite the same JSON/PNG.
        if state.is_world_process_zero:
            logging.info(f"Epoch {epoch}: "
                         f"Train Loss: {train_metrics['train_loss']:.4f}, "
                         f"Train Acc: {train_metrics['train_accuracy']:.4f}, "
                         f"Val Loss: {val_metrics['eval_loss']:.4f}, "
                         f"Val Acc: {val_metrics['eval_accuracy']:.4f}")
            self.save_learning_curves()

        return control

    def save_learning_curves(self):
        """Save learning curves as plots and JSON data.

        Writes ``learning_curves.json`` (raw per-epoch lists) and
        ``learning_curves.png`` (loss and accuracy side by side) into
        ``self.output_dir``, creating the directory if needed. Does nothing
        when no epoch has been recorded yet.
        """
        if not self.epochs:
            return

        os.makedirs(self.output_dir, exist_ok=True)

        curves_data = {
            'epochs': self.epochs,
            'train_losses': self.train_losses,
            'train_accuracies': self.train_accuracies,
            'val_losses': self.val_losses,
            'val_accuracies': self.val_accuracies
        }
        json_path = os.path.join(self.output_dir, 'learning_curves.json')
        with open(json_path, 'w', encoding='utf-8') as f:
            json.dump(curves_data, f, indent=2)

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
        try:
            # Loss plot
            ax1.plot(self.epochs, self.train_losses, label='Train Loss', marker='o')
            ax1.plot(self.epochs, self.val_losses, label='Validation Loss', marker='s')
            ax1.set_xlabel('Epoch')
            ax1.set_ylabel('Loss')
            ax1.set_title('Training and Validation Loss')
            ax1.legend()
            ax1.grid(True, alpha=0.3)

            # Accuracy plot
            ax2.plot(self.epochs, self.train_accuracies, label='Train Accuracy', marker='o')
            ax2.plot(self.epochs, self.val_accuracies, label='Validation Accuracy', marker='s')
            ax2.set_xlabel('Epoch')
            ax2.set_ylabel('Accuracy')
            ax2.set_title('Training and Validation Accuracy')
            ax2.legend()
            ax2.grid(True, alpha=0.3)

            fig.tight_layout()
            fig.savefig(os.path.join(self.output_dir, 'learning_curves.png'),
                        dpi=300, bbox_inches='tight')
        finally:
            # Close this specific figure even if plotting fails, so repeated
            # epochs do not accumulate open figures.
            plt.close(fig)


def set_reproducible_seed(seed):
    """Seed every RNG used in training so runs can be reproduced.

    Seeds Python's ``random``, NumPy's global RNG and torch, in that order.
    When CUDA is available it also seeds all GPUs and switches cuDNN to
    deterministic, non-benchmarking mode. Finally it calls
    ``transformers.set_seed``.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        # Trade cuDNN autotuning speed for run-to-run determinism.
        cudnn = torch.backends.cudnn
        cudnn.deterministic = True
        cudnn.benchmark = False

    transformers.set_seed(seed)


def get_system_info():
    """Collect environment details used to reproduce a training run.

    Returns:
        dict with these entries:

        - Python, torch, transformers, peft and NumPy versions.
        - Whether CUDA is available, the platform string, and a local-time
          ISO-8601 timestamp.
        - When CUDA is available: the CUDA and cuDNN versions plus the GPU
          count and names.
        - Git info: ``git_commit`` is ``HEAD``'s hash, or ``'unavailable'``
          if git is missing, fails, or times out.
        - When the status query succeeds: ``git_dirty``, and also
          ``git_changes`` (the ``git status --porcelain`` text) when the tree
          is dirty.
    """
    info = {
        'python_version': sys.version,
        'torch_version': torch.__version__,
        'transformers_version': transformers.__version__,
        'peft_version': peft.__version__,
        'numpy_version': np.__version__,
        'cuda_available': torch.cuda.is_available(),
        'platform': sys.platform,
        'timestamp': datetime.datetime.now().isoformat()
    }

    if torch.cuda.is_available():
        n_gpus = torch.cuda.device_count()
        info.update({
            'cuda_version': torch.version.cuda,
            'cudnn_version': torch.backends.cudnn.version(),
            'gpu_count': n_gpus,
            'gpu_names': [torch.cuda.get_device_name(i) for i in range(n_gpus)]
        })

    def run_git(*git_args):
        # stderr is discarded so "fatal: not a git repository" does not leak
        # to the console. The timeout guards against a hung git (e.g. a lock
        # or credential prompt).
        out = subprocess.check_output(['git', *git_args],
                                      stderr=subprocess.DEVNULL, timeout=10)
        return out.decode(errors='replace').strip()

    # git info is best-effort. Catch only "git missing / failed / timed out"
    # (OSError covers FileNotFoundError; SubprocessError covers
    # CalledProcessError and TimeoutExpired) instead of a bare except that
    # would also swallow KeyboardInterrupt.
    try:
        info['git_commit'] = run_git('rev-parse', 'HEAD')
    except (OSError, subprocess.SubprocessError):
        info['git_commit'] = 'unavailable'
        return info

    try:
        git_status = run_git('status', '--porcelain')
    except (OSError, subprocess.SubprocessError):
        return info

    info['git_dirty'] = len(git_status) > 0
    if info['git_dirty']:
        info['git_changes'] = git_status

    return info


def save_confusion_matrix(y_true, y_pred, class_names, output_path):
    """Save a confusion matrix as both a PNG heatmap and a JSON report.

    Args:
        y_true: Ground-truth integer labels.
        y_pred: Predicted integer labels.
        class_names: Display names; position ``i`` names label ``i``. Every
            label ``0..len(class_names)-1`` gets a row/column, even if it is
            absent from ``y_true``/``y_pred``.
        output_path: Path of the JSON file. The heatmap is written next to it
            with the same stem and a ``.png`` extension.

    The JSON contains the matrix, the class names, the overall accuracy and
    sklearn's ``classification_report`` as a dict (zero_division=0).
    """
    # All possible labels (0 .. num_classes-1) so missing classes still show.
    all_labels = list(range(len(class_names)))
    cm = confusion_matrix(y_true, y_pred, labels=all_labels)

    # Derive the PNG path from the extension only. str.replace('.json', '.png')
    # rewrites every '.json' in the path (e.g. in a directory name). For a
    # path not containing '.json' it returns the path unchanged, so the JSON
    # written below would overwrite the PNG.
    png_path = os.path.splitext(output_path)[0] + '.png'

    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=class_names, yticklabels=class_names, ax=ax)
        ax.set_title('Confusion Matrix')
        ax.set_ylabel('True Label')
        ax.set_xlabel('Predicted Label')
        fig.tight_layout()
        fig.savefig(png_path, dpi=300, bbox_inches='tight')
    finally:
        plt.close(fig)

    cm_data = {
        'matrix': cm.tolist(),
        # list() so tuples / numpy arrays of names serialize as a JSON list.
        'class_names': list(class_names),
        'accuracy': float(accuracy_score(y_true, y_pred)),
        'classification_report': classification_report(y_true, y_pred,
                                                      labels=all_labels,
                                                      target_names=class_names,
                                                      output_dict=True,
                                                      zero_division=0)
    }

    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(cm_data, f, indent=2)


def generate_model_card(args, system_info, experiment_info, metrics, output_dir):
    """Generate a Markdown model card for the BERT-LoRA AG News classifier.

    NOTE(review): a second ``generate_model_card`` with a compatible
    signature is defined later in this module. Because it is bound
    afterwards, it replaces this one at import time, so calls through the
    module-level name never reach this body. Consider removing one of the two.

    Args:
        args: Parsed CLI namespace. Reads ``model_name``, ``num_epochs``,
            ``batch_size``, ``learning_rate``, ``max_seq_length``, ``lora_r``,
            ``lora_alpha`` and ``lora_dropout``. Reads ``lora_target_modules``
            when present, otherwise ``['query', 'value']``.
        system_info: Dict with the keys produced by ``get_system_info``.
        experiment_info: Dict with ``experiment_id``, ``dataset_info``
            (``splits`` -> ``train``/``val``/``test``, ``num_classes``) and
            ``model_info`` (``total_params``, ``trainable_params``).
        metrics: Dict with ``{train,val,test}_{accuracy,loss}`` numbers.
        output_dir: Existing directory; ``MODEL_CARD.md`` is written there as UTF-8.
    """
    ds = experiment_info['dataset_info']
    splits = ds['splits']
    arch = experiment_info['model_info']
    total_params = arch['total_params']
    trainable_pct = arch['trainable_params'] / total_params * 100 if total_params else 0.0
    target_modules = getattr(args, 'lora_target_modules', ['query', 'value'])

    # Lines are built without source indentation. In the previous
    # triple-quoted version every line started with 4 spaces, which Markdown
    # renders as one big indented code block instead of headings and lists.
    lines = [
        "# BERT-LoRA Text Classification Model",
        "",
        "## Model Description",
        f"- **Base Model**: {args.model_name}",
        "- **Task**: AG News Classification (4 classes: World, Sports, Business, Technology)",
        "- **Training Method**: LoRA (Low-Rank Adaptation)",
        f"- **Experiment ID**: {experiment_info['experiment_id']}",
        "",
        "## Training Configuration",
        f"- **Epochs**: {args.num_epochs}",
        f"- **Batch Size**: {args.batch_size}",
        f"- **Learning Rate**: {args.learning_rate}",
        f"- **Max Sequence Length**: {args.max_seq_length}",
        f"- **LoRA Rank (r)**: {args.lora_r}",
        f"- **LoRA Alpha**: {args.lora_alpha}",
        f"- **LoRA Dropout**: {args.lora_dropout}",
        f"- **Target Modules**: {target_modules}",
        "",
        "## Dataset Information",
        "- **Dataset**: AG News",
        f"- **Train Size**: {splits['train']}",
        f"- **Validation Size**: {splits['val']}",
        f"- **Test Size**: {splits['test']}",
        f"- **Number of Classes**: {ds['num_classes']}",
        "",
        "## Model Performance",
        f"- **Test Accuracy**: {metrics['test_accuracy']:.4f}",
        f"- **Test Loss**: {metrics['test_loss']:.4f}",
        f"- **Validation Accuracy**: {metrics['val_accuracy']:.4f}",
        f"- **Validation Loss**: {metrics['val_loss']:.4f}",
        f"- **Train Accuracy**: {metrics['train_accuracy']:.4f}",
        f"- **Train Loss**: {metrics['train_loss']:.4f}",
        "",
        "## Model Architecture",
        f"- **Total Parameters**: {total_params:,}",
        f"- **Trainable Parameters**: {arch['trainable_params']:,}",
        f"- **Trainable Percentage**: {trainable_pct:.2f}%",
        "",
        "## System Information",
        f"- **Python Version**: {system_info['python_version'].split()[0]}",
        f"- **PyTorch Version**: {system_info['torch_version']}",
        f"- **Transformers Version**: {system_info['transformers_version']}",
        f"- **PEFT Version**: {system_info['peft_version']}",
        f"- **CUDA Available**: {system_info['cuda_available']}",
        f"- **Training Date**: {system_info['timestamp']}",
        f"- **Git Commit**: {system_info.get('git_commit', 'unavailable')}",
        "",
        "## Usage Example",
        "```python",
        "import torch",
        "from transformers import BertForSequenceClassification, BertTokenizerFast",
        "",
        "# Load model and tokenizer",
        "model = BertForSequenceClassification.from_pretrained('./merged_output_dir')",
        "tokenizer = BertTokenizerFast.from_pretrained('./merged_output_dir')",
        "",
        "# Example inference",
        'text = "Apple stock rises after strong quarterly earnings report"',
        # Use the configured sequence length rather than a hard-coded 128.
        f'inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length={args.max_seq_length})',
        "outputs = model(**inputs)",
        "predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)",
        "predicted_class = torch.argmax(predictions, dim=-1).item()",
        "",
        "# Class mapping: 0=World, 1=Sports, 2=Business, 3=Technology",
        "class_names = ['World', 'Sports', 'Business', 'Technology']",
        'print(f"Predicted class: {class_names[predicted_class]}")',
        "```",
        "",
        "## Training Command",
        "```bash",
        " ".join(sys.argv),
        "```",
        "",
        "## Files Structure",
        "- `merged_output_dir/`: Complete merged model ready for inference",
        "- `config.json`: Model configuration",
        "- `pytorch_model.bin`: Model weights",
        "- `tokenizer.json`: Tokenizer configuration",
        "- `experiment_config.json`: Complete experiment configuration",
        "- `learning_curves.png`: Training progress visualization",
        "- `confusion_matrix_test.png`: Test set confusion matrix",
        "- `training.log`: Detailed training logs",
        "",
        "## Limitations",
        "- Trained specifically on AG News dataset",
        "- May not generalize well to other text classification tasks",
        "- Limited to 4 predefined news categories",
        f"- Maximum sequence length of {args.max_seq_length} tokens",
        "",
        "## Citation",
        "If you use this model, please cite:",
        "```bibtex",
        "@misc{bert_lora_ag_news,",
        "  title={BERT-LoRA Text Classification Model},",
        "  author={Your Name},",
        "  year={2025},",
        "  note={Trained on AG News dataset using LoRA fine-tuning}",
        "}",
        "```",
        "",  # trailing newline at end of file
    ]

    with open(os.path.join(output_dir, 'MODEL_CARD.md'), 'w', encoding='utf-8') as f:
        f.write("\n".join(lines))




def generate_model_card(
    args,
    system_info: dict,
    experiment_info: dict,
    metrics: dict,
    output_dir: str,
    tuning_method: str = "full",               # either "full" or "lora"
):
    """Write ``MODEL_CARD.md`` for a fully fine-tuned or LoRA-tuned classifier.

    Args:
        args: Parsed CLI namespace. Always reads ``model_name``,
            ``dataset_name``, ``num_epochs``, ``batch_size``,
            ``learning_rate`` and ``max_seq_length``. When ``tuning_method``
            is ``"lora"`` it also reads ``lora_r``, ``lora_alpha``,
            ``lora_dropout`` and ``lora_target_modules``.
        system_info: Dict with the keys produced by ``get_system_info``.
        experiment_info: Dict with ``experiment_id``, ``dataset_info``
            (``splits`` -> ``train``/``val``/``test``, ``num_classes``) and
            ``model_info`` (``total_params``, ``trainable_params``).
        metrics: Dict with ``{train,val,test}_{accuracy,loss}`` numbers.
        output_dir: Existing directory; the card is written to
            ``<output_dir>/MODEL_CARD.md`` as UTF-8.
        tuning_method: ``"full"`` or ``"lora"`` (case-insensitive). Only
            ``"lora"`` adds the LoRA hyperparameter section.
    """
    is_lora = tuning_method.lower() == "lora"
    # str.capitalize() would render "lora" as "Lora"; use the usual spelling.
    method_label = "LoRA" if is_lora else tuning_method.capitalize()

    ds = experiment_info['dataset_info']
    splits = ds['splits']
    arch = experiment_info['model_info']
    total_params = arch['total_params']
    train_pct = arch['trainable_params'] / total_params * 100 if total_params else 0.0

    # --- Header & model description ---
    lines = [
        "# Text Classification Model",
        "",
        "## Model Description",
        f"- **Base Model**: {args.model_name}",
        f"- **Task**: {args.dataset_name} Classification",
        f"- **Experiment ID**: {experiment_info['experiment_id']}",
        f"- **Tuning Method**: {method_label}",
        "",
        # --- Training Configuration (common hyperparameters) ---
        "## Training Configuration",
        f"- **Epochs**: {args.num_epochs}",
        f"- **Batch Size**: {args.batch_size}",
        f"- **Learning Rate**: {args.learning_rate}",
        f"- **Max Sequence Length**: {args.max_seq_length}",
    ]

    # --- LoRA section (only if requested) ---
    if is_lora:
        lines += [
            "",
            "## LoRA Hyperparameters",
            f"- **LoRA Rank (r)**: {args.lora_r}",
            f"- **LoRA Alpha**: {args.lora_alpha}",
            f"- **LoRA Dropout**: {args.lora_dropout}",
            f"- **Target Modules**: {args.lora_target_modules}",
        ]
    lines.append("")

    # --- Dataset, metrics, architecture, system info ---
    lines += [
        "## Dataset Information",
        f"- **Dataset**: {args.dataset_name}",
        f"- **Train Size**: {splits['train']}",
        f"- **Validation Size**: {splits['val']}",
        f"- **Test Size**: {splits['test']}",
        f"- **Number of Classes**: {ds['num_classes']}",
        "",
        "## Model Performance",
        f"- **Train Accuracy**: {metrics['train_accuracy']:.4f}",
        f"- **Train Loss**: {metrics['train_loss']:.4f}",
        f"- **Validation Accuracy**: {metrics['val_accuracy']:.4f}",
        f"- **Validation Loss**: {metrics['val_loss']:.4f}",
        f"- **Test Accuracy**: {metrics['test_accuracy']:.4f}",
        f"- **Test Loss**: {metrics['test_loss']:.4f}",
        "",
        "## Model Architecture",
        f"- **Total Parameters**: {total_params:,}",
        f"- **Trainable Parameters**: {arch['trainable_params']:,}",
        f"- **Trainable %**: {train_pct:.2f}%",
        "",
        "## System Information",
        f"- **Python Version**: {system_info['python_version'].split()[0]}",
        f"- **PyTorch Version**: {system_info['torch_version']}",
        f"- **Transformers Version**: {system_info['transformers_version']}",
        f"- **PEFT Version**: {system_info['peft_version']}",
        f"- **CUDA Available**: {system_info['cuda_available']}",
        f"- **Training Date**: {system_info['timestamp']}",
        f"- **Git Commit**: {system_info.get('git_commit', 'unavailable')}",
        "",
    ]

    # --- Usage, command, files, limitations, citation ---
    # The previous version emitted literal "…" placeholders here; every
    # section now has real content. The usage example uses the Auto* classes
    # because ``args.model_name`` is not necessarily a BERT checkpoint.
    lines += [
        "## Usage Example",
        "```python",
        "import torch",
        "from transformers import AutoModelForSequenceClassification, AutoTokenizer",
        "",
        "# Load model and tokenizer",
        "model = AutoModelForSequenceClassification.from_pretrained('./merged_output_dir')",
        "tokenizer = AutoTokenizer.from_pretrained('./merged_output_dir')",
        "",
        "# Example inference",
        'text = "Replace with the text to classify"',
        f'inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length={args.max_seq_length})',
        "with torch.no_grad():",
        "    logits = model(**inputs).logits",
        "predicted_class = int(torch.argmax(logits, dim=-1))",
        'print(f"Predicted class: {model.config.id2label[predicted_class]}")',
        "```",
        "",
        "## Training Command",
        "```bash",
        " ".join(sys.argv),
        "```",
        "",
        "## Files Structure",
        "- `merged_output_dir/`: Complete merged model ready for inference",
        "- `config.json`: Model configuration",
        "- `pytorch_model.bin`: Model weights",
        "- `tokenizer.json`: Tokenizer configuration",
        "- `experiment_config.json`: Complete experiment configuration",
        "- `learning_curves.png`: Training progress visualization",
        "- `confusion_matrix_test.png`: Test set confusion matrix",
        "- `training.log`: Detailed training logs",
        "",
        "## Limitations",
        f"- Trained specifically on the {args.dataset_name} dataset",
        "- May not generalize well to other text classification tasks",
        f"- Limited to the {ds['num_classes']} classes seen during training",
        f"- Maximum sequence length of {args.max_seq_length} tokens",
        "",
        "## Citation",
        "If you use this model, please cite:",
        "```bibtex",
        "@misc{bert_text_classification,",
        f"  title={{{args.model_name} fine-tuned on {args.dataset_name} ({method_label})}},",
        "  author={Your Name},",
        "  year={2025},",
        "}",
        "```",
        "",  # trailing newline at end of file
    ]

    model_card = "\n".join(lines)
    # Explicit UTF-8 so the card does not depend on the platform's locale encoding.
    with open(os.path.join(output_dir, "MODEL_CARD.md"), "w", encoding="utf-8") as f:
        f.write(model_card)
