import argparse
from copy import deepcopy
from datetime import datetime
from functools import partial
from os import makedirs, path
import random
import time
from typing import Callable, Optional, Tuple
from matplotlib import pyplot as plt
import numpy as np

import torch
from torch import nn
from torch.nn import functional as F
from tqdm import tqdm, trange
from geometry import GeometricModel
from experiment import Experiment

import mnist_networks, cifar10_networks

from torchvision import datasets, transforms
from xor_datasets import XorDataset
from xor_networks import xor_net
import plot


def project_kernel(jac: torch.Tensor, direction: torch.Tensor) -> torch.Tensor:
    """Project `direction` onto the trailing columns of a complete QR basis of `jac`.

    `jac` is a 2-D matrix of shape ``(m, n)`` and `direction` a 1-D vector of
    length ``m`` (both required by ``torch.mv``).  The complete QR
    factorisation gives an orthogonal ``(m, m)`` matrix ``Q``; columns
    ``n - 1`` onwards are used as the projection basis.

    NOTE(review): keeping column ``n - 1`` presumably relies on `jac` having
    rank at most ``n - 1`` (e.g. the Jacobian of outputs that sum to one), in
    which case the kept columns are orthogonal to the column space of `jac`
    -- confirm against the caller.

    Returns:
        The orthogonal projection of `direction` onto the span of the kept
        columns, as a 1-D tensor of length ``m``.
    """
    num_cols = jac.shape[1]
    # torch.qr(..., some=False) is deprecated; mode="complete" is its
    # documented replacement and returns the same full orthogonal Q.
    q_full = torch.linalg.qr(jac, mode="complete").Q
    kernel_basis = q_full[:, num_cols - 1:]
    # The basis columns are orthonormal, so the least-squares coefficients of
    # `direction` are simply basis^T @ direction (no extra solve needed).
    coefficients = torch.mv(kernel_basis.T, direction)
    return torch.mv(kernel_basis, coefficients)


def project_tangent(jac: torch.Tensor, direction: torch.Tensor) -> torch.Tensor:
    """Least-squares projection of `direction` onto the rows of `jac`.

    Solves ``jac^T @ c ~= direction`` in the least-squares sense (over the
    last two dimensions of `jac`, any leading dimensions being batch
    dimensions) and returns the corresponding combination of rows
    ``sum_a c[a] * jac[a, :]``.

    Returns:
        A tensor whose last dimension matches the last dimension of `jac`.
    """
    row_space = jac.transpose(-1, -2)
    target = direction.unsqueeze(-1)
    solved = torch.linalg.lstsq(row_space, target).solution
    weights = solved.squeeze(-1)
    return torch.einsum("...ak, ...a -> ...k", jac, weights)

