#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import json
import time
import torch
import logging
import argparse
import deepspeed
from itertools import chain
import torch.optim as optim
from datetime import timedelta
import torch.distributed as dist
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, get_scheduler
from models import (
    OPTForCausalLM,
    OPTSNNForCausalLM_v1, OPTSNNConfig_v1
)
from utils import (
    DistillationConfig, create_deepspeed_config,
    load_dataset_from_config,
    compute_kl_loss,
    get_num_layers, get_hidden_size, IntermediateLayerAligner, compute_inter_layer_loss, LayerAlignProjector,
    CheckpointManager,
    compute_attention_alignment_loss,
)


def is_main_process():
    """Return True when this process should act as the main (logging/saving) process.

    When ``torch.distributed`` is available and already initialised, the
    decision is based on the global rank. Otherwise it falls back to the
    ``LOCAL_RANK`` environment variable, treating a missing variable as 0.
    """
    distributed_ready = dist.is_available() and dist.is_initialized()
    if not distributed_ready:
        # No process group yet: rely on the environment variable, defaulting to 0.
        return int(os.environ.get("LOCAL_RANK", 0)) == 0
    return dist.get_rank() == 0


def setup_logger(name="deepspeed"):
    """Return the logger called *name*, configured for rank-aware console output.

    A stream handler with a timestamped format is attached only if the logger
    has no handlers yet, so repeated calls do not duplicate output. The main
    process logs at INFO; every other process is limited to CRITICAL.
    Propagation to ancestor loggers is switched off.
    """
    log = logging.getLogger(name)

    if not log.handlers:
        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(
            logging.Formatter(
                fmt="%(asctime)s - %(levelname)s - %(message)s",
                datefmt="%Y-%m-%d %H:%M:%S",
            )
        )
        log.addHandler(stream_handler)

    level = logging.INFO if is_main_process() else logging.CRITICAL
    log.setLevel(level)
    log.propagate = False
    return log

# Module-wide logger used throughout this file; it is configured once, at import time.
logger = setup_logger()


