import math
import torch
from transformers import ViTConfig, ViTForImageClassification,CLIPProcessor, CLIPModel
from transformers import Trainer, TrainingArguments
from transformers import get_cosine_with_hard_restarts_schedule_with_warmup
from torchvision import transforms
from sklearn.metrics import accuracy_score
import numpy as np
from datasets import load_dataset
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader, random_split,Subset
import open_clip
from torchvision.datasets import ImageFolder
import torch.nn.functional as F
from .classification_utils import IMAGENET_1K_CLASS_ID_TO_LABEL
# Evaluation-time preprocessing for CIFAR-10 images: convert to a tensor, then
# normalize each channel with the hard-coded mean/std triples below.
_val_transforms_cifar10 = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.2435, 0.2616)),
])
# Instantiate open_clip's 'ViT-L-14' on CPU with pretrained=None; the first two
# return values are discarded and only the third (the image processor) is kept.
_, _, image_processor = open_clip.create_model_and_transforms(
'ViT-L-14', pretrained=None, device='cpu'
)
# Split the processor's transform list into "everything except the final step"
# and "the final step on its own".
# NOTE(review): the names assume the final step is the normalization -- confirm
# against the open_clip version in use.
preprocessor_without_normalize = transforms.Compose(image_processor.transforms[:-1])
preprocessor_normalizer = image_processor.transforms[-1]

# Hard-coded filesystem location of the CLIP checkpoint. Both objects below are
# loaded as a side effect of importing this module.
model_name = "/inspire/hdd/global_user/zhangwanlin-240108540162/offline/clip-vit-large-patch14"
clipmodel = CLIPModel.from_pretrained(model_name)
hf_processor = CLIPProcessor.from_pretrained(model_name)

class ImageNetDataset(ImageFolder):
    """ImageNet-1k dataset read through the ``ImageFolder`` directory layout.

    Both construction and item access are delegated to ``ImageFolder``; the
    ``(sample, target)`` pair it produces is returned unchanged (the target is
    not mapped through ``IMAGENET_1K_CLASS_ID_TO_LABEL``).
    """

    def __init__(self, root, **kwargs):
        # Every extra keyword argument is forwarded to ImageFolder untouched.
        super(ImageNetDataset, self).__init__(root=root, **kwargs)

    def __getitem__(self, idx):
        item = super(ImageNetDataset, self).__getitem__(idx)
        sample, target = item
        return sample, target

# PUBLIC
def val_transforms_cifar10(examples):
    """Attach evaluation tensors to a batch of CIFAR-10 examples.

    Each entry of ``examples['img']`` is converted to RGB and passed through
    the module-level ``_val_transforms_cifar10`` pipeline; the results are
    stored, in the same order, under ``examples['pixel_values']``.

    Args:
        examples: Batch mapping whose ``'img'`` entry holds images exposing
            ``.convert``.

    Returns:
        The same ``examples`` object, updated in place.
    """
    tensors = []
    for picture in examples['img']:
        rgb_picture = picture.convert("RGB")
        tensors.append(_val_transforms_cifar10(rgb_picture))
    examples['pixel_values'] = tensors
    return examples

def load_dataset_vit(dataset, seed = 0):
    """Load the validation and test splits used for ViT evaluation.

    Only ``"cifar10"`` is supported. The official training split is divided
    90/10 with ``seed`` and the 10% part becomes the validation split; the
    training part is not returned.

    Args:
        dataset: Dataset identifier; must be ``"cifar10"``.
        seed: Seed forwarded to ``train_test_split``.

    Returns:
        ``(val_ds, test_ds)``, both with ``val_transforms_cifar10`` installed
        as their on-the-fly transform.

    Raises:
        NotImplementedError: For any other dataset identifier.
    """
    if dataset != "cifar10":
        raise NotImplementedError

    full_train, test_ds = load_dataset("cifar10", split=['train', 'test'])
    partition = full_train.train_test_split(test_size=0.1, seed=seed)
    val_ds = partition['test']

    # The transform is applied lazily, each time an item is accessed.
    for held_out in (val_ds, test_ds):
        held_out.set_transform(val_transforms_cifar10)
    return val_ds, test_ds

