import torch
import copy
from torch.utils.data import Dataset, DataLoader, Subset, TensorDataset
from tqdm import tqdm

from .base import AdaptMethod
from .utils import get_labeled_data
from utils import get_accuracy, comet

class GST(AdaptMethod):
    """Gradual self-training over an ordered sequence of domains.

    Hyper-parameters are read from ``self.cf.gst`` (``self.cf`` is not set in
    this class — presumably provided by ``AdaptMethod``; verify).
    """

    def __init__(self, model):
        super().__init__(model)
        gst_cfg = self.cf.gst
        adam_cfg = gst_cfg.adam
        self.dr = gst_cfg.discard_ratio  # forwarded to get_labeled_data
        self.bs = gst_cfg.batch_size
        self.lr = adam_cfg.learning_rate
        self.wd = adam_cfg.weight_decay
        self.epochs = gst_cfg.epochs
        self.criterion = torch.nn.CrossEntropyLoss()

    def gradual_adapt(self, domains: list[Dataset]) -> float:
        """Adapt to ``domains[1:]`` in order, then score on ``domains[-1]``.

        ``domains[0]`` is never adapted to (presumably the source domain).
        When ``self.cal_process`` is set, target-domain metrics are printed
        and logged to comet before every adaptation step and after the last.

        Returns:
            The accuracy reported by ``get_accuracy`` on ``domains[-1]``.
        """
        target = domains[-1]
        for step, domain in enumerate(domains[1:], start=1):
            if self.cal_process:
                t_acc, t_loss = get_accuracy(target, self.model, self.device, self.ts_batch_size)
                comet.log_metrics({"Target Domain acc": t_acc, "Target Domain loss": t_loss})
                print(f"Target Domain - Acc: {t_acc:.4f} - Loss: {t_loss:.4f}")
            print(f"┌────────── Adapt {step} ─────────┐")
            self.adapt(domain)
            print("└───────────────────────────────┘")

        final_acc, final_loss = get_accuracy(target, self.model, self.device, self.ts_batch_size)
        print(f"Final Acc: {final_acc:.4f}")
        if self.cal_process:
            comet.log_metrics({"Target Domain acc": final_acc, "Target Domain loss": final_loss})
        return final_acc

    def adapt(self, domain: Dataset):
        """Train the model for ``self.epochs`` epochs on labels from ``get_labeled_data``.

        The labels come from ``get_labeled_data(domain, model, ...)``, which
        presumably pseudo-labels ``domain`` with the current model and
        applies ``self.dr`` as a discard ratio (verify in ``.utils``).
        A fresh Adam optimizer is created on every call. When
        ``self.cal_process`` is set, ``self.test(domain)`` is reported after
        each epoch.
        """
        pseudo_labeled = get_labeled_data(domain, self.model, self.device, self.dr, self.bs)
        batches = DataLoader(pseudo_labeled, batch_size=self.bs)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.wd)

        for epoch_no in range(1, self.epochs + 1):
            # Re-enter train mode every epoch: self.test() below may switch modes.
            self.model.train()
            progress = tqdm(batches, desc=f"Epoch {epoch_no}/{self.epochs}", leave=False)
            for inputs, targets in progress:
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                loss = self.criterion(self.model(inputs), targets)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            if not self.cal_process:
                continue
            r_acc, r_loss = self.test(domain)
            print(f"Real-Labeled Epoch {epoch_no}/{self.epochs} - Acc: {r_acc:.4f} - Loss: {r_loss:.4f}")
            comet.log_metrics({"Real-Labeled acc": r_acc, "Real-Labeled loss": r_loss})

# ---------------

if __name__ == "__main__":
    # Manual smoke run: adapt a pretrained rotate_mnist CNN on the second
    # split returned by get_source_domains. The relative import of `.base`
    # at the top of this file means this must be launched as a module
    # (python -m <package>.<module>), not as a plain script.
    from config import get_config
    from data import get_source_domains
    from model import get_trained_model

    cf = get_config(["config/adapt.yaml"])
    cf.print()

    _, tr = get_source_domains("rotate_mnist")
    model = get_trained_model("rotate_mnist", "cnn").to(cf.device)
    # GST.__init__ takes only the model; passing `cf` as a second argument
    # raised TypeError. GST reads its settings via self.cf.
    # NOTE(review): confirm AdaptMethod's self.cf reflects the config loaded above.
    gst = GST(model)
    print(tr)
    gst.adapt(tr)