from typing import Any

import torch
import torch.nn as nn
import numpy as np
from .base_postprocessor import BasePostprocessor
from torch.utils.data import DataLoader
import torch.nn.functional as F
import csv

class SP_Postprocessor(BasePostprocessor):
    """Postprocessor that scores batches by energy or by cross-entropy loss.

    Two scoring modes are supported, selected by the ``mode`` argument of
    :meth:`inference`:

    * ``'loss'`` -- :meth:`postprocess_loss` is used; the score is the
      cross-entropy loss of the whole batch.
    * anything else -- :meth:`postprocess_energy` is used; the score is the
      per-sample temperature-scaled log-sum-exp of the logits.
    """

    # Decision threshold applied to the class-1 softmax probability in
    # ``postprocess_loss``.
    CLASS1_THRESHOLD = 0.5

    def __init__(self, config):
        """Cache the postprocessor settings found under ``config.postprocessor``.

        Args:
            config: configuration object forwarded to the base class. This
                class reads ``self.config`` afterwards, so the base
                ``__init__`` is presumably what stores it (not visible here).
                It must expose ``postprocessor.postprocessor_args.temperature``
                and ``postprocessor.mode``.
        """
        super().__init__(config)
        self.args = self.config.postprocessor.postprocessor_args
        self.temperature = self.args.temperature
        self.mode = self.config.postprocessor.mode

    @torch.no_grad()
    def postprocess_energy(self, nets, data: Any):
        """Run ``nets['net']`` on ``data`` and compute an energy score.

        Args:
            nets: mapping whose ``'net'`` entry is a callable returning logits.
            data: input batch passed straight to the network.

        Returns:
            Tuple ``(logit, pred, energy)`` where ``pred`` is the index of the
            largest softmax probability along dim 1 and ``energy`` is
            ``temperature * logsumexp(logit / temperature)`` along dim 1.
        """
        logit = nets['net'](data)
        # The prediction is taken from the softmax output (not the raw logits)
        # so tie-breaking stays exactly as it has always been.
        _, pred = torch.max(torch.softmax(logit, dim=1), dim=1)
        energy = self.temperature * torch.logsumexp(
            logit / self.temperature, dim=1)
        return logit, pred, energy

    @torch.no_grad()
    def postprocess_loss(self, nets, data: Any, label: Any):
        """Run ``nets['net']`` on ``data`` and compute the cross-entropy loss.

        Args:
            nets: mapping whose ``'net'`` entry is a callable returning logits.
            data: input batch passed straight to the network.
            label: targets handed to ``F.cross_entropy`` together with the
                logits.

        Returns:
            Tuple ``(prob, pred, loss)``: the softmax probabilities, a 0/1
            prediction per sample, and the cross-entropy loss of the batch
            (a single value, since the default reduction is used).
        """
        logit = nets['net'](data)
        prob = F.softmax(logit, dim=1)

        # A sample is predicted as class 1 when the probability in column 1
        # reaches CLASS1_THRESHOLD (0.5); otherwise it is predicted as 0.
        # NOTE(review): this presumes a binary problem with class 1 stored at
        # index 1 -- confirm against the training setup.
        pred = (prob[:, 1] >= self.CLASS1_THRESHOLD).long()
        loss = F.cross_entropy(logit, label)
        return prob, pred, loss

    def inference(self, nets, data_loader: DataLoader, mode: str, eval : bool = True, cat=None):
        """Score every batch of ``data_loader`` and log the correct samples.

        Besides returning the collected results, this writes
        ``f'{cat}_classified_images.csv'`` (relative to the current working
        directory) containing one row per sample whose prediction equals its
        label; each row holds only the sample's ``image_name``. The file is
        written on every call, so with the default ``cat=None`` it is named
        ``None_classified_images.csv``.

        Args:
            nets: mapping whose ``'net'`` entry is the network to evaluate.
            data_loader: yields dict batches with the keys ``'label'``,
                ``'image_name'`` and ``'data_aux'`` / ``'data'``.
            mode: ``'loss'`` selects :meth:`postprocess_loss`; any other value
                selects :meth:`postprocess_energy`. Also stored in
                ``self.mode``.
            eval: when true the input is read from ``batch['data_aux']``,
                otherwise from ``batch['data']``.
            cat: prefix of the CSV file name.

        Returns:
            Tuple ``(pred_list, logit_list, label_list, score_list)`` of numpy
            arrays. In ``'loss'`` mode ``logit_list`` holds softmax
            probabilities (the first value returned by
            :meth:`postprocess_loss`) and ``score_list`` has one loss value
            per batch; otherwise ``logit_list`` holds raw logits and
            ``score_list`` has one energy value per sample.
        """
        data_key = 'data_aux' if eval else 'data'
        self.mode = mode
        use_loss = self.mode == 'loss'

        pred_list, logit_list, label_list, score_list, filename_list = [], [], [], [], []
        for batch in data_loader:
            data = batch[data_key].cuda()
            label = batch['label'].cuda()
            filenames = batch['image_name']

            if use_loss:
                output, pred, score = self.postprocess_loss(nets, data, label)
                # The loss is a single value for the whole batch.
                score_list.append(score.cpu().tolist())
            else:
                output, pred, score = self.postprocess_energy(nets, data)
                score_list.extend(score.cpu().tolist())

            # Move each tensor to the host once per batch rather than once
            # per sample; the resulting Python lists are the same.
            pred_list.extend(pred.cpu().tolist())
            logit_list.extend(output.cpu().tolist())
            label_list.extend(label.cpu().tolist())
            filename_list.extend(filenames)

        # convert values into numpy array
        pred_list = np.array(pred_list, dtype=int)
        logit_list = np.array(logit_list)
        label_list = np.array(label_list, dtype=int)
        score_list = np.array(score_list)

        # No header row: the file is a plain list of correctly classified
        # image names.
        with open(f'{cat}_classified_images.csv', "w", newline='') as f:
            writer = csv.writer(f)
            writer.writerows(
                [name]
                for name, truth, guess in zip(filename_list, label_list, pred_list)
                if truth == guess
            )

        return pred_list, logit_list, label_list, score_list