class HuggingFaceLikeDataset:
    """Adapter exposing a pair-yielding dataset as dict-style examples.

    The wrapped ``dataset`` must support ``len()`` and integer indexing that
    yields a two-element ``(image, label)`` item.
    """

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        wrapped = self.dataset
        return len(wrapped)

    def __getitem__(self, idx):
        # Re-key the pair into the mapping format the downstream code reads.
        pixel_values, label = self.dataset[idx]
        return dict(pixel_values=pixel_values, label=label)

def load_dataset_vit_offline(dataset, num_samples, seed=0, *, cifar10_root=None, imagenet_root=None):
    """Load train / validation / test datasets from local disk.

    Every returned dataset is wrapped in ``HuggingFaceLikeDataset`` so items
    are ``{'pixel_values': ..., 'label': ...}`` dicts.

    Args:
        dataset: ``"cifar10"`` or ``"imagenet"``.
        num_samples: Size of the random evaluation subset drawn from the
            ImageNet validation folder. Not used for CIFAR-10.
        seed: Seeds the CIFAR-10 train/validation split and the ImageNet
            subset sampling, so both are reproducible for a given seed.
        cifar10_root: Optional override for the CIFAR-10 root directory;
            ``None`` keeps the built-in default path.
        imagenet_root: Optional override for the directory that contains the
            ImageNet ``train`` and ``val`` folders; ``None`` keeps the
            built-in default path.

    Returns:
        ``(train, val, test)``. For CIFAR-10 the official training set is
        split 90/10 into train/val and the official test set is the test
        split. For ImageNet, ``train`` is the full train folder, ``val`` is a
        random ``num_samples``-sized subset of the val folder, and ``test``
        is the full val folder.

    Raises:
        NotImplementedError: If ``dataset`` is neither supported name.
    """
    if dataset == "cifar10":
        if cifar10_root is None:
            cifar10_root = '/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/zhangwanlin-240108540162/Fusion-OT/otfusion/data'

        # Same preprocessing as the evaluation transform: tensor conversion
        # followed by per-channel normalization with CIFAR-10 statistics.
        eval_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2470, 0.2435, 0.2616)),
        ])

        full_train_dataset = CIFAR10(root=cifar10_root, train=True, download=True, transform=eval_transform)
        test_dataset = CIFAR10(root=cifar10_root, train=False, download=True, transform=eval_transform)

        # Split the official training set 90/10 into train and validation.
        train_size = int(0.9 * len(full_train_dataset))
        val_size = len(full_train_dataset) - train_size
        train_dataset, val_dataset = random_split(
            full_train_dataset, [train_size, val_size], generator=torch.Generator().manual_seed(seed)
        )

        return (
            HuggingFaceLikeDataset(train_dataset),
            HuggingFaceLikeDataset(val_dataset),
            HuggingFaceLikeDataset(test_dataset),
        )

    if dataset == "imagenet":
        if imagenet_root is None:
            imagenet_root = '/inspire/hdd/global_user/zhangwanlin-240108540162/imagenet/ILSVRC/Data/CLS-LOC'

        dataset_train = ImageNetDataset(
            root=imagenet_root + '/train',
            transform=preprocessor_without_normalize,
        )
        dataset_eval = ImageNetDataset(
            root=imagenet_root + '/val',
            transform=preprocessor_without_normalize,
        )

        # Draw the subset with a dedicated seeded generator so `seed` actually
        # controls which images are picked; plain Python ints are handed to
        # Subset rather than tensor elements.
        generator = torch.Generator().manual_seed(seed)
        indices = torch.randperm(len(dataset_eval), generator=generator)[:num_samples].tolist()
        dataset_eval_subset = HuggingFaceLikeDataset(Subset(dataset_eval, indices))

        dataset_test_subset = HuggingFaceLikeDataset(dataset_eval)
        dataset_train_subset = HuggingFaceLikeDataset(dataset_train)
        return dataset_train_subset, dataset_eval_subset, dataset_test_subset

    raise NotImplementedError(f"unsupported dataset: {dataset!r}")


