import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import re
import threading
import socket
import os
import tqdm

import matplotlib
matplotlib.use('QtAgg')
# matplotlib.use('Qt5Agg')
import matplotlib.pyplot as plt

from convexrobust.utils import file_utils


def device():
    """Return the torch device that tensors in this module are moved to.

    This is always CUDA; there is no CPU fallback. On a machine without a GPU,
    the error therefore surfaces when a tensor is moved to this device.
    """
    # Previously considered fallback, left for reference:
    # torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    cuda_device = torch.device('cuda')
    return cuda_device


def gpu_n():
    """Return the number of GPUs to use: 1 if CUDA is available, otherwise 0."""
    return int(torch.cuda.is_available())


def fetch_dataset(dataset, fetch_n):
    """Load the first batch of ``fetch_n`` samples from ``dataset`` onto ``device()``.

    A ``DataLoader`` with its default (non-shuffling) settings and
    ``batch_size=fetch_n`` is built, and only its first batch is taken.

    Returns:
        ``(signals, targets)``, both moved to ``device()``.
    """
    loader = DataLoader(dataset, batch_size=fetch_n)
    signals, targets = next(iter(loader))
    target_device = device()
    return signals.to(target_device), targets.to(target_device)


def fetch_dataloader(dataloader, fetch_n, do_tqdm=False):
    """Yield up to ``fetch_n`` individual ``(signal, target)`` samples from ``dataloader``.

    Each ``(signals, targets)`` batch is unrolled into single samples, and
    every sample is moved to ``device()``. Iteration stops after ``fetch_n``
    samples or when the loader is exhausted, whichever comes first. A
    ``fetch_n`` of zero or less yields nothing.

    Args:
        dataloader: iterable of ``(signals, targets)`` batches.
        fetch_n: maximum number of samples to yield.
        do_tqdm: if true, show a tqdm progress bar with ``total=fetch_n``.
    """
    if fetch_n <= 0:
        return

    target_device = device()
    pbar = tqdm.tqdm(total=fetch_n) if do_tqdm else None
    fetched = 0
    # try/finally closes the progress bar even if the generator is closed
    # before the loop completes (explicit .close() or garbage collection
    # raises GeneratorExit at the yield).
    try:
        for (signals, targets) in dataloader:
            for (signal, target) in zip(signals, targets):
                yield signal.to(target_device), target.to(target_device)
                if pbar is not None:
                    pbar.update(1)
                fetched += 1
                if fetched >= fetch_n:
                    return
    finally:
        if pbar is not None:
            pbar.close()


def fetch_dataloader_batch(dataloader, fetch_n):
    """Gather up to ``fetch_n`` samples via ``fetch_dataloader`` into stacked tensors.

    Returns:
        ``(signals, targets)``, each stacked along a new leading dim 0.
        ``torch.stack`` raises if no samples were fetched.
    """
    samples = list(fetch_dataloader(dataloader, fetch_n))
    signal_batch = torch.stack([signal for signal, _ in samples], dim=0)
    target_batch = torch.stack([target for _, target in samples], dim=0)
    return signal_batch, target_batch


def numpy(tensor):
    """Convert a torch tensor to a NumPy array.

    The tensor is detached from autograd and moved to the CPU first. If the
    tensor already lives on the CPU, the result may share its memory.
    """
    detached = tensor.detach()
    host_tensor = detached.cpu()
    return host_tensor.numpy()


def imshow(X):
    """Show a single image in a new matplotlib figure with axes hidden.

    ``X`` may be a tensor or array in channel-first layout:
    ``(1, C, H, W)``, ``(C, H, W)`` or a 2-D ``(H, W)`` grayscale image.
    Single-channel images are passed to matplotlib as 2-D arrays, because
    ``plt.imshow`` rejects a trailing channel axis of size 1.
    """
    if torch.is_tensor(X):
        X = numpy(X)
    X = np.asarray(X)

    # Strip a leading batch axis of size 1: (1, C, H, W) -> (C, H, W).
    if X.ndim == 4 and X.shape[0] == 1:
        X = X[0]

    # Channel-first -> channel-last only for 3-D input; a 2-D image is
    # already (H, W) and must not be transposed.
    if X.ndim == 3:
        X = np.moveaxis(X, 0, -1)
        # Grayscale (H, W, 1) -> (H, W), the shape plt.imshow accepts.
        if X.shape[-1] == 1:
            X = X[..., 0]

    plt.figure()
    plt.imshow(X)
    plt.axis('off')
    plt.show()


def norm_ball_conversion_factor(to_norm: float, from_norm: float, dim: int):
    """Return the radius scale factor between an l_from and an l_to ball in ``dim`` dims.

    Returns ``1`` when ``to_norm <= from_norm``. Otherwise it returns
    ``dim ** -(1/from_norm - 1/to_norm)``. ``float('inf')`` is accepted as a
    norm, since ``1 / inf == 0``.

    Presumably used so that an l_to ball of radius ``r * factor`` lies inside
    the l_from ball of radius ``r``. This follows from
    ``||x||_p <= d**(1/p - 1/q) * ||x||_q`` for ``p < q``; confirm against
    callers.
    """
    if to_norm > from_norm:
        exponent = 1 / from_norm - 1 / to_norm
        return 1 / (dim ** exponent)
    return 1


