################
#script to augment features with CLIP

import pickle
import os
import clip
import torch
import network
import torch.nn as nn
from utils.stats import calc_mean_std
import argparse
from main import get_dataset
from torch.utils import data
import numpy as np
import random
from clip.simple_tokenizer import SimpleTokenizer  # 放在文件开头 import
from torch.utils import data
from utils.freeze import freeze_all
from torch.nn.functional import fold, unfold
from utils.PPIN import PPIN 
from datasets import Cityscapes, gta5, endovis18, nalendovis

from torch.utils.tensorboard import SummaryWriter

def compose_text_with_templates(text: str, templates) -> list:
    """Fill every prompt template with ``text``.

    Args:
        text: The description substituted into each template.
        templates: Iterable of format strings, each with a ``{}`` placeholder.

    Returns:
        A new list with one formatted prompt per template, in template order.
    """
    filled = []
    for tpl in templates:
        filled.append(tpl.format(text))
    return filled

# Prompt templates used to build CLIP text-embedding ensembles. Each entry has
# a single ``{}`` placeholder that compose_text_with_templates() fills with a
# description; main() tokenizes and encodes all filled templates and averages
# the resulting embeddings. The first entry is specific to laparoscopic
# surgery imagery; the remaining entries appear to be the standard CLIP
# ImageNet prompt-ensemble templates (hence the name).
imagenet_templates = [
    'A laparoscopic surgery image showing {}',
    'a bad photo of a {}.',
    'a photo of many {}.',
    'a sculpture of a {}.',
    'a photo of the hard to see {}.',
    'a low resolution photo of the {}.',
    'a rendering of a {}.',
    'graffiti of a {}.',
    'a bad photo of the {}.',
    'a cropped photo of the {}.',
    'a tattoo of a {}.',
    'the embroidered {}.',
    'a photo of a hard to see {}.',
    'a bright photo of a {}.',
    'a photo of a clean {}.',
    'a photo of a dirty {}.',
    'a dark photo of the {}.',
    'a drawing of a {}.',
    'a photo of my {}.',
    'the plastic {}.',
    'a photo of the cool {}.',
    'a close-up photo of a {}.',
    'a black and white photo of the {}.',
    'a painting of the {}.',
    'a painting of a {}.',
    'a pixelated photo of the {}.',
    'a sculpture of the {}.',
    'a bright photo of the {}.',
    'a cropped photo of a {}.',
    'a plastic {}.',
    'a photo of the dirty {}.',
    'a jpeg corrupted photo of a {}.',
    'a blurry photo of the {}.',
    'a photo of the {}.',
    'a good photo of the {}.',
    'a rendering of the {}.',
    'a {} in a video game.',
    'a photo of one {}.',
    'a doodle of a {}.',
    'a close-up photo of the {}.',
    'a photo of a {}.',
    'the origami {}.',
    'the {} in a video game.',
    'a sketch of a {}.',
    'a doodle of the {}.',
    'a origami {}.',
    'a low resolution photo of a {}.',
    'the toy {}.',
    'a rendition of the {}.',
    'a photo of the clean {}.',
    'a photo of a large {}.',
    'a rendition of a {}.',
    'a photo of a nice {}.',
    'a photo of a weird {}.',
    'a blurry photo of a {}.',
    'a cartoon {}.',
    'art of a {}.',
    'a sketch of the {}.',
    'a embroidered {}.',
    'a pixelated photo of a {}.',
    'itap of the {}.',
    'a jpeg corrupted photo of the {}.',
    'a good photo of a {}.',
    'a plushie {}.',
    'a photo of the nice {}.',
    'a photo of the small {}.',
    'a photo of the weird {}.',
    'the cartoon {}.',
    'art of the {}.',
    'a drawing of the {}.',
    'a photo of the large {}.',
    'a black and white photo of a {}.',
    'the plushie {}.',
    'a dark photo of a {}.',
    'itap of a {}.',
    'graffiti of the {}.',
    'a toy {}.',
    'itap of my {}.',
    'a photo of a cool {}.',
    'a photo of a small {}.',
    'a tattoo of the {}.',
]


