import os
import torch

# Per-channel RGB mean / standard deviation used to standardise images.
# NOTE(review): despite the CIFAR100_* names, these values look like the
# statistics used by OpenAI CLIP preprocessing rather than CIFAR-100 dataset
# statistics -- confirm before reusing them for another model.
CIFAR100_MEAN = (0.48145466, 0.4578275, 0.40821073)
CIFAR100_STD = (0.26862954, 0.26130258, 0.27577711)

# Use the current CUDA device when one is available and fall back to the CPU
# otherwise, so that merely importing this module does not crash on a machine
# without a GPU.
_STATS_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Shaped (3, 1, 1) so they broadcast over (..., 3, H, W) image tensors.
mu = torch.tensor(CIFAR100_MEAN).view(3, 1, 1).to(_STATS_DEVICE)
std = torch.tensor(CIFAR100_STD).view(3, 1, 1).to(_STATS_DEVICE)


def normalize(X):
    """Standardise ``X`` channel-wise with the module-level ``mu`` / ``std``.

    Args:
        X: Image tensor that broadcasts against the (3, 1, 1) statistics,
            e.g. shape ``(..., 3, H, W)``.

    Returns:
        ``(X - mu) / std``.
    """
    # Follow X's device so inputs living on the CPU or on a different GPU than
    # the module-level statistics do not raise a device-mismatch error.
    # ``.to`` is a no-op when the devices already agree.
    device = X.device if torch.is_tensor(X) else mu.device
    return (X - mu.to(device)) / std.to(device)

def clip_img_preprocessing(X, img_size=224):
    """Resize a batch of images to a square resolution and standardise it.

    Args:
        X: Image batch accepted by ``torch.nn.functional.interpolate`` with a
            2-D ``size`` (i.e. a 4-D ``(N, C, H, W)`` tensor).
        img_size: Target height and width in pixels.  Defaults to 224, the
            value that was previously hard-coded.

    Returns:
        The bicubically resized batch passed through ``normalize``.
    """
    X = torch.nn.functional.interpolate(X, size=(img_size, img_size), mode='bicubic')
    return normalize(X)

def rev_normalize(X):
    """Invert ``normalize``: map standardised values back via ``X * std + mu``.

    Args:
        X: Standardised image tensor that broadcasts against the (3, 1, 1)
            module-level statistics.

    Returns:
        ``X * std + mu``.
    """
    # Follow X's device so inputs living on the CPU or on a different GPU than
    # the module-level statistics do not raise a device-mismatch error.
    # ``.to`` is a no-op when the devices already agree.
    device = X.device if torch.is_tensor(X) else mu.device
    return X * std.to(device) + mu.to(device)

def reverse_clip_img_preprocessing(X):
    """Undo the standardisation step of ``clip_img_preprocessing``.

    Only the normalisation is reverted (through ``rev_normalize``); the
    spatial resize is not undone.
    """
    return rev_normalize(X)

def multiGPU_CLIP_image_logits(images, model, text_tokens, prompter=None, add_prompter=None):
    """Return only the image-to-text logits for a batch of raw images.

    The images are run through ``clip_img_preprocessing``; then, when given,
    ``add_prompter()`` supplies an extra prompt token and ``prompter`` is
    applied to the preprocessed images.  The result is forwarded to
    ``multiGPU_CLIP`` (with no ``args`` / dataset name) and the first element
    of its return tuple is handed back.
    """
    pixels = clip_img_preprocessing(images)

    if add_prompter is None:
        extra_token = None
    else:
        extra_token = add_prompter()

    if prompter is not None:
        pixels = prompter(pixels)

    outputs = multiGPU_CLIP(
        None, None, None, model, pixels, text_tokens, prompt_token=extra_token
    )
    return outputs[0]


