from pyexpat import model
from header import *
# from nltk.translate.bleu_score import sentence_bleu
# from nltk.translate.bleu_score import SmoothingFunction
# from rouge import Rouge
from transformers import AutoTokenizer, AutoModel
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from nltk.translate.meteor_score import single_meteor_score
from bert_score import score
from rouge import Rouge
import os
import json
import sys
class DeepSpeedAgent:
    """Train and evaluate a model through a DeepSpeed engine, updating only LoRA weights.

    On construction every parameter whose name contains ``"lora"`` (case-insensitive)
    is marked trainable and every other parameter is frozen. Only the trainable
    parameters are handed to ``deepspeed.initialize``.
    """

    def __init__(self, model, args):
        """Freeze non-LoRA weights and wrap ``model`` in a DeepSpeed engine.

        Args:
            model: the model to train (passed to ``deepspeed.initialize``; the other
                methods also use its ``generate``, ``llama_tokenizer`` and
                ``llama_model`` attributes).
            args: configuration dict. Keys read here: ``stage``, ``delta_ckpt_path``
                (stage 2 only), ``ds_config_path``, ``total_steps`` and
                ``warmup_rate``. The whole dict is also forwarded to DeepSpeed as a
                ``types.SimpleNamespace``.
        """
        super().__init__()
        self.args = args
        self.model = model
        if args['stage'] == 2:
            # Stage 2 starts from the (trainable-only) weights saved by stage 1.
            self.load_stage_1_parameters(args["delta_ckpt_path"])
            print(f'[!] load stage 1 checkpoint from {args["delta_ckpt_path"]}')

        # Load the DeepSpeed JSON config and fill in the scheduler step counts.
        # The context manager closes the file handle the original code leaked.
        with open(self.args['ds_config_path'], encoding='utf-8') as f:
            ds_params = json.load(f)
        scheduler_params = ds_params['scheduler']['params']
        scheduler_params['total_num_steps'] = self.args['total_steps']
        # Warm up for warmup_rate of the total steps, but never fewer than 10 steps.
        scheduler_params['warmup_num_steps'] = max(
            10, int(self.args['total_steps'] * self.args['warmup_rate'])
        )

        # Train only LoRA weights. If nothing matches, the adapter layers may use a
        # different naming scheme (e.g. "adapter") and this keyword must change.
        for name, param in self.model.named_parameters():
            param.requires_grad = "lora" in name.lower()

        # Give DeepSpeed (and therefore the optimizer) only the trainable parameters.
        trainable_params = [p for p in self.model.parameters() if p.requires_grad]
        print(f"[DEBUG] Trainable params count: {len(trainable_params)}")

        self.ds_engine, self.optimizer, _, _ = deepspeed.initialize(
            model=self.model,
            model_parameters=trainable_params,
            config_params=ds_params,
            dist_init_required=True,
            args=types.SimpleNamespace(**args),
        )
        # Debug: confirm how many parameters ended up in each optimizer group.
        for i, group in enumerate(self.optimizer.param_groups):
            print(f"Group {i}, num params = {len(group['params'])}")

    @torch.no_grad()
    def predict(self, batch):
        """Return ``self.model.generate(batch)`` with eval mode on and gradients off."""
        self.model.eval()
        return self.model.generate(batch)

    @torch.no_grad()
    def predict_model(self, batch, output_dir=None, round_id=None, pbar=None):
        """Generate answers for ``batch``, score them, and print the scores.

        Args:
            batch: batch dict accepted by ``evaluate_batch``.
            output_dir: directory for ``eval_metrics.log``; defaults to ``get_eval_dir()``.
            round_id: when not None, the scores are appended, tagged with this id,
                to ``<output_dir>/eval_metrics.log``.
            pbar: unused; kept for interface compatibility.
        """
        if output_dir is None:
            output_dir = get_eval_dir()

        self.model.eval()
        eval_result = evaluate_batch(self.model, batch)

        if round_id is not None:
            metrics = {
                'round': round_id,
                'BLEU': eval_result['BLEU'],
                'ROUGE-L': eval_result['ROUGE-L'],
                'METEOR': eval_result['METEOR'],
                'BERTScore': eval_result['BERTScore'],
            }
            # os.makedirs('') raises, so only create a directory when one was given.
            if output_dir:
                os.makedirs(output_dir, exist_ok=True)
            log_path = os.path.join(output_dir, "eval_metrics.log")
            with open(log_path, 'a', encoding='utf-8') as f:
                f.write("[Eval Result]\n")
                for key, value in metrics.items():
                    f.write(f"{key}: {value}\n")
                f.write("\n")  # blank line between entries

        print(f"BLEU: {eval_result['BLEU']:.4f}, ROUGE-L F1: {eval_result['ROUGE-L']['f']:.4f},'METEOR': {eval_result['METEOR']:.4f},'BERTScore_p': {eval_result['BERTScore']['P']:.4f},'BERTScore_r': {eval_result['BERTScore']['R']:.4f},'BERTScore_F1': {eval_result['BERTScore']['F1']:.4f}")

    @torch.no_grad()
    def predict_model_1(self, batch, pbar=None, output_dir=None, round_id=None):
        """Compute the evaluation loss and token accuracy of the engine on ``batch``.

        Args:
            batch: input forwarded unchanged to the DeepSpeed engine.
            pbar: optional tqdm-style progress bar; updated once when given.
            output_dir: directory for ``loss_lora.log``; defaults to ``get_eval_dir()``.
            round_id: when not None, this round's loss/accuracy record is appended
                to ``<output_dir>/loss_lora.log``.

        Returns:
            ``(loss, token_acc * 100)`` where ``loss`` is a Python float.
        """
        if output_dir is None:
            output_dir = get_eval_dir()

        self.ds_engine.module.eval()
        loss, mle_acc = self.ds_engine(batch)
        loss_value = loss.item()
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        if pbar is not None:
            pbar.set_description(f'[Eval] loss: {round(loss_value, 4)}; token_acc: {round(mle_acc*100, 2)}')
            pbar.update(1)

        if round_id is not None:
            record = {
                'round': round_id,
                'loss': round(loss_value, 6),
                'token_acc': round(mle_acc * 100, 4),
            }
            # The log is opened in append mode, so write only this round's record.
            # Re-writing older history here would duplicate entries on every call.
            log_path = os.path.join(output_dir, "loss_lora.log")
            with open(log_path, 'a', encoding='utf-8') as f:
                f.write("[Eval Result]\n")
                for key, value in record.items():
                    f.write(f"{key}: {value}\n")
                f.write("\n")  # blank line between entries

        return loss_value, mle_acc * 100

    def train_model(self, batch, current_step=0, pbar=None):
        """Run one forward/backward/optimizer step on ``batch``.

        Args:
            batch: input forwarded unchanged to the DeepSpeed engine.
            current_step: global step index. A progress line is logged every
                ``args['logging_step']`` steps on local rank 0 when
                ``args['log_path']`` is set.
            pbar: optional tqdm-style progress bar; updated once per call when given.

        Returns:
            The engine's token accuracy multiplied by 100.
        """
        self.ds_engine.module.train()
        loss, mle_acc = self.ds_engine(batch)

        self.ds_engine.backward(loss)
        self.ds_engine.step()

        loss_value = loss.item()
        if pbar is not None:
            pbar.set_description(f'[!] loss: {round(loss_value, 4)}; token_acc: {round(mle_acc*100, 2)}')
            pbar.update(1)

        should_log = (
            self.args['local_rank'] == 0
            and self.args['log_path']
            and current_step % self.args['logging_step'] == 0
        )
        if should_log:
            # Progress and ETA can only be computed when the bar knows its total length.
            if pbar is not None and pbar.total:
                rate = pbar.format_dict['rate']
                remaining = (pbar.total - pbar.n) / rate if rate else 0
                progress = round(pbar.n / pbar.total, 5)
            else:
                remaining, progress = 0, 0.0
            remaining = str(datetime.timedelta(seconds=remaining))
            logging.info(f'[!] progress: {progress}; remaining time: {remaining}; loss: {round(loss_value, 4)}; token_acc: {round(mle_acc*100, 2)}')

        return mle_acc * 100

    def save_model(self, path, current_step):
        """Save the trainable parameters, tokenizer, and model config into ``path``.

        Only parameters with ``requires_grad=True`` go into ``pytorch_model.pt``,
        which keeps the checkpoint small. ``path`` is created if missing.
        ``current_step`` is accepted for interface compatibility but is unused.
        """
        # torch.save does not create directories, and it runs before save_pretrained.
        os.makedirs(path, exist_ok=True)
        checkpoint = OrderedDict(
            (name, param)
            for name, param in self.ds_engine.module.named_parameters()
            if param.requires_grad
        )
        torch.save(checkpoint, os.path.join(path, 'pytorch_model.pt'))
        self.model.llama_tokenizer.save_pretrained(path)
        self.model.llama_model.config.save_pretrained(path)
        print(f'[!] save model into {path}')

    def load_stage_1_parameters(self, path):
        """Load a (possibly partial) state dict from ``path`` onto the CPU-mapped model.

        ``strict=False`` lets a checkpoint that holds only the trainable subset
        (as written by ``save_model``) load without errors for the missing keys.
        """
        delta_ckpt = torch.load(path, map_location=torch.device('cpu'))
        self.model.load_state_dict(delta_ckpt, strict=False)
    

