import argparse
from shutil import copyfile
import importlib
import os
import utils
import logging

import numpy as np
import torch
import torch.nn as nn

from torch.amp import autocast, GradScaler
import models
from dataloaders import CodeDataset
from target_functions import *
from tqdm import tqdm
from itertools import islice

def log_model_architecture(model: nn.Module, logger: logging.Logger) -> None:
    """
    Log a human-readable description of ``model`` through ``logger``.

    Emits up to three INFO records: the module ``repr``, the model config, and
    the names of trainable parameters. Each section has its own ``try`` so a
    failure in one is reported as a WARNING and the remaining sections still run.

    The config is looked up on ``model.model`` when ``model`` has a ``model``
    attribute (wrapper modules), otherwise on ``model`` itself. It is serialized
    with ``to_json_string(use_diff=False)`` when that method exists, else with
    ``str``. The config section is skipped entirely when no ``config`` is found.
    """
    # 1) Full module tree.
    try:
        tree_text = repr(model)
        logger.info("\n===== MODEL ARCHITECTURE (repr) =====\n%s", tree_text)
    except Exception as exc:
        logger.warning("Failed to repr(model): %s", exc)

    # 2) Config, reached through an optional ``.model`` wrapper attribute.
    unwrapped = getattr(model, "model", model)
    config = getattr(unwrapped, "config", None)
    if config is not None:
        try:
            serialize = getattr(config, "to_json_string", None)
            if callable(serialize):
                config_text = serialize(use_diff=False)
            else:
                config_text = str(config)
            logger.info("\n===== MODEL CONFIG =====\n%s", config_text)
        except Exception as exc:
            logger.warning("Failed to serialize model config: %s", exc)

    # 3) Names of parameters that will receive gradients (no sizes).
    try:
        trainable_names = []
        for param_name, param in model.named_parameters():
            if param.requires_grad:
                trainable_names.append(param_name)
        logger.info(
            "\n===== TRAINABLE PARAMETER NAMES (%d) =====\n%s",
            len(trainable_names),
            "\n".join(trainable_names),
        )
    except Exception as exc:
        logger.warning("Failed to enumerate trainable params: %s", exc)

def setup_logger(log_file_name="job_log.log", log_level=logging.INFO, rank=0):
    """
    Return the logger named ``<module>_rank<rank>``, (re)configured for this process.

    Rank 0 gets two handlers at ``log_level`` with a timestamped format:
    - a file handler appending to ``log_file_name`` (UTF-8);
    - a console handler (``logging.StreamHandler()``).

    Any other rank gets only a ``NullHandler``, so its records are discarded.
    Propagation to ancestor loggers is disabled to avoid double logging.

    Calling this again for the same rank reuses the same logger object. The
    handlers installed by the previous call are removed *and closed*, so repeated
    calls (e.g. several runs of ``main`` in one process) do not leak the old log
    file's descriptor.

    Args:
        log_file_name: Path of the log file (rank 0 only).
        log_level: Level applied to the logger and to its handlers.
        rank: Process rank; only rank 0 produces output.

    Returns:
        The configured ``logging.Logger``.
    """
    logger = logging.getLogger(__name__ + f"_rank{rank}")
    # ``logger.handlers.clear()`` alone would drop the old handlers without
    # closing them, leaving the previous FileHandler's file open. Close each one.
    for stale in list(logger.handlers):
        logger.removeHandler(stale)
        stale.close()
    logger.setLevel(log_level)

    if rank == 0:
        formatter = logging.Formatter(
            '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S'
        )
        file_handler = logging.FileHandler(log_file_name, mode='a', encoding='utf-8')
        console_handler = logging.StreamHandler()
        # File first, then console (same order as they were always attached).
        for handler in (file_handler, console_handler):
            handler.setLevel(log_level)
            handler.setFormatter(formatter)
            logger.addHandler(handler)
    else:
        # Non-zero ranks stay silent.
        logger.addHandler(logging.NullHandler())

    logger.propagate = False  # Avoid double logging if the root logger is configured.
    return logger


