from re import L
from tqdm import tqdm
import network
import utils
import os
import random
import argparse
import numpy as np
from torch.utils import data
from datasets import Cityscapes, gta5, endovis18, nalendovis
from utils import ext_transforms as et
from metrics import StreamSegMetrics
import torch
import torch.nn as nn
from PIL import Image
import matplotlib
import matplotlib.pyplot as plt
import pickle
from utils.utils import denormalize
from torchvision.utils import save_image
from torch.nn.functional import unfold
import clip
import torch.nn.functional as F


from torch.utils.tensorboard import SummaryWriter
# Module-level TensorBoard writer, created at import time; main() logs the
# training loss and validation mIoU scalars to it.
writer = SummaryWriter()

def get_argparser():
    """Build the command-line parser for training and evaluation.

    Returns:
        argparse.ArgumentParser: parser with dataset, model, optimisation,
        checkpointing, validation and CLIP-augmentation options.
    """
    parser = argparse.ArgumentParser()

    # Dataset Options
    parser.add_argument("--data_root", type=str, default='./datasets/data',
                        help="path to Dataset")
    parser.add_argument("--dataset", type=str, default='cityscapes',
                        choices=['cityscapes', 'ACDC', 'gta5', 'nalendovis', 'endovis18'],
                        help='Name of dataset')
    parser.add_argument("--ACDC_sub", type=str, default="night",
                        help="specify which subset of ACDC  to use")

    # Deeplab Options
    # Every public, lower-case callable in network.modeling is offered as a model.
    available_models = sorted(
        name for name, obj in network.modeling.__dict__.items()
        if name.islower() and not name.startswith('_') and callable(obj)
    )
    parser.add_argument("--model", type=str, default='deeplabv3plus_resnet_clip',
                        choices=available_models, help='model name')
    parser.add_argument("--BB", type=str, default="RN50",
                        help="backbone of the segmentation network")

    # Train Options
    parser.add_argument("--test_only", action='store_true', default=False)
    # argparse does not apply `type` to non-string defaults, so the default
    # must already be an int (200e3 would stay the float 200000.0).
    parser.add_argument("--total_itrs", type=int, default=200000,
                        help="number of training iterations (default: 200k)")
    parser.add_argument("--lr", type=float, default=0.1,
                        help="learning rate (default: 0.1)")
    parser.add_argument("--lr_policy", type=str, default='poly', choices=['poly', 'step'],
                        help="learning rate scheduler policy")
    parser.add_argument("--step_size", type=int, default=10)
    parser.add_argument("--batch_size", type=int, default=8,
                        help='batch size (default: 8)')
    parser.add_argument("--val_batch_size", type=int, default=4,
                        help='batch size for validation (default: 4)')
    parser.add_argument("--crop_size", type=int, default=768)

    parser.add_argument("--ckpt", default=None, type=str,
                        help="restore from checkpoint")

    parser.add_argument("--continue_training", action='store_true', default=False)

    parser.add_argument("--loss_type", type=str, default='cross_entropy',
                        choices=['cross_entropy', 'focal_loss'],
                        help="loss type (default: cross_entropy)")
    parser.add_argument("--gpu_id", type=str, default='0',
                        help="GPU ID")
    parser.add_argument("--weight_decay", type=float, default=1e-4,
                        help='weight decay (default: 1e-4)')
    parser.add_argument("--random_seed", type=int, default=1,
                        help="random seed (default: 1)")
    parser.add_argument("--val_interval", type=int, default=1000,
                        help="iteration interval for eval (default: 1000)")
    parser.add_argument("--forward_pass", action='store_true', default=False,
                        help="forward pass to update BN statistics")
    parser.add_argument("--save_val_results", action='store_true', default=False,
                        help="save segmentation results to \"./results\"")
    parser.add_argument("--freeze_BB", action='store_true', default=False,
                        help="Freeze the backbone when training")
    parser.add_argument("--ckpts_path", type=str,
                        help="path for checkpoints saving")
    parser.add_argument("--data_aug", action='store_true', default=False)

    # Validation
    parser.add_argument("--val_results_dir", type=str,
                        help="Folder name for validation results saving")

    # Augmented features
    parser.add_argument("--train_aug", action='store_true', default=False,
                        help="train on augmented features using CLIP")
    parser.add_argument("--path_mu_sig", type=str)
    parser.add_argument("--mix", action='store_true', default=False,
                        help="mix statistics")
    parser.add_argument("--temperature", type=float, default=10)
    parser.add_argument("--proj_lr", type=float, default=1e-3)

    return parser