def launch_tensorboard(tensorboard_dir, port, erase=True):
    """Start a TensorBoard server for ``tensorboard_dir`` on a background thread.

    Args:
        tensorboard_dir: log directory passed to ``--logdir``.
        port: port passed to ``--port``. The server is started with
            ``--bind_all``.
        erase: if true, ``file_utils.create_empty_directory`` is called on
            ``tensorboard_dir`` first (presumably clearing old logs; confirm).

    TensorBoard's stdout and stderr are discarded.
    """
    import shlex

    if erase:
        file_utils.create_empty_directory(tensorboard_dir)

    # Quote every interpolated value so that directories or hostnames with
    # spaces or shell metacharacters reach tensorboard as one argument.
    command = (
        f'tensorboard --bind_all --port {shlex.quote(str(port))} '
        f'--logdir {shlex.quote(str(tensorboard_dir))} '
        f'--window_title {shlex.quote(socket.gethostname())} '
        f'> /dev/null 2>&1'
    )
    # os.system blocks until tensorboard exits, so run it off the main thread.
    # NOTE(review): this thread is non-daemon, so interpreter shutdown waits
    # for it (i.e. for tensorboard to exit). Confirm this matches the intended
    # "closed on process end" behavior.
    t = threading.Thread(target=os.system, args=(command,))
    t.start()

    print(f'Launching tensorboard on http://localhost:{port}')


class SingleLogitWrapper(nn.Module):
    """Present a single-score model as a two-logit classifier.

    If the wrapped module produces ``score``, this wrapper returns
    ``torch.stack((score, -score), dim=1)``.
    """

    def __init__(self, module, balanced_forward=True):
        super().__init__()
        # Wrapped model. It needs ``forward_balanced`` when balanced_forward is true.
        self.module = module
        # Selects module.forward_balanced (True) or module.forward (False).
        self.balanced_forward = balanced_forward

    def forward(self, x):
        score_fn = (self.module.forward_balanced if self.balanced_forward
                    else self.module.forward)
        score = score_fn(x)
        return torch.stack((score, -score), dim=1)


class SoftmaxWrapper(nn.Module):
    """Wrap a model so its forward pass returns softmax probabilities.

    The wrapped module's ``forward_balanced`` (default) or ``forward`` output
    is passed through ``F.softmax``.
    """

    def __init__(self, module, balanced_forward=True):
        super().__init__()
        # Wrapped model. It needs ``forward_balanced`` when balanced_forward is true.
        self.module = module
        # Selects module.forward_balanced (True) or module.forward (False).
        self.balanced_forward = balanced_forward

    def forward(self, x):
        if self.balanced_forward:
            pred = self.module.forward_balanced(x)
        else:
            pred = self.module.forward(x)
        # F.softmax without ``dim`` is deprecated and warns on every call.
        # Pass the dim PyTorch's legacy implicit rule picks (0 for 0-, 1- and
        # 3-D inputs, 1 otherwise) so outputs stay identical.
        # NOTE(review): for 3-D outputs dim 0 is probably not the class axis;
        # confirm whether that case occurs.
        dim = 0 if pred.dim() in (0, 1, 3) else 1
        return F.softmax(pred, dim=dim)



def _set_attr_by_path(root, dotted_name, value):
    """Assign ``value`` to the attribute reached by the dotted state-dict key under ``root``."""
    *parents, leaf = dotted_name.split('.')
    owner = root
    for part in parents:
        # nn.Module.__getattr__ resolves registered children by name, including
        # the numeric names used by nn.Sequential / nn.ModuleList.
        owner = getattr(owner, part)
    setattr(owner, leaf, value)


def load_model_from_checkpoint(checkpoint_path, blueprint, constructor_params):
    """Load ``blueprint.model_class`` from a checkpoint and restore its 'alpha' parameters.

    The model is built via ``blueprint.model_class.load_from_checkpoint`` with
    ``strict=False``. Every state-dict entry whose key contains ``'alpha'``
    is then re-assigned by hand as an ``nn.Parameter`` on ``device()``.

    Args:
        checkpoint_path: path to the checkpoint file. It must contain a
            ``'state_dict'`` entry.
        blueprint: object exposing ``model_class``.
        constructor_params: extra keyword arguments forwarded to
            ``load_from_checkpoint``.

    Returns:
        The loaded model.
    """
    model = blueprint.model_class.load_from_checkpoint(
        checkpoint_path, strict=False, **constructor_params)

    # Handle pytorch None weight bug... https://github.com/pytorch/pytorch/issues/16675
    # Relevant for cayley transform 'alpha' parameter
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(checkpoint_path)['state_dict']
    for (name, data) in state_dict.items():
        if 'alpha' in name:
            # Move the data before wrapping it. Parameter(...).to(device) returns
            # a plain Tensor whenever a copy is needed, and nn.Module refuses to
            # assign a plain Tensor to a registered parameter slot.
            _set_attr_by_path(
                model, name, nn.Parameter(data.to(device()), requires_grad=True))

    return model