class Metrics:
    """
    Running history of train/test loss and accuracy.

    Every ``update`` call appends exactly one value to each of the four lists,
    so the lists always have the same length.
    """

    def __init__(self):
        self.train_losses, self.test_losses = [], []
        self.train_accuracies, self.test_accuracies = [], []

    def update(self, train_loss, test_loss, train_acc, test_acc):
        """Append one measurement of each metric to its history list."""
        for history, value in (
            (self.train_losses, train_loss),
            (self.test_losses, test_loss),
            (self.train_accuracies, train_acc),
            (self.test_accuracies, test_acc),
        ):
            history.append(value)

def evaluate_model(model, dataloader, device, model_name, amp_dtype):
    """
    Return ``(avg_loss, avg_accuracy)`` of ``model`` over every batch of ``dataloader``.

    Puts the model in eval mode (and leaves it there) and runs under
    ``torch.no_grad()`` with CUDA autocast at ``amp_dtype``.

    ``model_name`` selects both the batch format and the task:

    * Batch format: names containing ``"finetune"`` get dict batches with
      ``input_ids``/``attention_mask``/``labels``. All others get
      ``(input_ids, labels)`` pairs and no attention mask.
    * Binary task (name in the fixed list below, or a ``"finetune"`` name):
      BCE-with-logits against float labels reshaped to ``(-1, 1)``. A sample
      is correct when ``sigmoid(logit) >= 0.5`` matches its label.
    * Otherwise: cross-entropy on the last position when the output is 3-D
      (presumably ``(N, T, C)``), else on the output as-is. A sample is correct
      when the arg-max class equals its label.

    Both averages are per example. Each batch's mean loss is weighted by its
    batch size, and both totals are divided by ``len(dataloader.dataset)``.
    """
    binary_names = (
        'llama3', 'deepseek', 'qwen3BCoder', 'qwen7BCoder', 'qwen1.5BCoder',
        'qwen1.7B', 'qwen1.5B', 'qwen0.6B', 'bloom', 'deberta', 'mlp',
    )
    finetune = "finetune" in model_name
    binary = finetune or model_name in binary_names
    loss_fn = nn.BCEWithLogitsLoss() if binary else nn.CrossEntropyLoss()

    model.eval()
    loss_sum, correct = 0.0, 0.0
    with torch.no_grad():
        for batch in dataloader:
            if finetune:
                inputs = batch['input_ids'].to(device)
                mask = batch['attention_mask'].to(device)
                targets = batch['labels'].to(device)
            else:
                inputs, targets = batch
                inputs, mask = inputs.to(device), None

            if binary:
                targets = targets.to(device, dtype=torch.float).view(-1, 1)
                with autocast(device_type='cuda', dtype=amp_dtype):
                    logits = model(inputs, attention_mask=mask)
                    batch_loss = loss_fn(logits, targets)
                hits = (torch.sigmoid(logits) >= 0.5).float() == targets
            else:
                # CrossEntropyLoss wants long class indices of shape (N,).
                targets = targets.to(device, dtype=torch.long)
                with autocast(device_type='cuda', dtype=amp_dtype):
                    raw = model(inputs, attention_mask=mask)
                    logits = raw[:, -1] if raw.dim() == 3 else raw
                    batch_loss = loss_fn(logits, targets)
                hits = logits.argmax(1) == targets

            correct += hits.float().sum().item()
            loss_sum += batch_loss.item() * inputs.size(0)

    dataset_size = len(dataloader.dataset)
    return loss_sum / dataset_size, correct / dataset_size

def train_epoch(model, optimizer, scheduler, dataloader, device, train_iters, scaler, model_name, amp_dtype, logger):
    model.train()

    is_finetune = "finetune" in model_name # Helper variable
    is_binary = model_name in ['llama3', 'deepseek', 'qwen3BCoder', 'qwen7BCoder', 'qwen1.5BCoder', 'qwen1.7B', 'qwen1.5B', 'qwen0.6B', 'bloom', 'deberta', 'mlp'] or is_finetune
    criterion = nn.BCEWithLogitsLoss() if is_binary else nn.CrossEntropyLoss()
    for batch_idx, batch in enumerate(islice(dataloader, train_iters)):
        optimizer.zero_grad(set_to_none=True)
        
        if is_finetune:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
        else:
            input_ids, labels = batch
            input_ids = input_ids.to(device)
            attention_mask = None

        with autocast(device_type='cuda', dtype=amp_dtype):
            if is_binary:
                labels = labels.to(device, dtype=torch.float)
                preds = model(input_ids, attention_mask=attention_mask) # Pass attention mask
                loss  = criterion(preds, labels.view(-1, 1))
            else:
                labels = labels.to(device, dtype=torch.long)
                # preds = model(data)                        # (N, T, C) or (N, C)
                preds = model(input_ids, attention_mask=attention_mask)
                logits = preds[:, -1] if preds.dim() == 3 else preds
                loss  = criterion(logits, labels)
        
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()