def multiGPU_CLIP(args, model_image, model_text, model, images, text_tokens=None, prompt_token=None, dataset_name=None):
    """Compute CLIP similarity logits between images and class text features.

    Text features come from a cache file when ``args`` and ``dataset_name``
    are given and ``<args.cache>/refined_<dataset>_prompts.pt`` exists;
    otherwise they are produced by ``model.encode_text(text_tokens)``.  With a
    cache hit and ``args.advanced_text == "wordnet_def"`` the cached prompt
    features are averaged with the ones stored in
    ``refined_<dataset>_wn_def.pt``.

    Args:
        args: Optional namespace providing ``cache`` (directory) and
            ``advanced_text``; may be ``None``.
        model_image: Unused; kept for interface compatibility.
        model_text: Unused; kept for interface compatibility.
        model: Object exposing ``encode_text``, ``encode_image(images,
            prompt_token)`` and a ``logit_scale`` tensor.
        images: Preprocessed image batch, first dimension is the batch size.
        text_tokens: Tokens for ``model.encode_text``; required when no cache
            file is used.
        prompt_token: Optional prompt tensor, repeated once per image along
            the first dimension before being given to ``encode_image``.
        dataset_name: Dataset name used (lower-cased) in the cache file names.

    Returns:
        Tuple ``(logits_per_image, logits_per_text, image_features,
        text_features)`` where both feature tensors are L2-normalised along
        ``dim=1`` and ``logits_per_text`` is the transpose-ordered product.
    """
    if prompt_token is not None:
        # One copy of the prompt per image in the batch.
        prompt_token = prompt_token.repeat(images.size(0), 1, 1)

    cache_prompts = cache_wordnet_def = None
    if args is not None and dataset_name is not None:
        tag = dataset_name.lower()
        cache_prompts = os.path.join(args.cache, f"refined_{tag}_prompts.pt")
        cache_wordnet_def = os.path.join(args.cache, f"refined_{tag}_wn_def.pt")

    if cache_prompts is not None and os.path.exists(cache_prompts):
        # map_location keeps the load on the CPU even if the cache was saved
        # from a GPU that is not present on this machine.
        text_features = torch.load(cache_prompts, map_location="cpu")
        if args.advanced_text == "wordnet_def":
            a_text_features = torch.load(cache_wordnet_def, map_location="cpu")
            text_features = (text_features + a_text_features) * 0.5
    else:
        text_features = model.encode_text(text_tokens)

    text_features = text_features / text_features.norm(dim=1, keepdim=True)  # [n_class, d_emb]

    image_features = model.encode_image(images, prompt_token)
    image_features = image_features / image_features.norm(dim=1, keepdim=True)  # [bs, d_emb]

    # Put the text features wherever the image features live instead of
    # forcing the current CUDA device: identical in the usual single-GPU
    # setup, and it also works on the CPU or when the model sits on a GPU
    # that is not the current one.
    text_features = text_features.to(image_features.device)

    scale = model.logit_scale.exp()
    logits_per_image = image_features @ text_features.t() * scale
    logits_per_text = text_features @ image_features.t() * scale

    return logits_per_image, logits_per_text, image_features, text_features