# Default BERTScore backbone (local RoBERTa-large checkpoint) and the hidden layer
# whose embeddings are compared. Callers can override both per call.
_DEFAULT_BERTSCORE_MODEL = "/home-ssd/Users/gm_intern/lhy/PandaGPT/roberta-large"
_DEFAULT_BERTSCORE_LAYERS = 17


def _collect_predictions(model, media_paths, conversations, media_key):
    """Generate one answer per sample and pair it with its ground-truth reply.

    Args:
        model: object exposing ``generate(inputs) -> str``.
        media_paths: one media file path per sample.
        conversations: one conversation per sample, indexed in parallel with
            ``media_paths``. Each is expected to start with a ``'human'`` turn
            followed by a ``'gpt'`` turn; other shapes are skipped with a warning.
        media_key: generation-input key (``'image_paths'`` or ``'audio_paths'``)
            that receives the sample's path; the other modality lists stay empty.

    Returns:
        ``(hyps, refs)``: predictions and stripped references for every sample whose
        conversation was valid and whose stripped prediction was non-empty.
    """
    hyps, refs = [], []
    for i, media_path in enumerate(media_paths):
        conv = conversations[i]

        # Sanity check on the conversation layout.
        if len(conv) < 2 or conv[0]['from'] != 'human' or conv[1]['from'] != 'gpt':
            print(f"[Warning] Skipping invalid conversation: {conv}")
            continue

        prompt = conv[0]['value']
        gt_answer = conv[1]['value'].strip()

        inputs = {
            'image_paths': [],
            'audio_paths': [],
            'video_paths': [],
            'thermal_paths': [],
            'prompt': prompt,
            'max_tgt_len': 128,
            'top_p': 0.9,
            'temperature': 1.0,
            'modality_embeds': [],
            'modality_cache': {},
        }
        # Only the current sample's modality gets a path; the others stay empty.
        inputs[media_key] = [media_path]

        try:
            pred_answer = model.generate(inputs).strip()
        except Exception as e:
            # Best effort: a single failing sample must not abort the whole evaluation.
            print(f"[Error] Generation failed for sample {i}: {e}")
            pred_answer = ""

        if pred_answer == "":
            print(f"[Warning] Empty hypothesis at sample {i}, skipping")
            continue

        refs.append(gt_answer)
        hyps.append(pred_answer)
    return hyps, refs


