# Copyright (c) Kakao Brain. All Rights Reserved.

import torch
import torch.nn as nn
import torch.nn.functional as F

from domainbed.ur_networks import URFeaturizer
from domainbed.lib import misc
from domainbed.algorithms import Algorithm
import sys

from domainbed.lib.misc import random_pairs_of_minibatches
from sklearn.neighbors import NearestNeighbors
from torchvision import transforms as T
import numpy as np
#torch.autograd.set_detect_anomaly(True)



def get_optimizer(name, params, **kwargs):
    """Instantiate a ``torch.optim`` optimizer selected by name.

    Args:
        name: optimizer name, case-insensitive; one of ``"adam"``, ``"sgd"``,
            ``"adamw"``.
        params: iterable of parameters or parameter-group dicts.
        **kwargs: forwarded verbatim to the optimizer constructor
            (e.g. ``lr``, ``weight_decay``).

    Returns:
        The constructed optimizer.

    Raises:
        KeyError: if ``name`` is not one of the supported optimizers.
    """
    key = name.lower()
    registry = {
        "adam": torch.optim.Adam,
        "sgd": torch.optim.SGD,
        "adamw": torch.optim.AdamW,
    }
    return registry[key](params, **kwargs)


class ForwardModel(nn.Module):
    """Minimal wrapper that holds only ``network``.

    Used to reduce GPU memory usage of SWAD: it exposes just the prediction
    network rather than the whole algorithm object.
    """

    def __init__(self, network):
        super().__init__()
        self.network = network

    def predict(self, x):
        """Return ``network(x)``."""
        return self.network(x)

    def forward(self, x):
        """Same as :meth:`predict`."""
        return self.predict(x)


class MeanEncoder(nn.Module):
    """Mean encoder that is the identity function.

    ``shape`` is recorded on the instance but not used by this module.
    """

    def __init__(self, shape):
        super().__init__()
        self.shape = shape

    def forward(self, x):
        """Return ``x`` unchanged (same object)."""
        out = x
        return out


class VarianceEncoder(nn.Module):
    """Bias-only variance model with diagonal covariance.

    The variance does not depend on the input: it is ``softplus(b) + eps``
    for a learnable bias ``b``. ``b`` is initialized by inverse softplus so
    that the variance equals ``init`` at the start of training.

    Args:
        shape: shape of the feature map this encoder models. With
            ``channelwise=True`` it must be 4-D ``[B, C, H, W]`` or 3-D
            ``[H*W+1, B, C]`` (CLIP-ViT layout).
        init: initial variance; must be strictly greater than ``eps``.
        channelwise: if True, share one variance per channel; otherwise
            ``b`` has exactly ``shape``.
        eps: constant added to the softplus output (variance lower bound).

    Raises:
        ValueError: if ``init <= eps`` (the inverse softplus would be NaN or
            ``-inf``), or if ``channelwise=True`` and ``shape`` is not 3-D/4-D.
    """

    def __init__(self, shape, init=0.1, channelwise=True, eps=1e-5):
        super().__init__()
        self.shape = shape
        self.eps = eps

        if float(init) <= eps:
            raise ValueError(f"init ({init}) must be greater than eps ({eps})")
        # Inverse softplus: softplus(b) + eps == init at initialization.
        init = (torch.as_tensor(init - eps).exp() - 1.0).log()

        b_shape = shape
        if channelwise:
            if len(shape) == 4:
                # [B, C, H, W]
                b_shape = (1, shape[1], 1, 1)
            elif len(shape) == 3:
                # CLIP-ViT: [H*W+1, B, C]
                b_shape = (1, 1, shape[2])
            else:
                raise ValueError(
                    "channelwise VarianceEncoder expects a 3-D or 4-D feature "
                    f"shape, got {tuple(shape)}"
                )

        self.b = nn.Parameter(torch.full(b_shape, init))

    def forward(self, x):
        """Return the variance ``softplus(b) + eps``; ``x`` is ignored."""
        return F.softplus(self.b) + self.eps


