from typing import Any, List, Union

import numpy as np
import torch
from opacus.grad_sample import AbstractGradSampleModule
from torch.autograd import Variable
from torch.nn import functional as F
from torch.utils.data import DataLoader

from ipp import IPP
from ipp.data_loader import IPPDataLoader
from ipp.optimizer import IPPOptimizer
from utils.evaluate import evaluate


def focal_loss(output, target, gammas=torch.arange(10), alpha=1, device=None):
    """Return the per-sample focal loss with a class-dependent focusing exponent.

    The loss of each sample is ``alpha * (1 - p_t) ** gamma * CE`` where ``CE``
    is the unreduced cross-entropy, ``p_t = exp(-CE)`` and ``gamma`` is looked
    up in ``gammas`` using the sample's target class.

    Args:
        output: Logits accepted by ``torch.nn.functional.cross_entropy``.
        target: Class indices; also used to index ``gammas``.
        gammas: Per-class focusing exponents (tensor or any array-like that
            ``torch.as_tensor`` accepts). The default ``torch.arange(10)``
            makes the exponent equal to the class index for labels 0..9.
        alpha: Scalar multiplier applied to every sample's loss.
        device: Device on which ``gammas`` is placed. When ``None`` the device
            of ``output`` is used.

    Returns:
        A tensor holding one loss value per sample (no reduction is applied).
    """
    # Fall back to the logits' device so that indexing ``gammas`` with
    # ``target`` never mixes a CPU tensor with accelerator indices.
    gammas_device = output.device if device is None else device
    gammas = torch.as_tensor(gammas, device=gammas_device)

    ce_loss = torch.nn.functional.cross_entropy(output, target, reduction='none')
    # exp(-CE) is the probability the model assigns to the true class.
    pt = torch.exp(-ce_loss)
    per_sample_loss = alpha * ((1 - pt) ** gammas[target]) * ce_loss
    return per_sample_loss


def train_one_epoch_mixed_loss(data_loader: IPPDataLoader,
                    model: AbstractGradSampleModule,
                    optimizer: IPPOptimizer,
                    ipp: IPP,
                    n_steps: int,
                    max_n_steps: int,
                    mu: Union[int, List[int]],
                    device: Any,
                    adaptive_threshold: float,
                    dro: bool,
                    gammas):
    """Run one pass over ``data_loader`` with a truncated focal loss.

    For every batch the per-sample focal loss is computed and the samples are
    ordered by increasing loss. Per-sample clipping thresholds are accumulated
    in that order; ``k`` is the position at which the running sum first
    exceeds the current ``mu``. The ``k`` lowest-loss samples are left out of
    the averaged loss, and the sample at position ``k`` receives a fractional
    gradient weight equal to the share of its threshold that lies above
    ``mu``. All other samples keep weight 1.

    Args:
        data_loader: Yields ``(data, target)`` batches. ``target[:, 0]`` is
            used as the class label and ``target[:, 1]`` as the index into the
            per-sample clipping thresholds.
        model: Model to train; switched to training mode.
        optimizer: Optimizer whose ``step`` receives the batch size, the
            batch clipping thresholds, the noise scale, ``adaptive_threshold``
            and the gradient weights.
        ipp: Provides per-step batch sizes, per-sample clipping thresholds
            and noise scales, all indexed by the global step.
        n_steps: Global step counter at the start of this pass.
        max_n_steps: The pass stops as soon as the counter reaches this value.
        mu: Threshold budget; either one ``int`` used for every step or a
            sequence indexed by the global step.
        device: Device the batch and derived tensors are moved to.
        adaptive_threshold: Forwarded unchanged to ``optimizer.step``.
        dro: Currently unused; kept so existing callers continue to work.
        gammas: Per-class focusing exponents forwarded to ``focal_loss``.

    Returns:
        ``(n_steps, loss)``: the updated global step counter and the mean of
        the per-step losses recorded during this pass.
    """
    model.train()
    losses = []
    for data, target in data_loader:
        curr_mu = mu if isinstance(mu, int) else mu[n_steps]

        # Split the two target columns before the labels are moved to `device`.
        indices = target[:, 1]
        labels = target[:, 0].long().to(device)
        data = data.to(device)
        output = model(data)

        batch_size = ipp.get_batch_sizes()[n_steps].item()
        clipping_thresholds = ipp.get_per_sample_clipping_thresholds()[n_steps]
        noise_scale = ipp.noise_scales[n_steps].item()
        batch_clipping_thresholds = clipping_thresholds[indices].to(device)

        loss = focal_loss(output, labels, device=device, gammas=gammas)

        # Accumulate thresholds from the lowest-loss sample upwards and find
        # the position where the running sum first exceeds the budget.
        order = torch.argsort(loss.detach())
        cumsums = torch.cumsum(batch_clipping_thresholds[order], dim=-1)
        # NOTE(review): if the running sum never exceeds `curr_mu` the mask is
        # all False and `k` resolves to 0, which yields a non-positive weight
        # below. Confirm the intended behaviour for that case.
        k = (cumsums > curr_mu).max(-1, True).indices.item()

        # The weights must be floating point: an integer tensor (what
        # `torch.full(size, 1)` produces) would silently truncate the
        # fractional weight of the boundary sample.
        grad_weights = torch.ones(len(loss), dtype=loss.dtype, device=device)
        grad_weights[order[k]] = ((cumsums[k] - curr_mu) / batch_clipping_thresholds[order[k]]).item()

        # NOTE: `dro` is not consulted. An exponential re-weighting of the
        # form `gamma * exp(gamma * loss)` with gamma = 0.5 was sketched here
        # but is disabled.

        # Average the (n - k) largest losses, i.e. drop the k smallest ones;
        # at least one sample is always kept.
        loss = torch.mean(torch.topk(loss, max(len(labels) - k, 1), sorted=False, dim=0)[0])

        losses.append(loss.item())
        optimizer.zero_grad()
        loss.backward()
        # Shrink the nominal batch size by the budget expressed in units of
        # the mean clipping threshold, never going below zero.
        batch_size = max(batch_size - curr_mu / torch.mean(clipping_thresholds).item(), 0)
        optimizer.step(batch_size,
                       batch_clipping_thresholds,
                       noise_scale,
                       adaptive_threshold=adaptive_threshold,
                       grad_weights=grad_weights)
        n_steps += 1
        if n_steps == max_n_steps:
            break

    loss = np.mean(losses)
    return n_steps, loss