def multiGPU_CLIP_v3(args, model_image, model_text, model, images, text_features=None, text_projector=None, prompt_token=None, dataset_name=None, dataset_class_num=None, batch=None, aug_num=None):
    """Score augmented image views against text features after two feature corrections.

    Runs entirely under ``torch.no_grad()``.  Steps, in order:

    1. Encode ``images`` with ``model.encode_image``, L2-normalise, and regroup
       the rows as ``(batch, aug_num, d_emb)``.
    2. For every sample, contract ``text_projector`` with the sample's views
       and take the ``argmin`` over the class axis of the summed result as a
       provisional class.
    3. For every sample, build a ``d_emb x d_emb`` matrix from randomly picked
       views and the provisional class's text features, take its SVD, form
       ``W = U @ V.T`` and blend it with the identity
       (``0.1 * W + 0.9 * I``); the sample's features are multiplied by ``W.T``.
    4. Multiply all features by ``P = U_k @ U_k.T`` where ``U_k`` holds the
       first 256 left singular vectors of ``text_features.T @ text_features``.
    5. Re-normalise, compute logits against ``text_features`` scaled by
       ``model.logit_scale.exp()`` and reshape them.

    Step 3 draws indices with ``torch.randint`` from the global RNG, so the
    output is not deterministic unless the caller seeds it.

    Args:
        args: Not used in this function.
        model_image: Not used in this function.
        model_text: Not used in this function.
        model: Object exposing ``encode_image(images, prompt_token)`` and a
            ``logit_scale`` tensor.
        images: Input batch; ``encode_image`` must return ``batch * aug_num``
            rows for the ``view`` below to succeed.
        text_features: 2-D tensor of text embeddings (required despite the
            ``None`` default).  Presumably already L2-normalised and holding
            ``n_prompt * dataset_class_num`` rows -- TODO confirm with caller.
        text_projector: 3-D tensor used as the ``'cdb'`` einsum operand; its
            last axis must match ``d_emb``.  Presumably one projector per
            class (leading axis == ``dataset_class_num``) -- TODO confirm.
        prompt_token: Forwarded unchanged to ``model.encode_image``.
        dataset_name: Not used in this function.
        dataset_class_num: Number of classes.
        batch: Number of samples.
        aug_num: Number of augmented views per sample.

    Returns:
        ``(logits_per_sample, None, image_features, text_features)`` where
        ``logits_per_sample`` has shape
        ``(batch, aug_num, text_features.shape[0] // dataset_class_num,
        dataset_class_num)`` and ``image_features`` are the normalised,
        uncorrected encoder outputs.
    """
    with torch.no_grad():
        # text_features = text_features / text_features.norm(dim=-1, keepdim=True) # [n_class, d_emb]
        # text_features = text_features.to(device)
        # image_features = model.module.encode_image(images, prompt_token)
        image_features = model.encode_image(images, prompt_token)

        image_features = image_features / image_features.norm(dim=-1, keepdim=True) # [bs, d_emb]

        # Rows are grouped sample-major: the aug_num views of one sample are
        # consecutive rows of image_features.
        reshape_images = image_features.view(batch, aug_num, -1).contiguous()
        # Disabled alternative for a view-major row layout:
        # reshape_images = image_features.view(aug_num, batch, -1).permute(1, 0, 2).contiguous()

        # #------------------------------------------------
        # Per sample: contract text_projector (c, d, b) with the transposed
        # views (b = d_emb, k = aug_num) copied once per class -> (c, d, k).
        projection_error = [torch.einsum('cdb,cbk->cdk', text_projector, reshape_images[i].t()[None,:,:].repeat(dataset_class_num, 1, 1)) for i in range(batch)]

        # Collapse to one score per class.  This is a plain signed sum of the
        # projected entries, not a norm.
        # NOTE(review): the "error" naming suggests a residual magnitude was
        # intended -- confirm the sum is the desired criterion.
        min_error_list = [projection_error[i].sum(-1).sum(-1) for i in range(batch)]
        predict_idx = [min_error_list[i].argmin() for i in range(batch)]

        # Regroup text rows as (class, prompt, d_emb).  This view reads
        # text_features rows as prompt-major (row = prompt * n_class + class).
        # NOTE(review): the final logits view below reads the same axis as
        # class-major (row = class * n_prompt + prompt); the two only agree
        # when there is a single prompt or a single class -- confirm which
        # layout the caller produces.
        text_features_reshape = text_features.clone().view(-1, dataset_class_num,  image_features.shape[-1]).permute(1, 0, 2).contiguous()

        # Text features of the provisional class for each sample.
        used_text_features_list = [text_features_reshape[i] for i in predict_idx]
        
        n_prompt = text_features.shape[0] // text_projector.shape[0]
        Mat_list = list()
        for i in range(batch):
            # Row-selection matrix (n_prompt, aug_num): picks n_prompt views at
            # random, with replacement.  Built on the default device, then moved.
            P = torch.eye(reshape_images[i].shape[0])[torch.randint(0, reshape_images[i].shape[0], (n_prompt,))].to(reshape_images.device)

            # (d_emb, n_prompt) @ (n_prompt, d_emb) cross-product between the
            # picked views and the provisional class's text features.
            Mat_list.append(((P @ reshape_images[i]).T @ used_text_features_list[i]).float())

        Mat = torch.stack(Mat_list, dim=0)
        # NOTE(review): torch.svd is documented as returning V rather than
        # V^H, so the variable named `Vh` presumably holds V and
        # `U[i] @ Vh[i].T` is U @ V^T -- confirm against the installed
        # PyTorch version before migrating to torch.linalg.svd.
        U, S, Vh = torch.svd(Mat)
        new_reshape_images = list()
        for i in range(batch):
            W = U[i] @ Vh[i].T
            identity = torch.eye(W.size(0)).to(W.device)
            # Keep only 10% of the deviation from the identity: 0.1 * W + 0.9 * I.
            W = W - (W - identity) * 0.9
            new_reshape_images.append(reshape_images[i] @ W.T)
        reshape_images = torch.stack(new_reshape_images, dim=0)

        # Disabled batched variant of the loop above (references an undefined
        # `device` and a different blend factor):
        # W = torch.einsum('bdr,brk->bdk', U, Vh.permute(0, 2, 1))
        # identity = torch.eye(W.size(1)).to(device)
        # identity = identity[None,:,:].repeat(W.shape[0], 1, 1)
        # W = W - (W - identity) * 0.85
        # reshape_images = torch.einsum('bnd, bkk->bnk', reshape_images, W)
        # #-----------------------------------------------------------------------
        new_reshape_images = list()
        # SVD of the d_emb x d_emb matrix text_features.T @ text_features; only
        # U is used below.
        U, S, V = torch.linalg.svd((text_features.T @ text_features).float(), )
        # Projector built from the first 256 columns of U.
        # NOTE(review): 256 is hard-coded here, whereas multiGPU_CLIP_v2 reads
        # the count from args.SVD_comp -- confirm this is intentional.
        P = U[:,:256] @ U[:,:256].T
        
        for i in range(batch):
            new_reshape_images.append(reshape_images[i] @ P)
        reshape_images = torch.stack(new_reshape_images, dim=0)
        # #---------------------------------------------------------------------------

        # The corrections above do not preserve unit length, so normalise again.
        reshape_images = reshape_images / reshape_images.norm(dim=-1, keepdim=True) # [bs, d_emb]

        logits = reshape_images.view(-1, image_features.shape[-1]).contiguous() @ text_features.t() * model.logit_scale.exp()
        
        # Disabled alternative for a view-major row layout:
        # logits_per_sample = logits.view(aug_num, batch,  -1).permute(1, 0, 2).contiguous() 
        # logits_per_sample = logits_per_sample.view(batch, aug_num, dataset_class_num, -1,).permute(0, 1, 3, 2).contiguous() 

        # (batch * aug_num, T) -> (batch, aug_num, n_class, T // n_class)
        # -> (batch, aug_num, T // n_class, n_class).
        logits_per_sample = logits.view(batch, aug_num,  -1).contiguous() 
        logits_per_sample = logits_per_sample.view(batch, aug_num, dataset_class_num, -1,).permute(0, 1, 3, 2).contiguous() 



        # Disabled: logits computed from the uncorrected image features.
        # logits_per_image = image_features @ text_features.t() * model.logit_scale.exp()
        # logits_per_image = logits_per_image.view(image_features.shape[0],  dataset_class_num, -1,).permute(0, 2, 1).contiguous() 

        # # logits_per_text = text_features @ image_features.t() * model.module.logit_scale.exp()
        # logits_per_text = text_features @ image_features.t() * model.logit_scale.exp()
        # logits_per_text = logits_per_text.view( -1, dataset_class_num, image_features.shape[0]).contiguous() 

        # The second slot (text-side logits) is not computed by this variant.
        return logits_per_sample, None, image_features, text_features