def get_shapes(model, input_shape):
    """Return the shapes of the intermediate features produced by ``model``.

    One random sample of ``input_shape`` is moved to the device of the
    model's first parameter and run through ``model(..., ret_feats=True)``
    without gradient tracking.

    Args:
        model: module whose forward accepts ``ret_feats=True`` and returns a
            pair ``(output, feats)``; ``feats`` is iterated for ``.shape``.
        input_shape: per-sample input shape, without the batch dimension.

    Returns:
        list with the ``.shape`` of each element of ``feats`` (batch size 1).
    """
    with torch.no_grad():
        probe = torch.rand(1, *input_shape).to(next(model.parameters()).device)
        feats = model(probe, ret_feats=True)[1]
        return [feat.shape for feat in feats]


class MIRO(Algorithm):
    """Mutual-Information Regularization with Oracle.

    A trainable featurizer + linear classifier is trained with cross-entropy
    plus a regularizer that ties each intermediate feature of the trainable
    featurizer to the matching feature of a frozen copy (``pre_featurizer``,
    built with ``freeze="all"``), via per-layer mean/variance encoders.
    """

    def __init__(self, input_shape, num_classes, num_domains, hparams, **kwargs):
        super().__init__(input_shape, num_classes, num_domains, hparams)
        self.pre_featurizer = URFeaturizer(
            input_shape, self.hparams, freeze="all", feat_layers=hparams.feat_layers
        )
        self.featurizer = URFeaturizer(
            input_shape, self.hparams, feat_layers=hparams.feat_layers
        )
        self.classifier = nn.Linear(self.featurizer.n_outputs, num_classes)
        self.network = nn.Sequential(self.featurizer, self.classifier)
        self.ld = hparams.ld  # weight of the MI regularization term
        self.input_shape = input_shape

        # One mean/variance encoder pair per intermediate feature, sized from
        # a dry run of the frozen featurizer.
        shapes = get_shapes(self.pre_featurizer, self.input_shape)
        self.mean_encoders = nn.ModuleList([
            MeanEncoder(shape) for shape in shapes
        ])
        self.var_encoders = nn.ModuleList([
            VarianceEncoder(shape) for shape in shapes
        ])

        # Encoder parameters use a scaled learning rate (lr * lr_mult).
        parameters = [
            {"params": self.network.parameters()},
            {"params": self.mean_encoders.parameters(), "lr": hparams.lr * hparams.lr_mult},
            {"params": self.var_encoders.parameters(), "lr": hparams.lr * hparams.lr_mult},
        ]
        self.optimizer = get_optimizer(
            hparams["optimizer"],
            parameters,
            lr=self.hparams["lr"],
            weight_decay=self.hparams["weight_decay"],
        )

    def _mi_reg_loss(self, all_x, inter_feats):
        """Return the MIRO regularizer for a batch.

        For every feature layer, computes
        ``mean((mean_enc(f) - pre_f)^2 / var + log(var)) / 2`` where ``pre_f``
        comes from the frozen featurizer (no gradient), and sums over layers.
        Works for any feature rank (e.g. 4-D CNN maps or 3-D ViT tokens).
        """
        with torch.no_grad():
            _, pre_feats = self.pre_featurizer(all_x, ret_feats=True)

        reg_loss = 0.
        for f, pre_f, mean_enc, var_enc in misc.zip_strict(
            inter_feats, pre_feats, self.mean_encoders, self.var_encoders
        ):
            mean = mean_enc(f)
            var = var_enc(f)
            vlb = (mean - pre_f).pow(2).div(var) + var.log()
            reg_loss += vlb.mean() / 2.
        return reg_loss

    def update(self, minibatches, unlabeled=None):
        """Run one optimization step on the concatenation of ``minibatches``.

        Args:
            minibatches: iterable of ``(x, y)`` pairs; all pairs are
                concatenated into a single batch.
            unlabeled: unused.

        Returns:
            dict with the scalar total ``"loss"`` (CE + ld * MI regularizer).
        """
        all_x = torch.cat([x for x, y in minibatches])
        all_y = torch.cat([y for x, y in minibatches])
        feat, inter_feats = self.featurizer(all_x, ret_feats=True)
        logit = self.classifier(feat)
        loss = F.cross_entropy(logit, all_y)

        loss += self._mi_reg_loss(all_x, inter_feats) * self.ld

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return {"loss": loss.item()}

    def predict(self, x):
        return self.network(x)

    def get_forward_model(self):
        """Return a ``ForwardModel`` wrapping only ``self.network``."""
        forward_model = ForwardModel(self.network)
        return forward_model

