import os
import copy
import time
from tqdm import tqdm
import random
import numpy as np
import torch

import clip.clip as clip

from src.args import parse_arguments
from src.datasets.common import get_dataloader, maybe_dictionarize
from src.models.zeroshot import get_zeroshot_classifier
from src.models.eval import evaluate
from src.models.modeling import ClassificationHead, ImageEncoder, ImageClassifier, ImageEncoderMLP, PLIPImageEncoderMLP
from src.models.utils import cosine_lr, torch_load, LabelSmoothing

import src.datasets as datasets
import torchvision.transforms as T
from PIL import Image

def compute_skewness(X):
    """Return the skewness of ``X`` computed along dim 0.

    Every entry is standardized with the mean and ``torch.std`` (PyTorch's
    default, Bessel-corrected estimator) taken over dim 0, and the cubed
    z-scores are then averaged over dim 0.

    Positions whose standard deviation is zero or NaN (a constant feature, or
    an input with a single row) are divided by 1 instead, so they contribute a
    skewness of 0 rather than a NaN that would propagate into the loss.

    Args:
        X: Floating-point tensor whose dim 0 indexes the samples.

    Returns:
        Tensor with the shape of ``X[0]`` holding one skewness value per
        feature position.
    """
    mean_X = torch.mean(X, dim=0)
    std_X = torch.std(X, dim=0)
    # `std_X > 0` is False for both 0 and NaN, so both fall back to a divisor of 1.
    safe_std = torch.where(std_X > 0, std_X, torch.ones_like(std_X))
    z_scores = (X - mean_X) / safe_std
    return torch.mean(z_scores ** 3, dim=0)

def compute_kurtosis(X):
    """Return the excess kurtosis of ``X`` computed along dim 0.

    Every entry is standardized with the mean and ``torch.std`` (PyTorch's
    default, Bessel-corrected estimator) taken over dim 0; the fourth powers
    of the z-scores are averaged over dim 0 and 3 is subtracted.

    Positions whose standard deviation is zero or NaN (a constant feature, or
    an input with a single row) are divided by 1 instead, so they yield the
    finite value ``-3`` rather than a NaN that would propagate into the loss.

    Args:
        X: Floating-point tensor whose dim 0 indexes the samples.

    Returns:
        Tensor with the shape of ``X[0]`` holding one excess-kurtosis value
        per feature position.
    """
    mean_X = torch.mean(X, dim=0)
    std_X = torch.std(X, dim=0)
    # `std_X > 0` is False for both 0 and NaN, so both fall back to a divisor of 1.
    safe_std = torch.where(std_X > 0, std_X, torch.ones_like(std_X))
    z_scores = (X - mean_X) / safe_std
    return torch.mean(z_scores ** 4, dim=0) - 3


import torch
import torch.nn.functional as F

def cal_dis_loss(h, y):
    """Return the mean softplus margin between the labelled score and the rest.

    For every row ``i`` the margin is
    ``h[i, y[i]] - mean(h[i, k] for k != y[i])`` and the per-sample loss is
    ``log(1 + exp(margin)) / log(2)``, so a zero margin gives a loss of 1.

    The softplus is evaluated with ``F.softplus`` so that large margins do not
    overflow ``exp`` and turn the loss into ``inf``.

    NOTE(review): the loss shrinks as the labelled class score drops below the
    average of the other classes -- confirm this sign is the intended one.

    Args:
        h: ``(num_samples, num_classes)`` score tensor; needs at least two
            classes, otherwise the average over the other classes divides by
            zero.
        y: Integer tensor with one class index per row of ``h``.

    Returns:
        0-dim tensor holding the loss averaged over the rows of ``h``.
    """
    import math

    num_classes = h.size(1)  # number of classes |Y|

    # Score of the labelled class, h(x)_y; the row index lives on h's device.
    rows = torch.arange(h.size(0), device=h.device)
    pos_score = h[rows, y]

    # Average score over all remaining classes.
    neg_scores = (h.sum(dim=1) - pos_score) / (num_classes - 1)

    # Numerically stable log(1 + exp(margin)), rescaled by 1 / log(2).
    loss = F.softplus(pos_score - neg_scores) / math.log(2.0)

    return loss.mean()