def multiGPU_CLIP_v2(args, model_image, model_text, model, images, text_features=None, text_projector=None, prompt_token=None, dataset_name=None, dataset_class_num=None, batch=None, aug_num=None):
    """Score augmented views after projecting them onto a text-derived subspace.

    Runs under ``torch.no_grad()``.  Image features are encoded and
    L2-normalised, multiplied by ``P = U_k @ U_k.T`` where ``U_k`` holds the
    first ``args.SVD_comp`` left singular vectors of
    ``text_features.T @ text_features``, re-normalised, and compared with
    ``text_features`` (scaled by ``model.logit_scale.exp()``).

    Args:
        args: Namespace providing ``SVD_comp``, the number of singular vectors
            kept for the projection.
        model_image: Unused; kept for interface compatibility.
        model_text: Unused; kept for interface compatibility.
        model: Object exposing ``encode_image(images, prompt_token)`` and a
            ``logit_scale`` tensor.
        images: Input batch; ``encode_image`` must return ``batch * aug_num``
            rows, with the views of one sample stored consecutively.
        text_features: 2-D tensor of text embeddings (required despite the
            ``None`` default).  Presumably already L2-normalised -- TODO
            confirm with the caller; no normalisation is applied here.
        text_projector: Unused; kept for interface compatibility.
        prompt_token: Forwarded unchanged to ``model.encode_image``.
        dataset_name: Unused; kept for interface compatibility.
        dataset_class_num: Number of classes.
        batch: Number of samples.
        aug_num: Number of augmented views per sample.

    Returns:
        ``(logits_per_sample, None, image_features, text_features)`` where
        ``logits_per_sample`` has shape
        ``(batch, aug_num, text_features.shape[0] // dataset_class_num,
        dataset_class_num)`` and ``image_features`` are the normalised,
        unprojected encoder outputs.
    """
    with torch.no_grad():
        image_features = model.encode_image(images, prompt_token)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)  # [batch * aug_num, d_emb]
        d_emb = image_features.shape[-1]

        # Sample-major grouping: (batch, aug_num, d_emb).
        reshape_images = image_features.view(batch, aug_num, -1)

        # Only the left singular vectors of the d_emb x d_emb text matrix are
        # needed to build the projector.
        U, _, _ = torch.linalg.svd((text_features.T @ text_features).float())
        basis = U[:, :args.SVD_comp]
        P = basis @ basis.T

        # A single broadcast matmul applies P to every sample at once; this
        # replaces the per-sample Python loop followed by torch.stack.
        reshape_images = reshape_images @ P

        # The projection does not preserve unit length, so normalise again.
        reshape_images = reshape_images / reshape_images.norm(dim=-1, keepdim=True)

        logits = reshape_images.reshape(-1, d_emb) @ text_features.t() * model.logit_scale.exp()

        # (batch * aug_num, T) -> (batch, aug_num, n_class, T // n_class)
        # -> (batch, aug_num, T // n_class, n_class).
        logits_per_sample = logits.view(batch, aug_num, -1)
        logits_per_sample = logits_per_sample.view(batch, aug_num, dataset_class_num, -1).permute(0, 1, 3, 2).contiguous()

        # The second slot (text-side logits) is not computed by this variant.
        return logits_per_sample, None, image_features, text_features

