from transformers import LlamaModel, AutoTokenizer, AutoModel, Trainer, default_data_collator, TrainerCallback, TrainingArguments
from contextlib import nullcontext
import torch
import torch.distributed as dist
import numpy as np
from utils.args import Arguments
from utils.dist import is_dist, set_dist_env
from utils.metrics import accuracy
from utils.peft import create_lora_config
from utils.utils import model_id
from models.LMs import BertClassifier, LlamaClassifier, HierBertClassifier, CatBertClassifier, AH_BertClassifier, AttBertClassifier, LabelBertClassifier, LAH_BertClassifier, TestLabelBertClassifier, LAttBertClassifier, OLAttBertClassifier, SelfLabelBertClassifier
from data.load import load_data
from data.dataset import NCDataset, A_NCDataset, H_NCDataset, L_NCDataset, S_NCDataset
from data.sampling import collect_subgraphs, ego_graphs_sampler, constrained_ego_graphs_sampler
import os
from copy import deepcopy

# Disable HuggingFace tokenizers' internal parallelism for this process.
os.environ.update({"TOKENIZERS_PARALLELISM": "false"})

def collect_txt(idx, txt):
    """Return the entries of ``txt`` selected by ``idx``, in ``idx`` order.

    Args:
        idx: iterable of indices/keys into ``txt``.
        txt: any indexable container (list, dict, string, ...).

    Returns:
        A new list ``[txt[i] for i in idx]``; repeated indices yield repeated
        entries.
    """
    return [txt[i] for i in idx]

def compute_class_prototypes(encodings, labels, num_classes):
    """Return the mean encoding of each class, one prototype row per class.

    Args:
        encodings: 2-D float tensor ``(num_samples, hidden_dim)``.
        labels: integer tensor with ``num_samples`` class ids in
            ``[0, num_classes)`` (flattened, so ``(num_samples, 1)`` also
            works), on the same device as ``encodings``.
        num_classes: total number of classes.

    Returns:
        Tensor ``(num_classes, hidden_dim)`` with ``encodings``' dtype/device.
        Rows of classes that have no sample in ``labels`` are all zeros.
    """
    labels = labels.reshape(-1)
    prototypes = torch.zeros(
        num_classes, encodings.size(1), device=encodings.device, dtype=encodings.dtype
    )
    # Sum every sample's encoding into its class row.
    prototypes.index_add_(0, labels, encodings)
    # bincount(minlength=num_classes) gives exactly one count per class, aligned
    # with the prototype rows even when a class is absent from `labels`
    # (unique(return_counts=True) silently drops such classes and misaligns).
    counts = torch.bincount(labels, minlength=num_classes).clamp(min=1)
    return prototypes / counts.unsqueeze(-1).to(prototypes.dtype)