class MIRO_ELR(Algorithm):
    """Mutual-Information Regularization with Oracle, plus per-environment ELR loss.

    Same model and MI regularizer as MIRO; in addition, each training
    environment has its own criterion in ``self.train_criterion`` that is
    applied to that environment's logits.
    """

    def __init__(self, input_shape, num_classes, num_domains, hparams, **kwargs):
        super().__init__(input_shape, num_classes, num_domains, hparams)
        self.pre_featurizer = URFeaturizer(
            input_shape, self.hparams, freeze="all", feat_layers=hparams.feat_layers
        )
        self.featurizer = URFeaturizer(
            input_shape, self.hparams, feat_layers=hparams.feat_layers
        )
        self.classifier = nn.Linear(self.featurizer.n_outputs, num_classes)
        self.network = nn.Sequential(self.featurizer, self.classifier)
        self.ld = hparams.ld  # weight of the MI regularization term
        self.input_shape = input_shape

        # One mean/variance encoder pair per intermediate feature, sized from
        # a dry run of the frozen featurizer.
        shapes = get_shapes(self.pre_featurizer, self.input_shape)
        self.mean_encoders = nn.ModuleList([
            MeanEncoder(shape) for shape in shapes
        ])
        self.var_encoders = nn.ModuleList([
            VarianceEncoder(shape) for shape in shapes
        ])

        # Encoder parameters use a scaled learning rate (lr * lr_mult).
        parameters = [
            {"params": self.network.parameters()},
            {"params": self.mean_encoders.parameters(), "lr": hparams.lr * hparams.lr_mult},
            {"params": self.var_encoders.parameters(), "lr": hparams.lr * hparams.lr_mult},
        ]
        self.optimizer = get_optimizer(
            hparams["optimizer"],
            parameters,
            lr=self.hparams["lr"],
            weight_decay=self.hparams["weight_decay"],
        )

        # Presumably maps environment name -> number of training samples,
        # e.g. {'Caltech101': 1415, 'LabelMe': 2656, 'SUN09': 3282, 'VOC2007': 3376}
        # (TODO confirm against the hparams producer).
        self.train_len_dict = self.hparams["train_len_dict"]
        # One criterion per environment, in the dict's iteration order.
        # NOTE(review): `elr_loss` is not imported in this file's visible
        # imports; it must be brought into scope or construction fails.
        self.train_criterion = [
            elr_loss(self.train_len_dict[d], num_classes=num_classes, beta=0.3, args_lambda=1.0)
            for d in self.train_len_dict
        ]

    def _mi_reg_loss(self, all_x, inter_feats):
        """Return the MIRO regularizer for a batch.

        For every feature layer, computes
        ``mean((mean_enc(f) - pre_f)^2 / var + log(var)) / 2`` where ``pre_f``
        comes from the frozen featurizer (no gradient), and sums over layers.
        Works for any feature rank (e.g. 4-D CNN maps or 3-D ViT tokens).
        """
        with torch.no_grad():
            _, pre_feats = self.pre_featurizer(all_x, ret_feats=True)

        reg_loss = 0.
        for f, pre_f, mean_enc, var_enc in misc.zip_strict(
            inter_feats, pre_feats, self.mean_encoders, self.var_encoders
        ):
            mean = mean_enc(f)
            var = var_enc(f)
            vlb = (mean - pre_f).pow(2).div(var) + var.log()
            reg_loss += vlb.mean() / 2.
        return reg_loss

    def update(self, x, y, key, **kwargs):
        """Run one optimization step with CE + per-environment ELR + MI losses.

        Args:
            x: list of input batches, one per environment.
            y: list of label tensors aligned with ``x``.
            key: per-environment values passed as the first argument of the
                matching ELR criterion (presumably sample indices — confirm
                against caller).
            **kwargs: must contain ``env_list`` when ``key`` is non-empty;
                ``env_list[i]`` indexes ``self.train_criterion`` for ``x[i]``.

        Returns:
            dict with the scalar total ``"loss"``.
        """
        all_x = torch.cat(x)
        all_y = torch.cat(y)
        feat, inter_feats = self.featurizer(all_x, ret_feats=True)
        logit = self.classifier(feat)
        loss = F.cross_entropy(logit, all_y)

        # ELR: each environment's batch is re-featurized separately and scored
        # by that environment's criterion.
        for i, key_i in enumerate(key):
            feat_i, _ = self.featurizer(x[i], ret_feats=True)
            logit_i = self.classifier(feat_i)
            loss += self.train_criterion[kwargs["env_list"][i]](key_i, logit_i, y[i])

        loss += self._mi_reg_loss(all_x, inter_feats) * self.ld

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return {"loss": loss.item()}

    def predict(self, x):
        return self.network(x)

    def get_forward_model(self):
        """Return a ``ForwardModel`` wrapping only ``self.network``."""
        forward_model = ForwardModel(self.network)
        return forward_model

