import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F

from tqdm import tqdm, trange
from torch.utils.data import DataLoader, TensorDataset
from torch.optim.lr_scheduler import ExponentialLR

from utils.types import List, Union, Tensor, SeedType
from datasim import Bootstrapper, MultiBootstrapper
from losses import ECELoss, EPS


class C3(nn.Module):

    def __init__(self, layer_nums: List[int], X_init: Union[None, Tensor] = None, y_init: Union[None, Tensor] = None,
                 sigma: float = 0.5, weight_factor: float = 1., clear_buffer: bool = False, seed: SeedType = None):
        """Build the projection network and set up the (initially unset) memory buffer.

        Args:
            layer_nums: Layer widths. Each consecutive pair becomes an
                ``nn.Linear``; ``nn.Softplus`` sits between layers and there is
                no activation after the last one.
            X_init: Rows used to seed the buffer in ``clear_buffer``. Defaults
                to a single all-zero row of width ``layer_nums[0]``.
            y_init: Labels matching ``X_init``. Defaults to a single zero.
            sigma: Bandwidth of the Gaussian kernel ``exp(-d^2 / (2 * sigma^2))``.
            weight_factor: Default scale applied to the pseudo-counts in
                ``infer_batch``.
            clear_buffer: When true, ``self.clear_buffer(seed=seed)`` is called
                right away so the buffer and RNG are ready to use.
            seed: Seed forwarded to ``clear_buffer``.
        """
        super().__init__()

        # Emit Linear/Softplus pairs, then strip the activation that would
        # otherwise trail the output layer.
        stack = []
        for n_in, n_out in zip(layer_nums[:-1], layer_nums[1:]):
            stack.extend((nn.Linear(n_in, n_out), nn.Softplus()))
        stack.pop()
        self.project = nn.Sequential(*stack)

        self.X_init = X_init if X_init is not None else torch.zeros((1, layer_nums[0]))
        self.y_init = y_init if y_init is not None else torch.zeros((1, 1))

        # Buffer state and RNG stay unset until clear_buffer() runs.
        self.X_buff = self.y_buff = self.w_buff = self.rng = None

        self.sigma = sigma
        self.weight_factor = weight_factor

        if clear_buffer:
            self.clear_buffer(seed=seed)

    def forward(self, Q, K, V):
        """Kernel-weighted estimate of the positive rate at each query.

        Queries and references are projected by ``self.project``; each
        reference is weighted by its Gaussian similarity to the query times the
        inverse of its summed similarity to all references.

        Args:
            Q: Query inputs. Assumes shape ``(B, 1, d_in)`` so the pairwise
                distance broadcasts against ``K`` -- TODO confirm with callers.
            K: Reference inputs, assumed ``(M, d_in)``.
            V: Reference labels, assumed ``(M, 1)`` with values in ``[0, 1]``.

        Returns:
            Tuple ``(mean, std)``: the weighted positive fraction per query and
            the square root of ``pos * neg / ((pos + neg)^2 * (pos + neg + 1))``.
        """
        two_sigma_sq = 2 * self.sigma ** 2

        q_emb = self.project(Q)
        k_emb = self.project(K)

        # Gaussian similarity between every query and every reference.
        S = torch.exp(-F.pairwise_distance(q_emb, k_emb) ** 2 / (two_sigma_sq + EPS))

        # Inverse of each reference's summed similarity to all references.
        ref_density = torch.exp(-torch.cdist(k_emb, k_emb) ** 2 / two_sigma_sq).sum(dim=-1)
        w = 1 / ref_density

        weighted = S * w
        labels = V.T
        pos = (weighted * labels).sum(dim=-1)
        neg = (weighted * (1. - labels)).sum(dim=-1)

        mean = pos / (weighted.sum(dim=-1) + EPS)

        total = pos + neg
        variance = (pos * neg) / (total ** 2 * (total + 1) + EPS)

        return mean, torch.sqrt(variance)

    def infer_batch(self, Q, K=None, V=None, w=None, sigma=None, weight_factor=None, matrix_size=1e6):
        """Score candidate queries against a reference set and sample one action.

        The references are processed in batches so that the ``(n_queries,
        batch)`` score matrix holds roughly ``matrix_size`` entries at most.
        One ``Beta(pos + 1, neg + 1)`` sample is drawn per query from
        ``self.rng`` and the index of the largest sample is returned.

        Args:
            Q: Raw (un-projected) candidate inputs; projected here with
                ``self.project``.
            K: Reference points, already in projected form. When ``None`` the
                stored buffer (``X_buff``, ``y_buff``, ``w_buff``) is used and
                ``V`` / ``w`` are ignored.
            V: Labels of the reference points.
            w: Importance weight of each reference point. Must be supplied
                together with ``K``.
            sigma: Kernel bandwidth; defaults to ``self.sigma``.
            weight_factor: Scale applied to the pseudo-counts; defaults to
                ``self.weight_factor``.
            matrix_size: Target upper bound on the number of entries in one
                score batch.

        Returns:
            Tuple ``(action, mean, std, w_new)``: index of the selected query,
            per-query mean and standard deviation, and the updated importance
            weights with one extra trailing entry for the selected query.

        Note:
            ``self.rng`` is only created by ``clear_buffer``; it must have been
            called (or ``rng`` set) before this method is used.
        """
        if sigma is None:
            sigma = self.sigma
        if weight_factor is None:
            weight_factor = self.weight_factor

        if K is None:
            K = self.X_buff
            V = self.y_buff
            w = self.w_buff

        with torch.no_grad():
            Q = self.project(Q).unsqueeze(dim=1).numpy()

        X_ref = K.numpy()
        y_ref = V.numpy().reshape(-1)
        w_ref = w.numpy()

        n_ref = len(X_ref)
        n = len(Q)
        # Cap the batch at the number of reference ROWS (the original used the
        # feature dimension, which forced needlessly tiny batches). The floor
        # of 1 keeps np.arange's step valid when the buffer is empty.
        batch_size = max(1, int(np.ceil(min(matrix_size / n, n_ref))))

        X_batches = np.split(X_ref, np.arange(batch_size, len(X_ref), batch_size))
        y_batches = np.split(y_ref, np.arange(batch_size, len(y_ref), batch_size))
        w_batches = np.split(w_ref, np.arange(batch_size, len(w_ref), batch_size))

        pos = 0.
        neg = 0.
        eta = 0.

        w_new = list()

        for X_batch, y_batch, w_batch in zip(X_batches, y_batches, w_batches):
            # (n_queries, batch) Gaussian similarities.
            score_batch = np.exp(-np.linalg.norm(X_batch - Q, axis=-1) ** 2 / (2 * sigma ** 2 + EPS))

            # NOTE(review): the density update sums the scores of ALL queries,
            # not only the one that ends up selected -- confirm this is intended
            # when more than one candidate is passed.
            w_new.append((1 / (1 / w_batch + score_batch.sum(axis=0))))

            eta += score_batch.sum(axis=-1)
            pos += (score_batch * w_batch * y_batch).sum(axis=-1)
            neg += (score_batch * w_batch * (1. - y_batch)).sum(axis=-1)

        # Rescale the weighted counts so that pos + neg equals the raw
        # (unweighted) similarity mass eta.
        total = pos + neg
        pos = eta * pos / (total + EPS)
        neg = eta * neg / (total + EPS)

        pos *= weight_factor
        neg *= weight_factor

        mean = pos / (pos + neg + EPS)
        stderr = (pos * neg) / ((pos + neg) ** 2 * (pos + neg + 1) + EPS)

        sampled_reward = self.rng.beta(pos + 1, neg + 1, size=mean.shape)
        action = np.argmax(sampled_reward)

        # Weight of the newly selected point: its own similarity (1) plus its
        # similarity mass to the existing references.
        w_new = np.concatenate(w_new + [1 / (eta + 1)[[action]]])

        return action, mean, np.sqrt(stderr), w_new

    def store_buffer(self, X_query, y_query, importance_weights=None, chunk_size=512):
        """Append new observations to the buffer and refresh the importance weights.

        The buffer must already be initialised (see ``clear_buffer``), since
        the new rows are concatenated onto the existing ones.

        Args:
            X_query: Raw inputs; they are projected with ``self.project``
                before being stored.
            y_query: Labels of ``X_query``.
            importance_weights: Weights for the whole updated buffer, e.g. the
                ``w_new`` returned by ``infer_batch``. Tensors are stored
                unchanged; NumPy arrays and sequences are converted to a tensor
                so later calls can rely on the tensor API. When ``None`` the
                weights are recomputed from scratch.
            chunk_size: Number of buffer rows per column block when the weights
                are recomputed, bounding the size of each distance matrix.
        """
        with torch.no_grad():
            self.X_buff = torch.cat((self.X_buff, self.project(X_query)))
        self.y_buff = torch.cat((self.y_buff, y_query))

        if importance_weights is None:
            two_sigma_sq = 2 * self.sigma ** 2 + EPS
            # Weight = inverse of each point's summed Gaussian similarity to
            # every buffered point, computed one column block at a time.
            self.w_buff = 1 / torch.concat(
                [torch.exp(-torch.cdist(self.X_buff, X_batch) ** 2 / two_sigma_sq).sum(dim=0)
                 for X_batch in torch.split(self.X_buff, chunk_size)])
        else:
            # infer_batch hands back a NumPy array; keep the buffer a tensor.
            self.w_buff = torch.as_tensor(importance_weights)

    def clear_buffer(self, seed=None, chuck_size=512):
        """Reset the buffer to the initial points and reseed the sampler.

        Args:
            seed: Seed for the fresh ``np.random.RandomState`` stored in
                ``self.rng``.
            chuck_size: Number of buffer rows per column block when computing
                the importance weights.
        """
        with torch.no_grad():
            projected_init = self.project(self.X_init.clone())
        self.X_buff = projected_init  # the buffer holds projected points
        self.y_buff = self.y_init.clone()

        # Summed Gaussian similarity of every buffered point to all others,
        # accumulated one column block at a time.
        column_sums = []
        for block in torch.split(self.X_buff, chuck_size):
            kernel = torch.exp(-torch.cdist(self.X_buff, block) ** 2 / (2 * self.sigma ** 2 + EPS))
            column_sums.append(kernel.sum(dim=0))
        self.w_buff = 1 / torch.cat(column_sums)

        self.rng = np.random.RandomState(seed=seed)

    def pop_buffer(self, indices, chunk_size=512):
        """Remove points from the buffer and update the remaining importance weights.

        The similarity that each removed point contributed to the density of
        the surviving points is subtracted, so the weights do not have to be
        recomputed from scratch.

        Args:
            indices: Row(s) to drop: a Python or NumPy integer, a sequence /
                array / tensor of integer positions (negative values count
                from the end), or a boolean mask over the buffer rows.
                Repeated positions are removed once.
            chunk_size: Number of surviving rows per block when computing the
                similarities to the removed points.
        """
        device = self.X_buff.device
        selector = torch.as_tensor(indices)

        # Boolean "remove" mask over the buffer rows. Pure torch indexing
        # replaces np.delete, which only worked on tensors through NumPy's
        # __array_wrap__ round trip and only for CPU tensors.
        if selector.dtype == torch.bool:
            remove = selector.reshape(-1).to(device)
        else:
            remove = torch.zeros(len(self.X_buff), dtype=torch.bool, device=device)
            remove[selector.reshape(-1).long().to(device)] = True
        keep = ~remove

        X_removed = self.X_buff[remove]
        X_out = self.X_buff[keep]
        y_out = self.y_buff[keep]
        w_kept = torch.as_tensor(self.w_buff)[keep]

        # (n_removed, n_kept) Gaussian similarities, built block by block.
        score = torch.concat([
            torch.exp(-torch.linalg.norm(X - X_removed.unsqueeze(dim=1), dim=-1) ** 2 / (2 * self.sigma ** 2))
            for X in torch.split(X_out, chunk_size)
        ], dim=1)

        # 1 / w is the summed similarity (density); take out the removed mass.
        w = 1 / (1 / w_kept - score.sum(dim=0))

        self.X_buff = X_out
        self.y_buff = y_out
        self.w_buff = w

    def fit(self, X_train, y_train, prob_train, device, loss_coef, X_val, y_val, prob_val, epochs=10, batch_size=16,
            M_ece=5, base_ratio=0.6, usage_ratio=1.0, lr=8e-3, explr_gamma=0.99,
            val_base_ratio=0.8, val_seed=None, pick_best_val=True,
            plot=True, save_plot_name=None, tqdm_pbar=True):
        """Train the projection network on bootstrapped query/reference splits.

        Every epoch draws a new ``Bootstrapper`` split of the training data
        (seeded with the epoch index), predicts the query labels from the
        reference set and minimises ``coef_bce * BCE + coef_ece * ECE`` with
        Adam and an exponentially decayed learning rate. One fixed validation
        split is evaluated after each epoch.

        Args:
            X_train, y_train, prob_train: Training data handed to ``Bootstrapper``.
            device: Device used for training; the model is moved back to the
                CPU and put in eval mode before returning.
            loss_coef: Mapping with the keys ``"bce"`` and ``"ece"``.
            X_val, y_val, prob_val: Validation data handed to ``Bootstrapper``.
            epochs: Number of training epochs.
            batch_size: Query batch size (validation uses twice this value).
            M_ece: Argument forwarded to ``ECELoss``.
            base_ratio, usage_ratio: Forwarded to the training ``Bootstrapper``.
            lr: Adam learning rate.
            explr_gamma: Decay factor of ``ExponentialLR``, applied once per epoch.
            val_base_ratio, val_seed: Forwarded to the validation ``Bootstrapper``.
            pick_best_val: Restore the weights of the epoch with the lowest
                weighted validation loss once training is over.
            plot: Draw the loss curves with matplotlib.
            save_plot_name: If given, the figure is saved there and closed.
            tqdm_pbar: Show a tqdm progress bar.
        """
        self.to(device)
        self.train()

        optim = torch.optim.Adam(self.parameters(), lr=lr)
        scheduler = ExponentialLR(optim, gamma=explr_gamma)
        ece_loss_fn = ECELoss(M_ece)

        train_bce = list()
        train_ece = list()

        val_bce = list()
        val_ece = list()

        coef_bce = loss_coef["bce"]
        coef_ece = loss_coef["ece"]

        # Pinned host memory is only useful (and only available) when the
        # model actually lives on a CUDA device; requesting it unconditionally
        # breaks CPU-only training.
        if next(self.parameters()).device.type == "cuda":
            loader_kwargs = {"pin_memory": True, "pin_memory_device": "cuda"}
        else:
            loader_kwargs = {}

        val_dataset = Bootstrapper(X_val, y_val, prob_val, usage_ratio=1, base_ratio=val_base_ratio, seed=val_seed)
        val_dl = DataLoader(val_dataset, batch_size=2 * batch_size, shuffle=False, **loader_kwargs)
        X_val_ref, y_val_ref, _ = val_dataset.get_ref_data()
        X_val_ref = X_val_ref.to(device)
        y_val_ref = y_val_ref.to(device).view(-1)

        best_loss = float("inf")
        # Stays None until some epoch improves on best_loss, so the final
        # weights are kept if no validation loss is ever comparable (e.g. NaN).
        best_state_dict = None

        pbar = trange(epochs) if tqdm_pbar else range(epochs)
        for e in pbar:
            cur_bce_loss = 0
            cur_ece_loss = 0

            dataset = Bootstrapper(X_train, y_train, prob_train, base_ratio=base_ratio, usage_ratio=usage_ratio, seed=e)
            dl = DataLoader(dataset, batch_size=batch_size, shuffle=True, **loader_kwargs)

            X_ref, y_ref, _ = dataset.get_ref_data()
            X_ref = X_ref.to(device)
            y_ref = y_ref.to(device).view(-1)

            self.train()
            for X_query, y_query, _ in dl:
                X_query = X_query.to(device)
                y_query = y_query.to(device)

                mean, _ = self(X_query, X_ref, y_ref.view(-1, 1))
                mean = mean.view(-1, 1)

                loss_bce = F.binary_cross_entropy(mean, y_query)
                loss_ece = ece_loss_fn(mean, y_query)

                loss = coef_bce * loss_bce + coef_ece * loss_ece

                # Weight batch means by batch size so that dividing by the
                # dataset length below yields a per-sample average.
                cur_bce_loss += loss_bce.item() * len(y_query)
                cur_ece_loss += loss_ece.item() * len(y_query)

                optim.zero_grad()
                loss.backward()
                optim.step()
            scheduler.step()

            cur_val_bce_loss = 0
            cur_val_ece_loss = 0

            self.eval()
            with torch.no_grad():
                for X_query, y_query, _ in val_dl:
                    X_query = X_query.to(device)
                    y_query = y_query.to(device)

                    mean, _ = self(X_query, X_val_ref, y_val_ref.view(-1, 1))
                    mean = mean.view(-1, 1)

                    loss_bce = F.binary_cross_entropy(mean, y_query)
                    loss_ece = ece_loss_fn(mean, y_query)

                    cur_val_bce_loss += loss_bce.item() * len(y_query)
                    cur_val_ece_loss += loss_ece.item() * len(y_query)

            train_bce.append(cur_bce_loss / len(dataset))
            train_ece.append(cur_ece_loss / len(dataset))

            val_bce.append(cur_val_bce_loss / len(val_dataset))
            val_ece.append(cur_val_ece_loss / len(val_dataset))

            # Model selection on the epoch-level validation objective (not on
            # a single batch). state_dict() returns references to the live
            # parameters, so the snapshot has to be cloned to survive later
            # optimiser steps.
            epoch_val_loss = coef_bce * val_bce[-1] + coef_ece * val_ece[-1]
            if epoch_val_loss < best_loss:
                best_loss = epoch_val_loss
                best_state_dict = {name: tensor.detach().clone()
                                   for name, tensor in self.state_dict().items()}

            if tqdm_pbar:
                pbar.set_description(f"avg loss = {train_bce[-1]:.4f} | avg val loss = {val_bce[-1]:.4f}")

        if tqdm_pbar:
            pbar.close()

        if pick_best_val and best_state_dict is not None:
            self.load_state_dict(best_state_dict)

        if plot:
            _, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

            ax1.plot(train_bce, label="transductive CE", color="tab:blue")
            ax1.plot(train_ece, label="calibration error", alpha=0.4, color="tab:orange")
            ax1.plot(val_bce, "--", label="[V] transductive CE", color="tab:blue")
            ax1.plot(val_ece, "--", label="[V] calibration error", alpha=0.4, color="tab:orange")
            ax1.set_title("Unweighted loss")
            ax1.legend()

            ax2.plot(np.array(train_bce) * coef_bce, label="transductive CE", color="tab:blue")
            ax2.plot(np.array(train_ece) * coef_ece, label="calibration error", alpha=0.4, color="tab:orange")
            ax2.plot(np.array(val_bce) * coef_bce, "--", label="[V] transductive CE", color="tab:blue")
            ax2.plot(np.array(val_ece) * coef_ece, "--", label="[V] calibration error", alpha=0.4, color="tab:orange")
            ax2.set_title("Effective loss (weighted by coefficients)")
            ax2.legend()

            plt.tight_layout()
            if save_plot_name is not None:
                plt.savefig(save_plot_name, dpi=250)
                plt.close()

        self.cpu()
        self.eval()

    def refit(self, device, loss_coef, X_val, y_val, prob_val,
              epochs=10, batch_size=16, M_ece=5,
              resample_size=1000, train_bce=None, train_ece=None, val_bce=None, val_ece=None,
              val_base_ratio=0.8, lr=8e-3, val_seed=None):
        """Continue training using the buffered points as queries.

        Each epoch samples up to ``resample_size`` rows (without replacement,
        seeded with the epoch index) from ``self.X_ref`` / ``self.y_ref`` as
        the reference set and optimises ``coef_bce * BCE + coef_ece * ECE``
        over the buffer with Adam.

        Args:
            device: Device used for training; the model is moved back to the
                CPU and put in eval mode before returning.
            loss_coef: Mapping with the keys ``"bce"`` and ``"ece"``.
            X_val, y_val, prob_val: Validation data handed to ``Bootstrapper``.
            epochs: Number of passes over the buffer.
            batch_size: Query batch size (validation uses twice this value).
            M_ece: Argument forwarded to ``ECELoss``.
            resample_size: Maximum number of reference rows drawn per epoch.
            train_bce, train_ece, val_bce, val_ece: Optional lists that receive
                one per-sample average loss per epoch. Logging is enabled by
                passing ``train_bce``; the other three lists must then be
                given as well.
            val_base_ratio, val_seed: Forwarded to the validation ``Bootstrapper``.
            lr: Adam learning rate.
        """
        self.to(device)
        self.train()

        optim = torch.optim.Adam(self.parameters(), lr=lr)
        ece_loss_fn = ECELoss(M_ece)

        log_losses = train_bce is not None

        coef_bce = loss_coef["bce"]
        coef_ece = loss_coef["ece"]

        # Pinned host memory is only useful (and only available) when the
        # model actually lives on a CUDA device.
        if next(self.parameters()).device.type == "cuda":
            loader_kwargs = {"pin_memory": True, "pin_memory_device": "cuda"}
        else:
            loader_kwargs = {}

        # NOTE(review): X_buff holds points that were already projected, yet
        # forward() projects its queries again -- confirm this is intended.
        dataset = TensorDataset(self.X_buff.unsqueeze(dim=1), self.y_buff, torch.zeros_like(self.y_buff))
        dl = DataLoader(dataset, batch_size=batch_size, shuffle=True, **loader_kwargs)

        val_dataset = Bootstrapper(X_val, y_val, prob_val, usage_ratio=1, base_ratio=val_base_ratio, seed=val_seed)
        val_dl = DataLoader(val_dataset, batch_size=2 * batch_size, shuffle=False, **loader_kwargs)

        X_val_ref, y_val_ref, _ = val_dataset.get_ref_data()
        X_val_ref = X_val_ref.to(device)
        y_val_ref = y_val_ref.to(device).view(-1)

        # NOTE(review): X_ref / y_ref are not assigned anywhere in this class;
        # presumably the caller must set them before calling refit -- confirm.
        indices = np.arange(len(self.X_ref), dtype=int)
        X_ref_full = self.X_ref.to(device)
        y_ref_full = self.y_ref.to(device).view(-1)

        for e in range(epochs):
            cur_bce_loss = 0
            cur_ece_loss = 0

            ref_idx = torch.from_numpy(np.random.RandomState(seed=e).choice(indices, size=min(resample_size, len(indices)), replace=False))
            X_ref = X_ref_full[ref_idx]
            y_ref = y_ref_full[ref_idx]

            self.train()
            for X_query, y_query, _ in dl:
                X_query = X_query.to(device)
                y_query = y_query.to(device)

                mean, _ = self(X_query, X_ref, y_ref.view(-1, 1))
                mean = mean.view(-1, 1)

                loss_bce = F.binary_cross_entropy(mean, y_query)
                loss_ece = ece_loss_fn(mean, y_query)

                loss = coef_bce * loss_bce + coef_ece * loss_ece

                # Weight batch means by batch size: the totals are divided by
                # the dataset length below, so the logged value is a true
                # per-sample average (same convention as fit).
                cur_bce_loss += loss_bce.item() * len(y_query)
                cur_ece_loss += loss_ece.item() * len(y_query)

                optim.zero_grad()
                loss.backward()
                optim.step()

            cur_val_bce_loss = 0
            cur_val_ece_loss = 0

            self.eval()
            with torch.no_grad():
                for X_query, y_query, _ in val_dl:
                    X_query = X_query.to(device)
                    y_query = y_query.to(device)

                    mean, _ = self(X_query, X_val_ref, y_val_ref.view(-1, 1))
                    mean = mean.view(-1, 1)

                    loss_bce = F.binary_cross_entropy(mean, y_query)
                    loss_ece = ece_loss_fn(mean, y_query)

                    cur_val_bce_loss += loss_bce.item() * len(y_query)
                    cur_val_ece_loss += loss_ece.item() * len(y_query)

            if log_losses:
                train_bce.append(cur_bce_loss / len(dataset))
                train_ece.append(cur_ece_loss / len(dataset))

                val_bce.append(cur_val_bce_loss / len(val_dataset))
                val_ece.append(cur_val_ece_loss / len(val_dataset))

        self.cpu()
        self.eval()

    def fit_multi(self, X_train, y_train, prob_train, device, loss_coef, X_val, y_val, prob_val, epochs=10, batch_size=16,
                  M_ece=5, usage_ratio=1, base_ratio=0.6, lr=8e-3, explr_gamma=0.99,
                  val_base_ratio=0.8, val_seed=None, pick_best_val=False,
                  plot=True, save_plot_name=None, tqdm_pbar=True):
        """Train on the query/reference splits produced by ``MultiBootstrapper``.

        Every epoch builds a new ``MultiBootstrapper`` (seeded with the epoch
        index) and trains on each ``(dataloader, reference data)`` pair it
        yields, minimising ``coef_bce * BCE + coef_ece * ECE`` with Adam. The
        validation splits are created once and evaluated after every epoch.

        Args:
            X_train, y_train, prob_train: Training data handed to
                ``MultiBootstrapper``.
            device: Device used for training; the model is moved back to the
                CPU and put in eval mode before returning.
            loss_coef: Mapping with the keys ``"bce"`` and ``"ece"``.
            X_val, y_val, prob_val: Validation data handed to
                ``MultiBootstrapper``.
            epochs: Number of training epochs.
            batch_size: Query batch size (validation uses twice this value).
            M_ece: Argument forwarded to ``ECELoss``.
            usage_ratio, base_ratio: Forwarded to the training ``MultiBootstrapper``.
            lr: Adam learning rate.
            explr_gamma: Decay factor of ``ExponentialLR``.
            val_base_ratio, val_seed: Forwarded to the validation
                ``MultiBootstrapper``.
            pick_best_val: Restore the weights of the epoch with the lowest
                total validation BCE once training is over.
            plot: Draw the loss curves with matplotlib.
            save_plot_name: If given, the figure is saved there.
            tqdm_pbar: Show a tqdm progress bar.
        """
        self.to(device)
        self.train()

        optim = torch.optim.Adam(self.parameters(), lr=lr)
        scheduler = ExponentialLR(optim, gamma=explr_gamma)
        ece_loss_fn = ECELoss(M_ece)

        train_bce = list()
        train_ece = list()

        val_bce = list()
        val_ece = list()

        coef_bce = loss_coef["bce"]
        coef_ece = loss_coef["ece"]

        mboot_val = MultiBootstrapper(X_val, y_val, prob_val, base_ratio=val_base_ratio, seed=val_seed, batch_size=2 * batch_size, shuffle=False)
        # NOTE(review): this object is iterated again in every epoch, which
        # assumes get_dl() returns something re-iterable (not a one-shot
        # generator) -- confirm against MultiBootstrapper.
        boot_val_dl = mboot_val.get_dl()

        best_loss = float("inf")
        # Stays None until some epoch improves on best_loss, so the final
        # weights are kept if no validation loss is ever comparable (e.g. NaN).
        best_state_dict = None

        pbar = trange(epochs) if tqdm_pbar else range(epochs)
        for e in pbar:
            cur_bce_loss = 0
            cur_ece_loss = 0

            mboot = MultiBootstrapper(X_train, y_train, prob_train, usage_ratio=usage_ratio, base_ratio=base_ratio,
                                      seed=e, batch_size=batch_size, shuffle=True)
            boot_dl = mboot.get_dl()

            self.train()
            for dl, (X_ref, y_ref, _) in boot_dl:
                X_ref = X_ref.to(device)
                y_ref = y_ref.to(device).view(-1)

                for X_query, y_query, _ in dl:
                    X_query = X_query.to(device)
                    y_query = y_query.to(device)

                    mean, _ = self(X_query, X_ref, y_ref.view(-1, 1))
                    mean = mean.view(-1, 1)

                    loss_bce = F.binary_cross_entropy(mean, y_query)
                    loss_ece = ece_loss_fn(mean, y_query)

                    loss = coef_bce * loss_bce + coef_ece * loss_ece

                    cur_bce_loss += loss_bce.item() * len(y_query)
                    cur_ece_loss += loss_ece.item() * len(y_query)

                    optim.zero_grad()
                    loss.backward()
                    optim.step()
                # NOTE(review): the learning rate decays once per bootstrap
                # split here, whereas fit() decays once per epoch -- confirm
                # this is intended.
                scheduler.step()

            cur_val_bce_loss = 0
            cur_val_ece_loss = 0

            self.eval()
            with torch.no_grad():
                for dl, (X_val_ref, y_val_ref, _) in boot_val_dl:
                    X_val_ref = X_val_ref.to(device)
                    y_val_ref = y_val_ref.to(device).view(-1)

                    for X_query, y_query, _ in dl:
                        X_query = X_query.to(device)
                        y_query = y_query.to(device)

                        mean, _ = self(X_query, X_val_ref, y_val_ref.view(-1, 1))
                        mean = mean.view(-1, 1)

                        loss_bce = F.binary_cross_entropy(mean, y_query)
                        loss_ece = ece_loss_fn(mean, y_query)

                        cur_val_bce_loss += loss_bce.item() * len(y_query)
                        cur_val_ece_loss += loss_ece.item() * len(y_query)

            if cur_val_bce_loss < best_loss:
                best_loss = cur_val_bce_loss
                # state_dict() returns references to the live parameters;
                # clone them so the snapshot survives later optimiser steps.
                best_state_dict = {name: tensor.detach().clone()
                                   for name, tensor in self.state_dict().items()}

            train_bce.append(cur_bce_loss / mboot.total_length)
            train_ece.append(cur_ece_loss / mboot.total_length)

            val_bce.append(cur_val_bce_loss / mboot_val.total_length)
            val_ece.append(cur_val_ece_loss / mboot_val.total_length)

            if tqdm_pbar:
                pbar.set_description(f"avg loss = {train_bce[-1]:.4f} | avg val loss = {val_bce[-1]:.4f}")

        if tqdm_pbar:
            pbar.close()

        if pick_best_val and best_state_dict is not None:
            self.load_state_dict(best_state_dict)

        if plot:
            _, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

            ax1.plot(train_bce, label="transductive CE", color="tab:blue")
            ax1.plot(train_ece, label="calibration error", alpha=0.4, color="tab:orange")
            ax1.plot(val_bce, "--", label="[V] transductive CE", color="tab:blue")
            ax1.plot(val_ece, "--", label="[V] calibration error", alpha=0.4, color="tab:orange")
            ax1.set_title("Unweighted loss")
            ax1.legend()

            ax2.plot(np.array(train_bce) * coef_bce, label="transductive CE", color="tab:blue")
            ax2.plot(np.array(train_ece) * coef_ece, label="calibration error", alpha=0.4, color="tab:orange")
            ax2.plot(np.array(val_bce) * coef_bce, "--", label="[V] transductive CE", color="tab:blue")
            ax2.plot(np.array(val_ece) * coef_ece, "--", label="[V] calibration error", alpha=0.4, color="tab:orange")
            ax2.set_title("Effective loss (weighted by coefficients)")
            ax2.legend()

            plt.tight_layout()
            if save_plot_name is not None:
                plt.savefig(save_plot_name, dpi=250)

        self.cpu()
        self.eval()
