import torch
import numpy as np
import random
import tqdm
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import CosineAnnealingLR
try:
    from .utils import BufferCustom
    from .model import Autoencoder, AutoencoderCustom
    from .config import setup_training_config
except ImportError:
    # Fallback for direct script execution
    from utils import BufferCustom
    from model import Autoencoder, AutoencoderCustom
    from config import setup_training_config
import argparse
import pprint
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM

def parse_args(argv=None):
    """Parse the command-line options for dictionary training.

    Args:
        argv: Optional sequence of argument strings to parse instead of the
            process command line. When ``None`` (the default) argparse reads
            ``sys.argv[1:]``, which is the behaviour of a plain
            ``parse_args()`` call.

    Returns:
        A plain ``dict`` mapping each option name to its parsed value.
        Options that were not supplied hold their default; the optional
        path options default to ``None``.

    Raises:
        SystemExit: raised by argparse when ``--model_name_or_path`` is
            missing or an argument cannot be parsed.
    """
    parser = argparse.ArgumentParser(description="Train layer-specific KV dictionaries through direct gradient-based optimization")

    # Model and optimisation options.
    parser.add_argument("--model_name_or_path", type=str, required=True, help="Path or identifier of the pretrained model")
    parser.add_argument("--dictionary_size", type=int, default=4096, help="Size of the dictionary")
    parser.add_argument("--sparsity", type=int, default=8, help="Sparsity level for approximation")
    parser.add_argument("--num_epochs", type=int, default=20, help="Number of epochs for training")
    parser.add_argument("--eval_interval", type=int, default=1, help="Evaluation interval in epochs")
    parser.add_argument("--batch_size", type=int, default=64, help="Training batch size")
    parser.add_argument("--lm_batch_size", type=int, default=16, help="Batch size for forward pass of language model")
    parser.add_argument("--lr", type=float, default=0.0001, help="Learning rate")
    parser.add_argument("--buffer_mult", type=int, default=384, help="Multiplier determining buffer size for KV storage")
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument("--use_norm", action='store_true', help="apply normalization to train loss")
    parser.add_argument("--concat", type=int, default=1, help="How much concat layers")
    parser.add_argument("--layer_idx", type=int, default=-1, help="layer index to be trained")

    # Path configuration arguments (optional; left as None when omitted).
    parser.add_argument("--checkpoint_dir", type=str, help="Directory for saving checkpoints")
    parser.add_argument("--dictionary_dir", type=str, help="Directory for saving dictionaries")
    parser.add_argument("--runs_dir", type=str, help="Directory for TensorBoard logs")
    parser.add_argument("--data_base_dir", type=str, help="Base directory for training data")

    # argv=None makes argparse fall back to sys.argv[1:].
    return vars(parser.parse_args(argv))

def _relative_recon_error(kvs, k_hat):
    """Return the batch-mean relative L2 reconstruction error as a float.

    Evaluates ``mean(||kvs - k_hat|| / (||kvs|| + 1e-8))`` with both norms
    taken over the last dimension. The computation runs under
    ``torch.no_grad()`` because the value is only used for logging and must
    not extend the autograd graph.
    """
    with torch.no_grad():
        residual_norm = torch.norm(kvs - k_hat, dim=-1)
        target_norm = torch.norm(kvs, dim=-1) + 1e-8  # epsilon guards all-zero rows
        return torch.mean(residual_norm / target_norm).item()


def _mean_or_nan(values):
    """Return the arithmetic mean of *values*, or NaN when *values* is empty."""
    return sum(values) / len(values) if values else float("nan")


