import gc
import os

import torch
from tqdm import tqdm
from torchvision.datasets import VisionDataset
from torch.utils.data import TensorDataset, DataLoader
import numpy as np

from .torch_utils.losses import get_outputs_loss
from .torch_utils import get_scheduler, get_loss
from .torch_base import TorchModelBase

# Integer debug switch, read once at import time from the DEBUG environment
# variable (0 when the variable is not set).
DEBUG = int(os.environ.get("DEBUG", "0"))

class TorchBBoxAttackModel(TorchModelBase):
    """Training / inference wrapper for a network that takes two inputs.

    The wrapped network is invoked as ``self.model(x1, x2)``. The attributes
    read below (``model``, ``optimizer``, ``device``, ``loss_name``,
    ``epochs``, ``start_epoch``, ``batch_size``, ``num_workers``, ``tst_ds``,
    ``grad_clip``, ``eval_callbacks``) are not assigned in this class; they
    are presumably set up by ``TorchModelBase`` -- confirm against that class.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments unchanged to ``TorchModelBase``."""
        super().__init__(*args, **kwargs)

    def fit(self, X1, X2, y, with_scheduler=True, verbose=None):
        """Train on in-memory data.

        Args:
            X1: first model input; anything ``torch.tensor`` accepts
                (e.g. a numpy array). Converted to a float tensor.
            X2: second model input, handled like ``X1``.
            y: targets; converted to a long tensor.
            with_scheduler: forwarded to ``fit_dataset``.
            verbose: forwarded to ``fit_dataset``.

        Returns:
            The per-epoch history list produced by ``fit_dataset``.
        """
        dataset = TensorDataset(
            torch.tensor(X1).float(),
            torch.tensor(X2).float(),
            torch.tensor(y).long(),
        )
        return self.fit_dataset(dataset, verbose=verbose, with_scheduler=with_scheduler)

    def fit_dataset(self, dataset, with_scheduler=True, verbose=None):
        """Run the training loop over ``dataset``.

        Args:
            dataset: dataset whose items are ``(x1, x2, y)`` triples.
            with_scheduler: when true, a scheduler obtained from
                ``get_scheduler`` is stepped once per epoch.
            verbose: defaults to 1 when the module-level ``DEBUG`` flag is
                non-zero, else 0. NOTE(review): the value is resolved but not
                consulted anywhere in this method.

        Returns:
            A list with one dict per logged epoch holding ``epoch``, ``lr``,
            ``trn_loss``, ``trn_acc`` and, when ``self.tst_ds`` is set,
            ``tst_loss`` and ``tst_acc``.
        """
        if verbose is None:
            verbose = 0 if not DEBUG else 1
        log_interval = 1

        history = []
        # Per-sample losses are requested so that both the training loop and
        # _calc_eval can aggregate them explicitly.
        base_loss_fn = get_loss(self.loss_name, reduction="none")
        scheduler = None
        if with_scheduler:
            scheduler = get_scheduler(self.optimizer, n_epochs=self.epochs, loss_name=self.loss_name)

        train_loader = DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
        )

        test_loader = None
        if self.tst_ds is not None:
            if isinstance(self.tst_ds, VisionDataset):
                # Use the dataset as-is. NOTE(review): _calc_eval reads three
                # fields per batch, so the dataset must yield (x1, x2, y).
                ts_dataset = self.tst_ds
            else:
                # Otherwise tst_ds is unpacked as (X1, X2, y). Building the
                # TensorDataset must stay inside this branch: the unpacked
                # names do not exist on the VisionDataset path.
                tstX1, tstX2, tsty = self.tst_ds
                ts_dataset = TensorDataset(
                    torch.tensor(tstX1).float(),
                    torch.tensor(tstX2).float(),
                    torch.tensor(tsty).long(),
                )
            test_loader = DataLoader(
                ts_dataset,
                batch_size=self.batch_size,
                shuffle=False,
                num_workers=self.num_workers,
            )

        for epoch in range(self.start_epoch, self.epochs+1):
            train_loss, train_acc = 0., 0.
            should_log = (epoch - 1) % log_interval == 0

            for data in tqdm(train_loader, desc=f"Epoch {epoch}"):
                # Re-enter train mode on every batch: the logging branch
                # below switches the model to eval mode.
                self.model.train()

                x1, x2, y = (d.to(self.device) for d in data)

                self.optimizer.zero_grad()

                outputs = self.model(x1, x2)
                loss = base_loss_fn(outputs, y).mean()
                loss.backward()
                if self.grad_clip != 0.:
                    torch.nn.utils.clip_grad_value_(self.model.parameters(), self.grad_clip)
                self.optimizer.step()

                if should_log:
                    # Weight the batch mean by batch size so the epoch total
                    # can be divided by the dataset size afterwards.
                    train_loss += loss.item() * len(x1)
                    train_acc += (outputs.argmax(dim=1)==y).sum().float().item()

                    self.model.eval()
                    # NOTE(review): callbacks run after every batch, each one
                    # receiving the full train loader -- confirm this is
                    # intended rather than once per epoch.
                    if self.eval_callbacks is not None:
                        for cb_fn in self.eval_callbacks:
                            cb_fn(self.model, train_loader, self.device)

            # Capture the LR used during this epoch before the scheduler
            # advances it.
            current_lr = self.optimizer.param_groups[0]['lr']
            if scheduler:
                scheduler.step()
            self.start_epoch = epoch

            if should_log:
                print(f"current LR: {current_lr}")
                self.model.eval()
                n_train = len(train_loader.dataset)
                history.append({
                    'epoch': epoch,
                    'lr': current_lr,
                    'trn_loss': train_loss / n_train,
                    'trn_acc': train_acc / n_train,
                })
                print('epoch: {}/{}, train loss: {:.3f}, train acc: {:.3f}'.format(
                    epoch, self.epochs, history[-1]['trn_loss'], history[-1]['trn_acc']))

                if test_loader is not None:
                    tst_loss, tst_acc = self._calc_eval(test_loader, base_loss_fn)
                    history[-1]['tst_loss'] = tst_loss
                    history[-1]['tst_acc'] = tst_acc
                    print('             test loss: {:.3f}, test acc: {:.3f}'.format(
                          history[-1]['tst_loss'], history[-1]['tst_acc']))

        # Drop the loaders explicitly and force a collection pass.
        if test_loader is not None:
            del test_loader
        del train_loader
        gc.collect()

        return history

    def _calc_eval(self, loader, loss_fn):
        """Compute mean loss and accuracy of the model over ``loader``.

        Args:
            loader: iterable of batches whose first three fields are
                ``(x1, x2, y)``; must expose ``loader.dataset`` with a length.
            loss_fn: loss callable exposing a ``reduction`` attribute
                (``'none'``, ``'mean'`` or ``'sum'``).

        Returns:
            ``(loss, accuracy)``, each divided by ``len(loader.dataset)``.
        """
        self.model.eval()
        cum_loss, cum_acc = 0., 0.
        with torch.no_grad():
            for data in loader:
                tx1, tx2, ty = (data[i].to(self.device) for i in range(3))
                outputs = self.model(tx1, tx2)
                reduction = loss_fn.reduction
                loss = loss_fn(outputs, ty)
                if reduction == 'none':
                    batch_loss = torch.sum(loss).item()
                elif reduction == 'mean':
                    # A batch mean has to be scaled back to a batch total,
                    # otherwise dividing by the dataset size under-reports
                    # the loss by roughly a factor of the batch size.
                    batch_loss = loss.item() * len(ty)
                else:
                    batch_loss = loss.item()
                cum_loss += batch_loss
                cum_acc += (outputs.argmax(dim=1)==ty).sum().float().item()
        n_samples = len(loader.dataset)
        return cum_loss / n_samples, cum_acc / n_samples

    def predict_real(self, X1, X2):
        """Return the raw model outputs for the given inputs.

        Args:
            X1: first model input; converted to a float tensor.
            X2: second model input; converted to a float tensor.

        Returns:
            A numpy array with the per-batch outputs of ``self.model``
            concatenated along axis 0, in input order.
        """
        self.model.eval()
        dataset = TensorDataset(torch.tensor(X1).float(), torch.tensor(X2).float())
        loader = DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
        )
        ret = []
        with torch.no_grad():
            for (x1, x2) in loader:
                batch_out = self.model(x1.to(self.device), x2.to(self.device))
                ret.append(batch_out.detach().cpu().numpy())
        del loader
        return np.concatenate(ret, axis=0)