def get_argparser():
    """Build the command-line parser for this CLIP feature-augmentation script.

    Returns:
        argparse.ArgumentParser: parser whose ``parse_args()`` result is the
        ``opts`` object consumed by ``main``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--lr", type=float, default=1, help='optimization step')
    parser.add_argument("--gpu_id", type=str, default='0',
                        help="GPU ID")
    parser.add_argument("--data_root", type=str, default='./datasets/data',
                        help="path to dataset")
    # main() stores its outputs in this directory and cannot run without it,
    # so reject a missing value at parse time instead of failing much later.
    parser.add_argument("--save_dir", type=str, required=True,
                        help="path for learnt parameters saving")
    parser.add_argument("--dataset", type=str, default='cityscapes',
                        choices=['cityscapes', 'gta5', 'endovis18'], help='Name of dataset')
    parser.add_argument("--crop_size", type=int, default=768,
                        help="crop size passed to the dataset loader")
    # %(default)s keeps the help text in sync with the actual default.
    parser.add_argument("--batch_size", type=int, default=2,
                        help='batch size (default: %(default)s)')
    parser.add_argument("--div", type=int, default=3,
                        help="split each image / feature map into div x div patches")
    # Public (no leading underscore), lower-case callables exported by
    # network.modeling are offered as --model choices.
    available_models = sorted(
        name for name, obj in network.modeling.__dict__.items()
        if name.islower() and not name.startswith('_') and callable(obj)
    )
    parser.add_argument("--model", type=str, default='deeplabv3plus_resnet_clip',
                        choices=available_models, help='model name')
    parser.add_argument("--BB", type=str, default='RN50',
                        help="backbone name")
    parser.add_argument("--weight_decay", type=float, default=1e-4,
                        help='weight decay (default: 1e-4)')
    parser.add_argument("--total_it", type=int, default=100,
                        help="total number of optimization iterations")
    # learn statistics
    parser.add_argument("--resize_feat", action='store_true', default=False,
                        help="resize the features map to the dimension corresponding to CLIP")
    # random seed
    parser.add_argument("--random_seed", type=int, default=1,
                        help="random seed (default: 1)")
    # target domain description
    parser.add_argument("--domain_desc", type=str, default="driving at night.",
                        help="description of the target domain")

    return parser



# Class names used to pre-populate a fresh per-class style-statistics dict.
# Any other class name met at runtime (e.g. "photo" for label-255 patches)
# gets its own "<name>_mu" / "<name>_std" keys on first use.
_STATS_CLASS_NAMES = (
    'Background Tissue',
    'Bipolar Forceps',
    'Grasper',
    'Large Needle Driver',
    'Monopolar Curved Scissors',
    'Ultrasound Probe',
    'Suction Irrigator',
    'Clip Applier',
    '255',
)

# Suffix appended to a class name to build a patch-level text target.
_PATCH_PROMPT_SUFFIX = ' under dim lighting conditions with blood splatter.'

# Maximum number of BPE tokens kept from a per-image context description, so
# the text stays within the CLIP tokenizer's context length.
_MAX_CONTEXT_TOKENS = 70


def _encode_prompt_ensemble(clip_model, text, device):
    """Encode ``text`` with every template in ``imagenet_templates``.

    Returns the mean CLIP text embedding over all filled templates, L2
    normalised along the last dim, with shape (1, D). No autograd graph is
    built.
    """
    prompts = compose_text_with_templates(text, imagenet_templates)
    with torch.no_grad():
        tokens = clip.tokenize(prompts).to(device)
        embedding = clip_model.encode_text(tokens).mean(dim=0, keepdim=True)
        embedding /= embedding.norm(dim=-1, keepdim=True)
    return embedding


def _load_finished(finished_txt):
    """Return the set of image basenames listed in ``finished_txt`` (empty if absent)."""
    if not os.path.exists(finished_txt):
        return set()
    with open(finished_txt, 'r', encoding='utf-8') as f:
        return {line.strip() for line in f}


def _load_stats(stats_path):
    """Load the accumulated stats dict from ``stats_path`` or build an empty one.

    The pickle is only ever written by this script (see ``_save_stats``).
    """
    if os.path.exists(stats_path):
        with open(stats_path, 'rb') as f:
            return pickle.load(f)
    stats = {f"{name}_mu": [] for name in _STATS_CLASS_NAMES}
    stats.update({f"{name}_std": [] for name in _STATS_CLASS_NAMES})
    return stats


def _save_stats(stats, stats_path):
    """Pickle ``stats`` to ``stats_path`` atomically (temp file + rename)."""
    tmp_path = stats_path + '.tmp'
    with open(tmp_path, 'wb') as f:
        pickle.dump(stats, f)
    # os.replace is atomic, so an interrupted run never leaves a truncated pickle.
    os.replace(tmp_path, stats_path)


def _mark_finished(finished_txt, img_names, finished_imgs):
    """Append not-yet-recorded ``img_names`` to ``finished_txt`` and to ``finished_imgs``."""
    with open(finished_txt, 'a', encoding='utf-8') as f:
        for name in img_names:
            if name not in finished_imgs:
                f.write(name + '\n')
                finished_imgs.add(name)


def _read_context_texts(img_names, context_dir, tokenizer, max_tokens=_MAX_CONTEXT_TOKENS):
    """Read ``<context_dir>/<name>.txt`` for each image, truncated to ``max_tokens`` BPE tokens."""
    texts = []
    for name in img_names:
        txt_path = os.path.join(context_dir, f"{name}.txt")
        with open(txt_path, 'r', encoding='utf-8') as f:
            context = f.read().strip()
        tokens = tokenizer.encode(context)
        if len(tokens) > max_tokens:
            context = tokenizer.decode(tokens[:max_tokens])
        texts.append(context)
    return texts


def _patch_text_targets(lbl_patches, clip_model, device, cache):
    """Build one CLIP text target per label patch from the patch's majority class.

    Args:
        lbl_patches: label patches shaped (B, P, 1, patch_H, patch_W).
        clip_model: CLIP model used to encode prompts.
        device: device for the tokenized prompts / returned tensor.
        cache: dict mapping prompt text -> encoded (1, D) embedding; filled
            lazily and reused across batches, since the set of prompts is small.

    Returns:
        (selected_patch_indices, selected_classes_names, text_targets):
        per-image lists of selected patch indices (all P patches), per-image
        lists of class names, and a (B*P, D) tensor of text targets in
        image-major, patch-minor order.
    """
    selected_patch_indices = []
    selected_classes_names = []
    text_targets = []
    batch_size, num_patches = lbl_patches.shape[:2]
    for b in range(batch_size):
        classes_name = []
        for j in range(num_patches):
            mode_class = int(torch.mode(lbl_patches[b, j].flatten()).values)
            if mode_class == 255:
                # Ignore-label patches get a generic prompt.
                class_name = "photo"
                prompt = class_name
            else:
                class_name = endovis18.Endovis18DataSet.name(mode_class)
                prompt = class_name + _PATCH_PROMPT_SUFFIX
            classes_name.append(class_name)
            if prompt not in cache:
                cache[prompt] = _encode_prompt_ensemble(clip_model, prompt, device)
            text_targets.append(cache[prompt])
        # Every patch of every image is optimised.
        selected_patch_indices.append(list(range(num_patches)))
        selected_classes_names.append(classes_name)
    return selected_patch_indices, selected_classes_names, torch.cat(text_targets, dim=0).to(device)


def main():
    """Learn per-patch feature styles guided by CLIP text and accumulate them per class.

    For each training batch:
    - The low-level backbone features ``f1`` are split into div x div patches.
    - A PPIN module's style parameters are optimised so the stylised patches
      match a CLIP text target built from each patch's majority label class.
    - The learnt ``style_mean`` / ``style_std`` of every patch are appended to
      ``<save_dir>/category_style_stats_all.pkl``, keyed by class name.

    Processed image names go to ``<save_dir>/finished.txt``, so an interrupted
    run can be resumed. TensorBoard scalars are written by a ``SummaryWriter``.
    """
    opts = get_argparser().parse_args()
    if opts.save_dir is None:
        raise ValueError("--save_dir must be provided")

    # Honour --gpu_id (first id if a comma-separated list is given).
    gpu_index = opts.gpu_id.split(',')[0].strip()
    device = torch.device(f'cuda:{gpu_index}' if torch.cuda.is_available() else 'cpu')

    # INIT
    torch.manual_seed(opts.random_seed)
    torch.cuda.manual_seed(opts.random_seed)
    np.random.seed(opts.random_seed)
    random.seed(opts.random_seed)

    train_dst, val_dst = get_dataset(opts.dataset, opts.data_root, opts.crop_size, data_aug=False)
    train_loader = data.DataLoader(
        train_dst, batch_size=opts.batch_size, shuffle=True, num_workers=0,
        drop_last=False)
    print("Dataset: %s, Train set: %d, Val set: %d" %
          (opts.dataset, len(train_dst), len(val_dst)))

    model = network.modeling.__dict__[opts.model](
        num_classes=8, BB=opts.BB, replace_stride_with_dilation=[False, False, False])
    # The backbone is only used as a frozen feature extractor.
    for p in model.backbone.parameters():
        p.requires_grad = False
    model.backbone.eval()
    model.to(device)

    clip_model, _preprocess = clip.load(opts.BB, device, jit=False)

    os.makedirs(opts.save_dir, exist_ok=True)
    t1 = nn.AdaptiveAvgPool2d((56, 56)) if opts.resize_feat else (lambda x: x)

    div = opts.div
    patch_num = div * div      # every image is split into div x div patches
    alpha = 0.1                # weight of the per-image context embedding
    context_dir = '/mnt/sdc/zhangdong/code/HRDA/data_me/image_descriptions_cot/'
    stats_path = os.path.join(opts.save_dir, 'category_style_stats_all.pkl')
    finished_txt = os.path.join(opts.save_dir, 'finished.txt')

    # Target-domain embedding, shape (1, D); broadcast against each batch below.
    domain_embedding = _encode_prompt_ensemble(clip_model, opts.domain_desc, device).type(torch.float32)

    tokenizer_clip = SimpleTokenizer()
    prompt_cache = {}
    finished_imgs = _load_finished(finished_txt)
    # Loaded once and kept in memory; it is re-saved after every batch.
    stats = _load_stats(stats_path)

    writer = SummaryWriter()
    try:
        for batch_idx, (img_id, _tar_id, images, labels) in enumerate(train_loader):
            print(batch_idx)

            batch_img_names = [os.path.splitext(os.path.basename(p))[0] for p in img_id]
            if all(name in finished_imgs for name in batch_img_names):
                print(f"✅ Skipping batch {batch_idx}: already processed.")
                continue

            # Whole-image text target: per-image description fused with the domain text.
            context_texts = _read_context_texts(batch_img_names, context_dir, tokenizer_clip)
            with torch.no_grad():
                tokens_context = clip.tokenize(context_texts).to(device)
                text_embedding_context = clip_model.encode_text(tokens_context)
                text_embedding_context /= text_embedding_context.norm(dim=-1, keepdim=True)
            context_text_target = alpha * text_embedding_context + (1 - alpha) * domain_embedding
            context_text_target /= context_text_target.norm(dim=-1, keepdim=True)

            f1 = model.backbone(images.to(device), trunc1=False, trunc2=False,
                                trunc3=False, trunc4=False, get1=True, get2=False,
                                get3=False, get4=False)

            # Split labels into the same div x div grid to find each patch's class.
            labels_ = labels.unsqueeze(1).to(torch.float32)  # (B, 1, H, W)
            lbl_B, _, lbl_H, lbl_W = labels_.shape
            lbl_patch_H, lbl_patch_W = lbl_H // div, lbl_W // div
            lbl_patches = unfold(labels_, kernel_size=(lbl_patch_H, lbl_patch_W),
                                 stride=(lbl_patch_H, lbl_patch_W))
            lbl_patches = lbl_patches.permute(0, 2, 1).reshape(
                lbl_B, patch_num, 1, lbl_patch_H, lbl_patch_W)

            selected_patch_indices, selected_classes_names, text_targets = _patch_text_targets(
                lbl_patches, clip_model, device, prompt_cache)

            model_ppin = PPIN(f1, div=div, ind=selected_patch_indices)
            model_ppin.to(device)
            optimizer_ppin1 = torch.optim.SGD(model_ppin.parameters(), lr=opts.lr,
                                              momentum=0.9, weight_decay=opts.weight_decay)

            # Loop-invariant geometry of the feature-map patch grid.
            B, C, H, W = f1.shape
            patch_H, patch_W = H // div, W // div
            f1_unfold_base = unfold(f1, kernel_size=(patch_H, patch_W),
                                    stride=(patch_H, patch_W)).permute(0, 2, 1)
            scalar_tag = "loss_CLIP_f1" + str(batch_idx)

            for cur_itrs in range(1, opts.total_it + 1):
                if cur_itrs % opts.total_it == 0:
                    print(cur_itrs)
                optimizer_ppin1.zero_grad()

                # Stylised patches, ordered image-major: index b * patch_num + j.
                patches_low_hal_ = model_ppin()
                patches_low_hal = t1(patches_low_hal_)

                # Re-assemble a full feature map from the stylised patches.
                f1_unfold = f1_unfold_base.clone()
                for b, patch_indices in enumerate(selected_patch_indices):
                    for local_idx, patch_idx in enumerate(patch_indices):
                        f1_unfold[b, patch_idx] = patches_low_hal_[b * patch_num + local_idx].flatten()
                f1_new = fold(f1_unfold.permute(0, 2, 1), output_size=(H, W),
                              kernel_size=(patch_H, patch_W), stride=(patch_H, patch_W))
                f1_hal_trans = t1(f1_new)

                # Patch-level loss: each stylised patch vs. its class text target.
                target_features_from_low = model.backbone(
                    patches_low_hal.to(device), trunc1=True, trunc2=False, trunc3=False,
                    trunc4=False, get1=False, get2=False, get3=False, get4=False)
                target_features_from_low /= target_features_from_low.norm(dim=-1, keepdim=True).clone().detach()
                loss = (1 - torch.cosine_similarity(text_targets, target_features_from_low, dim=1)).mean()

                # Image-level loss: re-assembled map vs. fused context/domain text.
                target_features_from_f1 = model.backbone(
                    f1_hal_trans, trunc1=True, trunc2=False, trunc3=False, trunc4=False,
                    get1=False, get2=False, get3=False, get4=False)
                target_features_from_f1 /= target_features_from_f1.norm(dim=-1, keepdim=True).clone().detach()
                loss_CLIP1 = (1 - torch.cosine_similarity(context_text_target, target_features_from_f1, dim=1)).mean()

                writer.add_scalar(scalar_tag, loss_CLIP1, cur_itrs)
                print("loss", loss)
                # NOTE(review): only the patch-level loss is backpropagated;
                # loss_CLIP1 is logged but does not drive the optimisation. If
                # the combined objective `loss + 0.5 * loss_CLIP1` is intended,
                # backpropagate that instead. retain_graph is kept in case
                # PPIN shares graph state across calls (not visible here).
                loss.backward(retain_graph=True)
                optimizer_ppin1.step()

            ppin_params = dict(model_ppin.named_parameters())
            learnt_mu_f1 = ppin_params['style_mean'].data
            learnt_std_f1 = ppin_params['style_std'].data

            # Accumulate per-class statistics. Images already recorded (possible
            # when a resumed, shuffled batch mixes done and new images) are
            # skipped so their stats are not appended twice.
            for k, img_name in enumerate(batch_img_names):
                if img_name in finished_imgs:
                    continue
                for j, class_name in enumerate(selected_classes_names[k]):
                    flat_idx = k * patch_num + j
                    stats.setdefault(f"{class_name}_mu", []).append(learnt_mu_f1[flat_idx].detach().cpu())
                    stats.setdefault(f"{class_name}_std", []).append(learnt_std_f1[flat_idx].detach().cpu())

            _save_stats(stats, stats_path)
            print(f"✅ Saved stats after batch {batch_idx} to {stats_path}")

            _mark_finished(finished_txt, batch_img_names, finished_imgs)
            print(f"✅ Updated finished.txt for batch {batch_idx}")
    finally:
        writer.close()

# Run only when executed as a script, not when this module is imported.
if __name__ == "__main__":
    main()