import os
import torch
import torch.nn as nn
from network.network_digits import GeneratorDigits, LabelClassifierDigits, ExtractorClassifierDigits


class SolverDigitsTest:
    """Evaluate a pretrained digits domain-adaptation model on a target test set.

    At construction time the checkpoint
    ``<pretrained_save_dir>/<da_method>_target_<target>.pth`` is loaded and all
    sub-networks are switched to eval mode. Supported ``da_method`` values:

    * ``'Multi-PL'``: a single generator ``G1`` followed by classifier ``LC1``.
    * ``'Multi-EPL-2'``: generators ``G1`` and ``G2`` whose features are
      concatenated along dim 1 and classified by ``LC_total``.
    """

    # For each supported method: (module attribute, checkpoint dict key) pairs.
    _CHECKPOINT_KEYS = {
        'Multi-PL': (('G1', 'G'), ('LC1', 'LC')),
        'Multi-EPL-2': (('G1', 'G1'), ('G2', 'G2'), ('LC_total', 'LC_total')),
    }

    def __init__(self, da_method, target, target_test_dataloader):
        """Build the networks, load pretrained weights and enter eval mode.

        Args:
            da_method: name of the adaptation method (selects checkpoint and
                forward path; see ``_CHECKPOINT_KEYS``).
            target: target-domain identifier, used in the checkpoint filename.
            target_test_dataloader: iterable of dicts with ``'image'`` and
                ``'label'`` tensors.
        """
        # basic settings
        self.target = target
        self.da_method = da_method

        # device setting
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # evaluation data
        self.t_test_dataloader = target_test_dataloader

        # NOTE(review): resolved relative to the current working directory,
        # not to this source file.
        self.pretrained_save_dir = './../pretrained'

        # model setting
        self.G1 = GeneratorDigits().to(self.device)
        self.G2 = GeneratorDigits().to(self.device)
        self.LC1 = LabelClassifierDigits().to(self.device)
        self.LC2 = LabelClassifierDigits().to(self.device)
        self.LC_total = LabelClassifierDigits(input_size=2048*2).to(self.device)
        self.EC = ExtractorClassifierDigits().to(self.device)

        self.load_model()
        self.eval_mode()

    def load_model(self):
        """Load the pretrained checkpoint for ``self.da_method``.

        For an unsupported method nothing is loaded (the networks keep their
        fresh initialisation); ``test`` rejects such methods explicitly.
        """
        keys = self._CHECKPOINT_KEYS.get(self.da_method)
        if keys is None:
            return
        fname = os.path.join(self.pretrained_save_dir, '{}_target_{}.pth'.format(self.da_method, self.target))
        # map_location lets a checkpoint saved on GPU load on a CPU-only host,
        # which the device selection in __init__ explicitly supports.
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(fname, map_location=self.device)
        for attr, key in keys:
            getattr(self, attr).load_state_dict(checkpoint[key])

    def eval_mode(self):
        """Put every sub-network into evaluation mode."""
        for module in (self.G1, self.G2, self.LC1, self.LC2, self.LC_total, self.EC):
            module.eval()

    def _logits(self, image):
        """Return class logits for one batch of images under ``self.da_method``."""
        if self.da_method == 'Multi-PL':
            return self.LC1(self.G1(image))
        # 'Multi-EPL-2': concatenate the two generators' features along dim 1.
        feat = torch.cat([self.G1(image), self.G2(image)], dim=1)
        return self.LC_total(feat)

    def test(self):
        """Evaluate the loaded model on the target test dataloader.

        Returns:
            tuple ``(acc, loss)``: accuracy in percent and the mean per-sample
            cross-entropy loss over the whole test set.

        Raises:
            ValueError: if ``self.da_method`` is unsupported or the dataloader
                yields no samples.
        """
        if self.da_method not in self._CHECKPOINT_KEYS:
            raise ValueError('Unsupported da_method: {!r}'.format(self.da_method))

        # Sum per-sample losses so that dividing by num_data below gives the
        # true per-sample mean (the default 'mean' reduction would already
        # average each batch, making the final division wrong).
        criterion = nn.CrossEntropyLoss(reduction='sum')
        correct = 0
        num_data = 0
        total_loss = 0.0
        # Evaluation needs no gradients; skip building autograd graphs.
        with torch.no_grad():
            for sample in self.t_test_dataloader:
                image = sample['image'].to(self.device)
                label = sample['label'].to(self.device)

                logits = self._logits(image)
                total_loss += criterion(logits, label).item()
                pred = torch.argmax(logits, dim=-1)
                correct += (pred == label).sum().item()
                num_data += image.size(0)

        if num_data == 0:
            raise ValueError('target test dataloader yielded no samples')

        acc = 100 * correct / num_data
        loss = total_loss / num_data
        return acc, loss
