import torch
from tqdm import tqdm
from itertools import cycle
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

from ..base import AdaptMethod
from utils import get_accuracy, comet
from .ntk_mmd2 import cal_mmd

class NTK(AdaptMethod):
    """Adaptation method driven by the ``ntk`` section of the configuration.

    This file only contains the hyper-parameter wiring and the driver loop
    over domains; the per-domain training step (:meth:`adapt`) is a
    placeholder here.
    """

    def __init__(self, model):
        """Copy hyper-parameters from ``self.cf.ntk`` and keep model sub-modules.

        ``self.cf`` is not assigned in this class -- presumably it is set up by
        ``AdaptMethod.__init__``; confirm against the base class.

        Args:
            model: Object exposing ``encoder`` and ``classifier`` attributes.
        """
        super().__init__(model)

        ntk_cf = self.cf.ntk
        self.dr = ntk_cf.discard_ratio
        self.bs = ntk_cf.batch_size
        self.lr = ntk_cf.adam.learning_rate
        self.wd = ntk_cf.adam.weight_decay
        self.epochs = ntk_cf.epochs
        self.lambda_mmd = ntk_cf.lambda_mmd
        self.lambda_src = ntk_cf.lambda_src
        self.temp = ntk_cf.temperature

        self.encoder = model.encoder
        self.classifier = model.classifier

    def gradual_adapt(self, domains: list[Dataset]) -> float:
        """Run :meth:`adapt` once per domain after the first and score the last.

        ``domains[0]`` is passed as the source to every :meth:`adapt` call and
        ``domains[-1]`` is the domain used for all "Target Domain" evaluations.
        When ``self.cal_process`` is truthy, intermediate accuracies/losses are
        printed and sent to ``comet``.

        Args:
            domains: Ordered datasets; the first is the source, the last the
                target.

        Returns:
            The accuracy reported by ``get_accuracy`` on ``domains[-1]`` after
            the final adaptation step.
        """
        source = domains[0]
        target = domains[-1]

        def evaluate(domain):
            # get_accuracy yields an (accuracy, loss) pair.
            return get_accuracy(domain, self.model, self.device, self.ts_batch_size)

        def log_target(acc, loss):
            comet.log_metrics({"Target Domain acc": acc, "Target Domain loss": loss})

        if self.cal_process:
            src_acc, src_loss = evaluate(source)
            print(f"Source Domain - Acc: {src_acc:.4f} - Loss: {src_loss:.4f}")

        for idx, current in enumerate(domains[1:], start=1):
            if self.cal_process:
                # Track target performance *before* this adaptation step.
                tgt_acc, tgt_loss = evaluate(target)
                log_target(tgt_acc, tgt_loss)
                print(f"Target Domain - Acc: {tgt_acc:.4f} - Loss: {tgt_loss:.4f}")
            print(f"┌────────── Adapt {idx} ─────────┐")
            self.adapt(source, current)
            print("└───────────────────────────────┘")

        final_acc, final_loss = evaluate(target)
        print(f"Final Acc: {final_acc:.4f}")
        if self.cal_process:
            log_target(final_acc, final_loss)
        return final_acc

    def adapt(self, domain_s: Dataset, domain_t: Dataset):
        """Adapt the model using ``domain_s`` and ``domain_t`` (currently a no-op).

        Args:
            domain_s: Dataset passed as the source domain.
            domain_t: Dataset passed as the domain to adapt to.
        """
        # ! Code will be released when paper is accepted!
        return None
    
def softmax_T(x, T):
    """Return the softmax of ``x / T`` taken along dimension 0.

    Args:
        x: Input tensor of scores.
        T: Temperature that divides ``x`` before the softmax.

    Returns:
        A tensor shaped like ``x`` whose entries sum to 1 along dimension 0.
    """
    tempered = x / T
    return torch.softmax(tempered, dim=0)

def cross_entropy(output, target, weights=None):
    """Reduce the unreduced cross-entropy to a scalar, optionally weighted.

    Args:
        output: Predictions forwarded to ``F.cross_entropy``.
        target: Targets forwarded to ``F.cross_entropy``.
        weights: Optional factors multiplied into the unreduced losses.

    Returns:
        The mean of the unreduced losses when ``weights`` is ``None``;
        otherwise the sum of ``loss * weights`` (a plain sum, with no
        division by the total weight).
    """
    per_item = F.cross_entropy(output, target, reduction='none')
    if weights is not None:
        return (per_item * weights).sum()
    return per_item.mean()