#!/usr/bin/env python3
"""
Step 2: Train Distilled Model with Embedding Distillation

The distilled (student) model learns from:
1. Hard labels (ground truth sequence) - CE loss
2. Teacher hidden states - MSE loss

Loss = alpha * CE(distilled, labels) + (1-alpha) * MSE(distilled_hidden, teacher_hidden)
"""

import sys
sys.path.append('..')

import os
import json
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from pathlib import Path
from tqdm import tqdm
from datetime import datetime
import h5py
import pandas as pd

from config import get_model_paths, get_model_config
from distill_config import get_distill_paths, DISTILL_CONFIG, DISTILLED_CONFIG, TEACHER_CONFIG

sys.path.append('../2_train')
from model import annDNA


class DistillationDataset(Dataset):
    """Dataset for distillation training.

    Each item pairs the distilled-tokenizer inputs, labels and attention mask
    (loaded eagerly from ``.npy`` files in ``token_dir``) with the teacher's
    hidden states for the same row, read lazily one row at a time from
    ``{hidden_dir}/{split}_hidden_states.h5``.
    """

    def __init__(self, token_dir, hidden_dir, split='train'):
        """
        Args:
            token_dir: directory holding ``{split}_distilled_input_ids.npy``,
                ``{split}_distilled_labels.npy`` and ``{split}_attention_mask.npy``.
            hidden_dir: directory holding ``{split}_hidden_states.h5``.
            split: split name used as the file-name prefix (e.g. 'train', 'val').

        Raises:
            ValueError: if the three token arrays have different row counts.
        """
        print(f"Loading {split} distillation data...")

        # Distilled data (from token_dir)
        self.input_ids = np.load(f'{token_dir}/{split}_distilled_input_ids.npy')
        self.labels = np.load(f'{token_dir}/{split}_distilled_labels.npy')
        self.attention_mask = np.load(f'{token_dir}/{split}_attention_mask.npy')

        # Mismatched preprocessing outputs would otherwise silently misalign
        # samples or fail later inside a DataLoader worker.
        num_samples = len(self.input_ids)
        if len(self.labels) != num_samples or len(self.attention_mask) != num_samples:
            raise ValueError(
                f"{split}: row counts differ (input_ids={num_samples}, "
                f"labels={len(self.labels)}, attention_mask={len(self.attention_mask)})"
            )

        # Teacher hidden states (from hidden_dir). The HDF5 file is opened
        # lazily in __getitem__ so each DataLoader worker gets its own handle.
        self.h5_path = f'{hidden_dir}/{split}_hidden_states.h5'
        self.h5_file = None

        print(f"Loaded {len(self.input_ids):,} {split} samples")

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        """Return one sample as a dict of tensors (ids/labels/mask as int64, hidden as float32)."""
        if self.h5_file is None:
            self.h5_file = h5py.File(self.h5_path, 'r')

        return {
            'input_ids': torch.tensor(self.input_ids[idx], dtype=torch.long),
            'labels': torch.tensor(self.labels[idx], dtype=torch.long),
            'attention_mask': torch.tensor(self.attention_mask[idx], dtype=torch.long),
            'teacher_hidden': torch.tensor(
                self.h5_file['hidden_states'][idx], dtype=torch.float32
            )
        }

    def __getstate__(self):
        """Pickle without the open HDF5 handle.

        h5py file objects cannot be pickled. DataLoader workers started with
        'spawn'/'forkserver' receive a pickled copy of the dataset, so an
        already-open handle would crash them; the copy reopens lazily instead.
        """
        state = self.__dict__.copy()
        state['h5_file'] = None
        return state

    def close(self):
        """Close the HDF5 handle if one is open (safe to call repeatedly)."""
        if self.h5_file is not None:
            self.h5_file.close()
            self.h5_file = None


