#!/usr/bin/env python3
"""
annDNA Training Script
Unified training script for the seq, struct, and full model variants
"""

import sys
sys.path.append('..')

import os
import argparse
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
from pathlib import Path
from tqdm import tqdm
from datetime import datetime

from config import get_model_config, get_model_paths
from model import annDNA
from dataset import GenomeDataset


class Trainer:
    """Train an annDNA model for one of the configured variants.

    Construction performs all setup eagerly:
    - resolves the variant's config and paths;
    - selects the device;
    - loads the vocabulary;
    - builds the model, wrapped in ``nn.DataParallel`` when more than one GPU
      is visible;
    - builds the train/val ``DataLoader``s;
    - builds an AdamW optimizer with a linear-warmup + cosine-decay LR
      schedule;
    - opens either a W&B run or a timestamped text log in ``model_dir``.

    Call :meth:`train` to run the training loop.
    """

    # Label value excluded from both the loss (criterion ignore_index) and
    # the accuracy metric.
    IGNORE_INDEX = -100

    def __init__(self, model_name, gpu_ids='0', batch_size=96, learning_rate=3e-5,
                 epochs=10, use_wandb=False):
        """Set up model, data, optimizer, scheduler and logging.

        Args:
            model_name: Variant key passed to ``get_model_config`` and
                ``get_model_paths``.
            gpu_ids: Comma-separated GPU ids exported as ``CUDA_VISIBLE_DEVICES``.
            batch_size: Training batch size (validation uses twice this).
            learning_rate: Peak AdamW learning rate, reached after warmup.
            epochs: Number of passes over the training set.
            use_wandb: Log to Weights & Biases instead of a local text file.

        Raises:
            ValueError: If the training loader yields no batches.
        """
        # Config
        self.model_name = model_name
        self.config = get_model_config(model_name)
        self.paths = get_model_paths(model_name)

        # GPU settings. CUDA reads CUDA_VISIBLE_DEVICES when it is first
        # initialised, so this only takes effect if nothing earlier in the
        # process has touched CUDA.
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu_ids
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Training params
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.epochs = epochs
        self.use_wandb = use_wandb

        # Fixed params
        self.weight_decay = 0.01
        self.warmup_steps = 1000
        self.gradient_clip = 1.0
        self.log_every = 100
        self.validate_every = 1000

        print(f"=== Training {self.config['name']} ===")
        print(f"Description: {self.config['description']}")
        print(f"Device: {self.device}")
        print(f"GPUs: {gpu_ids}")
        print(f"W&B logging: {'enabled' if use_wandb else 'disabled'}")

        # Load vocabulary (JSON is UTF-8 by specification)
        with open(self.paths['vocab_file'], 'r', encoding='utf-8') as f:
            self.vocab = json.load(f)
        self.vocab_size = len(self.vocab)
        print(f"Vocabulary size: {self.vocab_size}")

        # Create model
        self.model = annDNA(
            vocab_size=self.vocab_size,
            d_model=self.config['d_model'],
            nhead=self.config['nhead'],
            num_layers=self.config['num_layers'],
            max_seq_len=self.config['max_seq_len']
        )

        # Multi-GPU
        if torch.cuda.device_count() > 1:
            print(f"Using {torch.cuda.device_count()} GPUs")
            self.model = nn.DataParallel(self.model)

        self.model = self.model.to(self.device)

        # Print model info
        model_info = self._unwrapped_model().get_model_info()
        print("\nModel:")
        for key, value in model_info.items():
            print(f"  {key}: {value}")

        # Data loaders
        self.train_loader = DataLoader(
            GenomeDataset(self.paths['processed_dir'], 'train'),
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=4,
            pin_memory=True,
            drop_last=True  # Important for consistent batch sizes
        )
        if len(self.train_loader) == 0:
            # With drop_last=True a dataset smaller than one batch yields
            # nothing; fail early instead of dividing by zero after epoch 1.
            raise ValueError(
                f"Training set yields no full batches of size {self.batch_size}; "
                "reduce the batch size."
            )

        # Validation is enabled only if the split exists and is non-empty.
        self.val_loader = None
        val_file = self.paths['processed_dir'] / 'val_input_ids.npy'
        if val_file.exists():
            val_loader = DataLoader(
                GenomeDataset(self.paths['processed_dir'], 'val'),
                batch_size=self.batch_size * 2,
                shuffle=False,
                num_workers=4,
                pin_memory=True
            )
            if len(val_loader) > 0:
                self.val_loader = val_loader

        print(f"Train batches: {len(self.train_loader):,}")
        if self.val_loader is not None:
            print(f"Val batches: {len(self.val_loader):,}")
        else:
            print("Validation data not found - skipping validation")

        # Optimizer and scheduler
        self.optimizer = optim.AdamW(
            self.model.parameters(),
            lr=self.learning_rate,
            weight_decay=self.weight_decay
        )

        total_steps = len(self.train_loader) * self.epochs
        self.scheduler = optim.lr_scheduler.LambdaLR(
            self.optimizer,
            lambda step: self._lr_multiplier(step, self.warmup_steps, total_steps),
        )

        # Loss function
        self.criterion = nn.CrossEntropyLoss(ignore_index=self.IGNORE_INDEX)

        # Training state
        self.global_step = 0
        self.best_val_loss = float('inf')

        # Create model directory
        self.paths['model_dir'].mkdir(parents=True, exist_ok=True)

        # Initialize logging
        if self.use_wandb:
            import wandb
            self.wandb = wandb
            self.init_wandb(model_info)
        else:
            log_name = f"train_log_{datetime.now().strftime('%Y%m%d-%H%M%S')}.txt"
            self.log_file = open(self.paths['model_dir'] / log_name, 'w', encoding='utf-8')
            self.log(f"annDNA {self.model_name} Training Log\n")
            self.log(f"{'='*50}\n")
            for key, value in model_info.items():
                self.log(f"{key}: {value}\n")
            self.log(f"{'='*50}\n\n")

    @staticmethod
    def _lr_multiplier(step, warmup_steps, total_steps):
        """Return the LR scale factor for ``step``.

        The schedule is a linear warmup followed by cosine decay:
        - the factor rises linearly from 0 to 1 over ``warmup_steps``;
        - it then follows half a cosine from 1 down to 0 at ``total_steps``.

        Progress is clamped to [0, 1]. The decay length is floored at 1, so a
        schedule with ``total_steps <= warmup_steps`` cannot divide by zero.
        """
        if warmup_steps > 0 and step < warmup_steps:
            return step / warmup_steps
        decay_steps = max(1, total_steps - warmup_steps)
        progress = min(1.0, max(0.0, (step - warmup_steps) / decay_steps))
        return float(0.5 * (1.0 + np.cos(np.pi * progress)))

    def _unwrapped_model(self):
        """Return the underlying model, stripping an ``nn.DataParallel`` wrapper."""
        if isinstance(self.model, nn.DataParallel):
            return self.model.module
        return self.model

    def init_wandb(self, model_info):
        """Initialize Weights & Biases logging with the run's hyperparameters."""
        run_name = f"anndna-{self.model_name}-{datetime.now().strftime('%Y%m%d-%H%M%S')}"

        self.wandb.init(
            project=f"anndna-{self.model_name}",
            name=run_name,
            config={
                "model": model_info,
                "training": {
                    "batch_size": self.batch_size,
                    "learning_rate": self.learning_rate,
                    "weight_decay": self.weight_decay,
                    "epochs": self.epochs,
                    "warmup_steps": self.warmup_steps,
                    "gradient_clip": self.gradient_clip
                },
                "data": {
                    # NOTE(review): hard-coded; confirm it matches the
                    # preprocessing window used to build processed_dir.
                    "window_size": 1000,
                    "vocab_size": self.vocab_size
                }
            }
        )

    def log(self, message):
        """Append ``message`` to the text log (no-op when W&B is enabled)."""
        if not self.use_wandb and hasattr(self, 'log_file'):
            self.log_file.write(message)
            self.log_file.flush()

    def calculate_accuracy(self, logits, labels):
        """Return the fraction of non-ignored label positions predicted correctly.

        Args:
            logits: Scores whose last dimension indexes the vocabulary; the
                argmax over it is compared elementwise against ``labels``.
            labels: Target ids; positions equal to ``IGNORE_INDEX`` are excluded.

        Returns:
            Accuracy as a Python float, or 0.0 when every position is ignored.
        """
        mask = labels != self.IGNORE_INDEX
        num_targets = mask.sum()
        if num_targets == 0:
            return 0.0
        correct = (torch.argmax(logits, dim=-1) == labels) & mask
        return (correct.sum().float() / num_targets.float()).item()

    def validate(self):
        """Evaluate on the validation loader and return ``(loss, accuracy)``.

        Both metrics are averaged over non-ignored target tokens, not over
        batches. This has two consequences:
        - a smaller final batch (the validation loader does not drop it) is
          not over-weighted;
        - batches with no targets are skipped, since their mean loss would be
          NaN.

        Returns ``(nan, nan)`` when no batch contains a target. The model is
        put back in training mode on exit, even if evaluation raises.
        """
        self.model.eval()
        loss_sum = 0.0
        correct_sum = 0
        token_count = 0

        try:
            with torch.no_grad():
                for batch in self.val_loader:
                    input_ids = batch['input_ids'].to(self.device)
                    labels = batch['labels'].to(self.device)
                    attention_mask = batch['attention_mask'].to(self.device)

                    mask = labels != self.IGNORE_INDEX
                    num_targets = int(mask.sum().item())
                    if num_targets == 0:
                        continue

                    logits = self.model(input_ids, attention_mask)
                    loss = self.criterion(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))
                    predictions = torch.argmax(logits, dim=-1)

                    # The criterion uses the default 'mean' reduction over
                    # non-ignored targets, so re-weight it by the target count.
                    loss_sum += loss.item() * num_targets
                    correct_sum += int(((predictions == labels) & mask).sum().item())
                    token_count += num_targets
        finally:
            self.model.train()

        if token_count == 0:
            return float('nan'), float('nan')
        return loss_sum / token_count, correct_sum / token_count

    def save_checkpoint(self, epoch, loss, is_best=False, save_epoch=True):
        """Save model/optimizer/scheduler state to ``model_dir``.

        Args:
            epoch: Epoch number stored in the checkpoint and used in the file name.
            loss: Loss value recorded in the checkpoint.
            is_best: Also write ``best_model.pt``.
            save_epoch: Write ``epoch_{epoch}.pt``. Mid-epoch saves pass False so
                a partial epoch does not masquerade as the epoch checkpoint.
        """
        checkpoint = {
            'epoch': epoch,
            'global_step': self.global_step,
            'model_state_dict': self._unwrapped_model().state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'loss': loss,
            'vocab_size': self.vocab_size
        }
        model_dir = self.paths['model_dir']

        if is_best:
            torch.save(checkpoint, model_dir / "best_model.pt")
            print(f"Saved best model (loss: {loss:.4f})")

        if save_epoch:
            torch.save(checkpoint, model_dir / f"epoch_{epoch}.pt")
            print(f"Saved checkpoint for epoch {epoch}")

    def train_epoch(self, epoch):
        """Run one pass over the training loader.

        Metrics are logged every ``log_every`` steps. Validation runs every
        ``validate_every`` steps, and a best-model checkpoint is written when
        the validation loss improves.

        Returns:
            The mean of the per-batch training losses for this epoch.
        """
        self.model.train()
        total_loss = 0.0
        total_accuracy = 0.0
        num_batches = 0

        pbar = tqdm(self.train_loader, desc=f"Epoch {epoch}")

        for batch in pbar:
            input_ids = batch['input_ids'].to(self.device)
            labels = batch['labels'].to(self.device)
            attention_mask = batch['attention_mask'].to(self.device)

            # Forward (reshape tolerates non-contiguous DataParallel outputs)
            logits = self.model(input_ids, attention_mask)
            loss = self.criterion(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))

            # Backward
            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.gradient_clip)
            self.optimizer.step()
            self.scheduler.step()

            # Metrics: pull each scalar off the device once per step
            loss_value = loss.item()
            accuracy = self.calculate_accuracy(logits.detach(), labels)
            current_lr = self.scheduler.get_last_lr()[0]
            total_loss += loss_value
            total_accuracy += accuracy
            num_batches += 1
            self.global_step += 1

            # Update progress
            pbar.set_postfix({
                'loss': f"{loss_value:.4f}",
                'acc': f"{accuracy:.4f}",
                'lr': f"{current_lr:.2e}"
            })

            # Log training metrics
            if self.global_step % self.log_every == 0:
                if self.use_wandb:
                    self.wandb.log({
                        'train/loss': loss_value,
                        'train/accuracy': accuracy,
                        'train/perplexity': np.exp(loss_value),
                        'train/learning_rate': current_lr,
                        'train/step': self.global_step,
                        'train/epoch': epoch
                    })
                else:
                    self.log(f"Step {self.global_step} | Loss: {loss_value:.4f} | Acc: {accuracy:.4f} | "
                             f"LR: {current_lr:.2e}\n")

            # Validation
            if self.val_loader is not None and self.global_step % self.validate_every == 0:
                val_loss, val_accuracy = self.validate()
                print(f"\nStep {self.global_step} - Val Loss: {val_loss:.4f}, Val Acc: {val_accuracy:.4f}")

                if self.use_wandb:
                    self.wandb.log({
                        'val/loss': val_loss,
                        'val/accuracy': val_accuracy,
                        'val/perplexity': np.exp(val_loss)
                    })
                else:
                    self.log(f"[VALIDATION] Step {self.global_step} | Val Loss: {val_loss:.4f} | "
                             f"Val Acc: {val_accuracy:.4f}\n")

                if val_loss < self.best_val_loss:
                    self.best_val_loss = val_loss
                    # Mid-epoch: update best_model.pt only, not epoch_{N}.pt
                    self.save_checkpoint(epoch, val_loss, is_best=True, save_epoch=False)

        # Epoch summary
        avg_loss = total_loss / num_batches if num_batches else float('nan')
        avg_accuracy = total_accuracy / num_batches if num_batches else float('nan')
        print(f"Epoch {epoch} - Avg Loss: {avg_loss:.4f}, Avg Acc: {avg_accuracy:.4f}")

        if not self.use_wandb:
            self.log(f"\n[EPOCH {epoch} SUMMARY] Avg Loss: {avg_loss:.4f} | Avg Acc: {avg_accuracy:.4f}\n")

        return avg_loss

    def _close_logging(self):
        """Finish the W&B run or close the text log, whichever is active."""
        if self.use_wandb:
            self.wandb.finish()
        elif hasattr(self, 'log_file') and not self.log_file.closed:
            self.log_file.close()

    def train(self):
        """Run all epochs, checkpoint each one, and save ``final_model.pt``.

        With a validation set, ``final_loss`` is the best validation loss.
        Without one, it is the last epoch's mean training loss. Logging is
        closed even if training raises.
        """
        print("\nStarting training...")
        print(f"Total epochs: {self.epochs}")
        print(f"Total steps: {len(self.train_loader) * self.epochs:,}")

        has_val = self.val_loader is not None
        epoch_loss = float('nan')  # stays NaN if epochs == 0

        try:
            for epoch in range(1, self.epochs + 1):
                print(f"\nEpoch {epoch}/{self.epochs}")
                epoch_loss = self.train_epoch(epoch)

                if has_val:
                    # Run validation at end of epoch
                    val_loss, val_accuracy = self.validate()
                    print(f"Epoch {epoch} Validation - Loss: {val_loss:.4f}, Acc: {val_accuracy:.4f}")

                    if self.use_wandb:
                        self.wandb.log({
                            'epoch/val_loss': val_loss,
                            'epoch/val_accuracy': val_accuracy,
                            'epoch/train_loss': epoch_loss,
                            'epoch/number': epoch
                        })
                    else:
                        self.log(f"[EPOCH {epoch} VALIDATION] Val Loss: {val_loss:.4f} | Val Acc: {val_accuracy:.4f}\n\n")

                    is_best = val_loss < self.best_val_loss
                    if is_best:
                        self.best_val_loss = val_loss
                    self.save_checkpoint(epoch, val_loss, is_best=is_best)
                else:
                    # No validation - save checkpoint based on training loss
                    if self.use_wandb:
                        self.wandb.log({
                            'epoch/train_loss': epoch_loss,
                            'epoch/number': epoch
                        })
                    self.save_checkpoint(epoch, epoch_loss, is_best=(epoch == self.epochs))

            # Save final model. Without validation best_val_loss is still inf,
            # so record the last training loss instead.
            final_loss = self.best_val_loss if has_val else epoch_loss
            torch.save({
                'model_state_dict': self._unwrapped_model().state_dict(),
                'vocab_size': self.vocab_size,
                'final_loss': final_loss
            }, self.paths['model_dir'] / "final_model.pt")

            print("\nTraining complete!")
            if has_val:
                print(f"Best validation loss: {self.best_val_loss:.4f}")
            else:
                print(f"Final training loss: {epoch_loss:.4f}")

            if not self.use_wandb:
                self.log(f"\n{'='*50}\n")
                self.log("Training Complete!\n")
                if has_val:
                    self.log(f"Best validation loss: {self.best_val_loss:.4f}\n")
                else:
                    self.log(f"Final training loss: {epoch_loss:.4f}\n")
        finally:
            self._close_logging()


def main():
    """Command-line entry point: parse options, then build and run a Trainer."""
    cli = argparse.ArgumentParser(description='Train annDNA')
    cli.add_argument('--model', required=True,
                     choices=['seq', 'struct', 'full'],
                     help='Model to train')
    cli.add_argument('--gpu', default='0', help='GPU IDs (e.g., "0,1,2")')
    cli.add_argument('--batch-size', type=int, default=96, help='Batch size')
    cli.add_argument('--lr', type=float, default=3e-5, help='Learning rate')
    cli.add_argument('--epochs', type=int, default=10, help='Number of epochs')
    cli.add_argument('--wandb', action='store_true', help='Enable W&B logging')
    opts = cli.parse_args()

    # Banner framed by two rules of '=' characters.
    rule = "=" * 50
    for line in (rule, "annDNA Training", rule):
        print(line)

    Trainer(
        model_name=opts.model,
        gpu_ids=opts.gpu,
        batch_size=opts.batch_size,
        learning_rate=opts.lr,
        epochs=opts.epochs,
        use_wandb=opts.wandb,
    ).train()


# Run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
