import torch
import torch.nn as nn
import time
from typing import Tuple
import torch
from torch.nn import Module
import torch.nn.functional as F
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim import Optimizer
from torch.nn.modules.loss import _Loss
from torch.utils.data import DataLoader
from tqdm import tqdm
from typing import List
import numpy as np
import itertools

from ml_common import get_device, test

# Shared cosine-similarity operator that compares its two inputs along the last dimension.
cosine_similarity = nn.CosineSimilarity(dim=-1)


def compute_coherence(model: Module, dataloader: DataLoader) -> np.ndarray:
    """Evaluate ``model.coherence`` on every batch and return all scores.

    Args:
        model: module exposing a ``coherence(x)`` method; it is called on each
            input batch after the batch is moved to ``get_device()``.
        dataloader: yields ``(x, label)`` pairs; labels are ignored.

    Returns:
        The per-batch coherence tensors concatenated along dim 0, moved to the
        CPU and flattened into a 1-D NumPy array. An empty ``float32`` array is
        returned when the loader yields no batches.
    """
    device = get_device()
    chunks = []
    # Evaluation only: no autograd graph is needed, and without no_grad the
    # final ``.numpy()`` would fail if ``coherence`` returned a tensor that
    # requires grad.
    with torch.no_grad():
        for x, _ in dataloader:
            chunks.append(model.coherence(x.to(device)))

    if not chunks:
        # Matches the dtype of the original empty accumulator (torch default float32).
        return np.empty(0, dtype=np.float32)
    # Concatenate once instead of re-copying a growing tensor every batch.
    return torch.cat(chunks).cpu().numpy().flatten()


# Debug helper: the ensemble's averaged training loss, with no diversity term.
def simple_loss(
    pred_train_ensemble: torch.Tensor,
    y_train: torch.Tensor,
    criterion: _Loss,
    lamb: float = 0.0,
) -> torch.Tensor:
    """Return ``criterion`` averaged over the ensemble members.

    Args:
        pred_train_ensemble: per-model predictions stacked along dim 0; each
            slice ``pred_train_ensemble[m]`` is passed to ``criterion``.
        y_train: targets shared by every model.
        criterion: loss applied to each model's predictions.
        lamb: unused; accepted only for call-site compatibility.

    Returns:
        The mean of ``criterion(pred, y_train)`` over the models.

    Raises:
        ZeroDivisionError: if the ensemble is empty (``shape[0] == 0``).
    """
    n_models = pred_train_ensemble.shape[0]
    total = sum(criterion(pred, y_train) for pred in pred_train_ensemble)
    return total / n_models


