import torch
from torch.utils.data import Dataset, DataLoader, Subset, TensorDataset
from tqdm import tqdm
import numpy as np
import ot
import time
import torch.nn as nn

from .base import AdaptMethod
from .utils import get_labeled_data, LabeledDataset
from utils import get_accuracy, comet

class GOAT(AdaptMethod):
    """Gradual domain adaptation with OT-generated intermediate domains (GOAT).

    Every real domain is passed through ``model.encoder``. When
    ``cf.goat.gen_mid_domains`` (``self.gmd``) is positive, ``generate_domains``
    builds that many synthetic domains between each pair of consecutive encoded
    domains. Only ``model.classifier`` is then trained, one domain after another,
    on the resulting sequence of encoded features.
    """

    def __init__(self, model):
        super().__init__(model)
        # Hyper-parameters from the ``goat`` config section. ``self.cf`` is
        # presumably populated by AdaptMethod.__init__ — it is not set here.
        self.dr = self.cf.goat.discard_ratio
        self.bs = self.cf.goat.batch_size
        self.lr = self.cf.goat.adam.learning_rate
        self.wd = self.cf.goat.adam.weight_decay
        self.epochs = self.cf.goat.epochs
        self.gmd = self.cf.goat.gen_mid_domains

        self.encoder = model.encoder
        self.classifier = model.classifier

        self.criterion = torch.nn.CrossEntropyLoss()

    def gradual_adapt(self, domains: list[Dataset]) -> float:
        """Adapt the classifier through ``domains`` and return the final accuracy.

        Args:
            domains: Raw datasets ordered from source (index 0) to target (last).

        Returns:
            Accuracy reported by ``get_accuracy`` on the encoded last domain.
        """
        print(f"------------ Generate Intermediate domains ------------")
        encoded = [
            get_encoded_dataset(self.encoder, dataset, self.device, self.bs)
            for dataset in domains
        ]
        all_features = [encoded[0]]
        if self.gmd > 0:
            for i in range(len(encoded) - 1):
                print("┌───────── Generate ────────┐")
                all_features += generate_domains(self.gmd, encoded[i], encoded[i + 1])
                print("└───────────────────────────┘")
        else:
            all_features = encoded
        print(f"-------------- End --------------")

        # generate_domains returns gmd synthetic domains followed by the next real
        # domain, so real domain k sits at index k * stride of all_features.
        # Without generated domains every entry is a real domain (stride 1); the
        # guard also avoids a zero/negative modulus when gmd < 0.
        stride = self.gmd + 1 if self.gmd > 0 else 1

        for idx, domain_features in enumerate(all_features[1:], start=1):
            if self.cal_process:
                acc, loss = get_accuracy(all_features[-1], self.classifier, self.device, self.ts_batch_size)
                comet.log_metrics({"Target Domain acc": acc, "Target Domain loss": loss})
                print(f"Target Domain - Acc: {acc:.4f} - Loss: {loss:.4f}")
            print(f"┌────────── Adapt {idx} ─────────┐")
            if idx % stride == 0:
                # A real domain: pass the raw dataset so it can be monitored.
                self.adapt(domain_features, domains[idx // stride])
            else:
                self.adapt(domain_features)
            print(f"└───────────────────────────────┘")

        acc, loss = get_accuracy(all_features[-1], self.classifier, self.device, self.ts_batch_size)
        print(f"Final Acc: {acc:.4f}")
        if self.cal_process:
            comet.log_metrics({"Target Domain acc": acc, "Target Domain loss": loss})

        return acc

    def adapt(self, features: Dataset, domain: Dataset = None):
        """Train the classifier for ``self.epochs`` epochs on one encoded domain.

        Args:
            features: Encoded (possibly synthetic) domain. ``get_labeled_data``
                turns it into the training set (presumably pseudo-labelled with
                discard ratio ``self.dr`` — confirm against that helper).
            domain: Optional raw dataset. When given and ``self.cal_process`` is
                set, ``self.test(domain)`` is logged after every epoch.

        Raises:
            ValueError: if a batch is neither ``(data, label)`` nor
                ``(data, label, weight)``.
        """
        labeled_data = get_labeled_data(features, self.classifier, self.device, self.dr, self.bs)
        loader = DataLoader(labeled_data, batch_size=self.bs)
        # A fresh optimizer per domain; only the classifier head is updated.
        self.optimizer = torch.optim.Adam(self.classifier.parameters(), lr=self.lr, weight_decay=self.wd)
        # Per-sample loss so per-sample weights (3-tuple batches) can scale each term.
        weighted_criterion = nn.CrossEntropyLoss(reduction='none')
        for epoch in range(self.epochs):
            # Re-enter train mode every epoch: the per-epoch evaluation below
            # (self.test, defined outside this file) may switch to eval mode.
            self.classifier.train()
            for batch in tqdm(loader, desc=f"Epoch {epoch+1}/{self.epochs}", leave=False):
                if len(batch) == 2:
                    data, label = batch
                    weight = None
                elif len(batch) == 3:
                    data, label, weight = batch
                    weight = weight.to(self.device)
                else:
                    raise ValueError(f"Invalid input length: {len(batch)}")

                data = data.to(self.device)
                label = label.to(self.device)

                output = self.classifier(data)
                if weight is None:
                    loss = self.criterion(output, label)
                else:
                    loss = (weighted_criterion(output, label) * weight).mean()

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            if domain is not None and self.cal_process:
                r_acc, r_loss = self.test(domain)
                print(f"Real-Labeled Epoch {epoch+1}/{self.epochs} - Acc: {r_acc:.4f} - Loss: {r_loss:.4f}")
                comet.log_metrics({"Real-Labeled acc": r_acc,  "Real-Labeled loss": r_loss})
    
def get_encoded_dataset(encoder, dataset, device, batch_size):
    """Run ``encoder`` over ``dataset`` and wrap the latents in a LabeledDataset.

    The encoder is moved to ``device`` and switched to eval mode (both in place).
    Inputs are cast to float32 before encoding. Latents and labels are gathered
    on the CPU. ``dataset`` must yield ``(input, label)`` pairs.

    Args:
        encoder: Module mapping a batch of inputs to latent features.
        dataset: Dataset yielding ``(input, label)`` pairs.
        device: Device on which encoding runs.
        batch_size: Batch size used while encoding.

    Returns:
        LabeledDataset of the concatenated CPU latents and labels.
    """
    encoder = encoder.to(device)
    encoder.eval()
    batches = DataLoader(dataset, batch_size=batch_size)

    latent_chunks = []
    label_chunks = []
    with torch.no_grad():
        for inputs, targets in tqdm(batches, desc="Encoding data"):
            latent_chunks.append(encoder(inputs.to(device, dtype=torch.float32)).cpu())
            label_chunks.append(targets.cpu())

    all_latents = torch.cat(latent_chunks).detach()
    all_labels = torch.cat(label_chunks).detach()
    return LabeledDataset(all_latents, all_labels)

# ------------------------------------------------------------

def generate_domains(n_inter, dataset_s, dataset_t, entry_cutoff=0, conf=0, max_len=5000):
    """Generate ``n_inter`` synthetic domains between two encoded domains.

    An exact (``"emd"``) optimal-transport plan is computed between the source
    and target features. Intermediate domain ``i`` (for ``i`` in 1..n_inter) is
    ``(1 - t) * x_s + t * x_t`` with ``t = i / (n_inter + 1)``. It is evaluated
    on every source/target pair that carries positive transport mass.

    Args:
        n_inter: Number of intermediate domains to generate (0 yields none).
        dataset_s: Source dataset; its ``.data`` and ``.labels`` are used.
        dataset_t: Target dataset; its ``.data`` is used for OT.
        entry_cutoff: Forwarded to ``get_OT_plan``.
        conf: Quantile forwarded to ``get_conf_idx``. Target columns whose
            transported-label confidence is below it are dropped.
        max_len: At most this many leading samples of each domain enter the
            OT problem.

    Returns:
        List of ``n_inter`` LabeledDatasets (labels all -1, plan entries as
        weights) followed by ``dataset_t`` itself.
    """
    all_features = []

    xs = dataset_s.data
    xt = dataset_t.data
    ys = dataset_s.labels

    # Bound the size of the n x m OT problem. Each side is capped independently,
    # so a large target is also truncated when the source is small.
    if len(xs) > max_len:
        print(f"Downsample source data to {max_len}")
        xs = xs[:max_len]
        ys = ys[:max_len]
    if len(xt) > max_len:
        print(f"Downsample target data to {max_len}")
        xt = xt[:max_len]

    # OT works on flat feature vectors; interpolation below keeps the original shape.
    if len(xs.shape) > 2:
        xs_ot, xt_ot = nn.Flatten()(xs), nn.Flatten()(xt)
    else:
        xs_ot, xt_ot = xs, xt
    plan = get_OT_plan(xs_ot, xt_ot, solver="emd", entry_cutoff=entry_cutoff)

    # Keep only target samples whose transported source labels are confident enough.
    logits_t = get_transported_labels(plan, ys, logit=True)
    _, conf_idx = get_conf_idx(logits_t, confidence_q=conf)
    xt = xt[conf_idx]
    plan = plan[:, conf_idx]

    print(f"Remaining data after confidence filter: {len(conf_idx)}")

    x = None
    for i in range(1, n_inter + 1):
        x, weights = pushforward(xs, xt, plan, i / (n_inter + 1))
        if isinstance(x, np.ndarray):
            x = torch.from_numpy(x).float()
        # Labels are unknown for synthetic points, hence the -1 placeholder.
        all_features.append(LabeledDataset(x, -1 * torch.ones(len(x)), weights))

    all_features.append(dataset_t)

    if x is not None:
        print(f"Total data for each intermediate domain: {len(x)}")

    return all_features


def pushforward(X_S, X_T, plan, t):
    """Interpolate OT-matched source/target pairs at time ``t``.

    Args:
        X_S: Source samples, indexed along the first axis.
        X_T: Target samples, indexed along the first axis.
        plan: (len(X_S), len(X_T)) transport plan.
        t: Interpolation time in [0, 1]; 0 gives source points, 1 target points.

    Returns:
        ``(x_t, weights)``. ``x_t`` holds ``(1 - t) * X_S[i] + t * X_T[j]`` for
        every ``(i, j)`` with ``plan[i, j] > 0``, in row-major order.
        ``weights`` holds the matching plan entries.

    Raises:
        ValueError: if ``t`` lies outside [0, 1].
    """
    print(f"Pushforward to t={t}")
    # Explicit check instead of assert so validation survives ``python -O``.
    if not 0 <= t <= 1:
        raise ValueError(f"t must lie in [0, 1], got {t}")

    # One mask drives both the pair indices and the weights, keeping them aligned
    # (both follow row-major order).
    mask = plan > 0
    src_idx, tgt_idx = np.nonzero(mask)
    weights = plan[mask]
    x_t = (1 - t) * X_S[src_idx] + t * X_T[tgt_idx]

    return x_t, weights

def get_transported_labels(plan, ys, logit=False):
    """Propagate source labels to target samples through a transport plan.

    Args:
        plan: (n_source, n_target) transport plan. It is not modified; any
            non-finite entry is treated as 0.
        ys: Source labels of length n_source. They are shifted so the smallest
            label becomes 0.
        logit: If True, return the per-class transported mass instead of
            argmax labels.

    Returns:
        If ``logit`` is True, an (n_target, n_classes) array of transported
        mass, where ``n_classes`` is the largest shifted label + 1. Otherwise an
        (n_target,) array of argmax class indices.
    """
    # Same shift as ot.utils.label_normalization(y, start=0): smallest label -> 0.
    ys_arr = np.array(ys)
    ys_norm = ys_arr - ys_arr.min()
    n_classes = int(ys_norm.max()) + 1

    # One-hot source-label matrix: onehot[k, i] = 1 iff source i has class k.
    onehot = np.zeros((n_classes, len(ys_norm)))
    onehot[ys_norm.astype(int), np.arange(len(ys_norm))] = 1

    # Zero non-finite entries on a copy so the caller's plan is left untouched
    # (the previous version aliased and mutated it).
    transp = np.where(np.isfinite(plan), plan, 0)

    # transp_ys[j, k] = total mass sent to target j from sources of class k.
    transp_ys = np.dot(onehot, transp).T

    if logit:
        return transp_ys

    return np.argmax(transp_ys, axis=1)


def get_conf_idx(logits, confidence_q=0.2):
    """Pick rows whose logit spread (max - min) reaches a quantile threshold.

    Args:
        logits: 2-D array of per-class scores, one row per sample.
        confidence_q: Quantile of the per-row spread used as the cut-off. Rows
            whose spread is at or above it are kept.

    Returns:
        ``(labels, indices)``: the argmax class of every row (unfiltered) and
        the indices of the rows that pass the threshold.
    """
    predicted = np.argmax(logits, axis=1)
    spread = np.ptp(logits, axis=1)
    threshold = np.quantile(spread, confidence_q)
    keep = np.flatnonzero(spread >= threshold)
    return predicted, keep


def get_OT_plan(
    X_S,
    X_T,
    solver="sinkhorn",
    weights_S=None,
    weights_T=None,
    Y_S=None,
    numItermax=1e7,
    entropy_coef=1,
    entry_cutoff=0,
):
    """Compute an optimal-transport plan between two point clouds.

    The cost matrix is ``ot.dist(X_S, X_T)``. For numpy inputs it is used
    directly; for torch inputs (CPU or GPU) it is converted to numpy.

    Args:
        X_S: Source points, shape (n, d).
        X_T: Target points, shape (m, d).
        solver: ``"emd"`` (exact), ``"sinkhorn"`` (entropic), or ``"lpl1"``
            (label-regularised Sinkhorn; requires ``Y_S``).
        weights_S: Source marginal; uniform ``1/n`` when None.
        weights_T: Target marginal; uniform ``1/m`` when None.
        Y_S: Source labels, used only by ``"lpl1"``.
        numItermax: Iteration cap forwarded to the solver.
        entropy_coef: Regularisation strength for the Sinkhorn solvers.
        entry_cutoff: When > 0, plan entries below
            ``entry_cutoff / (n * m)`` are set to 0.

    Returns:
        (n, m) numpy plan multiplied by ``n``. With uniform source weights,
        each row therefore sums to about 1.

    Raises:
        ValueError: for an unknown ``solver``, or ``"lpl1"`` without ``Y_S``.
    """
    # Validate before the (potentially expensive) distance computation; an
    # unknown solver used to surface as an UnboundLocalError on ``plan``.
    if solver not in ("emd", "sinkhorn", "lpl1"):
        raise ValueError(f"Unknown OT solver: {solver!r}")
    if solver == "lpl1" and Y_S is None:
        raise ValueError("solver='lpl1' requires source labels Y_S")

    n, m = len(X_S), len(X_T)
    a = np.ones(n) / n if weights_S is None else weights_S
    b = np.ones(m) / m if weights_T is None else weights_T
    print(f"{n} source data, {m} target data. ")

    dist_mat = ot.dist(X_S, X_T)
    # ot.dist follows the input's array backend; the solvers below get numpy.
    if isinstance(dist_mat, torch.Tensor):
        dist_mat = dist_mat.detach().cpu().numpy()
    dist_mat_np = dist_mat

    t = time.time()
    if solver == "emd":
        plan = ot.emd(a, b, dist_mat_np, numItermax=int(numItermax))
    elif solver == "sinkhorn":
        plan = ot.sinkhorn(
            a,
            b,
            dist_mat_np,
            reg=entropy_coef,
            numItermax=int(numItermax),
            stopThr=10e-7,
        )
    else:  # "lpl1" — validated above
        plan = ot.sinkhorn_lpl1_mm(
            a,
            b,
            Y_S,
            dist_mat_np,
            reg=entropy_coef,
            numItermax=int(numItermax),
            stopInnerThr=10e-9,
        )

    if entry_cutoff > 0:
        # 1 / (n * m) is the mean entry of a plan whose total mass is 1.
        avg_val = 1 / (n * m)
        print(f"Zero out entries with value < {entry_cutoff}*{avg_val}")
        plan[plan < avg_val * entry_cutoff] = 0

    elapsed = round(time.time() - t, 2)
    print(f"Time for OT calculation: {elapsed}s")
    plan = plan * n

    return plan