# CLIP image normalisation statistics, shared by every dataset pipeline.
_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
_CLIP_STD = (0.26862954, 0.26130258, 0.27577711)

# Default endovis18 file list (absolute, machine-specific path).
_ENDOVIS18_FILE_LIST = '/mnt/sdc/zhangdong/code/HRDA/data_me/endovis18_best/endovis_file_list.txt'


def _build_transforms(train_crop, data_aug, val_center_crop=None):
    """Return the ``(train_transform, val_transform)`` pair used by ``get_dataset``.

    Args:
        train_crop: ``(h, w)`` random-crop size for training, or ``None`` for no crop.
        data_aug: if true, add colour jitter and random horizontal flip to training.
        val_center_crop: optional ``(h, w)`` centre-crop size for validation.

    Both pipelines end with ToTensor + CLIP normalisation.
    """
    def _tensor_and_normalize():
        # Fresh transform instances for each pipeline.
        return [et.ExtToTensor(),
                et.ExtNormalize(mean=list(_CLIP_MEAN), std=list(_CLIP_STD))]

    train_ops = [] if train_crop is None else [et.ExtRandomCrop(size=train_crop)]
    if data_aug:
        train_ops += [
            et.ExtColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
            et.ExtRandomHorizontalFlip(),
        ]
    train_transform = et.ExtCompose(train_ops + _tensor_and_normalize())

    val_ops = [] if val_center_crop is None else [et.ExtCenterCrop(size=val_center_crop)]
    val_transform = et.ExtCompose(val_ops + _tensor_and_normalize())
    return train_transform, val_transform


def get_dataset(dataset, data_root, crop_size, ACDC_sub="night", data_aug=True,
                endovis_list=_ENDOVIS18_FILE_LIST):
    """Build the train and validation datasets with their transforms.

    Args:
        dataset: one of 'cityscapes', 'nalendovis', 'ACDC', 'gta5', 'endovis18'.
        data_root: dataset root directory.
        crop_size: square random-crop size for 'cityscapes' and 'nalendovis'
            training (gta5/endovis18 always crop 768x768; ACDC does not crop).
        ACDC_sub: ACDC subset, forwarded to the ACDC dataset.
        data_aug: enable colour jitter + horizontal flip for training
            (ignored for ACDC).
        endovis_list: file list used for both endovis18 splits.

    Returns:
        ``(train_dst, val_dst)``.

    Raises:
        ValueError: if ``dataset`` is not one of the supported names.
    """
    if dataset == 'cityscapes':
        train_transform, val_transform = _build_transforms((crop_size, crop_size), data_aug)
        train_dst = Cityscapes(root=data_root, dataset=dataset,
                               split='train', transform=train_transform)
        val_dst = Cityscapes(root=data_root, dataset=dataset,
                             split='val', transform=val_transform)
    elif dataset == 'nalendovis':
        train_transform, val_transform = _build_transforms((crop_size, crop_size), data_aug)
        train_dst = nalendovis.Nalendovis(root=data_root, dataset=dataset,
                                          split='train', transform=train_transform)
        # NOTE(review): validation also reads split='train' -- confirm whether
        # a held-out split exists for this dataset.
        val_dst = nalendovis.Nalendovis(root=data_root, dataset=dataset,
                                        split='train', transform=val_transform)
    elif dataset == 'ACDC':
        # ACDC trains on full images with no augmentation, whatever data_aug says.
        train_transform, val_transform = _build_transforms(None, False)
        train_dst = Cityscapes(root=data_root, dataset=dataset,
                               split='train', transform=train_transform, ACDC_sub=ACDC_sub)
        val_dst = Cityscapes(root=data_root, dataset=dataset,
                             split='val', transform=val_transform, ACDC_sub=ACDC_sub)
    elif dataset == 'gta5':
        # NOTE(review): the 768x768 training crop is hard-coded and ignores crop_size.
        train_transform, val_transform = _build_transforms(
            (768, 768), data_aug, val_center_crop=(1046, 1914))
        train_dst = gta5.GTA5DataSet(data_root, 'datasets/gta5_list/gtav_split_train.txt',
                                     transform=train_transform)
        val_dst = gta5.GTA5DataSet(data_root, 'datasets/gta5_list/gtav_split_val.txt',
                                   transform=val_transform)
    elif dataset == 'endovis18':
        # NOTE(review): 768x768 crop ignores crop_size, and train and val read
        # the same file list.
        train_transform, val_transform = _build_transforms(
            (768, 768), data_aug, val_center_crop=(1046, 1914))
        train_dst = endovis18.Endovis18DataSet(data_root, endovis_list, transform=train_transform)
        val_dst = endovis18.Endovis18DataSet(data_root, endovis_list, transform=val_transform)
    else:
        raise ValueError("Unknown dataset: %r" % (dataset,))

    return train_dst, val_dst