def constant_direction(
        geo_model: GeometricModel,
        start: torch.Tensor,
        direction: torch.Tensor,
        projection: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
        step_size: float = 0.1,
        steps: int = 1000,
        post_processing: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
        verbose: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Walk `steps` fixed-length steps from `start` along a projected, fixed direction.

    At every step the Jacobian ``geo_model.jac_proba(x)`` is reshaped to
    ``(j.shape[0], -1)``, transposed and column-normalised, then handed to
    `projection` together with the flattened `direction`.  The result is
    rescaled to unit length, reshaped like `start`, and the point is moved by
    `step_size` along it before `post_processing` is applied.

    The ``.item()`` calls on the maximum probability and its index mean that
    ``geo_model.proba`` must yield a single prediction per call.

    Args:
        geo_model: Model exposing ``proba`` and ``jac_proba``.
        start: Initial point.
        direction: Direction to follow; flattened before use.
        projection: Maps ``(jacobian, direction)`` to a displacement.
        step_size: Length of each step.
        steps: Number of steps to take.
        post_processing: Applied to each new point (e.g. a domain clamp);
            ``None`` means the point is left untouched.
        verbose: Show a progress bar when true.

    Returns:
        ``(points, probabilities, predictions)``: the ``steps + 1`` visited
        points stacked along a new leading dimension, and the maximum
        probability / predicted index at each of them.
    """
    if post_processing is None:
        # The default used to be called as-is and raised TypeError; fall back
        # to the identity instead.
        post_processing = (lambda point: point)

    direction = torch.flatten(direction)
    points = [start]
    x = start
    probability, prediction = torch.max(geo_model.proba(x), dim=-1)
    probabilities = [probability.item()]
    predictions = [prediction.item()]
    for _ in trange(steps) if verbose else range(steps):
        # The Jacobian is computed outside no_grad; everything after it is
        # pure bookkeeping and does not need gradients.
        # noinspection PyTypeChecker
        j = geo_model.jac_proba(x)
        with torch.no_grad():
            # One row per flattened input coordinate, one unit-norm column per
            # leading index of the Jacobian.
            j = F.normalize(j.reshape(j.shape[0], -1).T, dim=0)
            displacement = projection(j, direction)
            displacement = F.normalize(displacement, dim=-1).reshape(start.shape)
            x = post_processing(x + step_size * displacement)
            points.append(x.detach())
            probability, prediction = torch.max(geo_model.proba(x), dim=-1)
            probabilities.append(probability.item())
            predictions.append(prediction.item())
    points = torch.stack(points, dim=0)
    probabilities = torch.tensor(probabilities, device=start.device)
    predictions = torch.tensor(predictions, device=start.device)
    return points, probabilities, predictions


def constant_direction_kernel(
        geo_model: GeometricModel,
        start: torch.Tensor,
        direction: torch.Tensor,
        step_size: float = 0.1,
        steps: int = 1000,
        post_processing: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
        **kwargs
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run `constant_direction` with `project_kernel` as the projection.

    All arguments are forwarded unchanged; extra keyword arguments (such as
    ``verbose``) are passed through to `constant_direction`.

    Returns:
        The ``(points, probabilities, predictions)`` triple produced by
        `constant_direction`.
    """
    return constant_direction(
        geo_model, start, direction, project_kernel,
        step_size, steps, post_processing,
        **kwargs
    )


def constant_direction_tangent(
        geo_model: GeometricModel,
        start: torch.Tensor,
        direction: torch.Tensor,
        step_size: float = 0.1,
        steps: int = 1000,
        post_processing: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
        **kwargs
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run `constant_direction` with `project_tangent` as the projection.

    All arguments are forwarded unchanged; extra keyword arguments (such as
    ``verbose``) are passed through to `constant_direction`.

    NOTE(review): `constant_direction` hands the projection a Jacobian laid
    out as ``(flattened input, outputs)`` with a flat direction, while
    `project_tangent` transposes its first argument before solving against the
    direction -- the row counts look incompatible; confirm this combination
    is actually exercised.

    Returns:
        The ``(points, probabilities, predictions)`` triple produced by
        `constant_direction`.
    """
    return constant_direction(
        geo_model, start, direction, project_tangent,
        step_size, steps, post_processing,
        **kwargs
    )


def path(
        geo_model: GeometricModel,
        start: torch.Tensor,
        end: torch.Tensor,
        projection: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
        step_size: float = 0.1,
        steps: int = 100000,
        threshold: float = 1e-17,
        post_processing: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
        verbose: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Step from `start` towards `end`, following the projection of the remaining offset.

    At every iteration the offset ``end - x`` (flattened from dimension 1) is
    passed with the row-normalised Jacobian ``geo_model.jac_proba(x)`` to
    `projection`; the result is rescaled to unit length, reshaped like
    `start`, and the point moves by `step_size` along it before
    `post_processing` is applied.  Iteration stops once the Euclidean
    distance to `end` is no larger than `threshold`, or after `steps` steps.

    The ``.item()`` calls mean `start` and `end` must describe a single
    sample with a leading batch dimension of size one.

    NOTE(review): this function is named ``path`` and therefore shadows the
    ``path`` imported from ``os`` at the top of the file.

    Args:
        geo_model: Model exposing ``proba`` and ``jac_proba``.
        start: Initial point (leading batch dimension of size one).
        end: Target point, same shape as `start`.
        projection: Maps ``(jacobian, direction)`` to a displacement.
        step_size: Length of each step.
        steps: Maximum number of steps.
        threshold: Distance to `end` below which the walk stops.
        post_processing: Applied to each new point; ``None`` means identity.
        verbose: Print a one-line progress report per iteration when true.

    Returns:
        ``(points, probabilities, predictions)``: the visited points stacked
        along a new leading dimension, and the maximum probability /
        predicted index at each of them.
    """
    if post_processing is None:
        post_processing = (lambda x: x)

    points = [start]
    x = start
    probability, prediction = torch.max(geo_model.proba(x), dim=-1)
    probabilities = [probability.item()]
    predictions = [prediction.item()]
    distance = torch.norm((end - x).flatten(1), dim=-1)
    if verbose:
        print(f"Start point: predicted {predictions[0]} with probability {probabilities[0]:0.4f}")
        proba_end, pred_end = torch.max(geo_model.proba(end), dim=-1)
        print(f"End point: predicted {pred_end.item()} with probability {proba_end.item():0.4f}")
        #  plt.imshow(start.squeeze())
        #  plt.show()
        #  plt.imshow(end.squeeze())
        #  plt.show()
        print(
            f'Iteration {len(points) - 1:05d} - Distance {distance.item():.04f} - '
            f'Predicted {predictions[-1]} with probability {probabilities[-1]:0.4f}\r',
            end='',
        )
    while distance > threshold and len(points) < steps + 1:
        # The Jacobian is computed outside no_grad; the update itself does not
        # need gradients.
        # noinspection PyTypeChecker
        j = geo_model.jac_proba(x)
        with torch.no_grad():
            # Unlike `constant_direction`, the Jacobian keeps its layout and
            # is normalised along its last dimension.
            #  j = F.normalize(j.reshape(j.shape[0], -1).T, dim=0)
            j = F.normalize(j, dim=-1, p=2)
            direction = (end - x).flatten(1)
            projected = projection(j, direction)
            displacement = F.normalize(projected, dim=-1).reshape(start.shape)
            if verbose:
                # Diagnostics are only needed for the progress line, so they
                # are skipped entirely in quiet mode.
                displacement_norm = torch.norm(projected.flatten(1), dim=-1)
                unit_step = displacement.flatten(1)
                # NOTE(review): the angle is measured against the unit-length
                # step, not the raw projection -- confirm this is intended.
                residual = direction - unit_step
                dot_prod = torch.einsum("...i, ...i -> ...", residual, unit_step)
                cosine = dot_prod / torch.norm(residual, dim=1) / torch.norm(unit_step, dim=-1)
                # Round-off can push the cosine slightly outside [-1, 1],
                # where acos returns NaN.
                angle = torch.acos(cosine.clamp(-1.0, 1.0))
            x = post_processing(x + step_size * displacement)
            points.append(x.detach())
            probability, prediction = torch.max(geo_model.proba(x), dim=-1)
            probabilities.append(probability.item())
            predictions.append(prediction.item())
            distance = torch.norm((end - x).flatten(1), dim=-1)
            if verbose:
                print(
                    f'n°{len(points) - 1:05d} - Distance {distance.item():0.4f} - ' +
                    f'Displacement angle: {angle.item():0.4f} - norm: {displacement_norm.item():0.4f} ' +
                    f'Predicted {predictions[-1]} with probability {probabilities[-1]:0.4f}\r',
                    end='',
                )
                #  if len(points) % 1000 == 0:
                    #  plt.imshow(x.squeeze())
                    #  plt.show()
    points = torch.stack(points, dim=0)
    probabilities = torch.tensor(probabilities, device=start.device)
    predictions = torch.tensor(predictions, device=start.device)
    return points, probabilities, predictions


def path_kernel(
        geo_model: GeometricModel,
        start: torch.Tensor,
        end: torch.Tensor,
        step_size: float = 0.1,
        steps: int = 10000,
        threshold: float = 1.0,
        post_processing: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
        **kwargs
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run `path` with `project_kernel` as the projection.

    All arguments are forwarded unchanged; extra keyword arguments (such as
    ``verbose``) are passed through to `path`.

    NOTE(review): `path` passes the projection a direction flattened to two
    dimensions, while `project_kernel` finishes with ``torch.mv`` on its
    coefficients -- confirm this combination is actually exercised.

    Returns:
        The ``(points, probabilities, predictions)`` triple produced by `path`.
    """
    return path(
        geo_model, start, end, project_kernel,
        step_size, steps, threshold, post_processing,
        **kwargs
    )


def path_tangent(
        geo_model: GeometricModel,
        start: torch.Tensor,
        end: torch.Tensor,
        step_size: float = 0.1,
        steps: int = 100000,
        threshold: float = 1e-15,
        post_processing: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
        **kwargs
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run `path` with `project_tangent` as the projection.

    All arguments are forwarded unchanged; extra keyword arguments (such as
    ``verbose``) are passed through to `path`.

    Returns:
        The ``(points, probabilities, predictions)`` triple produced by `path`.
    """
    return path(
        geo_model, start, end, project_tangent,
        step_size, steps, threshold, post_processing,
        **kwargs
    )


def simple_domain_projection(
    x: torch.Tensor,
    domain: Tuple[float, float] = (0.0, 1.0),
) -> torch.Tensor:
    """Clamp `x` element-wise into ``[domain[0], domain[1]]`` and report what changed.

    A single status line (terminated by a carriage return) is printed with
    the number of entries that were modified and the largest modification.
    The ``.item()`` call on the per-row count means `x` must have a leading
    dimension of size one.

    Returns:
        The clamped tensor; `x` itself is left untouched.
    """
    new_x = torch.clamp(x, domain[0], domain[1])
    total = x.flatten(1).shape[1]
    unchanged = torch.isclose(new_x, x).int().flatten(1).sum(dim=1).item()
    largest_cut = torch.norm(x - new_x, p=float('inf'))
    print(f"What did we cut? {total - unchanged:03d} / {total} items changed, d(x,x_clamped)={largest_cut:0.4e}\r", end="")
    return new_x


def domain_projection(
        x: torch.Tensor, normalization: transforms.Normalize,
        domain: Tuple[float, float] = (0.0, 1.0),
) -> torch.Tensor:
    """Clamp `x` into the image of `domain` under `normalization`.

    The two bounds of `domain` are replicated once per index of the leading
    dimension of `x`, pushed through `normalization`, and each slice ``x[i]``
    is clamped between its own pair of normalised bounds.

    NOTE(review): the leading dimension of `x` is presumably the channel axis
    of an unbatched ``(C, H, W)`` image, matching per-channel normalisation --
    confirm against callers that pass batched input.

    Args:
        x: Tensor to clamp; it is not modified.
        normalization: Transform applied to the raw bounds.
        domain: ``(lower, upper)`` bounds before normalisation.

    Returns:
        A new tensor with the same shape as `x`.
    """
    num_slices = x.shape[0]
    inf = torch.tensor(domain[0]).repeat(num_slices, 1, 1)
    sup = torch.tensor(domain[1]).repeat(num_slices, 1, 1)
    # One bound per leading index, shaped to broadcast over the rest of `x`.
    bound_shape = (num_slices,) + (1,) * (x.dim() - 1)
    normalized_inf = normalization(inf).reshape(bound_shape).to(device=x.device, dtype=x.dtype)
    normalized_sup = normalization(sup).reshape(bound_shape).to(device=x.device, dtype=x.dtype)
    # A single broadcast clamp replaces the per-slice in-place assignment, so
    # the caller's tensor is no longer mutated.
    return torch.clamp(x, min=normalized_inf, max=normalized_sup)



if __name__ == "__main__":
    # Command-line entry point: build one Experiment per (dataset,
    # non-linearity) pair and walk from its first input point to its second
    # one with `path_tangent`.
    parser = argparse.ArgumentParser(
        description="Test if there is one data leaf or multiples.",
    )
    parser.add_argument(
        "--datasets",
        type=str,
        nargs='+',
        # A list default keeps the type identical to what nargs='+' produces.
        default=["MNIST"],
        choices=['MNIST', 'Letters', 'FashionMNIST', 'KMNIST', 'QMNIST', 'CIFARMNIST', 'XOR', 'XOR3D', 'CIFAR10'],
        metavar='name',
        help="Dataset name to be used.",
    )
    parser.add_argument(
        "--restrict",
        type=int,
        metavar="class",
        default=None,
        help="Class to restrict the main dataset to if needed.",
    )
    parser.add_argument(
        "--nsample",
        type=int,
        metavar='N',
        default=2,
        help="Number of initial points to consider."
    )

    parser.add_argument(
        "--random",
        action="store_true",
        help="Permutes randomly the inputs."
    )

    parser.add_argument(
        "--savedirectory",
        type=str,
        metavar='path',
        default='./output/',
        help="Path to the directory to save the outputs in."
    )

    parser.add_argument(
        "--maxpool",
        action="store_true",
        help="Use the legacy architecture with maxpool2D instead of avgpool2d."
    )
    parser.add_argument(
        "--cpu",
        action="store_true",
        help="Force device to cpu."
    )
    parser.add_argument(
        "--nl",
        type=str,
        metavar='f',
        nargs='+',
        default=["ReLU"],
        choices=['Sigmoid', 'ReLU', 'GELU'],
        help="Non linearity used by the network."
    )

    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if args.cpu:
        device = torch.device('cpu')
    print(f"Device: {device}")

    num_samples = args.nsample

    # Normalise both options to lists *before* measuring their lengths:
    # taking len() of a bare string such as "ReLU" used to replicate the
    # default experiment once per character.
    dataset_names = [args.datasets] if isinstance(args.datasets, str) else list(args.datasets)
    non_linearities = [args.nl] if isinstance(args.nl, str) else list(args.nl)
    # A single value on either side is broadcast to match the other side.
    if len(dataset_names) == 1:
        dataset_names = dataset_names * len(non_linearities)
    if len(non_linearities) == 1:
        non_linearities = non_linearities * len(dataset_names)
    if len(dataset_names) != len(non_linearities):
        # zip() below would otherwise silently drop the unmatched entries.
        parser.error(
            "--datasets and --nl must have the same number of values, "
            "or one of them must have exactly one value."
        )
    dtype = torch.float
    # Honour --restrict (it used to be parsed and then ignored); the default
    # is still None.
    restrict_to_class = args.restrict

    pool = "maxpool" if args.maxpool else "avgpool"
    date = datetime.now().strftime("%y%m%d-%H%M%S")
    # NOTE(review): the disabled block below refers to `task`, which is not
    # defined here, and to `path.isdir`, although `path` now names a function
    # of this module rather than os.path.
    #  savedirectory = args.savedirectory + \
        #  ("" if args.savedirectory[-1] == '/' else '/') + \
        #  f"{'-'.join(dataset_names)}/{task}/{dtype}/" + \
        #  f"{date}_nsample={num_samples}{f'_class={restrict_to_class}' if restrict_to_class is not None else ''}_{pool}_{'-'.join(non_linearities)}/"
    #  if not path.isdir(savedirectory):
        #  makedirs(savedirectory)

    if not args.random:
        # Fix every source of randomness so that runs are reproducible.
        seed = 42
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
        np.random.seed(seed)  # Numpy module.
        random.seed(seed)  # Python random module.
        torch.backends.cudnn.benchmark = False  # type: ignore
        torch.backends.cudnn.deterministic = True # type: ignore

    experiment_list = []
    for (dataset, non_linearity) in zip(dataset_names, non_linearities):
        print(dataset, non_linearity)
        experiment = Experiment(
            dataset_name=dataset,
            non_linearity=non_linearity,
            adversarial_budget=0,
            dtype=dtype,
            device=device,
            num_samples=num_samples,
            restrict_to_class=restrict_to_class,
            pool=pool,
            random=args.random,
        )
        experiment_list.append(experiment)

    # Only used by the alternative post-processing kept in the comment below.
    normalize = transforms.Normalize((0.,), (1.,))

    for i, experiment in enumerate(tqdm(experiment_list)):
        print(f"Testing experiment n°{i}: {experiment.dataset_name}")
        path_tangent(experiment.geo_model,
                     experiment.input_points[0].unsqueeze(0),
                     experiment.input_points[1].unsqueeze(0),
                     step_size=1e-1,
                     steps=100000,
                     threshold=1e-16,
                     #  post_processing=partial(domain_projection, normalization=normalize),
                     post_processing=simple_domain_projection,
                     )
        
