import torch
from torch import nn
from torch.utils.data import DataLoader, Subset
from attributors.abstract_classes import Property
from attributors.utils import flatten_params
from tqdm import tqdm
from data.utils import process_batch
from transformers import default_data_collator
from torch.utils.data import default_collate

class AccuracyProperty(Property):
    """Classification accuracy of a model on ``dataset`` (non-differentiable).

    Accuracy is the fraction of samples whose arg-max over the last dimension of the
    model output equals the label.
    """

    def __init__(self, dataset, batch_size=None):
        super().__init__(dataset, 'accuracy', batch_size)
        # The qnli variants use the transformers collator (presumably because they yield
        # tokenizer-style feature dicts — confirm); everything else keeps DataLoader's
        # own default collation.
        if self.dataset.name in ['qnli', 'qnli_noisy']:
            self.collate_fn = default_data_collator
        else:
            self.collate_fn = default_collate
        self.loader = self._make_loader(self.dataset)

    def _make_loader(self, dataset):
        """Return a sequential (unshuffled) DataLoader over ``dataset``."""
        if self.collate_fn is default_data_collator:
            return DataLoader(dataset, batch_size=self.batch_size, shuffle=False,
                              collate_fn=self.collate_fn)
        # Omit collate_fn so DataLoader chooses its own default: with batch_size=None
        # (automatic batching disabled) that default is not default_collate.
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=False)

    def _paired_batches(self, test_outputs):
        """Yield ``(dataset_batch, outputs_batch)`` pairs in lock-step.

        Raises:
            ValueError: if ``test_outputs`` yields fewer or more batches than the dataset.
                A plain ``zip`` would silently truncate and skew the result.
        """
        outputs_iter = iter(DataLoader(test_outputs, batch_size=self.batch_size, shuffle=False))
        exhausted = object()
        for batch in self.loader:
            outputs = next(outputs_iter, exhausted)
            if outputs is exhausted:
                raise ValueError('test_outputs has fewer entries than the property dataset')
            yield batch, outputs
        if next(outputs_iter, exhausted) is not exhausted:
            raise ValueError('test_outputs has more entries than the property dataset')

    @staticmethod
    def _count_correct(outputs, labels):
        """Return how many arg-max predictions (over the last dim) equal ``labels``.

        Raises:
            ValueError: if the number of predictions differs from the number of labels.
        """
        predictions = outputs.argmax(dim=-1)
        if predictions.numel() != labels.numel():
            raise ValueError(
                f'got {predictions.numel()} predictions but {labels.numel()} labels')
        # Compare flattened tensors: element-wise == on mismatched shapes such as
        # (B,) vs (B, 1) would broadcast to (B, B) and inflate the count.
        return (predictions.reshape(-1) == labels.reshape(-1)).sum().item()

    def forward(self, model, device='cpu'):
        """Compute the accuracy of ``model`` over ``self.dataset``."""
        return self.loop(model, self.loader, device=device)

    def backward(self, model, device='cpu'):
        """Non-differentiable property: there is no gradient, so return None."""
        return None

    def test_output_forward(self, test_outputs, device='cpu'):
        """Compute accuracy from precomputed model outputs.

        Args:
            test_outputs: model outputs, one entry per dataset sample, in dataset order
                (they are batched sequentially alongside ``self.loader``).
            device: device the outputs and labels are compared on.

        Returns:
            Fraction of correctly classified samples, as a float.

        Raises:
            ValueError: if ``test_outputs`` does not line up with the dataset.
        """
        n_correct = 0.0
        with torch.no_grad():
            for batch, outputs in self._paired_batches(test_outputs):
                _, labels = process_batch(batch, device=device)
                n_correct += self._count_correct(outputs.to(device), labels)
                torch.cuda.empty_cache()
        return n_correct / len(self.loader.dataset)

    def loop(self, model, loader, device='cpu'):
        """Run ``model`` over ``loader`` and return its accuracy.

        ``process_batch`` must return ``(inputs, labels)``, where ``inputs`` is a mapping
        of keyword arguments for ``model``. If ``model`` is an ``nn.Module``, it is moved
        to ``device`` and switched to eval mode.
        """
        if isinstance(model, nn.Module):
            model = model.to(device).eval()

        n_correct = 0.0
        # Accuracy is never differentiated, so skip building the autograd graph.
        with torch.no_grad():
            for batch in loader:
                inputs, labels = process_batch(batch, device=device)
                outputs = model(**inputs).to(device)
                n_correct += self._count_correct(outputs, labels)
                # Release cached allocator memory between batches.
                torch.cuda.empty_cache()
        return n_correct / len(loader.dataset)

