from __future__ import annotations
import argparse
from datetime import datetime
import math
import numpy.random as random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as torch_models
import torch
import torch.backends.cudnn as cudnn
import torch.cuda as cuda

import utils.models as models

from typing import Sequence, Any
from argparse import ArgumentParser, Namespace

class ParseStrFloatDict(argparse.Action):
	"""Argparse action that turns ``key:value`` tokens into a ``dict`` of floats.

	Each token is split on its first ``':'``; the left part becomes the key
	and the right part is converted with ``float``. The resulting dict
	replaces whatever was previously stored under ``self.dest``.

	Example: ``--weights a:0.5 b:2`` (with ``nargs='+'``) stores
	``{'a': 0.5, 'b': 2.0}``.
	"""

	def __call__(
		self,
		parser: ArgumentParser,
		namespace: Namespace,
		values: str | Sequence[Any] | None,
		option_string: str | None = None,
	) -> None:
		"""Parse ``values`` and store the resulting dict on ``namespace``.

		Args:
			parser        Parser invoking the action (unused)
			namespace     Namespace receiving the parsed dict under ``self.dest``
			values        One ``key:value`` string, a sequence of them, or None
			option_string Option string that triggered the action (unused)
		Raises:
			argparse.ArgumentError if a token is not of the form ``key:float``
		"""
		# argparse hands over a bare string when the option is declared without
		# ``nargs``; iterating it directly would walk over single characters.
		if values is None:
			tokens = []
		elif isinstance(values, str):
			tokens = [values]
		else:
			tokens = values
		parsed = {}
		for token in tokens:
			key, sep, raw = token.partition(':')
			if not sep:
				raise argparse.ArgumentError(self, f"expected KEY:FLOAT, got {token!r}")
			try:
				parsed[key] = float(raw)
			except ValueError as err:
				# Also reached for tokens with more than one ':' (e.g. 'a:b:1').
				raise argparse.ArgumentError(self, f"expected KEY:FLOAT, got {token!r}") from err
		setattr(namespace, self.dest, parsed)

def output_current_time():
	"""Print the current local date and time as ``YYYYMMDD HH:MM:SS``."""
	stamp = datetime.now().strftime("%Y%m%d %H:%M:%S")
	print(f"date and time: {stamp}")


def setup_seed(seed):
    """Seed the torch, CUDA and NumPy random generators with ``seed``.

    Also switches cuDNN to deterministic mode. Note that ``random`` in this
    module is ``numpy.random``, not the standard-library module.
    """
    # Same calls, same order as before: torch, all CUDA devices, the current
    # CUDA device, then NumPy.
    for seed_fn in (torch.manual_seed, cuda.manual_seed_all, cuda.manual_seed, random.seed):
        seed_fn(seed)
    cudnn.deterministic = True

def init_net(arch: str, n_classes: int):
	"""Build a freshly initialised network for the requested architecture.

	Args:
		arch      Architecture name: 'alexnet', 'resnet18', 'femnistnet' or 'squeezenet'
		n_classes Number of output classes (forwarded to 'resnet18' and 'squeezenet' only)
	Returns:
		The constructed model
	Raises:
		ValueError if ``arch`` is not one of the supported names
	"""
	# NOTE(review): AlexNet and FemnistNet are built without n_classes;
	# presumably their output size is fixed — confirm against utils.models.
	if arch == 'alexnet':
		return models.AlexNet()
	if arch == 'resnet18':
		return torch_models.resnet18(num_classes=n_classes)
	if arch == 'femnistnet':
		return models.FemnistNet()
	if arch == 'squeezenet':
		return models.SqueezeNet(class_num=n_classes)
	# Previously an unknown name fell through to `return model` and failed
	# with an UnboundLocalError that did not mention the offending value.
	raise ValueError(
		f"unknown architecture {arch!r}; expected one of "
		"'alexnet', 'resnet18', 'femnistnet', 'squeezenet'"
	)

def inference(model: nn.Module, data_tensor: torch.Tensor, label_tensor: torch.Tensor, batch_size: int=64, device: torch.device | str = "cuda") -> tuple[float, float]:
	"""Evaluate ``model`` on a labelled dataset, batch by batch, without gradients.

	The model is moved to ``device`` and switched to eval mode; neither is
	restored afterwards.

	Args:
		model        Network to evaluate
		data_tensor  Inputs, indexed by sample along the first dimension
		label_tensor Class-index targets aligned with ``data_tensor``
		batch_size   Number of samples per forward pass
		device       Device to run on (defaults to 'cuda', the previous hard-coded behaviour)
	Returns:
		(accuracy in percent, mean cross-entropy loss per sample)
	Raises:
		ValueError if ``data_tensor`` is empty
	"""
	n_samples = len(data_tensor)
	if n_samples == 0:
		raise ValueError("inference() requires at least one sample")
	model.to(device)
	model.eval()
	total_loss = 0.0
	correct = 0
	with torch.no_grad():
		for start in range(0, n_samples, batch_size):
			stop = start + batch_size
			data = data_tensor[start:stop].to(device)
			target = label_tensor[start:stop].to(device)
			output = model(data)
			# Sum (not average) within the batch: the total is divided by
			# n_samples below, so a per-batch mean would shrink the reported
			# loss by roughly a factor of batch_size.
			total_loss += F.cross_entropy(output, target, reduction='sum').item()
			pred = torch.max(output, 1)[1]
			correct += pred.eq(target.view_as(pred)).sum().item()
	acc = 100. * correct / n_samples
	return acc, total_loss / n_samples

# refer to https://github.com/LPD-EPFL/ByzantineMomentum
#----------------------------------------------------------------------------#
# Find the x maximizing a function y = f(x), with (x, y) ∊ ℝ⁺× ℝ

def line_maximize(scape, evals=16, start=0., delta=1., ratio=0.8, tol=1e-5):
	""" Best-effort arg-maximize a scape: ℝ⁺⟶ ℝ, by mere exploration.

	The search first expands (doubling the step while each proposal improves
	on the best value by more than ``tol``), then contracts around the best
	point found, shrinking the step by ``ratio`` after every evaluation.

	Args:
		scape Function to best-effort arg-maximize
		evals Maximum number of evaluations, must be a positive integer
		start Initial x evaluated, must be a non-negative float
		delta Initial step delta, must be a positive float
		ratio Contraction ratio, must be between 0.5 and 1. (both excluded)
		tol   Minimum gain over the current best value for a proposal to replace it
	Returns:
		Best-effort maximizer x under the evaluation budget
	"""
	# The starting point is always evaluated and consumes one unit of budget
	top_x = start
	top_y = scape(top_x)
	budget = evals - 1
	step = delta
	# Expansion: keep stepping right with a doubling step while it pays off
	while budget > 0:
		cand_x = top_x + step
		cand_y = scape(cand_x)
		budget -= 1
		if cand_y > top_y + tol:
			top_x, top_y = cand_x, cand_y
			step *= 2
			continue
		# First non-improving proposal: shrink the step and start contracting
		step *= ratio
		break
	# Contraction: move the proposal toward the best point with shrinking steps
	while budget > 0:
		if cand_x < top_x:
			cand_x += step
		else:
			lowered = cand_x - step
			# Stay in the non-negative domain by repeatedly taking the
			# midpoint with the previous proposal
			while lowered < 0:
				lowered = (lowered + cand_x) / 2
			cand_x = lowered
		cand_y = scape(cand_x)
		budget -= 1
		if cand_y > top_y + tol:
			top_x, top_y = cand_x, cand_y
		step *= ratio
	return top_x