def _tokenizer_name_for(model_name):
    """Return the Hugging Face tokenizer id for a known ``*_finetune`` model name, else ``None``."""
    return {
        "llama3_finetune": "meta-llama/Llama-3.2-1B",
        "qwen3_finetune": "Qwen/Qwen2-1.5B",
        "deepseek_finetune": "deepseek-ai/deepseek-coder-1.3b-base",
    }.get(model_name)


def _dataset_flags_for(target_func):
    """
    Return the target-specific keyword flags passed to ``CodeDataset``.

    Raises:
        ValueError: If ``target_func`` is not a supported target function name.
    """
    flags = dict(pattern=False, prime=False, palindrome=False, dyck2=False, prime_odd=False)
    if target_func in ('func7', 'func1', 'func15', 'func18', 'func3', 'func4'):
        pass  # No extra flags needed.
    elif target_func == 'func20':
        flags['pattern'] = '10101010'
    elif target_func == 'func21':
        flags['pattern'] = '00111111'
    elif target_func == 'func17':
        flags['palindrome'] = True
    elif target_func == 'func16':
        flags['dyck2'] = True
    elif target_func == 'func19':
        flags['prime'] = True
    elif target_func == 'func22':
        flags['prime_odd'] = True
    else:
        raise ValueError(f"Target function support not implemented: {target_func}")
    return flags


def _amp_config(precision):
    """
    Return ``(amp_dtype, scaler)`` for a precision setting.

    Only float16 enables the ``GradScaler``; bf16 and fp32 get a disabled one.

    Raises:
        ValueError: If ``precision`` is not a recognized spelling. Without this
            check, ``amp_dtype``/``scaler`` would be left unbound.
    """
    if precision in ('bf16', 'bfloat16'):
        return torch.bfloat16, GradScaler(enabled=False)
    if precision in ('f16', 'float16'):
        return torch.float16, GradScaler(enabled=True)
    if precision in ('f32', 'float32'):
        return torch.float32, GradScaler(enabled=False)
    raise ValueError(
        f"Unsupported precision {precision!r}; expected one of "
        "'bf16'/'bfloat16', 'f16'/'float16', 'f32'/'float32'."
    )