class CrossEntropyProperty(Property):
    """Negated mean cross-entropy of a model on ``dataset``.

    The sign is flipped, so a larger property value means a lower loss. ``backward`` and
    ``group_backward`` return the gradient of that negated loss, as produced by
    ``flatten_params(model, gradients=True)`` (presumably a flat vector over all model
    parameters — confirm).
    """

    def __init__(self, dataset, batch_size=None):
        super().__init__(dataset, 'cross_entropy', batch_size)
        self.loss = nn.CrossEntropyLoss(reduction='mean')
        # The qnli variants use the transformers collator (presumably because they yield
        # tokenizer-style feature dicts — confirm); everything else keeps DataLoader's
        # own default collation.
        if self.dataset.name in ['qnli', 'qnli_noisy']:
            self.collate_fn = default_data_collator
        else:
            self.collate_fn = default_collate
        self.loader = self._make_loader(self.dataset)

    def _make_loader(self, dataset):
        """Return a sequential (unshuffled) DataLoader over ``dataset``."""
        if self.collate_fn is default_data_collator:
            return DataLoader(dataset, batch_size=self.batch_size, shuffle=False,
                              collate_fn=self.collate_fn)
        # Omit collate_fn so DataLoader chooses its own default: with batch_size=None
        # (automatic batching disabled) that default is not default_collate.
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=False)

    def _paired_batches(self, test_outputs):
        """Yield ``(dataset_batch, outputs_batch)`` pairs in lock-step.

        Raises:
            ValueError: if ``test_outputs`` yields fewer or more batches than the dataset.
                A plain ``zip`` would silently truncate and skew the result.
        """
        outputs_iter = iter(DataLoader(test_outputs, batch_size=self.batch_size, shuffle=False))
        exhausted = object()
        for batch in self.loader:
            outputs = next(outputs_iter, exhausted)
            if outputs is exhausted:
                raise ValueError('test_outputs has fewer entries than the property dataset')
            yield batch, outputs
        if next(outputs_iter, exhausted) is not exhausted:
            raise ValueError('test_outputs has more entries than the property dataset')

    def forward(self, model, device='cpu'):
        """Return the negated mean cross-entropy of ``model`` over ``self.dataset``."""
        return -1. * self.loop(model, self.loader, gradients=False, device=device)

    def backward(self, model, device='cpu'):
        """Return the gradient of the negated mean cross-entropy over ``self.dataset``."""
        return -1. * self.loop(model, self.loader, gradients=True, device=device)

    def test_output_forward(self, test_outputs, device='cpu'):
        """Compute the negated mean cross-entropy from precomputed model outputs.

        Args:
            test_outputs: model outputs (logits for ``nn.CrossEntropyLoss``), one entry per
                dataset sample, in dataset order.
            device: device the loss is evaluated on.

        Raises:
            ValueError: if ``test_outputs`` does not line up with the dataset.
        """
        property_val = 0.0
        with torch.no_grad():
            for batch, outputs in self._paired_batches(test_outputs):
                _, labels = process_batch(batch, device=device)
                loss_val = self.loss(outputs.to(device), labels.long())
                # Batch-mean loss * batch size == summed loss for this batch.
                property_val += loss_val.item() * labels.size(0)
                torch.cuda.empty_cache()
        return -1. * property_val / len(self.loader.dataset)

    def group_backward(self, model, group_indices_to_select, device='cpu'):
        """Return the negated-loss gradient restricted to the samples at ``group_indices_to_select``.

        Raises:
            ValueError: if the selected group is empty.
        """
        data_subset = Subset(self.dataset, group_indices_to_select)
        loader = self._make_loader(data_subset)
        return -1. * self.loop(model, loader, gradients=True, device=device)

    def loop(self, model, loader, gradients=False, device='cpu'):
        """Return the mean cross-entropy over ``loader``, or its parameter gradient.

        Args:
            model: called as ``model(**inputs)``. If it is an ``nn.Module``, it is moved to
                ``device`` and put in eval mode.
            loader: yields batches understood by ``process_batch``.
            gradients: if True, return the dataset-mean gradient as a tensor on ``device``;
                otherwise return the dataset-mean loss as a float.
            device: device batches and the model are placed on.

        Raises:
            ValueError: if ``gradients`` is True and ``loader`` yields no batches.
        """
        if isinstance(model, nn.Module):
            model = model.to(device).eval()

        res = None if gradients else 0.0
        # Build the autograd graph only when we actually backpropagate.
        with torch.set_grad_enabled(gradients):
            for batch in loader:
                inputs, labels = process_batch(batch, device=device)
                outputs = model(**inputs)
                loss_val = self.loss(outputs, labels.long())
                if gradients:
                    model.zero_grad()
                    loss_val.backward()
                    # Batch-mean gradient * batch size == gradient of the summed loss;
                    # dividing the running total by the dataset size gives the mean.
                    grads = flatten_params(model, gradients=True) * labels.size(0)
                    if res is None:
                        res = grads.detach().cpu()
                    else:
                        res += grads.detach().cpu()
                else:
                    res += loss_val.item() * labels.size(0)
                # Release cached allocator memory between batches.
                torch.cuda.empty_cache()

        if res is None:
            raise ValueError('cannot compute a cross-entropy gradient over an empty dataset')
        res /= len(loader.dataset)
        return res.to(device) if gradients else res