def sk_finetune(args):
    """Fine-tune an image classifier with optional skewness/kurtosis/score terms.

    Builds an MLP-augmented image encoder plus a freshly initialised linear
    classification head, trains it with cross-entropy on ``args.train_dataset``
    and, when the corresponding weight is non-zero, adds a skewness term, a
    kurtosis term (both on the MLP features) and ``cal_dis_loss`` on the
    logits. After every epoch the classifier is optionally checkpointed and
    then evaluated. The model is moved to CUDA, so a GPU is required.

    Args:
        args: Parsed argument namespace. Attributes read here: ``seed``,
            ``plip``, ``model_source``, ``train_dataset``, ``freeze_encoder``,
            ``data_location``, ``batch_size``, ``num_shots``, ``noise_ratio``,
            ``skewness_weight``, ``kurtosis_weight``, ``dis_weight``, ``lr``,
            ``wd``, ``warmup_length``, ``epochs`` and ``save``.
            ``args.current_epoch`` is written before each evaluation.

    Returns:
        Path of the most recently written checkpoint, or ``None`` when
        ``args.save`` is ``None`` or no epoch was run.
    """
    print(args)
    # Seed every RNG in use.
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    torch.backends.cudnn.deterministic = True
    # cuDNN autotuning may select different kernels from run to run, which
    # would undermine the seeding and the deterministic flag above.
    torch.backends.cudnn.benchmark = False

    keep_lang = args.model_source in ['clip', 'open_clip']
    if args.plip:
        image_encoder = PLIPImageEncoderMLP(args, keep_lang=keep_lang)
    else:
        image_encoder = ImageEncoderMLP(args, keep_lang=keep_lang)

    # Width of the features fed to the classification head.
    num_classes = datasets.dataset2classes[args.train_dataset]
    if args.model_source == 'timm':
        feature_dim = image_encoder.model.num_features
    elif args.plip:
        feature_dim = 512
    else:
        feature_dim = image_encoder.model.embed_dim
        # Drop the `transformer` submodule; it is not used by the code below.
        # NOTE(review): presumably the text tower -- confirm.
        delattr(image_encoder.model, 'transformer')
    weights = torch.nn.init.kaiming_uniform_(torch.empty((num_classes, feature_dim)))

    classification_head = ClassificationHead(False, weights)
    image_classifier = ImageClassifier(image_encoder, classification_head)
    image_classifier.return_mid_feats = True
    image_classifier.process_images = True

    # Freeze the image encoder backbone so only the remaining parameters train.
    if args.freeze_encoder:
        print('Fine-tuning mlp classifier')
        backbone = image_encoder.model.model if args.plip else image_encoder.model
        for param in backbone.parameters():
            param.requires_grad = False

    model = image_classifier

    if args.plip:
        preprocess_fn = T.Compose([
            T.Resize(224, interpolation=T.InterpolationMode.BICUBIC),  # resize the image
            T.CenterCrop(224),  # center-crop to (224, 224)
            T.Lambda(lambda image: image.convert("RGB")),  # force RGB mode
            T.ToTensor(),  # convert to a tensor
            T.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))])
    else:
        preprocess_fn = image_encoder.train_preprocess

    print_every = 200

    dataset_class = getattr(datasets, args.train_dataset)
    dataset = dataset_class(
        preprocess_fn,
        location=args.data_location,
        batch_size=args.batch_size,
        num_shots=args.num_shots,
        noise_ratio=args.noise_ratio,
    )
    num_batches = len(dataset.train_loader)

    model = model.cuda()
    devices = list(range(torch.cuda.device_count()))
    print('Using devices', devices)
    model = torch.nn.DataParallel(model, device_ids=devices)
    model.train()

    print("start training")
    ce_loss_fn = torch.nn.CrossEntropyLoss()

    skewness_weight = args.skewness_weight
    kurtosis_weight = args.kurtosis_weight
    dis_weight = args.dis_weight

    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(params, lr=args.lr, weight_decay=args.wd)
    scheduler = cosine_lr(optimizer, args.lr, args.warmup_length, args.epochs * num_batches)

    # Stays None when nothing is saved, so the final return is always defined.
    model_path = None

    for epoch in range(args.epochs):

        model.train()
        # The backbone is kept in eval mode for the whole run, frozen or not.
        if args.plip:
            model.module.image_encoder.model.model.eval()
        else:
            model.module.image_encoder.model.eval()

        data_loader = get_dataloader(
            dataset, is_train=True, args=args, image_encoder=None)

        for i, batch in enumerate(data_loader):
            # Started after the batch has been fetched, so `data_time` covers
            # the scheduler step and the host-to-device copy, not the loader.
            start_time = time.time()

            step = i + epoch * num_batches
            scheduler(step)
            optimizer.zero_grad()

            batch = maybe_dictionarize(batch)

            labels = batch["labels"]
            inputs = batch["images"]

            inputs = inputs.cuda()
            labels = labels.cuda()
            data_time = time.time() - start_time

            # Features; called on the wrapped module directly, so the
            # DataParallel wrapper's own forward is not used here.
            feats, mlp_feats = model.module.forward_encoder(inputs)

            # Cross-entropy loss.
            logits = model.module.forward_cls_head(mlp_feats)
            ce_loss = ce_loss_fn(logits, labels)

            # Total loss, accumulated out of place so `ce_loss` stays intact.
            loss = ce_loss

            if skewness_weight:
                skewness = compute_skewness(mlp_feats)
                # NOTE(review): the term is negated, so minimising the loss
                # increases the squared skewness -- confirm this is intended.
                skewness_penalty = -torch.mean(skewness ** 2)
                loss = loss + skewness_weight * skewness_penalty

            if kurtosis_weight:
                kurtosis = compute_kurtosis(mlp_feats)
                # NOTE(review): compute_kurtosis already subtracts 3, so 3 is
                # subtracted a second time here, and the term is negated --
                # confirm both are intended.
                kurtosis_penalty = -torch.mean((kurtosis - 3) ** 2)
                loss = loss + kurtosis_weight * kurtosis_penalty

            if dis_weight:
                dis_loss = cal_dis_loss(logits, labels)
                loss = loss + dis_weight * dis_loss

            loss.backward()
            optimizer.step()
            batch_time = time.time() - start_time

            if i % print_every == 0:
                percent_complete = 100 * i / len(data_loader)
                print(
                    f"Train Epoch: {epoch} [{percent_complete:.0f}% {i}/{len(dataset.train_loader)}]\t"
                    f"Loss: {loss.item():.6f}\tData (t) {data_time:.3f}\tBatch (t) {batch_time:.3f}", flush=True
                )

        # Save the model; every epoch overwrites the same "latest" files.
        if args.save is not None:
            os.makedirs(args.save, exist_ok=True)
            model_path = os.path.join(args.save, 'checkpoint_latest.pt')
            print('Saving model to', model_path)
            image_classifier.save(model_path)
            optim_path = os.path.join(args.save, 'optim_latest.pt')
            torch.save(optimizer.state_dict(), optim_path)

        # Evaluate the unwrapped classifier.
        args.current_epoch = epoch
        evaluate(model.module, args)

    return model_path


