import os
import argparse
from datetime import datetime
import yaml
from omegaconf import OmegaConf

import torch
from torch.utils.data import DataLoader
from transformers import (
    Trainer, 
    TrainingArguments,
    HfArgumentParser,
    set_seed
)
from transformers import TrainerCallback
from transformers.trainer_callback import TrainerControl, TrainerState
from transformers.training_args import TrainingArguments

from dataclasses import dataclass, field
from typing import Optional

from src.model import AnimeShooterGen
from src.utils.dataset_cogvideo_lora import Dataset_cogvideo_lora

def get_last_checkpoint(output_dir):
    """Return the path of the newest ``checkpoint-<step>`` directory in *output_dir*.

    Only directories whose name is exactly ``checkpoint-`` followed by a decimal
    step number are considered. Anything else, such as a stray ``checkpoint-tmp``
    folder or a plain file, is ignored instead of crashing the ``int()`` parse or
    being handed to the Trainer as a checkpoint.

    Args:
        output_dir: Directory that the Trainer writes its checkpoints into.

    Returns:
        The joined path of the checkpoint with the highest step, or ``None`` when
        *output_dir* is not an existing directory or holds no valid checkpoint.
    """
    if not os.path.isdir(output_dir):
        return None

    prefix = "checkpoint-"
    steps = {}
    for entry in os.listdir(output_dir):
        if not entry.startswith(prefix):
            continue
        suffix = entry[len(prefix):]
        # isdecimal() accepts exactly the digit strings that int() can parse.
        if suffix.isdecimal() and os.path.isdir(os.path.join(output_dir, entry)):
            steps[entry] = int(suffix)

    if not steps:
        return None

    last_checkpoint = max(steps, key=steps.get)
    return os.path.join(output_dir, last_checkpoint)

class CustomTrainer(Trainer):
    """Trainer whose optimizer uses a separate learning rate for the backbone.

    Trainable parameters whose qualified name contains ``"backbone_model"`` are
    optimized with ``args.llm_learning_rate``. Every other trainable parameter
    uses ``args.learning_rate``.
    """

    def create_optimizer(self):
        """Build (once) and return a two-group AdamW optimizer.

        The optimizer is also stored on ``self.optimizer``. The base ``Trainer``
        reads that attribute after calling this hook (scheduler creation,
        accelerator preparation, checkpoint save/restore), so returning it
        without assigning it would leave ``self.optimizer`` as ``None``.

        Returns:
            The optimizer held in ``self.optimizer``.
        """
        # Mirror the base Trainer: never replace an optimizer that already exists
        # (e.g. one supplied by DeepSpeed or passed to the constructor).
        if self.optimizer is not None:
            return self.optimizer

        backbone_params = []
        other_params = []
        for name, param in self.model.named_parameters():
            if not param.requires_grad:
                continue
            if "backbone_model" in name:
                backbone_params.append(param)
            else:
                other_params.append(param)

        # Both groups are always created, even if one is empty, so the optimizer
        # state layout stays the same when resuming from a checkpoint.
        optimizer_grouped_parameters = [
            {"params": backbone_params, "lr": self.args.llm_learning_rate},
            {"params": other_params, "lr": self.args.learning_rate},
        ]
        # Betas/eps follow the TrainingArguments; their defaults equal torch's.
        # NOTE(review): args.weight_decay is not applied, so torch.optim.AdamW's
        # own weight_decay default is used, exactly as before.
        self.optimizer = torch.optim.AdamW(
            optimizer_grouped_parameters,
            betas=(self.args.adam_beta1, self.args.adam_beta2),
            eps=self.args.adam_epsilon,
        )
        return self.optimizer

@dataclass
class TrainingArguments(TrainingArguments):
    """HF ``TrainingArguments`` extended with a dedicated LLM learning rate.

    The subclass deliberately reuses the imported base-class name, so later
    references to ``TrainingArguments`` in this module resolve to this class.
    """

    # Overrides the base-class default; CustomTrainer applies it to every
    # trainable parameter outside "backbone_model".
    learning_rate: float = 1e-4
    # CustomTrainer applies this to trainable parameters whose name contains
    # "backbone_model".
    llm_learning_rate: float = field(
        default=1e-4,
        metadata={"help": "Learning rate for LLM parameters"},
    )
    
