import pickle
import tqdm
import re

import torch
import torch.nn as nn
from torch import Tensor
import numpy as np

import foolbox

import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint

from sklearn import metrics

from convexrobust.model.base_certifiable import BaseCertifiable, Certificate, Norm
from convexrobust.model.randsmooth_certifiable import RandsmoothCertifiable
from convexrobust.utils import dirs, file_utils, pretty
from convexrobust.utils import torch_utils as TU

import lib.smoothingSplittingNoise.src.noises as rs_noises


from typing import Type, Dict, List, Optional
from dataclasses import dataclass


def balance_classifier(model, global_params):
    """Tune ``model.class_balance`` so per-class test accuracies are roughly equal.

    The model is moved to ``TU.device()`` and put in eval mode, and its
    ``class_balance`` must start at 0. Two strategies are used:

    * ``RandsmoothCertifiable`` models: ``class_balance`` is adjusted
      iteratively by a proportional controller driven by the accuracy gap
      ``class_1_acc - class_0_acc`` measured on the test set.
    * All other models: margin scores are computed once, and the ROC curve is
      searched for the threshold where class-0 and class-1 accuracy are
      closest. That threshold is written into ``class_balance``.

    Args:
        model: classifier exposing a tensor attribute ``class_balance``.
        global_params: object providing ``datamodule`` (with
            ``test_dataloader()``) and ``eval_n``.
    """
    model.to(TU.device())
    model.eval()
    assert model.class_balance == 0.0

    # Balancing is tuned against the test set, so that set must be iterated
    # deterministically. Sanity-check that two fresh loaders yield the same
    # first sample.
    s1, t1 = next(iter(TU.fetch_dataloader(global_params.datamodule.test_dataloader(), 1, False)))
    s2, t2 = next(iter(TU.fetch_dataloader(global_params.datamodule.test_dataloader(), 1, False)))
    assert (s1 == s2).all() and (t1 == t2).all()

    if isinstance(model, RandsmoothCertifiable):
        _balance_randsmooth(model, global_params)
    else:
        _balance_from_roc(model, global_params)


def _balance_randsmooth(model, global_params):
    """Tune ``class_balance`` of a smoothed model with a proportional controller."""
    K_p, max_steps, finishing_steps = 0.4, 150, 5  # proportional-controller parameters
    steps = max_steps

    # Scale the gain by the spread of the model's raw outputs on a test batch
    # (presumably 100 signals), so the step size tracks the output scale.
    # Only a statistic is needed, so no autograd graph is built.
    dataloader = global_params.datamodule.test_dataloader()
    signals, _ = TU.fetch_dataloader_batch(dataloader, 100)
    with torch.no_grad():
        preds = model.forward(model.training_signal_modify(signals))
    pred_std = preds.std().item()
    print(f'Scaling K_p by {pred_std}')
    K_p *= pred_std

    i = 0
    while i <= steps:
        # Re-measure the per-class accuracies under the current class_balance.
        preds, targets = compute_preds_targets(model, global_params, do_tqdm=True)
        class_0_acc, class_1_acc = compute_class_accuracies(preds, targets, False)
        error = (class_1_acc - class_0_acc).item()
        if i == 0:
            print(f'(Original) class 0 acc: {class_0_acc}, class 1 acc: {class_1_acc}')
            if abs(error) < 0.005:
                print('Initial error very low, returning')
                return

        # Proportional update: step class_balance by K_p times the accuracy gap.
        model.class_balance += K_p * error
        print(f'Error: {error}, threshold: {model.class_balance.item()}')

        # Once the gap is within 0.01, run a short finishing window. If the gap
        # leaves tolerance during that window, restart the window, as long as
        # doing so stays within max_steps.
        if abs(error) < 0.01 and (steps - i) >= finishing_steps:
            print('Achieved threshold, finishing')
            steps = i + finishing_steps
        if abs(error) > 0.01 and (steps - i) < finishing_steps and \
                i + finishing_steps < max_steps:
            print('Threshold violated, refinishing')
            steps = i + finishing_steps

        i += 1

    # The accuracies and error reported here were measured before the final
    # class_balance update above.
    print(f'(Final) Class 0 acc: {class_0_acc}, class 1 acc: {class_1_acc}')
    pretty.subsection_print(f'Got optimal balance: {model.class_balance.item()}')
    if abs(error) > 0.02:
        print('UNABLE TO BALANCE')


