import os
import random
import argparse
import numpy as np
import torch
import torch.nn as nn
from transformers import (
    AutoTokenizer,AutoModelForCausalLM,AutoConfig,BitsAndBytesConfig,
    get_scheduler,T5EncoderModel,T5Tokenizer,
)   
import bitsandbytes.optim as bnb_optim

from tools import print_trainable_parameters, str2bool, set_seed
from train import train_epoch
from data import load_and_concat_datasets, MultiTaskSamplerDataLoader
from moelayer import MoELoRA, MoELoRAQwen, WrappedLoRALayer


# Maps a model name (the value of ``--model_name``) to the path or identifier
# handed to the ``from_pretrained`` calls in ``load_model`` and ``main``.
# NOTE(review): the mapping is empty here, so every ``MODEL_PATHS[model_name]``
# lookup raises KeyError until entries are added for the supported models.
MODEL_PATHS = { }

def load_model(args, task_embeddings, model_name):
    """Build the MoELoRAQwen model and place its adapter modules on devices.

    Args:
        args: Parsed options. Reads ``use_q``, ``num_experts``, ``rank``,
            ``top_k``, ``alpha`` and ``use_lora_dropout``.
        task_embeddings: Task embeddings forwarded unchanged to
            ``MoELoRAQwen.from_pretrained``.
        model_name: Key into ``MODEL_PATHS`` selecting the checkpoint.

    Returns:
        The model after ``apply_film()`` and ``apply_moelora()`` were called
        and the adapter modules were moved next to the modules they wrap.

    Raises:
        KeyError: If ``model_name`` is not a key of ``MODEL_PATHS``.
    """
    if args.use_q:
        # 4-bit NF4 quantization with double quantization, fp16 compute.
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.float16,
        )
        print("You have used the bnb config!!")
    else:
        bnb_config = None
        print("You use the vanilla model")

    # NOTE(review): ``device_map`` is kept exactly as in the original call;
    # it is presumably not consumed by the config loader -- confirm.
    config = AutoConfig.from_pretrained(
        MODEL_PATHS[model_name], trust_remote_code=True, device_map="auto"
    )
    config.parallel_model = "simple"

    model = MoELoRAQwen.from_pretrained(
        MODEL_PATHS[model_name],
        config=config,
        num_experts=args.num_experts,
        rank=args.rank,
        top_k=args.top_k,
        alpha=args.alpha,
        task_embeddings=task_embeddings,
        use_lora_dropout=args.use_lora_dropout,
        quantization_config=bnb_config,
        trust_remote_code=True,
        device_map="auto",
    )

    model.apply_film()
    model.apply_moelora()

    # Move every adapter wrapper onto the device of the module it wraps, and
    # remember the last such device for the FiLM adapter below.
    adapter_device = None
    for module in model.modules():
        if isinstance(module, (WrappedLoRALayer, MoELoRA)):
            adapter_device = next(module.original_module.parameters()).device
            module.to(adapter_device)

    if adapter_device is None:
        # No wrapped module was found; previously this path crashed with an
        # unbound ``device`` name. Fall back to the model's first parameter.
        adapter_device = next(model.parameters()).device
    model.film_adapter.to(adapter_device)

    return model