def collate_fn(examples):
    """Collate dict-style examples into a single batch.

    Args:
        examples: Sequence of mappings, each with a ``"pixel_values"`` tensor
            and a ``"label"`` entry.

    Returns:
        Dict with ``"pixel_values"`` (the tensors stacked along a new leading
        dimension) and ``"labels"`` (a tensor built from the labels).
    """
    images = [sample["pixel_values"] for sample in examples]
    stacked_images = torch.stack(images)
    targets = [sample["label"] for sample in examples]
    return dict(pixel_values=stacked_images, labels=torch.tensor(targets))


def compute_metrics(eval_pred):
    """Compute accuracy from an ``(predictions, labels)`` pair.

    Args:
        eval_pred: Two-element iterable; the first element is reduced with
            ``argmax`` over axis 1 to obtain predicted class indices, the
            second holds the reference labels.

    Returns:
        ``{"accuracy": value}`` where ``value`` is what ``accuracy_score``
        returns for the predicted indices and the labels.
    """
    scores, labels = eval_pred
    predicted_classes = np.argmax(scores, axis=1)
    accuracy = accuracy_score(predicted_classes, labels)
    return {"accuracy": accuracy}


def get_new_model(dataset, patch_size, num_hidden_layers, num_attention_heads, hidden_size, intermediate_size, hidden_dropout_prob, attention_probs_dropout_prob):
    """Build a randomly initialised ViT classifier for ``dataset``.

    Only ``"cifar10"`` is supported: the config uses ``image_size=32``, GELU
    activations and 10 labels; every other architecture setting comes from the
    arguments.

    Returns:
        A new ``ViTForImageClassification`` built from that config.

    Raises:
        NotImplementedError: For any dataset other than ``"cifar10"``.
    """
    if dataset != "cifar10":
        raise NotImplementedError

    vit_settings = dict(
        image_size=32,
        patch_size=patch_size,
        num_hidden_layers=num_hidden_layers,
        num_attention_heads=num_attention_heads,
        hidden_size=hidden_size,
        intermediate_size=intermediate_size,
        hidden_act='gelu',
        hidden_dropout_prob=hidden_dropout_prob,
        attention_probs_dropout_prob=attention_probs_dropout_prob,
    )
    configs = ViTConfig(**vit_settings)
    # The label count is assigned after construction, as a config attribute.
    configs.num_labels = 10
    return ViTForImageClassification(configs)


class ClassificationModel(torch.nn.Module):
    """Classifier head that scores image features against text embeddings.

    ``forward`` runs ``Vmodel`` on the input, matrix-multiplies the result by
    ``text_embedding`` and optionally multiplies the logits by ``logit_scaler``.

    Args:
        Vmodel: Callable / module mapping ``pixel_values`` to image features.
        text_embedding: Matrix the image features are multiplied with.
            # NOTE(review): presumably (embed_dim, num_classes) -- confirm with caller.
        resizer: Optional callable stored on the instance (identity when
            ``None``); ``forward`` does not call it.
        logit_scale: Whether to multiply the logits by ``logit_scaler``.
        logit_scaler: Multiplicative factor applied when ``logit_scale`` is truthy.
    """

    def __init__(self, Vmodel, text_embedding=None, resizer=None, logit_scale=True,logit_scaler=1):
        super().__init__()
        self.model = Vmodel

        self.resizer = resizer if resizer is not None else lambda x: x
        self.logit_scale = logit_scale
        self.logit_scaler = logit_scaler
        if text_embedding is None or isinstance(text_embedding, torch.Tensor):
            # Registered as a buffer so the matrix follows the module through
            # .to() / .cuda() / .half(); non-persistent so the state_dict keys
            # stay exactly as before.
            self.register_buffer("text_embedding", text_embedding, persistent=False)
        else:
            self.text_embedding = text_embedding

    def forward(self, pixel_values, output_normalize=True):
        """Return class logits for ``pixel_values``.

        ``output_normalize`` must be truthy; it exists only for signature
        compatibility and is otherwise unused.
        """
        assert output_normalize, "ClassificationModel only supports output_normalize=True"
        image_features = self.model(pixel_values)
        logits = image_features @ self.text_embedding
        if self.logit_scale:
            # Out-of-place multiply: leaves the matmul result untouched.
            logits = logits * self.logit_scaler
        return logits