def multiGPU_CLIP_v2_orig(args, model_image, model_text, model, images, text_features=None, text_projector=None, prompt_token=None, dataset_name=None, dataset_class_num=None, batch=None, aug_num=None):
    """Baseline scoring of augmented views with no feature correction.

    Under ``torch.no_grad()``: encodes ``images`` with ``model.encode_image``,
    L2-normalises the features, multiplies them with ``text_features.t()``
    scaled by ``model.logit_scale.exp()`` and reshapes the logits to
    ``(batch, aug_num, text_features.shape[0] // dataset_class_num,
    dataset_class_num)``.

    ``args``, ``model_image``, ``model_text``, ``text_projector`` and
    ``dataset_name`` are accepted but not used.

    Returns:
        ``(logits_per_sample, None, image_features, text_features)``;
        ``text_features`` is handed back untouched.
    """
    with torch.no_grad():
        encoded = model.encode_image(images, prompt_token)
        image_features = encoded / encoded.norm(dim=-1, keepdim=True)  # [batch * aug_num, d_emb]
        emb_dim = image_features.shape[-1]

        # Group the views per sample, then normalise once more along the
        # embedding axis, exactly as the features are consumed below.
        grouped = image_features.view(batch, aug_num, -1).contiguous()
        grouped = grouped / grouped.norm(dim=-1, keepdim=True)

        flat_views = grouped.view(-1, emb_dim).contiguous()
        logits = flat_views @ text_features.t() * model.logit_scale.exp()

        # (batch * aug_num, T) -> (batch, aug_num, n_class, T // n_class)
        # -> (batch, aug_num, T // n_class, n_class).
        per_view = logits.view(batch, aug_num, -1).contiguous()
        per_class = per_view.view(batch, aug_num, dataset_class_num, -1)
        logits_per_sample = per_class.permute(0, 1, 3, 2).contiguous()

    return logits_per_sample, None, image_features, text_features