def validate(opts, model, loader, device, metrics):
    """Evaluate ``model`` on ``loader`` and return the accumulated metrics.

    Args:
        opts: parsed options; reads ``save_val_results`` and ``val_results_dir``.
        model: network whose forward returns a 3-tuple ``(logits, _, _)``.
        loader: DataLoader yielding ``(im_id, tg_id, images, labels)``; its
            dataset must provide ``decode_target`` when results are saved.
        device: device the batches are moved to.
        metrics: accumulator exposing ``reset`` / ``update`` / ``get_results``.

    Returns:
        The value of ``metrics.get_results()`` after the whole loader.

    When ``opts.save_val_results`` is set, every sample's decoded target and
    prediction (via ``loader.dataset.decode_target``) and its de-normalised
    input image are written to ``opts.val_results_dir`` as
    ``<n>_target.png``, ``<n>_pred.png`` and ``<n>_image.png``.
    """
    metrics.reset()
    save_results = opts.save_val_results
    if save_results:
        # makedirs also creates missing parent directories and tolerates an
        # already-existing (or concurrently created) directory.
        os.makedirs(opts.val_results_dir, exist_ok=True)
        img_id = 0

    with torch.no_grad():
        for _im_id, _tg_id, images, labels in tqdm(loader, total=len(loader)):
            images = images.to(device, dtype=torch.float32)
            labels = labels.to(device, dtype=torch.long)

            outputs, _features, _deeplab_feature = model(images)
            preds = outputs.detach().max(dim=1)[1].cpu().numpy()  # per-pixel argmax over classes
            targets = labels.cpu().numpy()
            metrics.update(targets, preds)

            if not save_results:
                continue

            for j in range(len(images)):
                target = loader.dataset.decode_target(targets[j]).astype(np.uint8)
                pred = loader.dataset.decode_target(preds[j]).astype(np.uint8)
                Image.fromarray(target).save(
                    os.path.join(opts.val_results_dir, '%d_target.png' % img_id))
                Image.fromarray(pred).save(
                    os.path.join(opts.val_results_dir, '%d_pred.png' % img_id))

                # Undo the CLIP normalisation (same mean/std as the dataset transforms).
                images[j] = denormalize(images[j], mean=[0.48145466, 0.4578275, 0.40821073],
                                        std=[0.26862954, 0.26130258, 0.27577711])
                save_image(images[j], os.path.join(opts.val_results_dir, '%d_image.png' % img_id))
                img_id += 1

        score = metrics.get_results()
    return score