class DistilledModel(nn.Module):
    """Distilled model that outputs both logits and hidden states.

    Token + learned positional embeddings feed a post-norm
    ``nn.TransformerEncoder``; a linear MLM head maps the final hidden states
    to vocabulary logits. ``forward`` returns ``(logits, hidden_states)`` so a
    trainer can supervise both.
    """

    def __init__(self, vocab_size, d_model, nhead, num_layers, max_seq_len, dropout=0.1):
        """
        Args:
            vocab_size: number of token ids (token-embedding rows, MLM-head width).
            d_model: hidden size; the feed-forward width is ``4 * d_model``.
            nhead: attention heads per layer (must divide ``d_model``).
            num_layers: number of encoder layers.
            max_seq_len: number of learned positions; longer inputs are rejected.
            dropout: dropout probability for embeddings and encoder layers.
        """
        super().__init__()

        self.d_model = d_model
        self.vocab_size = vocab_size

        # Embeddings
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.pos_embedding = nn.Embedding(max_seq_len, d_model)

        # Transformer. norm_first is left at its default (False, post-norm);
        # forward_with_attention relies on that when replaying layers by hand.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=d_model * 4,
            dropout=dropout,
            activation='gelu',
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Output
        self.mlm_head = nn.Linear(d_model, vocab_size)
        self.layer_norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        self._init_weights()

    def _init_weights(self):
        """Normal(0, 0.02) for Linear/Embedding weights, zero biases, identity LayerNorm."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                module.weight.data.normal_(mean=0.0, std=0.02)
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.Embedding):
                module.weight.data.normal_(mean=0.0, std=0.02)
            elif isinstance(module, nn.LayerNorm):
                module.bias.data.zero_()
                module.weight.data.fill_(1.0)

    def _embed(self, input_ids):
        """Token + positional embeddings, then dropout and LayerNorm.

        Raises:
            ValueError: if the sequence is longer than the positional table.
        """
        batch_size, seq_len = input_ids.shape
        max_len = self.pos_embedding.num_embeddings
        if seq_len > max_len:
            # Fail clearly instead of an embedding index error (a device-side
            # assert on CUDA).
            raise ValueError(f"Sequence length {seq_len} exceeds max_seq_len={max_len}")

        pos_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0).expand(batch_size, -1)
        embeddings = self.token_embedding(input_ids) + self.pos_embedding(pos_ids)
        # Dropout-then-LayerNorm order is kept so existing checkpoints behave identically.
        return self.layer_norm(self.dropout(embeddings))

    @staticmethod
    def _key_padding_mask(attention_mask):
        """Convert a 1=token / 0=padding mask into PyTorch's True=ignore key mask."""
        return (attention_mask == 0) if attention_mask is not None else None

    def forward(self, input_ids, attention_mask=None):
        """Return ``(logits, hidden_states)`` for a batch of token ids.

        ``attention_mask`` (same shape as ``input_ids``) marks padding with 0;
        those positions are excluded as attention keys.
        """
        embeddings = self._embed(input_ids)
        hidden_states = self.transformer(
            embeddings, src_key_padding_mask=self._key_padding_mask(attention_mask)
        )
        logits = self.mlm_head(hidden_states)
        return logits, hidden_states

    def get_num_params(self):
        """Total number of parameters (trainable or not)."""
        return sum(p.numel() for p in self.parameters())

    def forward_with_attention(self, input_ids, attention_mask=None):
        """Forward pass that also returns per-layer, per-head attention weights.

        Each encoder layer is replayed by hand in post-norm order, so that
        ``need_weights=True`` can be requested from its self-attention.

        Returns:
            ``(logits, hidden_states, attentions)``; ``attentions`` is a list
            with one ``(batch, nhead, seq, seq)`` tensor per layer.

        NOTE(review): in eval mode with a padding mask, ``nn.TransformerEncoder``
        may take its nested-tensor fast path and emit zeros at padded
        positions, so ``forward`` and this method can differ there.
        """
        hidden_states = self._embed(input_ids)
        key_padding_mask = self._key_padding_mask(attention_mask)

        all_attentions = []
        for layer in self.transformer.layers:
            attn_output, attn_weights = layer.self_attn(
                query=hidden_states,
                key=hidden_states,
                value=hidden_states,
                key_padding_mask=key_padding_mask,
                need_weights=True,
                average_attn_weights=False  # keep one attention map per head
            )
            all_attentions.append(attn_weights)

            # Residual + LayerNorm (post-norm), then the feed-forward sub-block.
            hidden_states = layer.norm1(hidden_states + layer.dropout1(attn_output))
            ff_output = layer.linear2(layer.dropout(layer.activation(layer.linear1(hidden_states))))
            hidden_states = layer.norm2(hidden_states + layer.dropout2(ff_output))

        logits = self.mlm_head(hidden_states)
        return logits, hidden_states, all_attentions

    def get_attention(self, input_ids, attention_mask=None, layer_idx=None, head_idx=None):
        """
        Extract attention scores as a numpy array.

        Runs in eval mode without gradients, then restores whatever train/eval
        mode the model was in before the call.

        Args:
            layer_idx: None (all layers), int (specific layer; negative values,
                e.g. -1 for the last layer, count from the end)
            head_idx: None (all heads), int (specific head)

        Returns:
            numpy array shaped ``(layers, batch, heads, seq, seq)``, with the
            layer and/or head axis removed when ``layer_idx``/``head_idx`` is given.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                _, _, attentions = self.forward_with_attention(input_ids, attention_mask)
        finally:
            self.train(was_training)

        all_attentions = torch.stack(attentions, dim=0).cpu().numpy()

        if layer_idx is not None:
            all_attentions = all_attentions[layer_idx]

        if head_idx is not None:
            all_attentions = all_attentions[..., head_idx, :, :]

        return all_attentions


class DistilledTrainer:
    """Embedding-distillation trainer for DistilledModel.

    Optimises ``alpha * CE(student_logits, labels)
    + (1 - alpha) * MSE(student_hidden, teacher_hidden)``, where the MSE term is
    averaged over non-padding positions only.

    ``args`` is the namespace produced by ``main()``. Besides ``gpu``,
    ``batch_size``, ``learning_rate``, ``epochs`` and ``alpha`` it must provide
    ``token_dir`` (the ``.npy`` token arrays), ``hidden_dir`` (the teacher
    ``.h5`` hidden states) and ``output_root`` (parent of the run directory).
    """

    def __init__(self, args):
        self.args = args

        # CUDA_VISIBLE_DEVICES only takes effect if CUDA has not been
        # initialised yet in this process.
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.num_gpus = torch.cuda.device_count()

        self.distill_paths = get_distill_paths()
        self.distilled_paths = get_model_paths(DISTILL_CONFIG['distilled_base'])

        # Data and output locations; needed before the data loaders are built.
        self.token_dir = self._require_arg(args, 'token_dir')
        self.hidden_dir = self._require_arg(args, 'hidden_dir')
        output_root = self._require_arg(args, 'output_root')

        with open(self.distilled_paths['vocab_file']) as f:
            self.vocab = json.load(f)
        self.vocab_size = len(self.vocab)

        print("=" * 60)
        print("Embedding Distillation Training")
        print("=" * 60)
        print(f"Teacher: {DISTILL_CONFIG['teacher_model']} ({TEACHER_CONFIG['num_layers']}L, {TEACHER_CONFIG['nhead']}H)")
        print(f"Distilled: {DISTILLED_CONFIG['num_layers']}L, {DISTILLED_CONFIG['nhead']}H, d={DISTILLED_CONFIG['d_model']}")
        print(f"Distilled vocab size: {self.vocab_size}")
        print(f"GPUs: {args.gpu} ({self.num_gpus} devices)")
        print(f"Alpha: {args.alpha} (CE weight)")

        # Create distilled model (smaller than teacher)
        model = DistilledModel(
            vocab_size=self.vocab_size,
            d_model=DISTILLED_CONFIG['d_model'],
            nhead=DISTILLED_CONFIG['nhead'],
            num_layers=DISTILLED_CONFIG['num_layers'],
            max_seq_len=DISTILLED_CONFIG['max_seq_len']
        )
        if self.num_gpus > 1:
            print(f"Using DataParallel with {self.num_gpus} GPUs")
            model = nn.DataParallel(model)
        self.model = model.to(self.device)

        print(f"Distilled parameters: {self._unwrapped_model().get_num_params():,}")

        # Data loaders
        # Note: shuffle=False for efficient H5 sequential chunk access
        # (random access with 4.1TB H5 file causes severe I/O bottleneck)
        self.train_loader = DataLoader(
            DistillationDataset(self.token_dir, self.hidden_dir, 'train'),
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=4,
            pin_memory=True,
            drop_last=True
        )
        if len(self.train_loader) == 0:
            # drop_last=True yields no batches when the set is smaller than one
            # batch; fail now rather than dividing by zero later.
            raise ValueError(
                f"Training set is smaller than batch_size={args.batch_size}; "
                "no training batches would be produced"
            )

        val_hidden_file = Path(self.hidden_dir) / 'val_hidden_states.h5'
        if val_hidden_file.exists():
            self.val_loader = DataLoader(
                DistillationDataset(self.token_dir, self.hidden_dir, 'val'),
                batch_size=args.batch_size * 2,
                shuffle=False,
                num_workers=4,
                pin_memory=True
            )
        else:
            self.val_loader = None

        print(f"Train batches: {len(self.train_loader):,}")
        if self.val_loader:
            print(f"Val batches: {len(self.val_loader):,}")

        # Optimizer
        self.optimizer = optim.AdamW(
            self.model.parameters(),
            lr=args.learning_rate,
            weight_decay=0.01
        )

        # Scheduler: linear warmup followed by cosine decay over the whole run.
        total_steps = len(self.train_loader) * args.epochs
        self.scheduler = optim.lr_scheduler.LambdaLR(
            self.optimizer,
            self._make_lr_lambda(DISTILL_CONFIG['warmup_steps'], total_steps)
        )

        # Loss functions
        self.ce_loss = nn.CrossEntropyLoss(ignore_index=-100)
        # compute_loss uses a padding-aware mean instead; this attribute is
        # kept so existing references to trainer.mse_loss keep working.
        self.mse_loss = nn.MSELoss()
        self.alpha = args.alpha

        # Training state
        self.global_step = 0
        self.best_val_loss = float('inf')
        self.training_log = []  # one dict of metrics per epoch

        # Output directory with timestamp
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        self.output_dir = Path(output_root) / f'distilled_model_{timestamp}'
        self.output_dir.mkdir(parents=True, exist_ok=True)
        print(f"Model save dir: {self.output_dir}")

    @staticmethod
    def _require_arg(args, name):
        """Return ``args.<name>`` or raise a clear error when it is missing/empty."""
        value = getattr(args, name, None)
        if not value:
            raise ValueError(f"Missing required argument '{name}' (pass --{name})")
        return value

    @staticmethod
    def _make_lr_lambda(warmup_steps, total_steps):
        """Build a LambdaLR multiplier: linear warmup, then cosine decay to 0.

        Progress through the decay phase is clamped to [0, 1], so the
        multiplier stays at 0 after ``total_steps`` instead of climbing back up
        the cosine. The decay length is at least 1, so
        ``total_steps <= warmup_steps`` does not divide by zero.
        """
        decay_steps = max(1, total_steps - warmup_steps)

        def lr_lambda(step):
            if step < warmup_steps:
                return step / warmup_steps
            progress = min(1.0, (step - warmup_steps) / decay_steps)
            return 0.5 * (1 + np.cos(np.pi * progress))

        return lr_lambda

    def _unwrapped_model(self):
        """The underlying DistilledModel, stripping a DataParallel wrapper if present."""
        return self.model.module if isinstance(self.model, nn.DataParallel) else self.model

    def _distill_config(self):
        """Distillation metadata stored alongside every saved checkpoint."""
        return {
            'alpha': self.args.alpha,
            'teacher': DISTILL_CONFIG['teacher_model'],
        }

    def compute_loss(self, distilled_logits, distilled_hidden, teacher_hidden, labels, attention_mask):
        """
        Compute the combined distillation loss.

        Args:
            distilled_logits: student MLM logits; last dim is the vocab size.
            distilled_hidden: student hidden states, ``[batch, seq, dim]``.
            teacher_hidden: teacher hidden states; must match ``distilled_hidden``'s shape.
            labels: target token ids; positions equal to -100 are ignored by CE.
            attention_mask: ``[batch, seq]``, 1 for real tokens and 0 for padding.

        Returns:
            ``(total_loss, ce, mse)``: ``total_loss`` is a tensor for backward;
            ``ce`` and ``mse`` are Python floats for logging.

        Raises:
            ValueError: if the student and teacher hidden shapes differ.
        """
        # CE loss
        ce = self.ce_loss(
            distilled_logits.reshape(-1, distilled_logits.size(-1)),
            labels.reshape(-1)
        )

        if distilled_hidden.shape != teacher_hidden.shape:
            raise ValueError(
                f"Hidden-state shape mismatch: distilled {tuple(distilled_hidden.shape)} "
                f"vs teacher {tuple(teacher_hidden.shape)}"
            )

        # MSE loss averaged over non-padding positions only. A plain mean of
        # zeroed tensors would still count padded elements in the denominator,
        # making the loss shrink as the amount of padding grows.
        mask = attention_mask.unsqueeze(-1).to(distilled_hidden.dtype)  # [batch, seq, 1]
        squared_error = (distilled_hidden - teacher_hidden).pow(2) * mask
        num_elements = (mask.sum() * distilled_hidden.size(-1)).clamp_min(1.0)
        mse = squared_error.sum() / num_elements

        # Combined loss
        total_loss = self.alpha * ce + (1 - self.alpha) * mse

        return total_loss, ce.item(), mse.item()

    def _run_batch(self, batch):
        """Move a batch to the device, run the model and return ``compute_loss``'s output."""
        input_ids = batch['input_ids'].to(self.device)
        labels = batch['labels'].to(self.device)
        attention_mask = batch['attention_mask'].to(self.device)
        teacher_hidden = batch['teacher_hidden'].to(self.device)

        distilled_logits, distilled_hidden = self.model(input_ids, attention_mask)

        return self.compute_loss(
            distilled_logits, distilled_hidden, teacher_hidden,
            labels, attention_mask
        )

    def train_epoch(self, epoch):
        """Run one training epoch; return mean ``(loss, ce, mse)`` over its batches."""
        self.model.train()
        total_loss = 0.0
        total_ce = 0.0
        total_mse = 0.0
        num_batches = 0

        pbar = tqdm(self.train_loader, desc=f"Epoch {epoch}")

        for batch in pbar:
            loss, ce, mse = self._run_batch(batch)

            # Backward
            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(
                self.model.parameters(),
                DISTILL_CONFIG['gradient_clip']
            )
            self.optimizer.step()
            self.scheduler.step()

            loss_value = loss.item()
            total_loss += loss_value
            total_ce += ce
            total_mse += mse
            num_batches += 1
            self.global_step += 1

            pbar.set_postfix({
                'loss': f"{loss_value:.4f}",
                'ce': f"{ce:.4f}",
                'mse': f"{mse:.4f}",
                'lr': f"{self.scheduler.get_last_lr()[0]:.2e}"
            })

        denom = max(num_batches, 1)
        avg_loss = total_loss / denom
        avg_ce = total_ce / denom
        avg_mse = total_mse / denom

        print(f"Epoch {epoch} - Loss: {avg_loss:.4f}, CE: {avg_ce:.4f}, MSE: {avg_mse:.4f}")
        return avg_loss, avg_ce, avg_mse

    def validate(self):
        """Return mean ``(loss, ce, mse)`` over the validation set; leaves the model in train mode.

        Returns NaNs if the loader yields no batches, so an empty pass can
        never register as a new best loss.
        """
        self.model.eval()
        total_loss = 0.0
        total_ce = 0.0
        total_mse = 0.0
        num_batches = 0

        with torch.no_grad():
            for batch in tqdm(self.val_loader, desc="Validation"):
                loss, ce, mse = self._run_batch(batch)
                total_loss += loss.item()
                total_ce += ce
                total_mse += mse
                num_batches += 1

        self.model.train()
        if num_batches == 0:
            nan = float('nan')
            return nan, nan, nan
        return total_loss / num_batches, total_ce / num_batches, total_mse / num_batches

    def save_checkpoint(self, epoch, loss, is_best=False):
        """Save ``epoch_{epoch}.pt`` (and ``best_model.pt`` when ``is_best``) to output_dir."""
        checkpoint = {
            'epoch': epoch,
            'global_step': self.global_step,
            'model_state_dict': self._unwrapped_model().state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'loss': loss,
            'vocab_size': self.vocab_size,
            'distill_config': self._distill_config(),
        }

        if is_best:
            torch.save(checkpoint, self.output_dir / 'best_model.pt')
            print(f"Saved best model (loss: {loss:.4f})")

        torch.save(checkpoint, self.output_dir / f'epoch_{epoch}.pt')

    def _save_training_log(self):
        """Save training log to CSV"""
        df = pd.DataFrame(self.training_log)
        log_file = self.output_dir / 'training_log.csv'
        df.to_csv(log_file, index=False)

    def train(self):
        """Run all epochs, checkpointing and logging after each, then save final_model.pt."""
        print(f"\nStarting training for {self.args.epochs} epochs...")

        for epoch in range(1, self.args.epochs + 1):
            print(f"\n{'='*60}")
            print(f"Epoch {epoch}/{self.args.epochs}")
            print(f"{'='*60}")

            train_loss, train_ce, train_mse = self.train_epoch(epoch)

            # Log entry
            log_entry = {
                'epoch': epoch,
                'train_loss': train_loss,
                'train_ce': train_ce,
                'train_mse': train_mse,
                'lr': self.scheduler.get_last_lr()[0],
            }

            # DataLoader truthiness is len() > 0, so an empty val set is skipped too.
            if self.val_loader:
                val_loss, val_ce, val_mse = self.validate()
                print(f"Validation - Loss: {val_loss:.4f}, CE: {val_ce:.4f}, MSE: {val_mse:.4f}")

                log_entry.update({
                    'val_loss': val_loss,
                    'val_ce': val_ce,
                    'val_mse': val_mse,
                })

                is_best = val_loss < self.best_val_loss
                if is_best:
                    self.best_val_loss = val_loss
                self.save_checkpoint(epoch, val_loss, is_best=is_best)
            else:
                # Without validation data, the final epoch is saved as "best".
                self.save_checkpoint(epoch, train_loss, is_best=(epoch == self.args.epochs))

            self.training_log.append(log_entry)

            # Save log after each epoch
            self._save_training_log()

        # Save final model
        torch.save({
            'model_state_dict': self._unwrapped_model().state_dict(),
            'vocab_size': self.vocab_size,
            'distill_config': self._distill_config(),
        }, self.output_dir / 'final_model.pt')

        print("\nTraining complete!")
        if self.val_loader:
            print(f"Best validation loss: {self.best_val_loss:.4f}")
        print(f"Model saved to: {self.output_dir}")