def get_clip_clsmodel(Vmodel):
    """Wrap ``Vmodel`` in a ``ClassificationModel`` over the ImageNet-1k labels.

    One prompt per value of ``IMAGENET_1K_CLASS_ID_TO_LABEL`` is encoded with
    the module-level ``clipmodel`` / ``hf_processor``; the L2-normalized text
    features, transposed, become the classifier's text-embedding matrix and
    ``clipmodel.logit_scale.exp()`` becomes its logit scaler.

    Args:
        Vmodel: Vision model passed through to ``ClassificationModel``.

    Returns:
        The assembled ``ClassificationModel``.
    """
    template = 'This is a photo of a {}'
    texts = [template.format(c) for c in IMAGENET_1K_CLASS_ID_TO_LABEL.values()]

    # Encode the prompts in batches to bound peak memory. Autograd is disabled
    # because the features are used as constants; this avoids building a graph
    # per batch only to throw it away.
    batch_size = 500
    text_feature_batches = []
    with torch.no_grad():
        for start in range(0, len(texts), batch_size):
            batch_texts = texts[start:start + batch_size]
            inputs = hf_processor(text=batch_texts, return_tensors="pt", padding=True)
            text_features = clipmodel.get_text_features(**inputs)
            text_feature_batches.append(F.normalize(text_features, dim=-1))
    embedding_text_labels_norm = torch.cat(text_feature_batches).T

    # Sanity check: after the transpose every column (one per class) must
    # already be a unit vector.
    assert torch.allclose(
        F.normalize(embedding_text_labels_norm, dim=0),
        embedding_text_labels_norm
    )

    model = ClassificationModel(
        Vmodel=Vmodel,
        text_embedding=embedding_text_labels_norm,
        resizer=None,
        logit_scale=True,
        logit_scaler = clipmodel.logit_scale.exp().detach(),
    )
    return model


def get_model(path):
    """Load a ``ViTForImageClassification`` from ``path`` via ``from_pretrained``."""
    load_pretrained = ViTForImageClassification.from_pretrained
    return load_pretrained(path)


