import copy
import csv
import os
import numpy as np
import torch
import torch.nn.functional as F
from scipy import stats
from sklearn.metrics import pairwise_distances
from tqdm import tqdm
from .strategy import Strategy
from torchvision import transforms
from PIL import Image

class SharpnessSampling(Strategy):
    """Query strategy that scores unlabeled images by their loss under a weight perturbation.

    For every batch of the unlabeled pool the backbone is copied, its weights
    are moved along ``rho * w**2 * grad / ||abs(w) * grad||`` (the form used by
    adaptive sharpness-aware minimisation), and the per-sample cross-entropy
    against the model's own pseudo-labels is recorded.  Samples with the
    highest score are selected.

    Attributes read in this class but not assigned here (``add_ratio``,
    ``num_train_set``, ``unlabeled_originlabels``, ``dataset_name``,
    ``classifier_name``, ``select_strategy``, ``select_type``,
    ``select_ratio``) are expected to be provided by ``Strategy``.
    """

    def __init__(self, train_args, unlabeled_originlabels, unlabeled_img_path, add_ratio, num_classes, unlabeled_target,
                 backbone, pool_batch_size=64, rho=0.5, acq_mode='Max', image_root=None):
        """Load the whole unlabeled pool into memory and store the scoring configuration.

        Args:
            train_args: Configuration object; ``train_args.dataset_name`` picks
                the preprocessing and default image root.  Forwarded to
                ``Strategy.__init__`` together with the next five arguments.
            unlabeled_originlabels: Forwarded to ``Strategy.__init__``.
            unlabeled_img_path: Iterable of image paths (joined onto the image
                root) that make up the unlabeled pool.
            add_ratio: Forwarded to ``Strategy.__init__``.
            num_classes: Forwarded to ``Strategy.__init__``.
            unlabeled_target: Forwarded to ``Strategy.__init__``.
            backbone: ``torch.nn.Module`` whose forward returns a 2-tuple with
                the class logits first.  It is switched to eval mode.
            pool_batch_size: Batch size used when iterating over the pool.
            rho: Radius multiplier of the weight perturbation.
            acq_mode: ``'Max'`` (score = perturbed loss) or ``'Diff'``
                (score = perturbed loss - original loss).
            image_root: Optional directory prepended to every image path.
                ``None`` keeps the built-in per-dataset default.
        """
        super().__init__(
            train_args, unlabeled_originlabels, unlabeled_img_path, add_ratio, num_classes, unlabeled_target)
        self.backbone = backbone.eval()
        self.pool_batch_size = pool_batch_size
        self.rho = rho
        self.acq_mode = acq_mode  # 'Max' or 'Diff'
        self.unlabeled_img_path = unlabeled_img_path

        # Test-time transform and image root directory for the dataset.
        transform_test, default_root = self._dataset_transform_and_root(train_args.dataset_name)
        root = default_root if image_root is None else image_root

        print("Constructing unlabeled_datalist...")
        # NOTE: every image is decoded up front and kept in one CPU tensor, so
        # memory use grows linearly with the pool size.
        unlabeled_datalist = torch.stack([
            self._load_image(os.path.join(root, rel_path), transform_test)
            for rel_path in unlabeled_img_path
        ])
        print(f"unlabeled_datalist constructed. Total images: {len(unlabeled_datalist)}")

        self.pool_data = unlabeled_datalist  # Tensor[N,C,H,W]

    @staticmethod
    def _dataset_transform_and_root(dataset_name):
        """Return ``(transform, default_image_root)`` for ``dataset_name``."""
        steps = [
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]
        root = ''
        if dataset_name == 'Imagette':
            steps.insert(0, transforms.Resize((128, 128)))
        elif dataset_name == 'Cifar10':
            root = '/home/star/Data/g2/gyh/Cifar10/images/'
        else:
            print(f"Warning: No specific resize rule for dataset: {dataset_name}. Using default transform.")
        return transforms.Compose(steps), root

    @staticmethod
    def _load_image(path, transform):
        """Open ``path``, convert it to RGB and apply ``transform``, closing the file afterwards."""
        with Image.open(path) as img:
            return transform(img.convert("RGB"))

    def query(self):
        """Score the pool and return the highest-scoring samples.

        Returns:
            list[tuple]: ``(image_path, str(original_label))`` for each selected
            sample, ordered by decreasing score.  The same rows are written to
            a CSV file by ``_save_selection``.

        Raises:
            ValueError: If ``acq_mode`` is neither ``'Max'`` nor ``'Diff'``.
        """
        # Fail fast: validate before running two full passes over the pool.
        if self.acq_mode not in ('Max', 'Diff'):
            raise ValueError(f"Invalid acquisition mode: {self.acq_mode}")

        device = next(self.backbone.parameters()).device
        criterion = torch.nn.CrossEntropyLoss(reduction='none')
        # shuffle=False keeps both passes aligned with the pool order.
        dataloader = torch.utils.data.DataLoader(self.pool_data, batch_size=self.pool_batch_size, shuffle=False)

        print("Computing original and perturbed loss:")
        original_loss, pseudo_labels = self._evaluate_pool(dataloader, criterion, device)

        print("Applying sharpness-aware perturbation:")
        perturbed_loss = self._perturbed_pool_loss(dataloader, pseudo_labels, criterion, device)

        # Compute acquisition score
        if self.acq_mode == 'Max':
            score = perturbed_loss
        else:  # 'Diff'
            score = perturbed_loss - original_loss

        # torch.topk raises when k exceeds the number of elements, so cap the
        # request at the pool size.
        topk = min(int(self.add_ratio * self.num_train_set), score.numel())
        # .tolist() yields plain ints, which index any Python sequence safely.
        selected_indices = torch.topk(score, topk).indices.tolist()

        select_data = [
            (self.unlabeled_img_path[idx], str(self.unlabeled_originlabels[idx]))
            for idx in selected_indices
        ]
        self._save_selection(select_data)

        return select_data

    def _evaluate_pool(self, dataloader, criterion, device):
        """Return ``(per_sample_loss, pseudo_labels)`` of the unperturbed backbone, both on CPU."""
        losses, labels = [], []
        with torch.no_grad():
            for batch in tqdm(dataloader, desc="Evaluating batches"):
                logits, _ = self.backbone(batch.to(device))
                # The pseudo-label is the model's own prediction; argmax of the
                # logits equals argmax of their softmax.
                preds = logits.argmax(dim=1)
                losses.append(criterion(logits, preds).cpu())
                labels.append(preds.cpu())
        return torch.cat(losses), torch.cat(labels)

    def _perturbed_pool_loss(self, dataloader, pseudo_labels, criterion, device):
        """Return the per-sample loss (CPU) after a per-batch weight perturbation of a backbone copy."""
        losses = []
        offset = 0
        for batch in tqdm(dataloader, desc="Perturbing model"):
            inputs = batch.to(device)
            targets = pseudo_labels[offset:offset + inputs.size(0)].to(device)
            offset += inputs.size(0)

            # A fresh copy per batch keeps self.backbone untouched and makes
            # every batch start from the same weights.
            model_copy = copy.deepcopy(self.backbone).to(device)
            # NOTE(review): the copy runs in train mode while the reference loss
            # is computed with the eval-mode backbone, so layers such as
            # dropout/batch-norm (if present) behave differently between the
            # two losses -- confirm this is intended.
            model_copy.train()
            # Defensive: make sure no stale gradients leak into this batch.
            model_copy.zero_grad()

            logits, _ = model_copy(inputs)
            criterion(logits, targets).mean().backward()

            grad_params = [p for p in model_copy.parameters() if p.grad is not None]
            # Global L2 norm of the weight-scaled gradient, abs(w) * grad.
            norm = torch.norm(
                torch.stack([(p.grad * torch.abs(p)).norm(2) for p in grad_params]), p=2)
            scale = self.rho / (norm + 1e-12)  # epsilon guards against a zero norm

            # Neither the in-place update nor the scoring forward pass needs an
            # autograd graph.
            with torch.no_grad():
                for p in grad_params:
                    p.add_((p ** 2) * p.grad * scale)
                perturbed_logits, _ = model_copy(inputs)
                losses.append(criterion(perturbed_logits, targets).cpu())

        return torch.cat(losses)

    def _save_selection(self, select_data):
        """Write ``select_data`` rows (path, label) to the strategy's selection CSV.

        The file is ``<dataset_name>_<select_ratio>.csv`` inside
        ``./Selcetion/<dataset_name>/<classifier_name>/<select_strategy>/<select_type>/``;
        the directory is created if missing and an existing file is overwritten.
        """
        # NOTE: 'Selcetion' is misspelled but kept as-is, since other code may
        # read from this exact location.
        filepath = './Selcetion/{}/{}/{}/{}/'.format(self.dataset_name, self.classifier_name,
                                                     self.select_strategy, self.select_type)
        os.makedirs(filepath, exist_ok=True)
        csv_path = '{}{}_{}.csv'.format(filepath, self.dataset_name, self.select_ratio)
        with open(csv_path, 'w', newline='') as f:
            csv.writer(f).writerows(select_data)