if __name__ == '__main__':
    if is_dist():
        rank = set_dist_env()
    else:
        rank = 0

    config = Arguments().parse_args()
    print(config)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    output_dir = "tmp"
    epochs = config.epochs
    enable_profiler = False

    # Optional torch profiler; when disabled a no-op context manager is used.
    if enable_profiler:
        wait, warmup, active, repeat = 1, 1, 2, 1
        total_steps = (wait + warmup + active) * (1 + repeat)
        schedule = torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=repeat)
        profiler = torch.profiler.profile(
            schedule=schedule,
            on_trace_ready=torch.profiler.tensorboard_trace_handler(f"{output_dir}/logs/tensorboard"),
            record_shapes=True,
            profile_memory=True,
            with_stack=True)

        class ProfilerCallback(TrainerCallback):
            """Advance the profiler schedule after every optimizer step."""

            def __init__(self, profiler):
                self.profiler = profiler

            def on_step_end(self, *args, **kwargs):
                self.profiler.step()

        profiler_callback = ProfilerCallback(profiler)
    else:
        profiler = nullcontext()

    acc_list = []
    # -inf so the first evaluated run is always recorded; with 0 a run with
    # 0% validation accuracy never bound best_prediction and the final save
    # raised NameError.
    best_val_acc = float("-inf")
    best_prediction = None

    for run in range(config.runs):

        data, text, num_classes = load_data(config.dataset, use_text=True, seed=run)

        # Load the tokenizer and the backbone language model.
        if config.lm_type == 'llama':
            tokenizer = AutoTokenizer.from_pretrained(model_id[config.lm_type], use_fast=False, trust_remote_code=True)
            bert_model = LlamaModel.from_pretrained(model_id[config.lm_type], device_map=rank if is_dist() else device, torch_dtype=torch.float16)
            # Reuse BOS as the separator token and EOS as the padding token.
            tokenizer.sep_token = tokenizer.bos_token
            bert_model.config.sep_token_id = bert_model.config.bos_token_id
            tokenizer.pad_token = tokenizer.eos_token
            bert_model.config.pad_token_id = bert_model.config.eos_token_id
            # The dataset / prototype code below reads `token_model`; for LLaMA
            # it is the same backbone (previously it was left undefined).
            token_model = bert_model
        else:
            model_path = f"../models/{config.lm_type}"
            tokenizer = AutoTokenizer.from_pretrained(
                model_path,
                local_files_only=True,  # load only from local files, never the Hub
            )
            token_model = AutoModel.from_pretrained(
                model_path,
                output_hidden_states=True,
                output_attentions=True,
                return_dict=True,
                local_files_only=True,  # load only from local files, never the Hub
            ).to(device)  # was hard-coded to "cuda:0", which fails on CPU-only hosts

        # 1-D index tensors of each split (assumes 1-D boolean masks).
        # view(-1) keeps a single-node split 1-D; squeeze() made it 0-d.
        train_idx = data.train_mask.nonzero().view(-1)
        val_idx = data.val_mask.nonzero().view(-1)
        test_idx = data.test_mask.nonzero().view(-1)

        print(train_idx.shape)
        print(train_idx)

        # One sampled ego-graph per node; the last argument (40) is passed
        # straight to constrained_ego_graphs_sampler.
        graphs = constrained_ego_graphs_sampler(torch.arange(data.num_nodes), data, 40)
        dataset = S_NCDataset(graphs, data.y, tokenizer, token_model, text, config)
        train_dataset = torch.utils.data.Subset(dataset, train_idx)
        val_dataset = torch.utils.data.Subset(dataset, val_idx)

        # Class prototypes = mean root encoding of each class's training nodes;
        # passed to SelfLabelBertClassifier as `label_initial`.
        proto_device = token_model.device
        root_encodings = torch.stack(dataset.root_encodings)
        train_root_encodings = root_encodings[train_idx].to(proto_device)
        train_labels = data.y[train_idx].to(proto_device)
        class_prototypes = compute_class_prototypes(train_root_encodings, train_labels, num_classes)

        if config.lora:
            print("lora_activated")
            if config.lm_type == 'llama':
                bert_model, _ = create_lora_config(bert_model, config.rank)
                # Upcast only the trainable parameters to fp32; the others keep
                # the fp16 dtype they were loaded with.
                for param in bert_model.parameters():
                    if param.requires_grad:
                        param.data = param.data.float()
            else:
                # Non-LLaMA backbones live in `token_model`; this branch used to
                # reference the undefined `bert_model` and raise NameError.
                # NOTE(review): confirm create_lora_config's target modules fit this backbone.
                token_model, _ = create_lora_config(token_model, config.rank)

        if config.lm_type == 'llama':
            model = LlamaClassifier(bert_model, num_classes)
        else:
            model = SelfLabelBertClassifier(token_model=token_model, n_labels=num_classes, label_initial=class_prototypes)
        if is_dist():
            dist.barrier()

        # Define training args
        training_args = TrainingArguments(
            output_dir=output_dir,
            overwrite_output_dir=True,
            report_to="none",
            bf16=False,  # BF16 mixed precision is disabled
            dataloader_pin_memory=False,
            # logging strategies
            logging_dir=f"{output_dir}/logs",
            logging_strategy="epoch",
            logging_steps=1,
            save_strategy="no",
            optim="adamw_torch_fused",
            max_steps=total_steps if enable_profiler else -1,
            learning_rate=config.lr,
            num_train_epochs=epochs,
            gradient_accumulation_steps=2,
            per_device_train_batch_size=config.batch_size,
            gradient_checkpointing=False,
            local_rank=rank if is_dist() else -1,
        )

        with profiler:
            trainer = Trainer(
                model=model,
                args=training_args,
                train_dataset=train_dataset,
                eval_dataset=val_dataset,
                data_collator=default_data_collator,
                callbacks=[profiler_callback] if enable_profiler else [],
            )
            trainer.train()

        # Predict on every node, then slice out each split. The first model
        # output is scored by `accuracy`; the second is kept as `embedding`.
        predictions = trainer.predict(dataset)
        outputs = predictions.predictions[0]
        print(outputs.shape)
        train_predictions, train_true = outputs[train_idx], predictions.label_ids[train_idx]
        val_predictions, val_true = outputs[val_idx], predictions.label_ids[val_idx]
        test_predictions, test_true = outputs[test_idx], predictions.label_ids[test_idx]

        if not is_dist() or rank == 0:
            train_acc = accuracy(train_predictions, train_true)
            test_acc = accuracy(test_predictions, test_true)
            val_acc = accuracy(val_predictions, val_true)

            print(f"# lr : {config.lr}  run : {run} , train acc : {train_acc * 100:.2f}, val acc : {val_acc * 100:.2f} , test acc : {test_acc * 100:.2f}")
            acc_list.append(test_acc)

            embedding = predictions.predictions[1]

            # Keep the second output of the run with the best validation accuracy.
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                best_prediction = deepcopy(embedding)

        # Release this run's models and tensors before the next run loads new
        # ones. `dataset` was built with token_model and may hold a reference to
        # it, so it is dropped too; then return cached CUDA blocks.
        del predictions, trainer, model, tokenizer, token_model
        del dataset, train_dataset, val_dataset
        del root_encodings, train_root_encodings, train_labels, class_prototypes
        if config.lm_type == 'llama':
            del bert_model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        if is_dist():
            dist.barrier()

    if not is_dist() or rank == 0:
        if acc_list:
            final_acc, final_acc_std = np.mean(acc_list), np.std(acc_list)
            print(f"final_acc: {final_acc * 100:.2f} ± {final_acc_std * 100:.2f}")
        if best_prediction is not None:
            out_dir = os.path.join('out', 'lm_emb', f"{config.lm_type}_logits")
            os.makedirs(out_dir, exist_ok=True)
            torch.save(torch.tensor(best_prediction), os.path.join(out_dir, f'{config.dataset}.pt'))