def main(cfg_dict):
    """Train the autoencoder dictionary and log metrics to TensorBoard.

    For every epoch the function runs ``buffer.get_length() // batch_size``
    optimisation steps, then ``buffer.get_test_length() // batch_size``
    evaluation batches under ``torch.no_grad()``, writes per-step and
    per-epoch scalars, steps the cosine learning-rate schedule and calls
    ``autoencoder.save_dictionary(epoch)``. After the last epoch
    ``autoencoder.save_dictionary()`` is called once more without an epoch.

    A split that yields no full batch is reported as ``nan`` in the epoch
    summary line and its epoch-level scalars are skipped, instead of
    aborting the run with ``ZeroDivisionError``.

    Args:
        cfg_dict: Options handed to ``setup_training_config``. The resulting
            config object is read here for ``"seed"``, ``"lr"``,
            ``"num_epochs"`` and ``"batch_size"``.

    NOTE(review): the ``--eval_interval`` option is parsed by ``parse_args``
    but is not consulted here; evaluation runs after every epoch. Confirm
    whether that is intended.
    """
    # Setup configuration manager while preserving argparse compatibility
    cfg = setup_training_config(cfg_dict)

    # Set random seeds
    seed = cfg["seed"]
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    # Initialize buffer and update feature dimension.
    # The buffer receives a copy of the config (kept for backward compatibility).
    buffer = BufferCustom(cfg.copy())
    cfg['feature_dim'] = buffer.get_feature_dim()

    # Ensure directories exist
    cfg.ensure_directories()

    # Setup tensorboard writer; the try/finally below guarantees it is closed
    # (and its pending events flushed) even if training raises.
    writer = SummaryWriter(log_dir=cfg.get_tensorboard_log_dir())
    try:
        autoencoder = Autoencoder(cfg)

        optimizer = torch.optim.Adam(autoencoder.parameters(), lr=cfg["lr"])
        scheduler = CosineAnnealingLR(optimizer, T_max=cfg["num_epochs"], eta_min=1e-5)

        batches_per_epoch = buffer.get_length() // cfg["batch_size"]
        test_batches_per_epoch = buffer.get_test_length() // cfg["batch_size"]

        for epoch in range(cfg["num_epochs"]):
            # ---- Training phase ----
            train_losses = []
            train_rel_errs = []
            pbar = tqdm.trange(batches_per_epoch, desc=f"Epoch {epoch+1}")

            for i in pbar:
                kvs = buffer.next()
                losses, k_hat, y = autoencoder(kvs)

                # Only the first entry of the returned losses (the
                # reconstruction loss) is optimised.
                loss = losses[0]

                loss.backward()
                # NOTE(review): runs after backward() and before step(), as in
                # the original ordering; presumably it may also touch the
                # decoder gradients -- confirm against the model.
                autoencoder.normalise_decoder_weights()
                optimizer.step()
                optimizer.zero_grad()

                # Calculate metrics
                loss_value = loss.item()
                rel_recon_error = _relative_recon_error(kvs, k_hat)

                # Log metrics against a global step that is monotonic across epochs
                global_step = epoch * batches_per_epoch + i
                writer.add_scalar('Loss/train', loss_value, global_step)
                writer.add_scalar('RelativeReconstructionError/train', rel_recon_error, global_step)

                train_losses.append(loss_value)
                train_rel_errs.append(rel_recon_error)

                if i % 10 == 0:
                    pbar.set_postfix(loss=loss_value, recon_err=rel_recon_error)

                # Drop tensor references before the next batch is fetched.
                del losses, loss, k_hat, y

            # ---- Evaluation phase ----
            test_losses = []
            test_rel_errs = []
            test_pbar = tqdm.trange(test_batches_per_epoch, desc=f"Epoch {epoch+1} - Test")
            with torch.no_grad():
                autoencoder.eval()

                for i in test_pbar:
                    kvs = buffer.test_next()  # Use test_next() for test data
                    losses, k_hat, y = autoencoder(kvs)

                    loss_value = losses[0].item()
                    rel_recon_error = _relative_recon_error(kvs, k_hat)

                    test_losses.append(loss_value)
                    test_rel_errs.append(rel_recon_error)

                    if i % 10 == 0:
                        test_pbar.set_postfix(loss=loss_value, recon_err=rel_recon_error)

                autoencoder.train()

            # ---- Epoch-level metrics ----
            avg_train_loss = _mean_or_nan(train_losses)
            avg_train_recon_err = _mean_or_nan(train_rel_errs)
            avg_test_loss = _mean_or_nan(test_losses)
            avg_test_recon_err = _mean_or_nan(test_rel_errs)

            epoch_series = (
                ('Loss/epoch_train', train_losses, avg_train_loss),
                ('Loss/epoch_test', test_losses, avg_test_loss),
                ('RelativeReconstructionError/epoch_train', train_rel_errs, avg_train_recon_err),
                ('RelativeReconstructionError/epoch_test', test_rel_errs, avg_test_recon_err),
            )
            for tag, samples, average in epoch_series:
                # A split with no full batch has nothing meaningful to plot.
                if samples:
                    writer.add_scalar(tag, average, epoch)

            print(f"Epoch {epoch + 1}/{cfg['num_epochs']} - "
                  f"Train Loss: {avg_train_loss:.6f}, Test Loss: {avg_test_loss:.6f}, "
                  f"Train Recon Err: {avg_train_recon_err:.6f}, Test Recon Err: {avg_test_recon_err:.6f}")

            scheduler.step()
            autoencoder.save_dictionary(epoch)

        autoencoder.save_dictionary()
        print("Dictionary saved")
    finally:
        writer.close()


if __name__ == "__main__":
    # Script entry point: parse the CLI options, echo them so the run log
    # records the exact configuration, then start training.
    cli_options = parse_args()
    pprint.pprint(cli_options)

    main(cli_options)
    