from typing import Any, List, Union

import gc
import numpy as np
import torch
from opacus.grad_sample import AbstractGradSampleModule
from scipy.stats import beta
from torch.autograd import Variable
from torch.nn import functional as F
from torch.utils.data import DataLoader

from ipp import IPP
from ipp.data_loader import IPPDataLoader
from ipp.optimizer import IPPOptimizer
from utils.evaluate import evaluate


def _get_gradient_multipliers(batch_losses: List[float], 
                              batch_clipping_thresholds: List[float], 
                              a, b, scale) -> List[float]:
    """
    Modifies the batch per-sample clipping thresholds based on the INO-SGD algorithm.

    Args:
        batch_losses: Individual losses corresponding to each datum in the batch.
        batch_clipping_thresholds: Clipping threshold associated with each datum in the batch.
        batch_mus: Specifies the partition.

    Returns:
        List[float]: The modified per-sample clipping thresholds.
    """
    #if len(batch_mus) == 1 and batch_mus[0].item() == 0:
    #    return batch_clipping_thresholds
    
    order = torch.argsort(batch_losses).cpu()
    accumulated_threshold = 0
    gradient_multipliers = torch.zeros((len(batch_losses),))
    for o in range(len(batch_losses)):
        curr_threshold = batch_clipping_thresholds[order[o]].cpu()
        pts = torch.linspace(accumulated_threshold, accumulated_threshold + curr_threshold, steps=100)
        pts_values = torch.Tensor(beta.cdf(pts, a, b, scale=scale))
        gradient_multipliers[order[o]] = torch.trapezoid(pts_values, x=pts) / curr_threshold
        accumulated_threshold += curr_threshold

    return gradient_multipliers


def train_one_epoch_beta(data_loader: IPPDataLoader,
                    model: AbstractGradSampleModule,
                    optimizer: IPPOptimizer,
                    ipp: IPP,
                    n_steps: int,
                    max_n_steps: int,
                    device: Any,
                    adaptive_threshold: float,
                    *, a, b, scale):
    """
    Runs one pass over ``data_loader``, performing one optimizer step per batch.

    For every batch the per-sample cross-entropy losses are computed, turned
    into gradient multipliers by ``_get_gradient_multipliers`` and handed to
    ``optimizer.step`` together with the batch size, per-sample clipping
    thresholds and noise scale that ``ipp`` provides for the current step.

    Args:
        data_loader: Yields ``(data, target)`` pairs. ``target`` is indexed as a
            2-D tensor: column 0 is used as the class label and column 1 is
            used to index the per-sample clipping thresholds (presumably the
            dataset index of each sample -- confirm against the loader).
        model: Model to train; it is switched to training mode.
        optimizer: Optimizer whose ``step`` accepts the batch size, clipping
            thresholds, noise scale, ``gradient_multipliers`` and
            ``adaptive_threshold``.
        ipp: Source of the per-step batch sizes, clipping thresholds and noise
            scales; all are indexed by the global step counter.
        n_steps: Number of steps already performed before this call.
        max_n_steps: The pass stops as soon as the step counter equals this value.
        device: Device the batch tensors are moved to.
        adaptive_threshold: Forwarded unchanged to ``optimizer.step``.
        a: Beta shape parameter forwarded to ``_get_gradient_multipliers``.
        b: Beta shape parameter forwarded to ``_get_gradient_multipliers``.
        scale: Beta scale forwarded to ``_get_gradient_multipliers``.

    Returns:
        Tuple ``(n_steps, loss)``: the updated step counter and the mean of the
        per-batch mean losses (NaN if no batch was processed).
    """
    model.train()
    losses = []
    for data, target in data_loader:
        # Split the two target columns before the labels are cast and moved.
        indices = target[:, 1]
        labels = target[:, 0].long().to(device)
        data = data.to(device)

        output = model(data)

        # Per-step quantities are looked up with the global step counter.
        batch_size = ipp.get_batch_sizes()[n_steps].item()
        clipping_thresholds = ipp.get_per_sample_clipping_thresholds()[n_steps]
        noise_scale = ipp.noise_scales[n_steps].item()
        batch_clipping_thresholds = clipping_thresholds[indices].to(device)

        per_sample_loss = F.cross_entropy(output, labels, reduction='none')
        # The multipliers depend only on the loss values, so they are computed
        # from a detached view and stay out of the autograd graph.
        gradient_multipliers = _get_gradient_multipliers(per_sample_loss.detach(),
                                                         batch_clipping_thresholds,
                                                         a, b, scale)

        loss = per_sample_loss.mean()
        losses.append(loss.item())
        optimizer.zero_grad()
        loss.backward()

        optimizer.step(batch_size,
                       batch_clipping_thresholds,
                       noise_scale,
                       gradient_multipliers=gradient_multipliers.to(device),
                       adaptive_threshold=adaptive_threshold)

        n_steps += 1
        if n_steps == max_n_steps:
            break

    # np.mean([]) would emit a RuntimeWarning; report NaN explicitly instead.
    mean_loss = np.mean(losses) if losses else float("nan")
    return n_steps, mean_loss