class MIRO_LSL(Algorithm):
    """Mutual-Information Regularization with Oracle, plus label-structure losses.

    On top of MIRO's CE + MI regularizer, each step adds:
      * a mixup CE on strongly augmented samples whose (possibly relabeled)
        label agrees with their k-NN majority,
      * a feature-consistency term between clean and augmented features,
      * a mixup CE against soft "structural" labels from reverse k-NN voting.
    """

    def __init__(self, input_shape, num_classes, num_domains, hparams, **kwargs):
        super().__init__(input_shape, num_classes, num_domains, hparams)
        self.num_classes = num_classes
        self.input_shape = input_shape
        self.pre_featurizer = URFeaturizer(
            input_shape, self.hparams, freeze="all", feat_layers=hparams.feat_layers
        )
        self.featurizer = URFeaturizer(
            input_shape, self.hparams, feat_layers=hparams.feat_layers
        )
        self.classifier = nn.Linear(self.featurizer.n_outputs, num_classes)
        self.network = nn.Sequential(self.featurizer, self.classifier)
        self.ld = hparams.ld  # weight of the MI regularization term

        # One mean/variance encoder pair per intermediate feature, sized from
        # a dry run of the frozen featurizer.
        shapes = get_shapes(self.pre_featurizer, self.input_shape)
        self.mean_encoders = nn.ModuleList([
            MeanEncoder(shape) for shape in shapes
        ])
        self.var_encoders = nn.ModuleList([
            VarianceEncoder(shape) for shape in shapes
        ])

        # Encoder parameters use a scaled learning rate (lr * lr_mult).
        parameters = [
            {"params": self.network.parameters()},
            {"params": self.mean_encoders.parameters(), "lr": hparams.lr * hparams.lr_mult},
            {"params": self.var_encoders.parameters(), "lr": hparams.lr * hparams.lr_mult},
        ]
        self.optimizer = get_optimizer(
            hparams["optimizer"],
            parameters,
            lr=self.hparams["lr"],
            weight_decay=self.hparams["weight_decay"],
        )

        # Batch-level strong augmentation (torchvision applies one set of
        # random parameters per call when given a batched tensor).
        self.strong_aug = T.Compose(
            [T.RandomHorizontalFlip(p=0.3),
             T.RandomVerticalFlip(p=0.3),
             T.RandomAffine(degrees=30, translate=(0.1, 0.1), scale=(0.8, 1.2), shear=10),
             T.GaussianBlur(kernel_size=(5, 5), sigma=(0.1, 2.0))]
        )

    def _mi_reg_loss(self, all_x, inter_feats):
        """Return the MIRO regularizer for a batch.

        For every feature layer, computes
        ``mean((mean_enc(f) - pre_f)^2 / var + log(var)) / 2`` where ``pre_f``
        comes from the frozen featurizer (no gradient), and sums over layers.
        Works for any feature rank (e.g. 4-D CNN maps or 3-D ViT tokens).
        """
        with torch.no_grad():
            _, pre_feats = self.pre_featurizer(all_x, ret_feats=True)

        reg_loss = 0.
        for f, pre_f, mean_enc, var_enc in misc.zip_strict(
            inter_feats, pre_feats, self.mean_encoders, self.var_encoders
        ):
            mean = mean_enc(f)
            var = var_enc(f)
            vlb = (mean - pre_f).pow(2).div(var) + var.log()
            reg_loss += vlb.mean() / 2.
        return reg_loss

    def update(self, x, y, key, **kwargs):
        """Run one optimization step with CE + MI + label-structure losses.

        Args:
            x: list of input batches (concatenated into one batch).
            y: list of integer label tensors aligned with ``x``.
            key: unused here.

        Returns:
            dict with the scalar total ``"loss"``.
        """
        all_x = torch.cat(x)
        all_y = torch.cat(y)
        feat, inter_feats = self.featurizer(all_x, ret_feats=True)
        logit = self.classifier(feat)
        loss = F.cross_entropy(logit, all_y)

        # MIRO
        loss += self._mi_reg_loss(all_x, inter_feats) * self.ld

        # LSL
        all_strong_x = self.strong_aug(all_x)
        relabel_y = self.sample_relabel(all_x, all_y)
        select_x_idx = self.knn_sample_selection(feat, relabel_y)
        structural_y = self.extract_structural_labels(all_x, relabel_y)
        loss_ce = self.mixup(all_strong_x[select_x_idx], relabel_y[select_x_idx])
        # Negative summed cosine similarity between clean and augmented features.
        loss_fc = -F.cosine_similarity(feat, self.featurizer(all_strong_x), dim=-1).sum()
        loss_st = self.mixup(all_strong_x, structural_y)
        loss += (
            loss_ce
            + self.hparams['lambda_fc'] * loss_fc
            + self.hparams['lambda_st'] * loss_st
        )

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return {"loss": loss.item()}

    def predict(self, x):
        return self.network(x)

    def get_forward_model(self):
        """Return a ``ForwardModel`` wrapping only ``self.network``."""
        forward_model = ForwardModel(self.network)
        return forward_model

    def extract_structural_labels(self, all_x, all_y):
        """Build soft "structural" labels by reverse k-NN label voting.

        Features come from ``self.featurizer`` and are L2-normalized. Every
        sample i adds one vote for its own label ``all_y[i]`` to each of its
        ``hparams['kst']`` most cosine-similar *other* samples (clamped to
        N - 1). Each row of votes is normalized to sum to 1; rows that
        received no votes stay all-zero.

        Args:
            all_x: batch of inputs; first dimension is N.
            all_y: integer labels of shape (N,).

        Returns:
            float tensor (N, num_classes) on the features' device.
        """
        n = len(all_x)
        k = min(self.hparams['kst'], n - 1)

        # Only neighbor indices are used, so no graph is needed here.
        with torch.no_grad():
            feats = F.normalize(self.featurizer(all_x), dim=1)  # assumes (N, d)

        # Allocated on the features' device: indexing a CPU tensor with CUDA
        # index tensors raises at runtime.
        counts = torch.zeros(n, self.num_classes, device=feats.device)
        if k > 0:
            sims = feats @ feats.t()
            # Exclude each sample from its own neighbor list explicitly rather
            # than assuming it ranks first (duplicates can tie).
            sims.fill_diagonal_(float('-inf'))
            _, neighbors = sims.topk(k, dim=1)  # (N, k)
            rows = neighbors.reshape(-1)
            cols = all_y.to(feats.device).repeat_interleave(k)
            counts.index_put_(
                (rows, cols),
                torch.ones(rows.numel(), device=feats.device),
                accumulate=True,
            )

        return counts / (counts.sum(dim=1, keepdim=True) + 1e-10)

    def sample_relabel(self, all_x, all_y):
        """Replace labels with the model's prediction where it is confident.

        Confidence is the maximum softmax probability of ``self.predict``;
        samples with confidence above ``hparams['relabel_t']`` take the
        predicted class, the rest keep ``all_y``. ``all_y`` is not modified.
        """
        with torch.no_grad():
            # Threshold probabilities, not raw logits: logits are unbounded,
            # so a fixed threshold on them has no consistent meaning.
            probs = F.softmax(self.predict(all_x), dim=1)
        confidence, predicted_labels = probs.max(dim=1)
        mask = confidence > self.hparams['relabel_t']
        relabel_y = all_y.clone()
        relabel_y[mask] = predicted_labels[mask]
        return relabel_y

    def knn_sample_selection(self, all_x, all_y):
        """Return indices of samples whose label equals their k-NN majority label.

        ``all_x`` is a 2-D feature matrix (``update`` passes featurizer
        outputs). Neighbors are found with sklearn ``NearestNeighbors``; since
        the query points are passed explicitly, each sample's neighbor list
        typically contains the sample itself. ``k`` is
        ``hparams['sample_selection_k']`` clamped to the batch size. Majority
        ties resolve to the smallest label (``np.bincount(...).argmax()``).

        Returns:
            list of int indices (empty for an empty batch).
        """
        X_np = all_x.detach().cpu().numpy()
        Y_np = all_y.detach().cpu().numpy()
        if len(X_np) == 0:
            return []

        k = min(self.hparams['sample_selection_k'], len(X_np))
        knn = NearestNeighbors(n_neighbors=k).fit(X_np)
        neighbors = knn.kneighbors(X_np, return_distance=False)  # (N, k) indices

        return [
            i for i, nbrs in enumerate(neighbors)
            if Y_np[i] == np.bincount(Y_np[nbrs]).argmax()
        ]

    def mixup(self, x, y):
        """Summed mixup cross-entropy over pairs of chunks of ``(x, y)``.

        ``x``/``y`` are cut into chunks of ``hparams['batch_size']``; chunks
        are paired by ``random_pairs_of_minibatches`` and mixed with a single
        ``lam ~ Beta(mixup_alpha, mixup_alpha)`` drawn per call. ``y`` may be
        class indices or per-class probability rows (both are passed to
        ``F.cross_entropy``). Returns ``0`` when ``x`` is empty.
        """
        alpha = self.hparams['mixup_alpha']
        lam = np.random.beta(alpha, alpha)
        device = x.device
        batch_size = self.hparams['batch_size']
        minibatches = [
            (x[i:i + batch_size], y[i:i + batch_size].to(device))
            for i in range(0, x.size(0), batch_size)
        ]

        loss = 0
        for (xi, yi), (xj, yj) in random_pairs_of_minibatches(minibatches):
            predictions = self.predict(lam * xi + (1 - lam) * xj)
            loss += lam * F.cross_entropy(predictions, yi, reduction='sum')
            loss += (1 - lam) * F.cross_entropy(predictions, yj, reduction='sum')
        return loss