# Map visual features from the CNN feature space into the same space as the
# CLIP text embeddings.
class PROJECTOR(nn.Module):
    """Per-pixel MLP that maps a channel-first feature map to ``outdim`` channels.

    Args:
        indim: number of input channels.
        outdim: number of output channels.
        shape: optional sequence of hidden-layer widths. ``None`` or an empty
            sequence gives a single linear layer ``indim -> outdim``. Every
            hidden Linear is followed by a ReLU; the output layer is not.
    """

    def __init__(self, indim, outdim, shape=None):
        super(PROJECTOR, self).__init__()
        widths = [indim, *(shape or []), outdim]
        last = len(widths) - 2  # index of the output layer
        layers = []
        for i, (w_in, w_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(nn.Linear(w_in, w_out))
            if i < last:
                # Without a non-linearity between them, stacked Linears
                # collapse into a single linear map.
                layers.append(nn.ReLU())
        self.projector = nn.Sequential(*layers)

    def forward(self, feature):
        """Project ``feature`` of shape (..., indim, H, W) to (..., outdim, H, W)."""
        feature = feature.movedim(-3, -1)  # channels last for nn.Linear
        feature = self.projector(feature)
        return feature.movedim(-1, -3)  # back to channels first

def extract_and_pool_features(raw_class, raw_label, raw_feature, template='{}'):
    """Remap labels to positions in ``raw_class`` and average features per class.

    Args:
        raw_class: sequence of original class ids (keys of
            ``Cityscapes.train_id_to_name``); position ``i`` becomes new id ``i``.
        raw_label: integer label map of original class ids; 255 means ignore.
        raw_feature: feature tensor of shape ``(C, *raw_label.shape)``.
        template: format string with exactly one ``{}`` placeholder, filled
            with each class name.

    Returns:
        new_label_map: long tensor shaped like ``raw_label``. Pixels of
            ``raw_class[i]`` get value ``i``; every other pixel, including
            255, gets 255 so the loss ignores it.
        filled_templates: ``template`` filled with each class name.
        pooled_features: ``(len(raw_class), C)`` per-class mean features. A
            class with no pixels yields a NaN row; an empty ``raw_class``
            yields a ``(0, C)`` tensor.
    """
    class_name = [Cityscapes.train_id_to_name[class_idx] for class_idx in raw_class]
    filled_templates = [template.format(item) for item in class_name]

    ignore_index = 255
    # Start from the ignore index so pixels whose label is not listed in
    # raw_class are ignored rather than silently assigned to new class 0.
    new_label_map = torch.full_like(raw_label, ignore_index, dtype=torch.long)

    pooled_features_list = []
    for new_class_id, old_class_id in enumerate(raw_class):
        mask = raw_label == old_class_id
        new_label_map = new_label_map.masked_fill(mask, new_class_id)
        # Mean over the masked pixels (NaN when the class is absent).
        pooled_features_list.append(raw_feature[:, mask].mean(dim=1))

    # Ignore pixels always stay ignored, even if 255 appears in raw_class.
    new_label_map = new_label_map.masked_fill(raw_label == ignore_index, ignore_index)

    if pooled_features_list:
        pooled_features_tensor = torch.stack(pooled_features_list)
    else:
        pooled_features_tensor = raw_feature.new_empty((0, raw_feature.shape[0]))

    return new_label_map, filled_templates, pooled_features_tensor

# Prompt-template ensemble: 80 templates, each with a single "{}" placeholder.
# compose_text_with_templates fills each class description into every
# template, and main() averages the resulting CLIP text embeddings per class.
imagenet_templates = [
    'a bad photo of a {}.',
    'a photo of many {}.',
    'a sculpture of a {}.',
    'a photo of the hard to see {}.',
    'a low resolution photo of the {}.',
    'a rendering of a {}.',
    'graffiti of a {}.',
    'a bad photo of the {}.',
    'a cropped photo of the {}.',
    'a tattoo of a {}.',
    'the embroidered {}.',
    'a photo of a hard to see {}.',
    'a bright photo of a {}.',
    'a photo of a clean {}.',
    'a photo of a dirty {}.',
    'a dark photo of the {}.',
    'a drawing of a {}.',
    'a photo of my {}.',
    'the plastic {}.',
    'a photo of the cool {}.',
    'a close-up photo of a {}.',
    'a black and white photo of the {}.',
    'a painting of the {}.',
    'a painting of a {}.',
    'a pixelated photo of the {}.',
    'a sculpture of the {}.',
    'a bright photo of the {}.',
    'a cropped photo of a {}.',
    'a plastic {}.',
    'a photo of the dirty {}.',
    'a jpeg corrupted photo of a {}.',
    'a blurry photo of the {}.',
    'a photo of the {}.',
    'a good photo of the {}.',
    'a rendering of the {}.',
    'a {} in a video game.',
    'a photo of one {}.',
    'a doodle of a {}.',
    'a close-up photo of the {}.',
    'a photo of a {}.',
    'the origami {}.',
    'the {} in a video game.',
    'a sketch of a {}.',
    'a doodle of the {}.',
    'a origami {}.',
    'a low resolution photo of a {}.',
    'the toy {}.',
    'a rendition of the {}.',
    'a photo of the clean {}.',
    'a photo of a large {}.',
    'a rendition of a {}.',
    'a photo of a nice {}.',
    'a photo of a weird {}.',
    'a blurry photo of a {}.',
    'a cartoon {}.',
    'art of a {}.',
    'a sketch of the {}.',
    'a embroidered {}.',
    'a pixelated photo of a {}.',
    'itap of the {}.',
    'a jpeg corrupted photo of the {}.',
    'a good photo of a {}.',
    'a plushie {}.',
    'a photo of the nice {}.',
    'a photo of the small {}.',
    'a photo of the weird {}.',
    'the cartoon {}.',
    'art of the {}.',
    'a drawing of the {}.',
    'a photo of the large {}.',
    'a black and white photo of a {}.',
    'the plushie {}.',
    'a dark photo of a {}.',
    'itap of a {}.',
    'graffiti of the {}.',
    'a toy {}.',
    'itap of my {}.',
    'a photo of a cool {}.',
    'a photo of a small {}.',
    'a tattoo of the {}.',
]

# Fill the user-provided description `text` into each language template,
# producing a batch of sentences for CLIP text encoding.
def compose_text_with_templates(text: str, templates) -> list:
    """Return one sentence per template, with ``text`` substituted for ``{}``.

    Args:
        text: string inserted into every template.
        templates: iterable of ``str.format`` templates with one positional
            placeholder.

    Returns:
        List of formatted strings, in template order.
    """
    return list(map(lambda tpl: tpl.format(text), templates))

def _patch_majority_classes(lbl_patches):
    """Return, for every label patch, the class name of its most frequent label.

    Args:
        lbl_patches: long tensor indexed as ``[patch][batch]`` (shape
            ``(num_patches, B, 1, h, w)`` as built in ``main``).

    Returns:
        List with one entry per patch position; each entry is a list with one
        item per batch element: ``Cityscapes.name(mode)`` for the patch's most
        frequent label, or the int ``255`` when that label is the ignore index.
    """
    most_list = []
    for patch_group in lbl_patches:  # patch positions
        most = []
        for patch in patch_group:  # batch elements
            mode_value = torch.mode(torch.flatten(patch)).values  # computed once per patch
            most.append(Cityscapes.name(mode_value) if mode_value != 255 else 255)
        most_list.append(most)
    return most_list


def _encode_class_text_matrix(clip_model, class_prompts, device):
    """Encode each class prompt with CLIP, averaged over ``imagenet_templates``.

    Each prompt is expanded with every template, encoded by the frozen CLIP
    text encoder, mean-pooled over the templates and L2-normalised.

    Returns:
        float32 tensor of shape ``(len(class_prompts), D)`` on ``device``,
        where D is CLIP's text-embedding width.
    """
    encoded_vectors = []
    with torch.no_grad():
        for prompt in class_prompts:
            tokens_txt = clip.tokenize(compose_text_with_templates(prompt, imagenet_templates)).to(device)
            encoded_text = clip_model.encode_text(tokens_txt).mean(axis=0, keepdim=True)
            encoded_text = encoded_text / encoded_text.norm(dim=-1, keepdim=True)
            encoded_vectors.append(encoded_text)
    return torch.cat(encoded_vectors, dim=0).to(device=device, dtype=torch.float32)


def _text_alignment_loss(labels, deeplab_features, fea_proj, text_matrix, criterion,
                         temperature, class_ids, template):
    """Mean over the batch of the pixel-to-text classification loss.

    For each image, the label map is resized (nearest) to the feature
    resolution, the features are projected by ``fea_proj`` and L2-normalised
    along channels, and every pixel is classified by its temperature-scaled
    similarity to each row of ``text_matrix``.

    Args:
        labels: (B, H, W) float label maps with original class ids.
        deeplab_features: (B, C, h, w) decoder features.
        fea_proj: module projecting C channels to the text-embedding width.
        text_matrix: (num_classes, D) normalised class text embeddings.
        criterion: loss taking ``(1, num_classes, N)`` logits and ``(1, N)`` targets.
        temperature: logit scale.
        class_ids: original class ids, one per row of ``text_matrix``.
        template: prompt template forwarded to ``extract_and_pool_features``.
    """
    loss_pixel = 0
    for index in range(len(labels)):
        cur_label = F.interpolate(
            labels[index].to(dtype=torch.float).unsqueeze(0).unsqueeze(0),
            size=[deeplab_features.shape[-2], deeplab_features.shape[-1]],
            mode='nearest',
        ).to(dtype=torch.long).squeeze(0).squeeze(0)  # (h, w)

        cur_features = fea_proj(deeplab_features[index])  # (D, h, w)
        cur_gt, _, _ = extract_and_pool_features(class_ids, cur_label, cur_features,
                                                 template=template)

        cur_features = F.normalize(cur_features.to(dtype=torch.float32), 2, 0)
        logits = torch.einsum('ij,jkl->ikl', text_matrix, cur_features) * temperature
        class_num = logits.shape[0]
        assert class_num == len(class_ids)
        cl_loss = criterion(logits.view(class_num, -1).unsqueeze(0),
                            cur_gt.to(logits.device).view(-1).unsqueeze(0))
        loss_pixel += cl_loss
    return loss_pixel / len(labels)


def main():
    """Train the segmentation model, or evaluate it with ``--test_only``.

    Parses the command line (see ``get_argparser``), builds datasets, model,
    optimizers and LR scheduler, optionally restores a checkpoint, and then
    either runs one validation pass or trains for ``--total_itrs``
    iterations. With ``--train_aug``, an auxiliary loss aligns projected
    pixel features with CLIP text embeddings of the class names.
    """
    opts = get_argparser().parse_args()

    # NOTE(review): the GPU index is hard-coded and --gpu_id is ignored
    # (setting CUDA_VISIBLE_DEVICES from it is disabled); left unchanged so
    # existing runs keep using the same device.
    device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')
    print("Device: %s" % device)

    # Frozen CLIP model, used only to encode the class-name prompts.
    clip_model, preprocess = clip.load(opts.BB, device, jit=False)
    clip_model.to(device)
    clip_model.eval()  # inference mode
    for p in clip_model.parameters():
        p.requires_grad = False  # CLIP is never trained

    # Setup random seed
    torch.manual_seed(opts.random_seed)
    torch.cuda.manual_seed(opts.random_seed)
    np.random.seed(opts.random_seed)
    random.seed(opts.random_seed)

    # Setup dataloader
    train_dst, val_dst = get_dataset(opts.dataset, opts.data_root, opts.crop_size, opts.ACDC_sub,
                                     data_aug=opts.data_aug)
    # drop_last=True ignores single-image batches and keeps every batch at
    # opts.batch_size, which the style-mixing sample below relies on.
    train_loader = data.DataLoader(
        train_dst, batch_size=opts.batch_size, shuffle=True, num_workers=4,
        drop_last=True)
    val_loader = data.DataLoader(
        val_dst, batch_size=opts.val_batch_size, shuffle=True, num_workers=4)
    print("Dataset: %s, Train set: %d, Val set: %d" %
          (opts.dataset, len(train_dst), len(val_dst)))

    # Set up model
    model = network.modeling.__dict__[opts.model](num_classes=19, BB=opts.BB,
                                                  replace_stride_with_dilation=[False, False, True])
    model.backbone.attnpool = nn.Identity()
    model.to(device)
    if opts.freeze_BB:
        for param in model.backbone.parameters():
            param.requires_grad = False
        model.backbone.eval()

    # Set up metrics
    metrics = StreamSegMetrics(19)

    # Set up optimizer
    if opts.freeze_BB:
        optimizer = torch.optim.SGD(params=[
            {'params': model.classifier.parameters(), 'lr': opts.lr},
        ], lr=opts.lr, momentum=0.9, weight_decay=opts.weight_decay)
    else:
        # The backbone is fine-tuned with a 1000x smaller learning rate.
        optimizer = torch.optim.SGD(params=[
            {'params': model.backbone.parameters(), 'lr': 0.001 * opts.lr},
            {'params': model.classifier.parameters(), 'lr': opts.lr},
        ], lr=opts.lr, momentum=0.9, weight_decay=opts.weight_decay)

    if opts.lr_policy == 'poly':
        scheduler = utils.PolyLR(optimizer, opts.total_itrs, power=0.9)
    elif opts.lr_policy == 'step':
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=opts.step_size, gamma=0.9)

    # Set up criterion
    if opts.loss_type == 'focal_loss':
        criterion = utils.FocalLoss(ignore_index=255, size_average=True)
    elif opts.loss_type == 'cross_entropy':
        criterion = nn.CrossEntropyLoss(ignore_index=255, reduction='mean')

    def save_ckpt(path):
        """Save model/optimizer/scheduler state and progress counters to ``path``."""
        torch.save({
            "cur_itrs": cur_itrs,
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "scheduler_state": scheduler.state_dict(),
            "best_score": best_score,
        }, path)
        print("Model saved as %s" % path)

    if not opts.test_only:
        utils.mkdir(opts.ckpts_path)

    # Restore
    best_score = 0.0
    cur_itrs = 0
    cur_epochs = 0
    if opts.ckpt is not None and os.path.isfile(opts.ckpt):
        checkpoint = torch.load(opts.ckpt, map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint["model_state"])
        model.to(device)
        if opts.continue_training:
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            cur_itrs = checkpoint["cur_itrs"]
            best_score = checkpoint['best_score']
            print("Training state restored from %s" % opts.ckpt)
        print("Model restored from %s" % opts.ckpt)
        del checkpoint  # free memory
    else:
        if opts.ckpt is not None:
            print("[!] Checkpoint not found: %s" % opts.ckpt)
        print("[!] Retrain")
        model.to(device)

    if opts.test_only:
        model.eval()
        val_score = validate(
            opts=opts, model=model, loader=val_loader, device=device, metrics=metrics)
        print(metrics.to_str(val_score))
        print(val_score["Mean IoU"])
        print(val_score["Class IoU"])
        return

    # ==========   Train Loop   ==========#
    interval_loss = 0

    loaded_dict_patches = None
    if opts.train_aug:
        # SECURITY: pickle.load can execute arbitrary code -- only point
        # --path_mu_sig at trusted, locally produced statistics.
        with open(os.path.join(opts.path_mu_sig, 'category_style_stats_all.pkl'), 'rb') as f:
            loaded_dict_patches = pickle.load(f)

    fea_proj = PROJECTOR(304, 1024, [512])
    fea_proj.to(device)

    # Separate optimizer for the feature projector.
    optimizer_proj = torch.optim.SGD(params=[
        {'params': fea_proj.parameters(), 'lr': opts.proj_lr},
    ], lr=1, momentum=0.9, weight_decay=opts.weight_decay)

    class_text = ['road', 'sidewalk', 'building', 'wall', 'fence', 'pole', 'traffic light', 'traffic sign',
                  'vegetation', 'terrain', 'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train',
                  'motorcycle', 'bicycle']
    all_class_ids = list(range(len(class_text)))
    template = '{} in the rainy scene.'

    # The class prompts never change, so their CLIP embeddings are computed
    # once here instead of for every image of every iteration.
    text_matrix = None
    if opts.train_aug:
        class_prompts = [template.format(Cityscapes.train_id_to_name[c]) for c in all_class_ids]
        text_matrix = _encode_class_text_matrix(clip_model, class_prompts, device)

    relu = nn.ReLU(inplace=True)
    beta_dist = torch.distributions.beta.Beta(0.1, 0.1)

    while True:  # exits via `return` once cur_itrs reaches opts.total_itrs
        # =====  Train  =====
        if opts.freeze_BB:
            model.classifier.train()
        else:
            model.train()

        cur_epochs += 1

        for _im_id, _tg_id, images, labels in train_loader:
            cur_itrs += 1
            images = images.to(device, dtype=torch.float32)
            labels = labels.to(device, dtype=torch.float32)  # float for unfold/interpolate

            optimizer.zero_grad()
            optimizer_proj.zero_grad()
            if opts.train_aug:
                # Split the label maps into non-overlapping 256x256 patches:
                # (num_patches, B, 1, 256, 256).
                lbl_patches = unfold(labels.unsqueeze(1), kernel_size=256, stride=256).permute(-1, 0, 1)
                lbl_patches = lbl_patches.reshape(lbl_patches.shape[0], lbl_patches.shape[1], 1, 256, 256)
                lbl_patches = lbl_patches.to(torch.long)
                most_list = _patch_majority_classes(lbl_patches)

                # Passed to the model as `s` (presumably per-channel style-mixing weights).
                s = beta_dist.sample((opts.batch_size, 256, 1, 1)).to(device)

                outputs, features, deeplab_features = model(
                    images, transfer=opts.train_aug, mix=True, most_list=most_list,
                    saved_params=loaded_dict_patches, activation=relu, s=s)
                loss_CE = _text_alignment_loss(
                    labels, deeplab_features, fea_proj, text_matrix, criterion,
                    opts.temperature, all_class_ids, template)
            else:
                outputs, features, deeplab_features = model(images)
                loss_CE = None  # no auxiliary text loss without --train_aug

            labels = labels.to(device, dtype=torch.long)
            loss = criterion(outputs, labels)
            loss_all = loss if loss_CE is None else loss + loss_CE * 0.5
            loss_all.backward()
            optimizer.step()
            optimizer_proj.step()  # no-op when the projector received no gradients
            writer.add_scalar("loss", loss, cur_itrs)
            interval_loss += loss_all.item()

            if cur_itrs % 10 == 0:
                interval_loss = interval_loss / 10
                print("Epoch %d, Itrs %d/%d, Loss=%f" %
                      (cur_epochs, cur_itrs, opts.total_itrs, interval_loss))
                interval_loss = 0.0

            if cur_itrs % opts.val_interval == 0 and not opts.train_aug:
                save_ckpt(os.path.join(opts.ckpts_path,
                                       'latest_%s_%s.pth' % (opts.model, opts.dataset)))
                print("validation...")
                model.eval()

                val_score = validate(
                    opts=opts, model=model, loader=val_loader, device=device, metrics=metrics)

                print(metrics.to_str(val_score))
                if val_score['Mean IoU'] > best_score:  # save best model
                    best_score = val_score['Mean IoU']
                    save_ckpt(os.path.join(opts.ckpts_path,
                                           'best_%s_%s.pth' % (opts.model, opts.dataset)))

                writer.add_scalar("mIoU", val_score['Mean IoU'], cur_itrs)

                if opts.freeze_BB:
                    model.classifier.train()
                else:
                    model.train()

            if opts.train_aug and cur_itrs == opts.total_itrs:
                save_ckpt(os.path.join(opts.ckpts_path,
                                       'adapted_%s_%s.pth' % (opts.model, opts.dataset)))

            scheduler.step()

            if cur_itrs >= opts.total_itrs:
                return


if __name__ == '__main__':
    main()