import torch
import torch.nn as nn
from DA_algorithms.DeepDA.transfer_losses import TransferLoss
from DA_algorithms.DeepDA import backbones
import os
import pickle
import pdb


class TransferNet(nn.Module):
    """Backbone + optional bottleneck + linear classifier trained with a transfer loss.

    The source branch is supervised with cross-entropy; a ``TransferLoss`` whose type
    is selected by ``transfer_loss`` is computed between source and target
    representations.
    """

    def __init__(self, num_class, base_net='resnet50', transfer_loss='mmd', use_bottleneck=True, bottleneck_width=256,
                 max_iter=1000, **kwargs):
        """Build the network.

        Args:
            num_class: Number of output classes of the classifier head.
            base_net: Backbone name passed to ``backbones.get_backbone``.
            transfer_loss: Loss type passed to ``TransferLoss``. The values 'lmmd',
                'daan', 'bnm' and 'adv' additionally change what ``forward`` feeds
                the loss and which parameter groups ``get_parameters`` returns.
            use_bottleneck: If True, insert a Linear+ReLU layer after the backbone.
            bottleneck_width: Output width of the bottleneck layer.
            max_iter: Forwarded to ``TransferLoss``.
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        self.num_class = num_class
        self.base_network = backbones.get_backbone(base_net)
        self.use_bottleneck = use_bottleneck
        self.transfer_loss = transfer_loss
        if self.use_bottleneck:
            bottleneck_list = [
                nn.Linear(self.base_network.output_num(), bottleneck_width),
                nn.ReLU()
            ]
            self.bottleneck_layer = nn.Sequential(*bottleneck_list)
            feature_dim = bottleneck_width
        else:
            feature_dim = self.base_network.output_num()

        self.classifier_layer = nn.Linear(feature_dim, num_class)
        transfer_loss_args = {
            "loss_type": self.transfer_loss,
            "max_iter": max_iter,
            "num_class": num_class
        }
        self.adapt_loss = TransferLoss(**transfer_loss_args)
        self.criterion = torch.nn.CrossEntropyLoss()

    def _embed(self, x):
        """Run the backbone and, when enabled, the bottleneck on ``x``."""
        features = self.base_network(x)
        if self.use_bottleneck:
            features = self.bottleneck_layer(features)
        return features

    def forward(self, source, target, source_label):
        """Compute the supervised and transfer losses for one source/target batch.

        Args:
            source: Source-domain input batch.
            target: Target-domain input batch.
            source_label: Targets for ``source`` as accepted by ``nn.CrossEntropyLoss``.

        Returns:
            Tuple ``(clf_loss, transfer_loss)``.
        """
        # Each domain is passed through the network in its own call.
        source = self._embed(source)
        target = self._embed(target)

        # Supervised classification loss on the labelled source batch.
        source_clf = self.classifier_layer(source)
        clf_loss = self.criterion(source_clf, source_label)

        # Some transfer losses need extra, class-aware inputs.
        kwargs = {}
        if self.transfer_loss == "lmmd":
            kwargs['source_label'] = source_label
            target_clf = self.classifier_layer(target)
            kwargs['target_logits'] = nn.functional.softmax(target_clf, dim=1)
        elif self.transfer_loss == "daan":
            # Reuse the source logits computed above instead of re-running the head.
            kwargs['source_logits'] = nn.functional.softmax(source_clf, dim=1)
            target_clf = self.classifier_layer(target)
            kwargs['target_logits'] = nn.functional.softmax(target_clf, dim=1)
        elif self.transfer_loss == 'bnm':
            # BNM operates on target class probabilities rather than target features.
            target = nn.functional.softmax(self.classifier_layer(target), dim=1)
        transfer_loss = self.adapt_loss(source, target, **kwargs)
        return clf_loss, transfer_loss

    def get_parameters(self, initial_lr=1.0):
        """Return optimizer parameter groups.

        The backbone is trained at ``0.1 * initial_lr``. Newly added layers (the
        classifier, the bottleneck, and any adversarial/local domain classifiers
        owned by the transfer loss) use ``initial_lr``.
        """
        params = [
            {'params': self.base_network.parameters(), 'lr': 0.1 * initial_lr},
            {'params': self.classifier_layer.parameters(), 'lr': 1.0 * initial_lr},
        ]
        if self.use_bottleneck:
            params.append(
                {'params': self.bottleneck_layer.parameters(), 'lr': 1.0 * initial_lr}
            )
        # Loss-dependent trainable components.
        if self.transfer_loss in ("adv", "daan"):
            params.append(
                {'params': self.adapt_loss.loss_func.domain_classifier.parameters(), 'lr': 1.0 * initial_lr}
            )
        if self.transfer_loss == "daan":
            params.append(
                {'params': self.adapt_loss.loss_func.local_classifiers.parameters(), 'lr': 1.0 * initial_lr}
            )
        return params

    def predict(self, x):
        """Return the classifier outputs (unnormalized scores) for ``x``."""
        return self.classifier_layer(self._embed(x))

    def get_feature_info(self, x):
        """Return ``(embedding, class_probabilities)`` for ``x``.

        The embedding is the bottleneck output, or the raw backbone output when the
        bottleneck is disabled. Probabilities are the softmax of the classifier output.
        """
        embedding = self._embed(x)
        probs = nn.functional.softmax(self.classifier_layer(embedding), dim=1)
        return embedding, probs

    def save_feature_info(self, loader, args, domain_name):
        """Extract embeddings, probabilities and labels from ``loader`` and pickle them.

        The result is written to ``<args.feature_dir>/<domain_name>_features.pkl`` as a
        dict with keys ``'features'``, ``'probs'`` and ``'labels'`` (CPU float tensors).
        ``args.feature_dir`` is created if it does not exist.

        Args:
            loader: Iterable of ``(data, label)`` batches supporting ``len()``.
            args: Namespace providing ``device`` and ``feature_dir``.
            domain_name: Prefix used for the output file name.

        Raises:
            ValueError: If the loader yields no batches.
        """
        n_batch = len(loader)
        if n_batch == 0:
            # A loader reporting length 0 is sampled for a fixed number of batches.
            n_batch = 50  # TODO(review): consider args.n_iter_per_epoch here.
        self.epoch_based_processing(n_batch)

        features_list = []
        probs_list = []
        labels_list = []
        iter_loader = iter(loader)
        # Inference only: skip building autograd graphs that would be discarded anyway.
        with torch.no_grad():
            for _ in range(n_batch):
                try:
                    data, label = next(iter_loader)
                except StopIteration:
                    # Finite loader ran out early; keep what was collected.
                    break
                data, label = data.to(args.device), label.to(args.device)
                embedding, probs = self.get_feature_info(data)
                features_list.append(embedding.float().cpu())
                probs_list.append(probs.float().cpu())
                labels_list.append(label.float().cpu())

        if not features_list:
            raise ValueError(f"No batches available to extract features for domain {domain_name!r}")

        os.makedirs(args.feature_dir, exist_ok=True)
        save_path = os.path.join(args.feature_dir, f"{domain_name}_features.pkl")
        with open(save_path, 'wb') as f:
            pickle.dump({
                'features': torch.cat(features_list, dim=0),
                'probs': torch.cat(probs_list, dim=0),
                'labels': torch.cat(labels_list, dim=0)
            }, f)

    def epoch_based_processing(self, *args, **kwargs):
        """Forward per-epoch bookkeeping to the transfer loss (only 'daan' uses it)."""
        if self.transfer_loss == "daan":
            self.adapt_loss.loss_func.update_dynamic_factor(*args, **kwargs)