def _mean_skip_zeros(values):
    """Mean of the strictly positive entries of a tensor, rounded to 4 places (0.0 if none)."""
    nonzero = values[values > 0]
    return round(nonzero.mean().item(), 4) if len(nonzero) > 0 else 0.0


def _mean_of_positive(scores):
    """Mean of the strictly positive numbers in ``scores``, rounded to 4 places (0.0 if none)."""
    positive = [s for s in scores if s > 0]
    return round(sum(positive) / len(positive), 4) if positive else 0.0


def evaluate_batch(model, batch, *, bert_model_type=_DEFAULT_BERTSCORE_MODEL,
                   bert_num_layers=_DEFAULT_BERTSCORE_LAYERS):
    """Generate answers for a batch and score them against the references.

    The batch is treated as an image batch when it has an ``'image_paths'`` key,
    and as an audio batch (``'audio_paths'``) otherwise. ``batch['output_texts']``
    holds one conversation per sample.

    Args:
        model: object exposing ``generate(inputs) -> str``.
        batch: dict with ``'image_paths'`` or ``'audio_paths'`` and ``'output_texts'``.
        bert_model_type: model name/path passed to ``bert_score.score``.
        bert_num_layers: hidden layer used by ``bert_score.score``.

    Returns:
        dict with ``'BLEU'``, ``'METEOR'`` (means over non-zero sentence scores),
        ``'BERTScore'`` (``P``/``R``/``F1`` means over non-zero scores), ``'ROUGE-L'``
        (``p``/``r``/``f`` corpus averages), and the ``'refs'``/``'hyps'`` lists.
        All scores are 0 when no usable prediction was produced.
    """
    media_key = 'image_paths' if 'image_paths' in batch else 'audio_paths'
    media_paths = batch[media_key]
    batch_size = len(media_paths)
    hyps, refs = _collect_predictions(model, media_paths, batch['output_texts'], media_key)

    if not hyps:
        print(f"[Warning] All {batch_size} predictions are empty.")
        return {
            'BLEU': 0.0,
            'ROUGE-L': {'f': 0.0, 'p': 0.0, 'r': 0.0},
            'refs': refs,
            'hyps': hyps,
            'METEOR': 0.0,
            'BERTScore': {'P': 0.0, 'R': 0.0, 'F1': 0.0},
        }

    # === BERTScore ===
    P, R, F1 = score(hyps, refs, lang='en', model_type=bert_model_type, num_layers=bert_num_layers)
    avg_bertscore = {
        'P': _mean_skip_zeros(P),
        'R': _mean_skip_zeros(R),
        'F1': _mean_skip_zeros(F1),
    }

    # === BLEU (whitespace tokens, custom n-gram weights, smoothing method1) ===
    smoothing = SmoothingFunction()
    bleu_scores = [
        sentence_bleu([ref.split()], hyp.split(), weights=(0.5, 0.3, 0.1, 0.1),
                      smoothing_function=smoothing.method1)
        for ref, hyp in zip(refs, hyps)
    ]
    avg_bleu = _mean_of_positive(bleu_scores)

    # === METEOR (expects pre-tokenized inputs) ===
    try:
        meteor_scores = [
            single_meteor_score(ref.split(), hyp.split())
            for ref, hyp in zip(refs, hyps)
        ]
    except Exception as e:
        # Best effort, e.g. when required NLTK data is missing.
        print(f"[Warning] METEOR score failed: {e}")
        avg_meteor = 0.0
    else:
        # An all-zero result is a valid score, not a failure.
        avg_meteor = _mean_of_positive(meteor_scores)

    # === ROUGE-L ===
    try:
        rouge_scores = Rouge().get_scores(hyps, refs, avg=True)['rouge-l']
        rouge_l = {key: round(rouge_scores[key], 4) for key in ('p', 'r', 'f')}
    except ValueError as e:
        # Rouge rejects hypotheses/references with no sentences (e.g. only ".").
        print(f"[Warning] ROUGE score failed: {e}")
        rouge_l = {'p': 0.0, 'r': 0.0, 'f': 0.0}

    return {
        'BLEU': avg_bleu,
        'METEOR': avg_meteor,
        'BERTScore': avg_bertscore,
        'ROUGE-L': rouge_l,
        'refs': refs,
        'hyps': hyps,
    }