def _print_trainable_parameter_summary(model):
    """Print trainable-parameter counts grouped by top-level submodule name."""
    print("\n=== Trainable Parameters Summary ===")
    counts = {}
    for name, param in model.named_parameters():
        if param.requires_grad:
            top_module = name.split('.')[0]
            counts[top_module] = counts.get(top_module, 0) + param.numel()

    if not counts:
        print("No trainable parameters found!")
        return

    width = max(len(name) for name in counts)
    for name, count in counts.items():
        print(f"{name:<{width}} : {count:,} parameters")
    print("-" * (width + 20))
    print(f"{'Total':<{width}} : {sum(counts.values()):,} parameters")


def _print_run_configuration(config, training_args, trainer):
    """Print the model/data config sections, public training args and DeepSpeed config."""
    print("\n=== Model Arguments ===")
    print(config['model'])

    print("\n=== Data Arguments ===")
    print(config['data'])

    print("\n=== Training Arguments ===")
    for k, v in training_args.__dict__.items():
        if not k.startswith('_'):
            print(f"{k}: {v}")

    print("\n=== DeepSpeed Configuration ===")
    if trainer.is_deepspeed_enabled:
        print(trainer.accelerator.state.deepspeed_plugin.deepspeed_config)
    else:
        print("DeepSpeed not enabled")


def main():
    """Fine-tune AnimeShooterGen on the video selected by ``--video_id``.

    Loads the OmegaConf file given by ``--config``, injects the video id into the
    data section, and redirects output to ``<training.output_dir>/weights/<video_id>``.
    It then builds the model, dataset and ``CustomTrainer``, and trains. Training
    resumes from the newest ``checkpoint-*`` in that directory when one exists.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='src/config/model_config.yaml', help='Path to config file')
    parser.add_argument('--video_id', type=str, default='1dCd6hCRoaQ', help='Video ID for finetune')
    args = parser.parse_args()

    config = OmegaConf.load(args.config)
    config['data']['video_id'] = args.video_id
    # Each video gets its own weights directory, so runs for different videos
    # never resume from each other's checkpoints.
    config['training']['output_dir'] = os.path.join(config['training']['output_dir'], "weights", args.video_id)
    os.makedirs(config['training']['output_dir'], exist_ok=True)

    training_args = TrainingArguments(**config['training'])

    set_seed(training_args.seed)

    model = AnimeShooterGen(**config['model'])
    print("AnimeShooterGen init done...")

    # NOTE(review): the LLM LoRA is attached before loading weights, presumably so
    # that adapter keys in the checkpoint can match. Confirm.
    model.adding_LLM_lora(config['peft'])
    pretrained_weight = config['model'].get('pretrained_weight')
    if pretrained_weight:
        state_dict = torch.load(pretrained_weight, map_location='cpu')
        load_result = model.load_state_dict(state_dict, strict=False)
        print(f"Loading pretrained weights from {pretrained_weight}...")
        # strict=False silently skips mismatched keys; report how many there were
        # so that a wrong checkpoint is noticed.
        missing = getattr(load_result, 'missing_keys', [])
        unexpected = getattr(load_result, 'unexpected_keys', [])
        print(f"  missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")
    else:
        print("No pretrained weights provided...")

    model.prepare_trainable_parameters_cogvideo_lora(config['peft'])
    model = model.cuda()

    _print_trainable_parameter_summary(model)

    dataset = Dataset_cogvideo_lora(**config['data'])

    trainer = CustomTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        data_collator=dataset.collate_fn,
    )

    _print_run_configuration(config, training_args, trainer)

    last_checkpoint = get_last_checkpoint(training_args.output_dir)
    if last_checkpoint:
        print(f"\nFound checkpoint: {last_checkpoint}")
        print("Resuming training from this checkpoint...")
        training_args.resume_from_checkpoint = last_checkpoint
    else:
        print("\nNo checkpoint found. Starting training from scratch.")

    # Trainer checkpoints (optimizer/scheduler/RNG state) contain non-tensor
    # objects, which recent PyTorch rejects under weights_only=True. Force full
    # unpickling for the duration of training only. weights_only=False can
    # execute arbitrary pickle code, so only resume from checkpoints this script
    # wrote itself.
    original_load = torch.load

    def custom_load(f, *load_args, **load_kwargs):
        # Override rather than append the keyword: callers such as Trainer may
        # pass weights_only themselves, and a duplicate keyword raises TypeError.
        load_kwargs['weights_only'] = False
        return original_load(f, *load_args, **load_kwargs)

    torch.load = custom_load
    try:
        trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
    finally:
        # Undo the global patch even if training fails.
        torch.load = original_load

if __name__ == "__main__":
    # Run fine-tuning only when executed as a script, not when imported.
    main()