def ipp_train_mixed_loss(data_loader: IPPDataLoader,
                            model: AbstractGradSampleModule,
                            optimizer: IPPOptimizer,
                            test_loader: DataLoader,
                            ipp: IPP,
                            mu: Union[int, List[int]],
                            device: Any,
                            gammas,
                            adaptive_threshold=0,
                            dro=False,
                            max_n_steps=None):
    """Train epoch by epoch until the step budget is spent, evaluating each epoch.

    After every call to ``train_one_epoch_mixed_loss`` the model is evaluated
    on ``test_loader``; the epoch's training loss is stored in the returned
    metrics under ``'train_loss'`` and a one-line summary is printed.

    Args:
        data_loader: Training batches, forwarded to the per-epoch routine.
        model: Model being trained and evaluated.
        optimizer: Optimizer forwarded to the per-epoch routine.
        test_loader: Data passed to ``evaluate`` after each epoch.
        ipp: Forwarded to the per-epoch routine; also supplies the default
            step budget via ``get_n_iterations()``.
        mu: Forwarded unchanged to the per-epoch routine.
        device: Device used for training and evaluation.
        gammas: Forwarded unchanged to the per-epoch routine.
        adaptive_threshold: Forwarded unchanged to the per-epoch routine.
        dro: Forwarded unchanged to the per-epoch routine.
        max_n_steps: Total number of optimizer steps. Any falsy value
            (``None`` or ``0``) selects ``ipp.get_n_iterations()``.

    Returns:
        A list with one metrics mapping per epoch, as returned by
        ``evaluate`` and extended with ``'train_loss'``.
    """
    step_limit = max_n_steps or ipp.get_n_iterations()

    history = []
    steps_done = 0
    while steps_done < step_limit:
        steps_done, train_loss = train_one_epoch_mixed_loss(
            data_loader=data_loader,
            model=model,
            optimizer=optimizer,
            ipp=ipp,
            n_steps=steps_done,
            max_n_steps=step_limit,
            mu=mu,
            device=device,
            adaptive_threshold=adaptive_threshold,
            dro=dro,
            gammas=gammas,
        )

        performances = evaluate(model, test_loader, device)
        performances['train_loss'] = train_loss
        history.append(performances)

        # One entry is appended per epoch, so the length is the epoch count.
        n_epochs = len(history)
        print(f"epoch: {n_epochs},\tloss: {round(float(train_loss), 4)},\taccuracy: {round(performances['accuracy'], 2)}")

    return history
