from typing import Any, List, Union

import gc
import numpy as np
import torch
from opacus.grad_sample import AbstractGradSampleModule
from torch.autograd import Variable
from torch.nn import functional as F
from torch.utils.data import DataLoader

from ipp import IPP
from ipp.data_loader import IPPDataLoader
from ipp.optimizer import IPPOptimizer
from utils.evaluate import evaluate


def _get_modified_clipping_thresholds(batch_losses: torch.Tensor,
                                      batch_clipping_thresholds: torch.Tensor,
                                      batch_mus: torch.Tensor) -> torch.Tensor:
    """
    Modifies the batch per-sample clipping thresholds based on the INO-SGD algorithm.

    Samples are visited in ascending order of loss and their clipping thresholds
    are laid end to end. ``batch_mus`` cuts this running total into consecutive
    levels: level ``k`` spans the next ``batch_mus[k]`` units of threshold mass,
    and an implicit final level absorbs everything after the last cut. With
    ``n = len(batch_mus)``, threshold mass that falls into level ``k`` is scaled
    by ``0`` for ``k == 0``, by ``1 / 2 ** (n - k)`` for ``0 < k < n`` and by
    ``1`` for the final level ``k == n``. A sample whose threshold straddles one
    or more cuts is split across every level it covers, and the spilled-over
    part counts toward the budget of the level it spills into.

    Args:
        batch_losses: Individual losses corresponding to each datum in the batch
            (1-D); only their ordering is used.
        batch_clipping_thresholds: Clipping threshold associated with each datum
            in the batch (1-D, same length as ``batch_losses``).
        batch_mus: Width of each level of the partition (1-D). A single zero,
            i.e. ``[0]``, disables the partition.

    Returns:
        torch.Tensor: The modified per-sample clipping thresholds, on the CPU
        with the default floating dtype. When the partition is disabled,
        ``batch_clipping_thresholds`` itself is returned unchanged.
    """
    mus = torch.as_tensor(batch_mus).detach().reshape(-1).tolist()
    if len(mus) == 1 and mus[0] == 0:
        return batch_clipping_thresholds

    # Bring everything to Python once instead of syncing a device tensor per element.
    order = torch.argsort(torch.as_tensor(batch_losses)).cpu().tolist()
    thresholds = torch.as_tensor(batch_clipping_thresholds).detach().reshape(-1).cpu().tolist()

    n_levels = len(mus)
    multipliers = [1.0 / 2 ** (n_levels - k) for k in range(n_levels + 1)]
    multipliers[0] = 0.0  # mass in the lowest-loss level is dropped entirely

    modified = [0.0] * len(thresholds)
    level = 0
    filled = 0.0  # threshold mass already assigned to the current level
    for idx in order:
        remaining = thresholds[idx]
        while remaining > 0:
            if level == n_levels:
                # Final, unbounded level: everything left goes here.
                modified[idx] += remaining * multipliers[level]
                break
            capacity = mus[level] - filled
            if remaining < capacity:
                modified[idx] += remaining * multipliers[level]
                filled += remaining
                break
            # This sample completes the current level. Carry the excess into the
            # next level(s) so it is charged against their budgets as well.
            modified[idx] += capacity * multipliers[level]
            remaining -= capacity
            level += 1
            filled = 0.0

    return torch.tensor(modified)