def main(argv=None):
    """Parse command-line options and run distillation training.

    Args:
        argv: optional argument list; defaults to ``sys.argv[1:]`` (argparse's
            behaviour for ``None``), which also lets the script be driven
            programmatically.
    """
    parser = argparse.ArgumentParser(description='Train a distilled model with embedding distillation')
    parser.add_argument('--gpu', default='0', help='GPU IDs (e.g., "0,1,2,3")')
    parser.add_argument('--batch_size', type=int, default=DISTILL_CONFIG['batch_size'])
    parser.add_argument('--learning_rate', type=float, default=DISTILL_CONFIG['learning_rate'])
    parser.add_argument('--epochs', type=int, default=DISTILL_CONFIG['epochs'])
    parser.add_argument('--alpha', type=float, default=DISTILL_CONFIG['alpha'],
                        help='Weight for CE loss (1-alpha for MSE)')
    # Data/output locations read by DistilledTrainer.
    parser.add_argument('--token_dir', required=True,
                        help='Directory with {split}_distilled_input_ids.npy, '
                             '{split}_distilled_labels.npy and {split}_attention_mask.npy')
    parser.add_argument('--hidden_dir', required=True,
                        help='Directory with the teacher {split}_hidden_states.h5 files')
    parser.add_argument('--output_root', required=True,
                        help='Parent directory for the timestamped run output directory')
    args = parser.parse_args(argv)

    trainer = DistilledTrainer(args)
    trainer.train()


# Run training only when executed as a script, not on import.
if __name__ == '__main__':
    main()