def calculate_entropy(softmax_tensor):
    """Return the mean per-sample entropy (natural log) of softmax outputs.

    For each sample (index along dim 0) the entropy is
    ``-sum(p * log(p + 1e-9))`` over all of that sample's entries; the
    per-sample values are then averaged. The computation is vectorised, so no
    per-sample host synchronisation is needed.

    Args:
        softmax_tensor: Tensor of probabilities whose dim 0 indexes samples.

    Returns:
        Detached 0-dim float32 CPU tensor with the average entropy.
    """
    probs = softmax_tensor.detach()
    # 1e-9 keeps log() finite for zero probabilities.
    weighted_log = probs * torch.log(probs + 1e-9)
    if weighted_log.dim() > 1:
        # Collapse everything but the sample dimension into one sum per sample.
        per_sample = -weighted_log.flatten(start_dim=1).sum(dim=1)
    else:
        # 1-D input: every element is its own sample.
        per_sample = -weighted_log
    return per_sample.float().cpu().mean()

def calculate_energy(softmax_tensor):
    """Return the mean per-sample negative sum of log-probabilities.

    For each sample (index along dim 0) the value is ``-sum(log(p + 1e-9))``
    over all of that sample's entries; the per-sample values are then
    averaged. The computation is vectorised, so no per-sample host
    synchronisation is needed.

    Args:
        softmax_tensor: Tensor of probabilities whose dim 0 indexes samples.

    Returns:
        Detached 0-dim float32 CPU tensor with the average value.
    """
    # 1e-9 keeps log() finite for zero probabilities.
    log_probs = torch.log(softmax_tensor.detach() + 1e-9)
    if log_probs.dim() > 1:
        # Collapse everything but the sample dimension into one sum per sample.
        per_sample = -log_probs.flatten(start_dim=1).sum(dim=1)
    else:
        # 1-D input: every element is its own sample.
        per_sample = -log_probs
    return per_sample.float().cpu().mean()

def calculate_msp(softmax_tensor):
    """Return the mean maximum softmax probability over a batch.

    For each sample (index along dim 0) the largest entry is taken; the
    per-sample maxima are then averaged. The computation is vectorised, so no
    per-sample host synchronisation is needed.

    Args:
        softmax_tensor: Tensor of probabilities whose dim 0 indexes samples.

    Returns:
        Detached 0-dim float32 CPU tensor with the average maximum probability.
    """
    probs = softmax_tensor.detach()
    if probs.dim() > 1:
        # Collapse everything but the sample dimension into one max per sample.
        per_sample = probs.flatten(start_dim=1).amax(dim=1)
    else:
        # 1-D input: every element is its own sample.
        per_sample = probs
    return per_sample.float().cpu().mean()