def diversity_loss(
    pred_train_ensemble: torch.Tensor,
    pred_ood_ensemble: torch.Tensor,
    y_train: torch.Tensor,
    criterion: _Loss,
    lamb: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Combine the ensemble's training loss with an OOD diversity penalty.

    Args:
        pred_train_ensemble: per-model predictions on the training batch,
            stacked along dim 0 (one slice per model).
        pred_ood_ensemble: per-model predictions on the OOD batch, stacked
            along dim 0; the last dimension is treated as class logits
            (softmax is taken over it).
        y_train: targets for the training batch.
        criterion: loss applied to each model's training predictions.
        lamb: weight of the diversity term.

    Returns:
        ``(loss, loss_train, loss_ood, cs_avg)`` where

        * ``loss = loss_train + lamb * loss_ood`` keeps the autograd graph;
        * ``loss_train`` is the SUM (not the mean) of per-model criterion values;
        * ``loss_ood`` is the batch mean of ``log(sum_{i<j} exp(cos(p_i, p_j)))``,
          with ``p_m`` the softmax of model ``m``'s OOD logits;
        * ``cs_avg`` is the mean pairwise cosine similarity.

        The last three are detached. With a single model there are no pairs,
        so ``loss_ood`` and ``cs_avg`` are zero.

    Raises:
        ValueError: if the ensemble is empty.
    """
    n_models = pred_train_ensemble.shape[0]
    if n_models == 0:
        raise ValueError("diversity_loss needs at least one model")

    loss_train = sum(criterion(pred, y_train) for pred in pred_train_ensemble)

    if n_models > 1:
        # Softmax each member's OOD logits once instead of once per pair.
        probs_ood = F.softmax(pred_ood_ensemble, dim=-1)
        cs_pairs = [
            F.cosine_similarity(probs_ood[i], probs_ood[j], dim=-1)
            for i in range(n_models)
            for j in range(i + 1, n_models)
        ]
        cs_stack = torch.stack(cs_pairs)  # one row per model pair
        # Smooth maximum over pairs: log(sum(exp(cs))), then averaged over the batch.
        loss_ood = torch.logsumexp(cs_stack, dim=0).mean()
        cs_avg = cs_stack.mean()
    else:
        # No pairs to compare; create the zero on the loss's device/dtype so
        # the sum below never mixes devices.
        loss_ood = cs_avg = loss_train.new_zeros(())

    loss = loss_train + lamb * loss_ood
    return loss, loss_train.detach(), loss_ood.detach(), cs_avg.detach()


def _endless_batches(dataloader: DataLoader):
    """Yield batches from ``dataloader`` forever, starting a new pass when exhausted.

    Unlike ``itertools.cycle``, this keeps no copy of past batches and
    re-iterates the loader on each pass, so a shuffling loader yields a fresh
    order instead of replaying the first pass.

    Raises:
        ValueError: if a complete pass yields no batches (prevents an infinite loop).
    """
    while True:
        produced = False
        for batch in dataloader:
            produced = True
            yield batch
        if not produced:
            raise ValueError("dataloader yielded no batches")


def train_diverse_step(
    model_list: List[Module],
    dataloader_train: DataLoader,
    dataloader_ood: DataLoader,
    opt: Optimizer,
    criterion: _Loss,
    lamb: float = 0.0,
    disable_pbar: bool = False,
) -> Tuple[float, float, float, float, float]:
    """Train every model in the ensemble for one epoch with the diversity loss.

    For each training batch, the next OOD batch is concatenated with it, every
    model runs one forward pass on the combined input, and the outputs are
    split back into train/OOD parts for ``diversity_loss``. A single
    ``opt.step()`` per batch updates whatever parameters ``opt`` holds.

    Args:
        model_list: ensemble members; each is put in train mode.
        dataloader_train: yields ``(x, y)`` training batches; defines the epoch length.
        dataloader_ood: yields ``(x, _)`` OOD batches; restarted whenever exhausted.
        opt: optimizer stepped once per training batch.
        criterion: per-model training loss passed to ``diversity_loss``.
        lamb: weight of the diversity term.
        disable_pbar: hide the tqdm progress bar.

    Returns:
        ``(loss, loss_train, loss_ood, acc, cs)``: per-batch averages over the
        epoch of the total loss, the training term, the OOD diversity term and
        the mean pairwise cosine similarity, plus training accuracy (samples
        seen this epoch) averaged over the models.

    Raises:
        ValueError: if either loader yields no batches.
    """
    device = get_device()
    for model in model_list:
        model.train()

    correct = np.zeros(len(model_list))
    n_samples = 0
    n_batches = 0
    running_loss = running_loss_train = running_loss_ood = running_cs = 0.0
    ood_batches = _endless_batches(dataloader_ood)

    for x_train, y_train in tqdm(
        dataloader_train, ncols=80, disable=disable_pbar, leave=False,
    ):
        x_ood, _ = next(ood_batches)
        x_train = x_train.to(device)
        y_train = y_train.to(device)
        x_ood = x_ood.to(device)
        n_train = x_train.shape[0]
        # One forward pass per model over train + OOD together; split afterwards.
        x = torch.cat((x_train, x_ood), dim=0)

        opt.zero_grad()
        pred_train_list, pred_ood_list = [], []
        for i, model in enumerate(model_list):
            pred = model(x)
            pred_train, pred_ood = pred[:n_train], pred[n_train:]
            pred_train_list.append(pred_train)
            pred_ood_list.append(pred_ood)
            correct[i] += (pred_train.argmax(dim=-1) == y_train).sum().item()

        loss, loss_train, loss_ood, cs = diversity_loss(
            torch.stack(pred_train_list, dim=0),
            torch.stack(pred_ood_list, dim=0),
            y_train,
            criterion,
            lamb=lamb,
        )
        loss.backward()
        opt.step()

        # Accumulate every reported metric so all of them are epoch averages.
        running_loss += loss.item()
        running_loss_train += loss_train.item()
        running_loss_ood += loss_ood.item()
        running_cs += cs.item()
        n_batches += 1
        n_samples += n_train

    if n_batches == 0:
        raise ValueError("dataloader_train yielded no batches")

    acc = correct / n_samples
    return (
        running_loss / n_batches,
        running_loss_train / n_batches,
        running_loss_ood / n_batches,
        float(np.mean(acc)),
        running_cs / n_batches,
    )


def train_diverse(
    model_list: List[Module],
    dataloader_train: DataLoader,
    dataloader_ood: DataLoader,
    dataloader_test: DataLoader,
    opt: Optimizer,
    criterion: _Loss,
    epochs: int = 10,
    sch: _LRScheduler = None,
    lamb: float = 0.0,
    disable_pbar: bool = False,
):
    """Train an ensemble with the diversity loss and print per-epoch metrics.

    Args:
        model_list: ensemble members; moved to ``get_device()`` before training.
        dataloader_train: training batches, passed to ``train_diverse_step``.
        dataloader_ood: OOD batches used for the diversity term.
        dataloader_test: evaluated with ``test`` for every model after each epoch.
        opt: optimizer shared by the ensemble.
        criterion: per-model training loss.
        epochs: number of epochs to run.
        sch: optional learning-rate scheduler, stepped once per epoch; ``None`` disables it.
        lamb: weight of the diversity term.
        disable_pbar: hide the per-epoch progress bar.
    """
    device = get_device()
    for model in model_list:
        model.train()
        # Module.to moves parameters in place; no need to rebind the loop variable.
        model.to(device)

    for epoch in range(1, epochs + 1):
        # perf_counter is monotonic, so the epoch time is unaffected by clock changes.
        start = time.perf_counter()
        loss, loss_train, loss_ood, acc_train, cs = train_diverse_step(
            model_list,
            dataloader_train,
            dataloader_ood,
            opt,
            criterion,
            lamb,
            disable_pbar,
        )
        # Test accuracy is reported as the mean over ensemble members.
        acc_test = np.mean([test(model, dataloader_test) for model in model_list])

        if sch is not None:
            sch.step()

        time_epoch = time.perf_counter() - start
        print(
            f"Epoch: {epoch} loss: {loss:.3f} loss_train: {loss_train:.3f} "
            f"loss_ood: {loss_ood:.3f} acc_train: {acc_train * 100:.2f}%, "
            f"acc_test: {acc_test * 100:.2f}%, CS: {cs:.2f}, time: {time_epoch:.1f}"
        )