def multiGPU_CLIP_v1(args, model_image, model_text, model, images, text_features=None, prompt_token=None, dataset_name=None, dataset_class_num=None):
    """Compute image/text logits from precomputed text features.

    Runs under ``torch.no_grad()``.  Both feature sets are L2-normalised along
    the last dimension and multiplied, scaled by ``model.logit_scale.exp()``.

    Args:
        args: Unused; kept for interface compatibility.
        model_image: Unused; kept for interface compatibility.
        model_text: Unused; kept for interface compatibility.
        model: Object exposing ``encode_image(images, prompt_token)`` and a
            ``logit_scale`` tensor.
        images: Input batch for ``model.encode_image``.
        text_features: 2-D tensor of text embeddings (required despite the
            ``None`` default).
        prompt_token: Forwarded unchanged to ``model.encode_image``.
        dataset_name: Unused; kept for interface compatibility.
        dataset_class_num: Number of classes, used to split the text axis.

    Returns:
        ``(logits_per_image, logits_per_text, image_features, text_features)``
        with ``logits_per_image`` shaped
        ``(n_images, T // dataset_class_num, dataset_class_num)`` and
        ``logits_per_text`` shaped
        ``(T // dataset_class_num, dataset_class_num, n_images)``, where ``T``
        is ``text_features.shape[0]``.
    """
    with torch.no_grad():
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        image_features = model.encode_image(images, prompt_token)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)  # [bs, d_emb]

        # Put the text features wherever the image features live instead of
        # forcing the current CUDA device: identical in the usual single-GPU
        # setup, and it also works on the CPU or when the model sits on a GPU
        # that is not the current one.
        text_features = text_features.to(image_features.device)

        scale = model.logit_scale.exp()
        n_images = image_features.shape[0]

        # NOTE(review): the two views below read the text axis differently --
        # logits_per_image splits it as (class, prompt) before swapping,
        # logits_per_text as (prompt, class).  They only agree when there is a
        # single prompt or a single class; confirm the intended row layout.
        logits_per_image = image_features @ text_features.t() * scale
        logits_per_image = logits_per_image.view(n_images, dataset_class_num, -1).permute(0, 2, 1).contiguous()

        logits_per_text = text_features @ image_features.t() * scale
        logits_per_text = logits_per_text.view(-1, dataset_class_num, n_images).contiguous()

        return logits_per_image, logits_per_text, image_features, text_features