def ipp_train_beta(data_loader: IPPDataLoader,
              model: AbstractGradSampleModule,
              optimizer: IPPOptimizer,
              test_loader: DataLoader,
              ipp: IPP,
              device: Any,
              adaptive_threshold=0,
              max_n_steps=None,
              *, a, b, scale):
    """
    Trains ``model`` epoch by epoch until ``max_n_steps`` steps have been taken.

    After every epoch the model is evaluated on ``test_loader``, the epoch's
    training loss is stored in the evaluation result under ``'train_loss'``,
    and a one-line summary is printed.

    Args:
        data_loader: Training data, forwarded to ``train_one_epoch_beta``.
        model: Model to train.
        optimizer: Optimizer, forwarded to ``train_one_epoch_beta``.
        test_loader: Data passed to ``evaluate`` after each epoch.
        ipp: Provides the default number of steps via ``get_n_iterations()``
            and is forwarded to ``train_one_epoch_beta``.
        device: Device used for training and evaluation.
        adaptive_threshold: Forwarded to ``train_one_epoch_beta``.
        max_n_steps: Total number of steps to run. A falsy value (``None`` or
            ``0``) means "use ``ipp.get_n_iterations()``".
        a: Beta shape parameter, forwarded to ``train_one_epoch_beta``.
        b: Beta shape parameter, forwarded to ``train_one_epoch_beta``.
        scale: Beta scale, forwarded to ``train_one_epoch_beta``.

    Returns:
        list: One evaluation result per epoch, each extended with ``'train_loss'``.

    Raises:
        RuntimeError: If an epoch finishes without performing a single step.
    """
    if not max_n_steps:
        max_n_steps = ipp.get_n_iterations()

    n_epochs = n_steps = 0
    results = []
    while n_steps < max_n_steps:
        steps_before = n_steps
        n_steps, train_loss = train_one_epoch_beta(data_loader,
                                                   model,
                                                   optimizer,
                                                   ipp,
                                                   n_steps,
                                                   max_n_steps,
                                                   device,
                                                   adaptive_threshold,
                                                   a=a,
                                                   b=b,
                                                   scale=scale)
        if n_steps == steps_before:
            # Without this guard the loop would spin forever, evaluating and
            # appending results while the step counter never advances.
            raise RuntimeError(
                f"train_one_epoch_beta made no progress (n_steps stayed at {n_steps}); "
                "is data_loader empty?"
            )

        performances = evaluate(model, test_loader, device)
        performances['train_loss'] = train_loss
        results.append(performances)
        n_epochs += 1

        print(f"epoch: {n_epochs},\tloss: {round(float(train_loss), 4)},\taccuracy: {round(performances['accuracy'], 2)}")

    return results