class ClipVisionModel(torch.nn.Module):
    """Vision encoder wrapper with optional input normalization and projection."""

    def __init__(self, visual_model, proj=None, normalize=None, freeze_visual=False):
        """
        Args:
            visual_model: Vision encoder; called as
                ``visual_model(pixel_values, output_hidden_states=False)`` and
                expected to return an object exposing ``pooler_output`` and
                ``last_hidden_state``.
            proj: Projection applied to the encoder features; ``None`` means
                no projection.
            normalize: Callable applied to ``pixel_values`` before the encoder;
                ``None`` means the input is used as given.
            freeze_visual: If true, gradients are disabled for every parameter
                of ``visual_model``.
        """
        super().__init__()
        self.visual_model = visual_model
        self.proj = proj
        self.normalize = normalize
        self.config = None
        # Optionally freeze the vision encoder.
        if freeze_visual:
            for param in self.visual_model.parameters():
                param.requires_grad = False

    def forward(self, pixel_values, output_normalize=True, return_pooled=True):
        """
        Args:
            pixel_values: Input image tensor.
            output_normalize: Whether to L2-normalize the output along the
                last dimension.
            return_pooled: Use ``pooler_output`` when true, otherwise
                ``last_hidden_state``.
        Returns:
            The (optionally projected and normalized) features.
        """
        # Both hooks are optional, matching the constructor defaults.
        if self.normalize is not None:
            pixel_values = self.normalize(pixel_values)
        outputs = self.visual_model(pixel_values, output_hidden_states=False)

        if return_pooled:
            embedding = outputs.pooler_output
        else:
            embedding = outputs.last_hidden_state
        if self.proj is not None:
            embedding = self.proj(embedding)
        if output_normalize:
            embedding = F.normalize(embedding, dim=-1)
        return embedding

    def save_components(self, save_dir):
        """Save the encoder and projection state dicts with ``torch.save``.

        Despite its name, ``save_dir`` is used directly as the destination
        file path. The ``'proj'`` entry is ``None`` when no projection is set.
        """
        torch.save({
            'vision_model': self.visual_model.state_dict(),
            'proj': self.proj.state_dict() if self.proj is not None else None
        }, f"{save_dir}")


def get_clip_model(path):
    """Build the CLIP vision model selected by ``path``.

    The backbone is chosen by whether ``"ViT-L-14"`` occurs in ``path``. The
    pretrained CLIP weights are loaded first; when ``path`` looks like a
    training checkpoint (it contains both ``/output/`` and ``/checkpoints/``)
    the vision encoder and projection weights are then replaced by the ones
    stored in that file. Otherwise the pretrained weights are used as they are.

    Args:
        path: Checkpoint path, or any string that identifies the backbone.

    Returns:
        A ``ClipVisionModel`` wrapping the vision encoder and projection. If
        loading raises ``RuntimeError``, the model returned by
        ``open_clip.create_model_and_transforms`` is returned instead.
    """
    import sys  # local import: `sys` is not provided by the module's import block

    # NOTE(review): the open_clip names are assumed to be the architectures
    # matching the two local HF checkpoints -- confirm.
    if "ViT-L-14" in path:
        model_name = "/inspire/hdd/global_user/zhangwanlin-240108540162/offline/clip-vit-large-patch14"
        clip_model_name = 'ViT-L-14'
    else:
        model_name = "/inspire/hdd/global_user/zhangwanlin-240108540162/offline/clip-vit-base-patch32"
        clip_model_name = 'ViT-B-32'

    try:  # try loading only the visual model
        model = CLIPModel.from_pretrained(model_name)
        if isinstance(path, str) and '/output/' in path and '/checkpoints/' in path:
            # NOTE: torch.load unpickles the file; only use trusted checkpoints.
            checkpoint = torch.load(path, map_location=torch.device('cpu'))
            model.vision_model.load_state_dict(checkpoint['vision_model'])
            # 'proj' is None when the checkpoint was saved without a projection.
            proj_state = checkpoint['proj']
            if proj_state is not None:
                model.visual_projection.load_state_dict(proj_state)

        # Built for both cases so the function always has a model to return.
        Vmodel = ClipVisionModel(visual_model=model.vision_model,  proj=model.visual_projection, normalize=preprocessor_normalizer)
        Vmodel.config = model.config
    except RuntimeError as e:  # try loading the whole model
        print(f'error: {e}', file=sys.stderr)
        print('retrying by loading whole model..', file=sys.stderr)
        torch.cuda.empty_cache()
        Vmodel, _, _ = open_clip.create_model_and_transforms(
            clip_model_name, pretrained=path, force_quick_gelu=True, device='cpu'
        )

    return Vmodel