def kl_div(p_logits, q_logits):
    """Return the batch mean of ``KL(p || q)`` for two probability tensors.

    Despite the parameter names, both inputs must already be probabilities
    (e.g. softmax outputs) of shape ``[bs, n_class]``.

    Args:
        p_logits: Reference distribution ``p``.
        q_logits: Compared distribution ``q``.

    Returns:
        Scalar tensor: ``sum_c p * (log p - log q)`` averaged over the batch.
    """
    # torch.xlogy(p, x) is defined as 0 wherever p == 0, so classes whose
    # probability underflowed to exactly zero contribute 0 instead of the NaN
    # that ``0 * log(0)`` would produce.  For p > 0 this equals
    # p * (log p - log q), and it is still +inf where q == 0 < p.
    kl_divs = (torch.xlogy(p_logits, p_logits) - torch.xlogy(p_logits, q_logits)).sum(dim=1)  # [bs,]
    return kl_divs.mean()

def get_loss_general(tgt_logits, a_images, model_image_copy, text_features):
    """KL regulariser between target logits and the original encoder's logits.

    The perturbed images are fed through ``model_image_copy`` (called
    directly), the resulting features and ``text_features`` are L2-normalised
    along ``dim=1``, and the scaled similarity logits are compared with
    ``tgt_logits`` via ``kl_div`` on their softmax distributions.

    Args:
        tgt_logits: Target logits, ``[bs, n_class]``.
        a_images: Perturbed image batch passed to ``model_image_copy``.
        model_image_copy: Callable image encoder returning ``[bs, d_emb]``
            features.  ``logit_scale`` is read from its ``.module`` attribute
            when it has one (wrapped model) and from the object itself
            otherwise.
        text_features: Text embeddings, ``[n_class, d_emb]``.

    Returns:
        Scalar tensor ``KL(softmax(tgt_logits) || softmax(logits))``.
    """
    image_features = model_image_copy(a_images)  # [bs, d_emb]
    image_features = image_features / image_features.norm(dim=1, keepdim=True)
    text_features = text_features / text_features.norm(dim=1, keepdim=True)

    # The rest of this file calls the model without a ``.module`` wrapper, so
    # accept both wrapped and bare models instead of requiring the wrapper.
    unwrapped = getattr(model_image_copy, "module", model_image_copy)
    logits_per_image_ = image_features @ text_features.t() * unwrapped.logit_scale.exp()  # [bs, n_class]

    return kl_div(tgt_logits.softmax(dim=1), logits_per_image_.softmax(dim=1))

def get_loss_clean(clean_images, tgt_logits, model, text_features, prompt_token=None):
    """KL regulariser between target logits and the logits of clean images.

    The clean images are encoded with the model's ``encode_image``, the
    features are L2-normalised along ``dim=1``, and the scaled similarity
    logits are compared with ``tgt_logits`` via ``kl_div`` on their softmax
    distributions.

    Args:
        clean_images: Unperturbed image batch.
        tgt_logits: Target logits, ``[bs, n_class]``.
        model: CLIP-style model exposing ``encode_image(images, prompt_token)``
            and ``logit_scale``, either directly or through a ``.module``
            attribute (wrapped model).
        text_features: Text embeddings, ``[n_class, d_emb]``.  They are not
            normalised here -- presumably the caller passes unit-length
            features; TODO confirm.
        prompt_token: Optional prompt forwarded to ``encode_image``.

    Returns:
        Scalar tensor ``KL(softmax(tgt_logits) || softmax(logits))``.
    """
    # The rest of this file calls the model without a ``.module`` wrapper, so
    # accept both wrapped and bare models instead of requiring the wrapper.
    unwrapped = getattr(model, "module", model)

    image_features = unwrapped.encode_image(clean_images, prompt_token)  # [bs, d_emb]
    image_features = image_features / image_features.norm(dim=1, keepdim=True)
    logits_per_image = image_features @ text_features.t() * unwrapped.logit_scale.exp()  # [bs, n_class]

    return kl_div(tgt_logits.softmax(dim=1), logits_per_image.softmax(dim=1))