class KnowledgeDistillationTrainer:
    
    def __init__(self, config: DistillationConfig):
        """Store *config* and build every component needed for training.

        The setup stages run in a fixed order: models and tokenizer, dataset,
        inter-layer alignment, then optimizer/engine setup.

        Args:
            config: Distillation settings read by all setup stages.
        """
        self.config = config
        self.checkpoint_manager = CheckpointManager(config)
        # Progress counters start at zero.
        self.current_global_step, self.current_epoch = 0, 0

        setup_stages = (
            self.setup_models_and_tokenizer,
            self.setup_dataset,
            self.setup_inter_layer_alignment,
            self.setup_training,
        )
        for stage in setup_stages:
            stage()

        # Only announces the resume request; the checkpoint itself is loaded in train().
        if getattr(config, 'resume_from_checkpoint', None):
            logger.info("resume")
    
    def setup_models_and_tokenizer(self):
        """Load the tokenizer, the teacher model and the student model.

        Populates ``self.tokenizer``, ``self.teacher_model``,
        ``self.student_model``, ``self.teacher_num_layers`` and
        ``self.student_num_layers``.

        * The tokenizer comes from the teacher checkpoint; if it defines no
          pad token, the EOS token is reused for padding.
        * The teacher is loaded in bfloat16 when ``config.bf16`` is set and in
          float16 otherwise, then put in eval mode. It is not moved to a
          device in this method.
        * The student is built from its config alone when
          ``config.train_from_scratch`` is set, otherwise loaded with weights.
        * The student is moved to this process's own CUDA device, taken from
          the ``LOCAL_RANK`` environment variable (0 when unset).
        """
        trust_remote_code = self.config.trust_remote_code

        self.tokenizer = AutoTokenizer.from_pretrained(
            self.config.teacher_model_path,
            trust_remote_code=trust_remote_code,
        )
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        teacher_dtype = torch.bfloat16 if self.config.bf16 else torch.float16
        self.teacher_model = OPTForCausalLM.from_pretrained(
            self.config.teacher_model_path,
            trust_remote_code=trust_remote_code,
            torch_dtype=teacher_dtype,
        ).eval()

        logger.info(f"model path: {self.config.student_model_path}")
        if self.config.train_from_scratch:
            # Architecture only: weights are freshly initialised.
            student_config = OPTSNNConfig_v1.from_pretrained(
                self.config.student_model_path,
                trust_remote_code=trust_remote_code,
            )
            self.student_model = OPTSNNForCausalLM_v1(student_config)
        else:
            self.student_model = OPTSNNForCausalLM_v1.from_pretrained(
                self.config.student_model_path,
                trust_remote_code=trust_remote_code,
            )

        self.teacher_num_layers = get_num_layers(self.teacher_model)
        self.student_num_layers = get_num_layers(self.student_model)

        # Report the student's footprint (every tensor in its state_dict).
        student_tensors = list(self.student_model.state_dict().values())
        total_elements = sum(t.numel() for t in student_tensors)
        total_size_mb = sum(t.numel() * t.element_size() for t in student_tensors) / (1024 * 1024)
        logger.info(f"student state_dict: {total_elements:,} elements, {total_size_mb:.2f} MB")

        if self.config.gradient_checkpointing:
            self.student_model.gradient_checkpointing_enable()

        logger.info(self.teacher_model)
        logger.info(self.student_model)

        # Place the student on this rank's GPU instead of always cuda:0, so
        # multi-GPU launches do not pile every replica (and anything created
        # on ``student_model.device`` afterwards) onto the first device.
        local_rank = int(os.environ.get("LOCAL_RANK", 0))
        if local_rank >= torch.cuda.device_count():
            # Only a subset of devices is visible to this process; use the first one.
            local_rank = 0
        self.student_model.to(f"cuda:{local_rank}")
    
    def setup_dataset(self):
        """Build ``self.train_dataloader`` from the configured training dataset.

        The loader shuffles, batches with ``per_device_train_batch_size`` and
        uses torch's default collate function.
        """
        cfg = self.config
        dataset = load_dataset_from_config(cfg, self.tokenizer)

        # NOTE(review): no distributed sampler is used here, so unless the
        # dataset loader already shards per rank, each process iterates the
        # whole dataset - confirm this is intended for multi-GPU runs.
        loader_options = dict(
            batch_size=cfg.per_device_train_batch_size,
            shuffle=True,
            num_workers=cfg.dataloader_num_workers,
            collate_fn=torch.utils.data.default_collate,
        )
        self.train_dataloader = DataLoader(dataset, **loader_options)


    def setup_inter_layer_alignment(self):
        self.use_intermediate_loss = self.config.use_intermediate_loss
        self.use_embedding_loss = self.config.use_embedding_loss
        self.use_attention_loss =  self.config.use_attention_loss
        
        self.projector = None
        
        if self.use_intermediate_loss or self.use_embedding_loss or self.use_attention_loss:
            alignment_strategy = self.config.alignment_strategy
            self.intermediate_loss_type = self.config.intermediate_loss_type

            self.layer_aligner = IntermediateLayerAligner(
                logger,
                self.teacher_num_layers,
                self.student_num_layers,
                alignment_strategy
            )

            teacher_hidden_size = get_hidden_size(self.teacher_model)
            student_hidden_size = get_hidden_size(self.student_model)
            
            if teacher_hidden_size != student_hidden_size:
                self.projector = LayerAlignProjector(self.layer_aligner.layer_mapping, student_hidden_size, teacher_hidden_size, device=self.student_model.device)
            else:
                self.projector = None
        

    def setup_training(self):
        """Create the optimizer, the DeepSpeed engine and the LR scheduler.

        Side effects:
            * ``self.max_train_steps`` is set to the number of optimizer
              updates implied by the dataloader length, the gradient
              accumulation steps and the number of epochs.
            * ``self.student_model`` is replaced by the engine returned from
              ``deepspeed.initialize``; ``self.optimizer`` and
              ``self.lr_scheduler`` are set.
            * ``self.teacher_model`` is moved to the engine's device.

        For ``lr_scheduler_type`` values ``'cosine'`` and ``'linear'`` a
        scheduler section is added to the DeepSpeed config. If DeepSpeed
        returns no scheduler, a linear warmup/decay scheduler from
        ``transformers`` is created instead.
        """
        cfg = self.config
        num_update_steps_per_epoch = len(self.train_dataloader) // cfg.gradient_accumulation_steps
        self.max_train_steps = int(cfg.num_train_epochs * num_update_steps_per_epoch)

        # The projector (if any) is trained jointly with the student.
        params_to_optimize = list(self.student_model.parameters())
        if self.projector is not None:
            params_to_optimize.extend(self.projector.parameters())

        optimizer = optim.AdamW(
            params_to_optimize,
            lr=cfg.learning_rate,
            weight_decay=cfg.weight_decay,
        )

        ds_config = create_deepspeed_config(cfg)
        logger.info(ds_config)

        warmup_steps = int(self.max_train_steps * cfg.warmup_ratio)
        if cfg.lr_scheduler_type == 'cosine':
            ds_config["scheduler"] = {
                "type": "WarmupCosineLR",
                "params": {
                    # NOTE(review): this feeds the warmup *length* ratio into
                    # the scheduler's starting-LR ratio - confirm it is intended.
                    "warmup_min_ratio": cfg.warmup_ratio,
                    "warmup_num_steps": warmup_steps,
                    "warmup_type": "log",
                    "total_num_steps": self.max_train_steps,
                },
            }
        elif cfg.lr_scheduler_type == 'linear':
            ds_config["scheduler"] = {
                "type": "WarmupDecayLR",
                "params": {
                    "warmup_min_lr": cfg.min_lr,
                    "warmup_max_lr": cfg.learning_rate,
                    "warmup_num_steps": warmup_steps,
                    "total_num_steps": self.max_train_steps,
                },
            }

        self.student_model, self.optimizer, _, self.lr_scheduler = deepspeed.initialize(
            model=self.student_model,
            optimizer=optimizer,
            config=ds_config,
        )

        if self.lr_scheduler is None:
            # Bind the fallback scheduler to the plain AdamW instance rather
            # than to the optimizer returned by DeepSpeed: that one may be a
            # wrapper that is not a ``torch.optim.Optimizer``, which torch's
            # scheduler base class rejects.
            self.lr_scheduler = get_scheduler(
                name="linear",
                optimizer=optimizer,
                num_warmup_steps=warmup_steps,
                num_training_steps=self.max_train_steps,
            )

        logger.info(f"gradient_accumulation_steps: {cfg.gradient_accumulation_steps}")
        logger.info(f"Using student model on: {self.student_model.device}")

        # Keep the teacher on the same device as the engine so both models can
        # consume the same batch tensors.
        self.teacher_model = self.teacher_model.to(self.student_model.device)
        logger.info(f"Using teacher model on: {self.teacher_model.device}")
        
        

    
    def train(self):
        """Run the distillation loop and write loss logs and checkpoints.

        For every micro-batch the teacher (under ``torch.no_grad``) and the
        student are run on the same inputs, and the optimised loss is

            alpha * student cross-entropy
            + beta * ``compute_kl_loss`` (student vs. teacher logits)
            + gamma * intermediate hidden-state loss
            + delta * embedding loss (scaled by 10)
            + epsilon * attention loss (scaled by 50)

        with each term divided by ``gradient_accumulation_steps``. An
        optimizer update happens once per ``gradient_accumulation_steps``
        micro-batches; micro-batches at the end of an epoch that cannot fill a
        complete accumulation window are skipped so their gradients do not
        leak into the next epoch.

        On the main process, per-update averages are logged every
        ``logging_steps`` updates, buffered into ``loss_history.jsonl`` and
        flushed every ``logging_steps * 50`` updates; a checkpoint is written
        every ``save_steps`` updates and once more at the end.

        Training resumes from the step/epoch reported by the checkpoint
        manager; already-completed micro-batches of the resumed epoch are
        skipped by position.

        NOTE(review): accumulation is tracked manually here. If the DeepSpeed
        config also sets ``gradient_accumulation_steps`` > 1, the engine would
        scale and gate updates a second time - confirm against
        ``create_deepspeed_config``.
        """
        os.makedirs(self.config.output_dir, exist_ok=True)
        resume_global_step, resume_epoch = self.checkpoint_manager.resume_from_checkpoint(self)

        if resume_global_step > 0:
            self.checkpoint_manager.update_training_progress(self, resume_global_step, resume_epoch)
            global_step, start_epoch = resume_global_step, resume_epoch
        else:
            global_step, start_epoch = 0, 0
        # Throughput/ETA must only count updates performed in this run.
        first_step_of_run = global_step

        loss_history = []

        self.student_model.train()

        alpha = self.config.alpha
        beta = self.config.beta
        gamma = self.config.gamma
        delta = self.config.delta
        epsilon = self.config.epsilon

        grad_acc_steps = self.config.gradient_accumulation_steps
        num_epochs = int(self.config.num_train_epochs)
        steps_per_epoch = len(self.train_dataloader) // grad_acc_steps
        # Micro-batches beyond this index cannot complete an accumulation window.
        usable_batches = steps_per_epoch * grad_acc_steps

        max_steps = steps_per_epoch * num_epochs
        step_width = len(str(max_steps))

        # A scheduler created by DeepSpeed is advanced inside engine.step();
        # stepping it again here would run the schedule at double speed.
        engine_owns_scheduler = getattr(self.student_model, "lr_scheduler", None) is self.lr_scheduler

        loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-100)
        loss_names = ("teacher", "student", "distill", "intermediate", "embedding", "attention")
        device = self.student_model.device

        train_start_time = time.time()
        for epoch in range(start_epoch, num_epochs):
            # When resuming, skip the micro-batches already consumed in this epoch.
            batches_to_skip = 0
            if epoch == start_epoch and resume_global_step > 0 and steps_per_epoch > 0:
                batches_to_skip = (resume_global_step % steps_per_epoch) * grad_acc_steps

            loss_sums = dict.fromkeys(loss_names, 0.0)
            accumulation_count = 0

            for batch_idx, batch in enumerate(self.train_dataloader):
                if batch_idx < batches_to_skip:
                    continue
                if batch_idx >= usable_batches:
                    # Incomplete trailing window: its gradients would never be
                    # applied on their own and would pollute the next epoch.
                    break

                input_ids = batch["input_ids"].to(device)
                attention_mask = batch.get("attention_mask", None)
                if attention_mask is not None:
                    attention_mask = attention_mask.to(device)

                labels = batch["labels"].to(device)

                with torch.no_grad():
                    teacher_outputs = self.teacher_model(
                        input_ids=input_ids,
                        attention_mask=attention_mask,
                        output_hidden_states=self.use_intermediate_loss,
                        output_attentions=self.use_attention_loss,
                    )
                    teacher_logits = teacher_outputs.logits
                    teacher_hidden_states = teacher_outputs.hidden_states if self.use_intermediate_loss else None
                    teacher_embeddings = self.teacher_model.model.decoder.embeddings if self.use_embedding_loss else None
                    teacher_attn = teacher_outputs.attentions if self.use_attention_loss else None

                student_outputs = self.student_model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    output_hidden_states=self.use_intermediate_loss,
                    output_attentions=self.use_attention_loss,
                )

                student_logits = student_outputs.logits
                student_hidden_states = student_outputs.hidden_states if self.use_intermediate_loss else None
                student_embeddings = self.student_model.model.decoder.embeddings if self.use_embedding_loss else None
                student_attn = student_outputs.attentions if self.use_attention_loss else None

                vocab_size = student_logits.size(-1)
                labels = labels.view(-1)

                # The teacher loss is only reported, never optimised.
                teacher_loss = loss_fn(teacher_logits.view(-1, vocab_size), labels)
                student_loss = loss_fn(student_logits.view(-1, vocab_size), labels)
                distill_loss = compute_kl_loss(student_logits, teacher_logits, labels, self.config.temperature)

                intermediate_loss = torch.tensor(0.0, device=student_logits.device)
                if self.use_intermediate_loss and teacher_hidden_states is not None and student_hidden_states is not None:
                    if self.projector is not None:
                        # NOTE(review): states are projected here and the
                        # projector is also passed to the loss below - confirm
                        # the loss does not project a second time.
                        student_hidden_states_for_loss = []
                        for i, hidden_state in enumerate(student_hidden_states):
                            key = f"layer_{i}_to_{i}"
                            if key in self.projector:
                                hidden_state = self.projector[key](hidden_state)
                            student_hidden_states_for_loss.append(hidden_state)
                    else:
                        student_hidden_states_for_loss = student_hidden_states

                    intermediate_loss = compute_inter_layer_loss(
                        student_hidden_states_for_loss,
                        teacher_hidden_states,
                        self.layer_aligner.layer_mapping,
                        self.projector,
                        self.intermediate_loss_type,
                        self.config.temperature
                    )

                embedding_loss = torch.tensor(0.0, device=student_logits.device)
                if self.use_embedding_loss and teacher_embeddings is not None and student_embeddings is not None:
                    if teacher_embeddings.shape != student_embeddings.shape:
                        # presumably collapses an extra leading (time-step) axis
                        # on the student side - TODO confirm
                        student_embeddings = student_embeddings.mean(0)
                    # NOTE(review): teacher is passed first here, while the
                    # hidden-state call above passes the student first - verify
                    # the argument order expected by compute_inter_layer_loss.
                    embedding_loss = compute_inter_layer_loss(
                        teacher_embeddings,
                        student_embeddings,
                        self.layer_aligner.layer_mapping,
                        self.projector,
                        self.intermediate_loss_type,
                        self.config.temperature
                    )
                    embedding_loss = embedding_loss * 10

                attention_loss = torch.tensor(0.0, device=student_logits.device)
                if self.use_attention_loss and teacher_attn is not None and student_attn is not None:
                    rate_mse_loss, ann_rate_loss = compute_attention_alignment_loss(
                        teacher_attn, student_attn,
                        T=self.student_model.model.decoder.time_steps, lif_params=None, device=device,
                    )
                    attention_loss = (rate_mse_loss + ann_rate_loss) * 50

                # Track unscaled values for reporting. Every term is summed,
                # including the attention loss.
                micro_batch_losses = (
                    ("teacher", teacher_loss),
                    ("student", student_loss),
                    ("distill", distill_loss),
                    ("intermediate", intermediate_loss),
                    ("embedding", embedding_loss),
                    ("attention", attention_loss),
                )
                for name, value in micro_batch_losses:
                    loss_sums[name] += value.item()
                accumulation_count += 1

                # Each term is divided by the window size so that the summed
                # gradients of one window correspond to the mean loss.
                combined_loss = alpha * (student_loss / grad_acc_steps) + \
                    beta * (distill_loss / grad_acc_steps) + \
                    gamma * (intermediate_loss / grad_acc_steps) + \
                    delta * (embedding_loss / grad_acc_steps) + \
                    epsilon * (attention_loss / grad_acc_steps)

                self.student_model.backward(combined_loss)

                if accumulation_count % grad_acc_steps != 0:
                    continue

                # ---- accumulation window complete: apply one update ----
                avg = {name: total / grad_acc_steps for name, total in loss_sums.items()}
                avg_combined_loss = alpha * avg["student"] + \
                    beta * avg["distill"] + \
                    gamma * avg["intermediate"] + \
                    delta * avg["embedding"] + \
                    epsilon * avg["attention"]

                self.student_model.step()
                if not engine_owns_scheduler:
                    self.lr_scheduler.step()
                # Read after step(): the engine exposes the norm computed
                # during the update it just performed.
                grad_norm = self.student_model.get_global_grad_norm()
                self.student_model.zero_grad()

                global_step += 1
                self.current_global_step = global_step
                loss_sums = dict.fromkeys(loss_names, 0.0)
                accumulation_count = 0

                if not is_main_process():
                    continue

                # LR currently applied to the first parameter group.
                current_lr = self.optimizer.param_groups[0]["lr"]

                if global_step % self.config.logging_steps == 0:
                    elapsed_total = time.time() - train_start_time
                    steps_this_run = global_step - first_step_of_run
                    steps_per_sec = steps_this_run / elapsed_total if elapsed_total > 0 else 0.0
                    eta_seconds = (max_steps - global_step) / steps_per_sec if steps_per_sec > 0 else 0
                    eta_string = str(timedelta(seconds=int(eta_seconds)))
                    elapsed_string = str(timedelta(seconds=int(elapsed_total)))

                    grad_norm_display = f"{grad_norm:.4f}" if grad_norm is not None else "N/A"
                    logger.info(
                        f"[Step {global_step:{step_width}d} / {max_steps}]  Loss={avg_combined_loss:.3f} "
                        f"(Student: {avg['student']:.3f}, Teacher: {avg['teacher']:.3f}, Distill: {avg['distill']:.3f}, Inter: {avg['intermediate']:.3f}, Embed: {avg['embedding']:.3f}, Attn: {avg['attention']:.3f})  | "
                        f"LR={current_lr:.2e} | Grad_norm = {grad_norm_display} "
                        f"[{elapsed_string}<{eta_string}, {steps_per_sec:.2f} Steps/s]"
                    )

                loss_history.append({
                    'step': global_step,
                    'loss': avg_combined_loss,
                    'distill_loss': avg["distill"],
                    'student_loss': avg["student"],
                    'teacher_loss': avg["teacher"],
                    'inter_loss': avg["intermediate"],
                    'embedding_loss': avg["embedding"],
                    'attention_loss': avg["attention"],
                    'lr': current_lr
                })

                if global_step % (self.config.logging_steps * 50) == 0:
                    self._flush_loss_history(loss_history)

                if global_step % self.config.save_steps == 0:
                    self.save_checkpoint(global_step)

            self.current_epoch = epoch + 1

        total_train_time = time.time() - train_start_time
        logger.info(f"training finished in {timedelta(seconds=int(total_train_time))}")

        # Only the main process writes files, matching the periodic saves above;
        # otherwise every rank would write the same paths concurrently.
        if is_main_process():
            self._flush_loss_history(loss_history)
            self.save_checkpoint(global_step, is_final=True)

    def _flush_loss_history(self, loss_history):
        """Append buffered loss records to ``loss_history.jsonl`` and empty the buffer."""
        if not loss_history:
            return
        log_path = os.path.join(self.config.output_dir, "loss_history.jsonl")
        with open(log_path, "a") as f:
            f.writelines(json.dumps(record) + "\n" for record in loss_history)
        loss_history.clear()


    def save_checkpoint(self, global_step, is_final=False):
        save_path = self.config.output_dir if is_final else os.path.join(self.config.output_dir, f"checkpoint-{global_step}")
        os.makedirs(save_path, exist_ok=True)

        if hasattr(self.student_model, 'module'):
            model_to_save = self.student_model.module
        elif hasattr(self.student_model, 'model'):
            model_to_save = self.student_model.model
        else:
            model_to_save = self.student_model

        try:
            model_to_save.save_pretrained(
                save_path,
                safe_serialization=True,
                max_shard_size="5GB"
            )

        except Exception as e:
            try:
                from safetensors.torch import save_file
                state_dict = model_to_save.state_dict()
                save_file(state_dict, os.path.join(save_path, "model.safetensors"))
                if hasattr(model_to_save, "config"):
                    model_to_save.config.to_json_file(os.path.join(save_path, "config.json"))
            except Exception as e2:
                logger.error(e2)
                return False

        self._save_additional_components(save_path, global_step, is_final)
        return save_path


    def _save_additional_components(self, save_path, global_step, is_final):

        if hasattr(self, 'tokenizer') and self.tokenizer is not None:
            try:
                self.tokenizer.save_pretrained(save_path)
            except Exception as e:
                logger.error(e)

        try:
            model_to_check = self.student_model.module if hasattr(self.student_model, 'module') else self.student_model
            if hasattr(model_to_check, 'config'):
                config_path = os.path.join(save_path, "config.json")
                with open(config_path, 'w', encoding='utf-8') as f:
                    config_dict = model_to_check.config.to_dict() if hasattr(model_to_check.config, 'to_dict') else vars(model_to_check.config)
                    json.dump(config_dict, f, indent=2, ensure_ascii=False)
        except Exception as e:
            logger.error(e)

        try:
            training_state = {
                'global_step': global_step,
                'is_final': is_final,
            }
            if hasattr(self, 'current_epoch'):
                training_state['epoch'] = self.current_epoch
            
            with open(os.path.join(save_path, "training_state.json"), 'w') as f:
                json.dump(training_state, f, indent=2)
        except Exception as e:
            logger.error(e)


def main():
    """Parse command-line arguments, build the trainer and run training.

    ``--config`` (required) is passed to ``DistillationConfig``.
    ``--local_rank`` is accepted but not read in this function; presumably it
    is supplied by the distributed launcher.
    """
    cli = argparse.ArgumentParser(description="KD")
    cli.add_argument("--config", type=str, required=True)
    cli.add_argument("--local_rank", type=int, default=-1)
    options = cli.parse_args()

    distill_config = DistillationConfig(options.config)
    KnowledgeDistillationTrainer(distill_config).train()

# Run the distillation job only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()