def main(results_path, settings_path = "./conf/settings.py"):
    """
    Train and periodically evaluate one model as described by a settings file.

    Steps:
    1. Load settings with ``utils.load_settings(settings_path)``.
    2. Build a ``CodeDataset`` for ``settings.target_func``.
    3. Build the model with ``models.get_model``.
    4. Run ``settings.n_epochs`` epochs of AdamW training with a cosine-annealed
       learning rate.

    When ``settings.online`` is False, each epoch may evaluate (epoch 1, every
    100th epoch after it, and the last epoch), then trains for one pass over the
    training loader, then saves the metrics with ``utils.save_data``. GPU-heavy
    objects are released at the end.

    Args:
        results_path: Directory that receives ``logs.log`` (created if missing)
            and is passed to ``utils.save_data``. When falsy, the log goes to
            ``job_log.log`` in the working directory.
        settings_path: Settings file passed to ``utils.load_settings``.

    Raises:
        ValueError: If ``settings.target_func`` or ``settings.precision`` is
            not supported.
    """
    # Allow TF32 math for CUDA matmuls and cuDNN convolutions.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    settings = utils.load_settings(settings_path)

    if results_path:
        # FileHandler does not create missing parent directories.
        os.makedirs(results_path, exist_ok=True)
    logger = setup_logger(
        log_file_name=os.path.join(results_path, "logs.log") if results_path else "job_log.log",
        log_level=logging.INFO,
    )

    # Validate the configuration before any data generation or model
    # construction, so a bad settings file fails fast with a clear error.
    dataset_flags = _dataset_flags_for(settings.target_func)
    amp_dtype, scaler = _amp_config(settings.precision)

    dataset = CodeDataset(
        globals()[settings.target_func],
        settings.sequence_length,
        settings.train_set_size,
        settings.test_set_size,
        settings.batch_size,
        bos_token = settings.BOS_TOKEN,
        online = settings.online,
        device = settings.device,
        logger = logger,
        tokenizer_name = _tokenizer_name_for(settings.model),
        **dataset_flags,
    )
    train_loader, test_loader = dataset.create_dataloaders()

    model = models.get_model(settings).to(settings.device)
    if settings.model == "bloom":
        base = getattr(model, "model", model)
        try:
            base.set_attention_implementation("flash_attention_2")
        except Exception as err:
            # Best effort: keep the default attention implementation, but say so.
            logger.warning("Could not enable flash_attention_2: %s", err)
        model = torch.compile(model, mode="reduce-overhead")

    log_model_architecture(model, logger)
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=settings.lr, weight_decay=settings.weight_decay, fused=True
    )

    # train_epoch steps the scheduler once per batch, so one cosine period spans
    # every batch of every epoch.
    train_iters = len(train_loader)
    total_steps = settings.n_epochs * train_iters
    # NOTE: for learning-rate ablations pass eta_min=settings.lr instead, which
    # keeps the learning rate constant at settings.lr.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, total_steps, eta_min=settings.eta_min)

    metrics = Metrics()

    logger.info(f'num params:, {sum(x.numel() for x in model.parameters())//10**3} K')
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())
    trainable_pct = 100 * trainable_params / total_params if total_params else 0.0
    logger.info(f"Trainable Parameters: {trainable_params:,}")
    logger.info(f"Total Parameters: {total_params:,}")
    logger.info(f"Percentage of Trainable Params: {trainable_pct:.4f}%")

    if settings.online:
        # NOTE(review): every epoch below is gated on ``not settings.online``,
        # so an online run evaluates, trains and saves nothing. Surface that
        # instead of finishing silently; confirm whether this is intended.
        logger.warning("settings.online is True: the epoch loop skips evaluation, training and saving.")

    for epoch in tqdm(range(1, settings.n_epochs + 1)):
        if not settings.online:
            # Evaluate on epoch 1, every 100th epoch after it, and the last epoch.
            # NOTE(review): evaluation runs *before* this epoch's training, so the
            # weights after the final train_epoch are never evaluated.
            if epoch % 100 == 1 or epoch == settings.n_epochs:
                train_loss, train_acc = evaluate_model(model, train_loader, settings.device, settings.model, amp_dtype)
                test_loss, test_acc = evaluate_model(model, test_loader, settings.device, settings.model, amp_dtype)
                metrics.update(train_loss, test_loss, train_acc, test_acc)
                logger.info(f"Epoch {epoch:03d} | Train Acc: {train_acc:.4f} | Test Acc: {test_acc:.4f}")

            train_epoch(model, optimizer, scheduler, train_loader, settings.device, train_iters, scaler, settings.model, amp_dtype, logger)
            logger.info( f"PRINTING RESULT PATH {results_path}" )
            # Metrics are re-saved every epoch so interrupted runs keep their history.
            utils.save_data(results_path, metrics)

    try:
        # Make sure no gradient tensors are left alive.
        optimizer.zero_grad(set_to_none=True)
    except Exception:
        pass  # Best-effort cleanup only.
    import gc
    # Drop GPU-heavy references, then collect and release cached CUDA memory.
    del model
    del scaler
    del scheduler
    del optimizer
    del train_loader, test_loader, dataset
    gc.collect()
    torch.cuda.empty_cache()

if __name__ == '__main__':
    # Command-line entry point: parse paths, echo the resolved settings file, run main().
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--results-path',
        action='store',
        default=None,
        help=('Specifies path of directory where results should be stored (including weights) and '
              'where the settings.py file is located.'),
    )
    cli.add_argument("--settings_path", default="./conf/settings.py")
    cli_args = cli.parse_args()

    # Resolve to an absolute path so the echoed location is unambiguous.
    resolved_settings = os.path.abspath(cli_args.settings_path)
    print(f"[main] Using settings: {resolved_settings}")
    main(results_path=cli_args.results_path, settings_path=resolved_settings)