def safe_json_dump(data, filepath):
    """Atomically write ``data`` as indented UTF-8 JSON to ``filepath``.

    The JSON is first written to ``filepath + ".tmp"``, flushed and fsynced, and
    then moved over ``filepath`` with ``os.replace``. On POSIX file systems this
    means readers see either the previous file or the complete new one. If
    serialisation or writing fails, the temporary file is removed and the
    exception is re-raised; an existing ``filepath`` is left untouched.

    Args:
        data: JSON-serialisable object (non-ASCII text is kept as-is).
        filepath: destination path. Its directory must already exist.

    Raises:
        TypeError / ValueError: if ``data`` cannot be serialised to JSON.
        OSError: on file-system errors.
    """
    tmp_path = filepath + ".tmp"
    try:
        # 1. Write to a temporary file first.
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2, ensure_ascii=False)
            # Push the bytes to disk before the rename. Otherwise a crash right
            # after os.replace could still leave an empty or truncated file.
            f.flush()
            os.fsync(f.fileno())
    except BaseException:
        # Do not leave a half-written temp file behind; the original error still propagates.
        try:
            os.remove(tmp_path)
        except OSError:
            pass  # best effort: the temp file may never have been created
        raise

    # 2. Swap it into place (atomic on POSIX when both paths share a file system).
    os.replace(tmp_path, filepath)

def get_eval_dir(base_dir="eval_metrics"):
    """Return ``<base_dir>/<entry-point script name without its extension>``.

    The script name is taken from ``sys.argv[0]``, so each launch script gets its
    own metrics sub-directory. The directory is only computed, not created.
    """
    launcher = os.path.basename(sys.argv[0])
    stem, _extension = os.path.splitext(launcher)
    return os.path.join(base_dir, stem)