def main(args):
    """Run MoELoRA multi-task training from parsed command-line options.

    Seeds the run, loads the tokenizers and the T5 encoder, builds the
    datasets and data loader, constructs the model, optimizer and cosine
    learning-rate scheduler, then calls ``train_epoch`` once per epoch.

    Args:
        args: Parsed options (see the argument parser at the bottom of the
            file). An optional ``t5_model_path`` attribute selects the T5
            checkpoint; when it is absent or empty the ``T5_MODEL_PATH``
            environment variable is used instead.

    Raises:
        ValueError: If no T5 checkpoint is configured or ``selected_tasks``
            contains no task name.
        KeyError: If ``args.model_name`` is not a key of ``MODEL_PATHS``.
    """
    set_seed(args.seed)

    # ``from_pretrained`` needs a checkpoint path/identifier; the original
    # calls passed none. Resolve it up front so a missing setting fails fast
    # with an actionable message.
    t5_path = getattr(args, "t5_model_path", None) or os.environ.get("T5_MODEL_PATH")
    if not t5_path:
        raise ValueError(
            "No T5 checkpoint configured: set args.t5_model_path or the "
            "T5_MODEL_PATH environment variable."
        )

    # Tolerate spaces and stray commas, e.g. "rte, mrpc,".
    selected_tasks = [name.strip() for name in args.selected_tasks.split(",") if name.strip()]
    if not selected_tasks:
        raise ValueError("selected_tasks must name at least one task")

    model_name = args.model_name
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATHS[model_name], trust_remote_code=True)
    t5_encoder = T5EncoderModel.from_pretrained(t5_path)
    t5_tokenizer = T5Tokenizer.from_pretrained(t5_path)

    train_dataset, valid_datasets, task_embeddings = load_and_concat_datasets(
        selected_tasks,
        tokenizer=tokenizer,
        t5_encoder=t5_encoder,
        t5_tokenizer=t5_tokenizer,
        max_length=args.max_length,
        label_max_length=args.label_length,
        samples_per_task=20,
    )
    print(f"The shape of task_embeddings is {task_embeddings.shape}")
    print("Load data successfully")
    train_loader = MultiTaskSamplerDataLoader(train_dataset, batch_size=args.batch_size)
    print("Dataloader successfully created")

    moe = load_model(args, task_embeddings, model_name)
    print_trainable_parameters(moe)
    print(moe.film_adapter)

    # Only parameters that require gradients are handed to the optimizer.
    trainable_params = [p for p in moe.parameters() if p.requires_grad]
    # Developer toggle: switch to the 8-bit AdamW from bitsandbytes.
    use_qadamw = False
    optimizer_cls = bnb_optim.AdamW8bit if use_qadamw else torch.optim.AdamW
    optimizer = optimizer_cls(  # type: ignore
        trainable_params,
        lr=args.learning_rate,
        weight_decay=args.weight_decay,
    )

    num_epochs = args.num_epochs
    num_training_steps = len(train_loader) * num_epochs
    num_warmup_steps = args.warmup_steps
    print(f"Total training steps: {num_training_steps}")
    print(f"Warmup steps: {num_warmup_steps}")
    if num_warmup_steps >= num_training_steps:
        # The schedule is left untouched; this only makes the situation visible.
        print(
            f"Warning: warmup steps ({num_warmup_steps}) >= total training steps "
            f"({num_training_steps}); the cosine decay phase will never run."
        )

    lr_scheduler = get_scheduler(
        name="cosine",
        optimizer=optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
    )

    device = moe.device
    for epoch in range(num_epochs):
        # Only the task loss is reported; the other two averages are unused.
        avg_task_loss, _avg_aux_loss, _avg_orth_loss = train_epoch(
            moe,
            tokenizer,
            train_loader,
            optimizer,
            lr_scheduler,
            device,
            selected_tasks,
            valid_datasets,
            valid_batchsize=args.valid_batch_size,
            label_length=args.label_length,
            beta=args.beta,
            lambda_1=args.lambda_1,
            save_steps=args.save_steps,
            save_path="./checkpoints",
            use_aux_loss=args.use_aux_loss,
            use_orth_loss=args.use_orth_loss,
        )
        print(f"Epoch {epoch + 1}/{num_epochs}")
        print(f"    Average Task Loss: {avg_task_loss:.4f}")



if __name__ == "__main__":
    # Command-line options as (flag, add_argument keyword arguments) pairs,
    # grouped by purpose and registered in this order.
    model_options = [
        ("--model_name", dict(type=str, choices=["Qwen/Qwen3-0.6B", "Qwen/Qwen3-4B", "Qwen/Qwen3-8B"], default="Qwen/Qwen3-4B")),
        ("--num_experts", dict(type=int, default=8)),
        ("--rank", dict(type=int, default=8)),
        ("--top_k", dict(type=int, default=2)),
        ("--alpha", dict(type=int, default=16)),
        ("--use_q", dict(type=str2bool, default=False)),
        ("--use_lora_dropout", dict(type=str2bool, default=False)),
    ]
    training_options = [
        ("--learning_rate", dict(type=float, default=5e-5)),
        ("--weight_decay", dict(type=float, default=0.01)),
        ("--num_epochs", dict(type=int, default=1)),
        ("--warmup_steps", dict(type=int, default=500)),
        ("--use_aux_loss", dict(type=str2bool, default=True)),
        ("--use_orth_loss", dict(type=str2bool, default=False)),
        ("--beta", dict(type=float, default=0.005)),
        ("--lambda_1", dict(type=float, default=0.01)),
        ("--save_steps", dict(type=int, default=1000)),
        ("--seed", dict(type=int, default=2025)),
    ]
    data_options = [
        ("--selected_tasks", dict(type=str, default="rte,mrpc")),
        ("--max_length", dict(type=int, default=128)),
        ("--label_length", dict(type=int, default=32)),
        ("--batch_size", dict(type=int, default=8)),
        ("--valid_batch_size", dict(type=int, default=16)),
    ]

    parser = argparse.ArgumentParser(description="Train MoELoRA on Qwen3")
    for flag, options in (*model_options, *training_options, *data_options):
        parser.add_argument(flag, **options)

    # Set in the process environment before the arguments are parsed and
    # before main() runs.
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    args = parser.parse_args()
    print(args)
    main(args)