def train_one_epoch(data_loader: IPPDataLoader,
                    model: AbstractGradSampleModule,
                    optimizer: IPPOptimizer,
                    ipp: IPP,
                    n_steps: int,
                    max_n_steps: int,
                    mu: List[int],
                    device: Any,
                    adaptive_threshold: float):
    """
    Runs one pass over ``data_loader``, taking one optimizer step per batch.

    Each batch's ``target`` has two columns: column 0 is the class label fed to
    cross-entropy, and column 1 is the datum's index, used to look up its
    clipping threshold for the current step. Per-sample thresholds are adjusted
    by ``_get_modified_clipping_thresholds`` using ``mu[n_steps]`` before being
    handed to ``optimizer.step``.

    Args:
        data_loader: Yields ``(data, target)`` batches as described above.
        model: Model being trained; switched to train mode.
        optimizer: Called as ``step(batch_size, thresholds, noise_scale,
            adaptive_threshold=...)``.
        ipp: Provides per-step batch sizes, per-sample clipping thresholds and
            noise scales, all indexed by the global step counter.
        n_steps: Global step counter at the start of this epoch.
        max_n_steps: The epoch stops early once the counter reaches this value.
        mu: Per-step partition specification, indexed by the global step.
        device: Device to which data and labels are moved.
        adaptive_threshold: Forwarded to ``optimizer.step``.

    Returns:
        Tuple of the updated global step counter and the mean of the per-batch
        mean losses (``nan`` if no batch was processed).
    """
    model.train()
    criterion = F.cross_entropy
    losses = []
    for data, target in data_loader:
        batch_mus = mu[n_steps]
        # Index tensors must be integral; the cast is a no-op if they already are.
        indices = target[:, 1].long()
        labels = target[:, 0].long()

        data, labels = data.to(device), labels.to(device)
        output = model(data)
        batch_size = ipp.get_batch_sizes()[n_steps].item()
        clipping_thresholds = ipp.get_per_sample_clipping_thresholds()[n_steps]
        noise_scale = ipp.noise_scales[n_steps].item()
        batch_clipping_thresholds = clipping_thresholds[indices].to(device)

        per_sample_loss = criterion(output, labels, reduction='none')
        # Only the loss ordering is needed, so detach it from the graph.
        modified_clipping_thresholds = _get_modified_clipping_thresholds(
            per_sample_loss.detach(), batch_clipping_thresholds, batch_mus)

        loss = per_sample_loss.mean()
        losses.append(loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step(batch_size,
                       modified_clipping_thresholds.to(device),
                       noise_scale,
                       adaptive_threshold=adaptive_threshold)

        n_steps += 1
        if n_steps >= max_n_steps:
            break

    # np.mean([]) warns and returns nan; make the empty-epoch result explicit.
    train_loss = np.mean(losses) if losses else float('nan')
    return n_steps, train_loss


def ipp_train(data_loader: IPPDataLoader,
              model: AbstractGradSampleModule,
              optimizer: IPPOptimizer,
              test_loader: DataLoader,
              ipp: IPP,
              mu: Union[int, List[int]],
              device: Any,
              adaptive_threshold=0,
              max_n_steps=None):
    """
    Trains ``model`` epoch by epoch until ``max_n_steps`` optimizer steps are done.

    After every epoch the model is evaluated on ``test_loader``. The training
    loss is added to the evaluation results, and a progress line is printed.

    Args:
        data_loader: Training batches, passed to ``train_one_epoch``.
        model: Model being trained.
        optimizer: Optimizer passed to ``train_one_epoch``.
        test_loader: Data used by ``evaluate`` after each epoch.
        ipp: Source of per-step schedules. ``ipp.get_n_iterations()`` supplies
            the default step budget.
        mu: Per-step partition specification. It is indexed by the step counter
            inside ``train_one_epoch``, so in practice it must be a sequence.
        device: Device used for training and evaluation.
        adaptive_threshold: Forwarded to the optimizer step.
        max_n_steps: Total step budget. ``None`` means
            ``ipp.get_n_iterations()``; an explicit ``0`` means no training.

    Returns:
        List of per-epoch result dicts from ``evaluate``, each extended with
        ``'train_loss'``.

    Raises:
        RuntimeError: If an epoch makes no progress, for example because
            ``data_loader`` is empty. Without this check the loop would never
            terminate.
    """
    if max_n_steps is None:
        max_n_steps = ipp.get_n_iterations()
    n_epochs = n_steps = 0
    results = []
    while n_steps < max_n_steps:
        steps_before = n_steps
        n_steps, train_loss = train_one_epoch(data_loader,
                                              model,
                                              optimizer,
                                              ipp,
                                              n_steps,
                                              max_n_steps,
                                              mu,
                                              device,
                                              adaptive_threshold)
        if n_steps == steps_before:
            raise RuntimeError("data_loader yielded no batches; training cannot make progress")

        performances = evaluate(model, test_loader, device)
        performances['train_loss'] = train_loss
        results.append(performances)
        n_epochs += 1

        print(f"epoch: {n_epochs},\tloss: {round(float(train_loss), 4)},\taccuracy: {round(performances['accuracy'], 2)}")

    return results
