from abc import ABC, abstractmethod

from torch.utils.data import Dataset, DataLoader
from utils import get_accuracy, comet
from config import get_config, cf


def cal_config(data_name, model_name, method_name, domain_num):
    """Load the adaptation config for a run, trying the most specific file first.

    Candidate YAML files are tried in this order:

    1. ``config/<data_name>/<model_name>/adapt_<domain_num>.yaml``
    2. ``config/<data_name>/<model_name>/adapt.yaml``
    3. ``config/<data_name>/adapt.yaml``
    4. ``config/adapt.yaml``

    The first candidate that ``get_config`` loads without raising *and* whose
    result has an attribute named ``method_name.lower()`` is returned.

    Args:
        data_name: Dataset name, used as a directory component.
        model_name: Model name, used as a directory component.
        method_name: Adaptation method name; its lower-cased form must be an
            attribute of the loaded config.
        domain_num: Number of domains, used in the most specific file name.

    Returns:
        The object returned by ``get_config`` for the chosen path.

    Raises:
        ValueError: If no candidate both loads and contains the method's
            section. The last underlying failure is chained as ``__cause__``.
    """
    candidates = [
        f"config/{data_name}/{model_name}/adapt_{domain_num}.yaml",
        f"config/{data_name}/{model_name}/adapt.yaml",
        f"config/{data_name}/adapt.yaml",
        "config/adapt.yaml",
    ]
    section = method_name.lower()
    # Track the failure explicitly: the target of ``except ... as e`` is
    # deleted when the handler exits, so it cannot be used after the loop.
    last_error = None
    for path in candidates:
        try:
            loaded = get_config([path])
        except Exception as err:  # best-effort: fall back to a more generic file
            last_error = err
            continue
        if not hasattr(loaded, section):
            last_error = ValueError(f"'{path}' has no '{section}' section")
            continue
        print(f"using AdaptMethod config: '{path}'")
        return loaded
    raise ValueError(
        f"No config file found for {data_name} {model_name} {method_name} "
        f"{domain_num} error: {last_error}"
    ) from last_error

class AdaptMethod(ABC):
    """Abstract base class for domain adaptation methods.

    Subclasses implement ``adapt`` (a single domain) and ``gradual_adapt``
    (a sequence of domains). The method-specific settings are loaded with
    ``cal_config`` and stored on ``self.cf``.
    """

    def __init__(self, model, config=None):
        """Set up the model, device and evaluation settings.

        Args:
            model: Model to adapt and evaluate; passed to ``get_accuracy``.
            config: Optional run-level config providing ``data_name``,
                ``model_name``, ``method_name``, ``domain_num`` and
                ``device``. Defaults to the module-level ``cf``. It is
                accepted so that subclasses may forward a config to
                ``super().__init__`` without a ``TypeError``.
                NOTE(review): assumed to have the same shape as the global
                ``cf`` — confirm with callers.
        """
        run_cf = cf if config is None else config
        self.cf = cal_config(run_cf.data_name, run_cf.model_name,
                             run_cf.method_name, run_cf.domain_num)
        self.model = model
        self.device = run_cf.device

        # Evaluation batch size and whether to log metrics during adaptation,
        # both read from the loaded adaptation config.
        self.ts_batch_size = self.cf.test.batch_size
        self.cal_process = self.cf.cal_process

    def test(self, domain: Dataset):
        """Evaluate ``self.model`` on ``domain``.

        Returns:
            The ``(acc, loss)`` pair produced by ``get_accuracy``.
        """
        loader = DataLoader(domain, batch_size=self.ts_batch_size)
        acc, loss = get_accuracy(loader, self.model, self.device)
        return acc, loss

    @abstractmethod
    def adapt(self, domain: Dataset):
        """Run one adaptation step on ``domain``."""

    @abstractmethod
    def gradual_adapt(self, domains: list[Dataset]) -> float:
        """Adapt across ``domains`` in order and return the final accuracy."""

class Base(AdaptMethod):
    """Baseline method whose ``adapt`` step does not update the model.

    When ``cal_process`` is enabled, each step only evaluates and logs metrics.
    """

    def __init__(self, model, config=None):
        """Create the baseline.

        Args:
            model: Model to evaluate.
            config: Kept in the signature for backward compatibility; it is
                currently unused. Run settings come from the global ``cf``
                inside ``AdaptMethod``.
        """
        # Forward only the model: passing ``config`` as well raised TypeError
        # against AdaptMethod.__init__(self, model).
        super().__init__(model)

    def gradual_adapt(self, domains: list[Dataset]) -> float:
        """Step through ``domains[1:]`` and report accuracy on ``domains[-1]``.

        ``domains[0]`` is skipped (presumably the source domain — TODO
        confirm). When ``cal_process`` is set, the target domain is evaluated
        and logged before every step, and once more at the end.

        Args:
            domains: Ordered domains; the last one is the target.

        Returns:
            Final accuracy on the target domain.

        Raises:
            ValueError: If ``domains`` is empty.
        """
        if not domains:
            raise ValueError("gradual_adapt requires at least one domain")
        target = domains[-1]
        for domain in domains[1:]:
            if self.cal_process:
                # Use self.test so every evaluation in the class uses the same
                # get_accuracy call convention (DataLoader, model, device).
                acc, loss = self.test(target)
                comet.log_metrics({"Target Domain acc": acc, "Target Domain loss": loss})
                print(f"Target Domain - Acc: {acc:.4f} - Loss: {loss:.4f}")
            self.adapt(domain)
        acc, loss = self.test(target)
        print(f"Final Acc: {acc:.4f}")
        if self.cal_process:
            comet.log_metrics({"Target Domain acc": acc, "Target Domain loss": loss})
        return acc

    def adapt(self, domain: Dataset):
        """Evaluate on ``domain`` (when ``cal_process`` is set); no model update."""
        if self.cal_process:
            r_acc, r_loss = self.test(domain)
            print(f"Real-Labeled - Acc: {r_acc:.4f} - Loss: {r_loss:.4f}")
            comet.log_metrics({"Real-Labeled acc": r_acc,  "Real-Labeled loss": r_loss})

        