def compute_tot_iters(ds, epoches, bs, grad_acc_steps):
    """Return the number of optimizer steps for a training run.

    Computed as ``len(ds) * epoches``, floor-divided by ``bs`` and then
    floor-divided by ``grad_acc_steps``.
    """
    samples_seen = len(ds) * epoches
    batches = samples_seen // bs
    return batches // grad_acc_steps


def get_cosine_lr_wup_rstr(opt, ds, epoches, bs, wup_ratio, num_cycles, grad_acc_steps):
    """Create a cosine-with-hard-restarts LR schedule with warmup for ``opt``.

    The total step count comes from ``compute_tot_iters``; the warmup length
    is ``wup_ratio`` times that count, truncated to an int.

    Returns:
        The scheduler built by
        ``get_cosine_with_hard_restarts_schedule_with_warmup``.
    """
    total_steps = compute_tot_iters(ds, epoches, bs, grad_acc_steps)
    warmup_steps = int(wup_ratio * total_steps)
    return get_cosine_with_hard_restarts_schedule_with_warmup(
        optimizer=opt,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_steps,
        num_cycles=num_cycles,
    )


def get_train_args(training_name, wup_ratio, lr, train_bs, eval_bs, epochs, wd, n_workers, grad_acc_steps, label_smoothing, seed, train_ds_len, report_to="wandb"):
    """Assemble the ``TrainingArguments`` for a training run.

    One epoch is ``ceil(train_ds_len / (train_bs * grad_acc_steps))`` steps.
    Logging happens once per that many steps, checkpoints are saved every 10
    times that many steps, and evaluation uses the ``"epoch"`` strategy.

    Returns:
        The configured ``TrainingArguments``; ``training_name`` is passed as
        its first positional argument.
    """
    samples_per_update = train_bs * grad_acc_steps
    updates_per_epoch = math.ceil(train_ds_len / samples_per_update)
    checkpoint_every_n_epochs = 10

    optimisation = dict(
        lr_scheduler_type="cosine",
        warmup_ratio=wup_ratio,
        learning_rate=lr,
        num_train_epochs=epochs,
        weight_decay=wd,
        gradient_accumulation_steps=grad_acc_steps,
        label_smoothing_factor=label_smoothing,
        seed=seed,
    )
    data_loading = dict(
        per_device_train_batch_size=train_bs,
        per_device_eval_batch_size=eval_bs,
        dataloader_num_workers=n_workers,
        remove_unused_columns=False,
    )
    bookkeeping = dict(
        save_strategy="steps",
        save_steps=updates_per_epoch * checkpoint_every_n_epochs,
        evaluation_strategy="epoch",
        load_best_model_at_end=False,
        metric_for_best_model="accuracy",
        logging_dir='logs',
        logging_steps=updates_per_epoch,
        report_to=report_to,
    )
    return TrainingArguments(training_name, **optimisation, **data_loading, **bookkeeping)


def evaluate_vit(model, test_ds):
    """Run ``Trainer.predict`` on ``test_ds`` and return the test accuracy.

    A ``Trainer`` is built with fixed evaluation settings (output directory
    ``"eval_temp"``, eval batch size 128 per device, 16 dataloader workers),
    using ``collate_fn`` as the collator and ``compute_metrics`` for metrics.

    Returns:
        The ``'test_accuracy'`` entry of the prediction metrics.
    """
    eval_settings = dict(
        save_strategy="epoch",
        evaluation_strategy="epoch",
        per_device_eval_batch_size=128,
        metric_for_best_model="accuracy",
        logging_dir='logs',
        remove_unused_columns=False,
        dataloader_num_workers=16,
    )
    eval_args = TrainingArguments("eval_temp", **eval_settings)

    predictor = Trainer(
        model,
        eval_args,
        eval_dataset=test_ds,
        data_collator=collate_fn,
        compute_metrics=compute_metrics,
    )

    prediction_output = predictor.predict(test_ds)
    metrics = prediction_output.metrics
    return metrics['test_accuracy']