import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
import copy
from tllib.alignment.dann import DomainAdversarialLoss
from tllib.modules.domain_discriminator import DomainDiscriminator
from tllib.modules.grl import WarmStartGradientReverseLayer

def get_activation(activation):
    """Build the activation module selected by name.

    The name is matched case-insensitively. Recognised names are
    'gelu', 'rrelu', 'selu', 'silu', 'hardswish' and 'leakyrelu'; any
    other string yields ``nn.ReLU``. Every module except GELU is
    constructed with ``inplace=True``.

    Args:
        activation: Name of the activation (a string).

    Returns:
        A newly constructed ``nn.Module`` activation layer.
    """
    key = activation.lower()
    if key == 'gelu':
        # GELU is the only option built without the inplace flag.
        return nn.GELU()

    inplace_layers = {
        'rrelu': nn.RReLU,
        'selu': nn.SELU,
        'silu': nn.SiLU,
        'hardswish': nn.Hardswish,
        'leakyrelu': nn.LeakyReLU,
    }
    # Unrecognised names silently fall back to ReLU.
    layer_cls = inplace_layers.get(key, nn.ReLU)
    return layer_cls(inplace=True)


class ResMLP(nn.Module):
    """Two-layer MLP block with a skip connection.

    ``net1`` is Linear -> BatchNorm1d -> activation and maps ``channel``
    features to ``int(channel * res_expansion)``; ``net2`` is
    Linear -> BatchNorm1d and maps back to ``channel``. The block input
    is added to the ``net2`` output and the activation is applied once
    more. The same activation module instance is used in both places.

    Args:
        channel: Input and output feature width.
        res_expansion: Multiplier giving the hidden width.
        bias: Whether the two Linear layers use a bias term.
        activation: Activation name, resolved by ``get_activation``.
    """

    def __init__(
        self, channel=512, res_expansion=1.0, bias=True, activation='relu'):
        super().__init__()
        hidden = int(channel * res_expansion)
        self.act = get_activation(activation)

        # Expansion half: channel -> hidden, then normalise and activate.
        expand_layers = [
            nn.Linear(channel, hidden, bias=bias),
            nn.BatchNorm1d(hidden),
            self.act,
        ]
        self.net1 = nn.Sequential(*expand_layers)

        # Projection half: hidden -> channel; activation is applied in forward.
        project_layers = [
            nn.Linear(hidden, channel, bias=bias),
            nn.BatchNorm1d(channel),
        ]
        self.net2 = nn.Sequential(*project_layers)

    def forward(self, x):
        """Return ``act(net2(net1(x)) + x)``.

        NOTE(review): x is presumably (batch, channel) -- confirm with callers.
        """
        transformed = self.net2(self.net1(x))
        return self.act(transformed + x)


class MLP(nn.Module):
    """Two-layer MLP block without a skip connection.

    ``net1`` is Linear -> BatchNorm1d -> activation and maps ``channel``
    features to ``int(channel * res_expansion)``; ``net2`` is
    Linear -> BatchNorm1d and maps back to ``channel``. The activation
    is applied again to the ``net2`` output. The same activation module
    instance is used in both places.

    Args:
        channel: Input and output feature width.
        res_expansion: Multiplier giving the hidden width.
        bias: Whether the two Linear layers use a bias term.
        activation: Activation name, resolved by ``get_activation``.
    """

    def __init__(
        self, channel=512, res_expansion=1.0, bias=True, activation='relu'):
        super().__init__()
        hidden = int(channel * res_expansion)
        self.act = get_activation(activation)

        # Expansion half: channel -> hidden, then normalise and activate.
        expand_layers = [
            nn.Linear(channel, hidden, bias=bias),
            nn.BatchNorm1d(hidden),
            self.act,
        ]
        self.net1 = nn.Sequential(*expand_layers)

        # Projection half: hidden -> channel; activation is applied in forward.
        project_layers = [
            nn.Linear(hidden, channel, bias=bias),
            nn.BatchNorm1d(channel),
        ]
        self.net2 = nn.Sequential(*project_layers)

    def forward(self, x):
        """Return ``act(net2(net1(x)))``.

        NOTE(review): x is presumably (batch, channel) -- confirm with callers.
        """
        hidden_out = self.net1(x)
        projected = self.net2(hidden_out)
        return self.act(projected)



def get_clones(module, N):
    """Return an ``nn.ModuleList`` holding ``N`` independent deep copies of ``module``.

    Each entry comes from ``copy.deepcopy``, so the copies start with the
    same values as ``module`` but share no parameter storage with it or
    with each other.
    """
    clones = nn.ModuleList()
    for _ in range(N):
        clones.append(copy.deepcopy(module))
    return clones


class Model(nn.Module):
    """Two parallel stacks of MLP blocks, one per embedding stream.

    ``CLIP_proj`` and ``CLAP_proj`` each hold ``cfg.mlp_num`` blocks. The
    blocks are ``ResMLP`` when ``cfg.resMLP`` is truthy and ``MLP``
    otherwise, built with ``cfg.res_expansion`` and the block's default
    channel width and activation.

    Args:
        cfg: Config object exposing ``mlp_num``, ``resMLP`` and
            ``res_expansion`` attributes.
    """

    def __init__(self, cfg):
        super().__init__()
        self.layer_num = cfg.mlp_num
        self.cfg = cfg

        if cfg.resMLP:
            print("Residule Connected")
            block_cls = ResMLP
        else:
            print("No Residule Connection")
            block_cls = MLP

        # A separate template is built per stream before cloning, so the
        # two stacks never start from the same template instance.
        self.CLIP_proj = get_clones(
            block_cls(res_expansion=cfg.res_expansion), cfg.mlp_num)
        self.CLAP_proj = get_clones(
            block_cls(res_expansion=cfg.res_expansion), cfg.mlp_num)

        self.init_weights()

    def forward(self, CLIP_embs, CLAP_embs):
        """Run each input through its own stack and normalise the results.

        Returns:
            A ``(CLIP_embs, CLAP_embs)`` tuple, each passed through
            ``F.normalize`` along the last dimension.
        """
        for depth in range(self.layer_num):
            clip_block = self.CLIP_proj[depth]
            clap_block = self.CLAP_proj[depth]
            CLIP_embs = clip_block(CLIP_embs)
            CLAP_embs = clap_block(CLAP_embs)
        clip_out = F.normalize(CLIP_embs, dim=-1)
        clap_out = F.normalize(CLAP_embs, dim=-1)
        return clip_out, clap_out

    def init_weights(self):
        """Apply Xavier-uniform init to every parameter with more than one dim.

        One-dimensional parameters (biases, BatchNorm affine terms) are
        left as constructed. The CLIP stack is processed before the CLAP
        stack.
        """
        for stack in (self.CLIP_proj, self.CLAP_proj):
            for param in stack.parameters():
                if param.dim() > 1:
                    nn.init.xavier_uniform_(param)