def _balance_from_roc(model, global_params):
    """Pick ``class_balance`` from the ROC curve of the model's margin scores."""
    preds, targets = compute_preds_targets(model, global_params)
    class_0_acc, class_1_acc = compute_class_accuracies(preds, targets, True)
    print(f'(Original) class 0 acc: {class_0_acc}, class 1 acc: {class_1_acc}')

    # Scores are negated so that larger means "class 1". roc_curve counts a
    # sample as positive when -pred >= threshold, i.e. pred + threshold <= 0,
    # which is the hard decision rule applied after shifting by the threshold.
    # tpr is the class-1 accuracy and 1 - fpr the class-0 accuracy. Keep every
    # threshold (drop_intermediate=False) so the balance point is not pruned.
    fpr, tpr, thresholds = metrics.roc_curve(
        TU.numpy(targets), -TU.numpy(preds), drop_intermediate=False)
    threshold = float(thresholds[np.argmin(np.abs(tpr - (1 - fpr)))])
    preds_shift = preds + threshold
    class_0_acc, class_1_acc = compute_class_accuracies(preds_shift, targets, True)

    print(f'(Final) Class 0 acc: {class_0_acc}, class 1 acc: {class_1_acc}')
    pretty.subsection_print(f'Got optimal balance: {threshold}')
    model.class_balance.fill_(threshold)


def compute_preds_targets(model, global_params, do_tqdm=True):
    """Run ``model`` over test samples and return ``(preds, targets)``.

    Samples come from ``TU.fetch_dataloader(test_dataloader, global_params.eval_n, ...)``.
    Each sample is unbatched: it gets a leading batch dim via ``unsqueeze(0)``.

    * ``RandsmoothCertifiable`` models use ``model.predict`` with ``n=100``
      and an ``n_scale`` chosen so that ``n * n_scale`` is roughly
      ``model.n * model.cert_n_scale``.
    * Two-logit models yield the margin ``logit_0 - logit_1``.
    * Single-logit models yield ``model.forward`` output unchanged.

    Args:
        model: the classifier to evaluate.
        global_params: object providing ``datamodule`` and ``eval_n``.
        do_tqdm: forwarded to ``TU.fetch_dataloader`` as its third argument.
            That argument appears to be its progress-bar flag; TODO confirm.
            The original always passed ``True``, so the default is unchanged.

    Returns:
        Predictions and targets, each concatenated along dim 0.
    """
    dataloader = global_params.datamodule.test_dataloader()
    samples = TU.fetch_dataloader(dataloader, global_params.eval_n, do_tqdm)

    all_preds, all_targets = [], []
    with torch.no_grad():
        for signal, target in samples:
            signal, target = signal.unsqueeze(0), target.unsqueeze(0)

            if isinstance(model, RandsmoothCertifiable):
                n = 100
                n_scale = (model.n * model.cert_n_scale) // n
                pred = model.predict(signal, n=n, n_scale=n_scale)
            else:
                pred = model.forward(signal)
                if not model.single_logit:
                    # Collapse two logits into a single margin score.
                    pred = pred[:, 0] - pred[:, 1]

            all_preds.append(pred)
            all_targets.append(target)

    return torch.cat(all_preds), torch.cat(all_targets)


def compute_class_accuracies(class_preds, targets, make_hard=False):
    """Return per-class accuracy ``(class_0_acc, class_1_acc)`` as 0-dim tensors.

    Args:
        class_preds: predicted labels (0/1). When ``make_hard`` is True, these
            are instead real-valued scores, where a score ``<= 0`` means class 1.
        targets: ground-truth labels (0/1), one per prediction.
        make_hard: convert scores to hard labels before comparing.

    A class with no samples in ``targets`` yields NaN (mean of an empty tensor).
    """
    # Flatten both inputs so that (N, 1)-shaped predictions compare
    # element-wise with (N,) targets instead of silently broadcasting to (N, N).
    class_preds = class_preds.reshape(-1)
    targets = targets.reshape(-1)
    if make_hard:
        class_preds = (class_preds <= 0).long()
    correct = class_preds == targets
    class_0_acc = correct[targets == 0].float().mean()
    class_1_acc = correct[targets == 1].float().mean()
    return class_0_acc, class_1_acc


def save_class_balance(class_balance, checkpoint_path):
    """Overwrite the stored ``class_balance`` tensor of a checkpoint file in place.

    The checkpoint at ``checkpoint_path`` is loaded with ``torch.load``. The
    ``state_dict['class_balance']`` entry is filled with ``class_balance``,
    and the whole checkpoint is written back to the same path. All other
    entries are preserved unchanged.
    """
    state = torch.load(checkpoint_path)
    balance_buffer = state['state_dict']['class_balance']
    balance_buffer.fill_(class_balance)
    torch.save(state, checkpoint_path)
