import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from Clients.losses import multi_view_contrastive_loss, multi_view_contrastive_loss_diverse
class BaseClient(object):
    """Shared training / evaluation loop for the client classes.

    This class never creates a model or an optimizer itself: ``train`` and
    ``test`` read ``self.model`` and ``train`` reads ``self.optimizer``, so a
    subclass (or the caller) must set both before using them. The model is
    invoked as ``self.model(*inputs)`` and must return a 3-tuple whose first
    element is an iterable of logits tensors.
    """

    def __init__(self, dataloader, modality_type, client_id=0):
        """
        Args:
            dataloader (torch.utils.data.DataLoader): batches iterated by ``train``.
            modality_type (list[int]): one entry per modality, e.g., [1, 0, 1];
                its length decides how many leading batch items are inputs.
            client_id (int): identifier stored on the client.
        """
        super(BaseClient, self).__init__()
        self.dataloader = dataloader
        self.modality_type = modality_type
        self.client_id = client_id

    @staticmethod
    def _resolve_device():
        """Return the CUDA device when one is available, otherwise the CPU."""
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def extract_modal_inputs(self, data):
        """Split one batch into per-modality inputs and labels on the compute device.

        Args:
            data: indexable batch laid out as [x_0, x_1, x_2, ..., labels].

        Returns:
            tuple: (list of the first ``len(self.modality_type)`` items, last item),
            every tensor moved to the device chosen by ``_resolve_device``.
        """
        device = self._resolve_device()
        num_modalities = len(self.modality_type)
        # Every modality slot is moved, including ones whose flag is 0.
        inputs = [data[i].to(device) for i in range(num_modalities)]
        labels = data[-1].to(device)
        return inputs, labels

    def train(self, n_epochs=1):
        """Run ``n_epochs`` passes over ``self.dataloader``.

        The loss is the sum of the cross-entropy of every logits tensor in the
        first element of the model output against the batch labels.
        """
        device = self._resolve_device()
        CE = nn.CrossEntropyLoss().to(device)
        self.model.to(device)
        self.model.train()
        for _ in range(n_epochs):
            for data in self.dataloader:
                inputs, labels = self.extract_modal_inputs(data)
                self.optimizer.zero_grad()
                outs, _, _ = self.model(*inputs)
                loss = sum(CE(out, labels) for out in outs)
                loss.backward()
                self.optimizer.step()

    def test(self, dataloader):
        """Return the accuracy of ``self.model`` over ``dataloader``.

        Predictions are the argmax of the summed logits of all model outputs.

        Returns:
            float: correct / total, or 0.0 when ``dataloader`` yields no samples.
        """
        # Inputs are moved to the compute device, so the model has to live
        # there too even if ``train`` has not been called yet.
        self.model.to(self._resolve_device())
        self.model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for data in dataloader:
                inputs, labels = self.extract_modal_inputs(data)
                outs, _, _ = self.model(*inputs)
                logits = sum(outs)
                _, predicted = torch.max(logits, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        # Guard against an empty loader instead of raising ZeroDivisionError.
        return correct / total if total else 0.0
    
class ParseDL(nn.Module):
    """Multi-modal network that splits every encoder output into three parts.

    For each active modality the encoder output is chunked into three pieces
    along dim 1 (called ``m``, ``c`` and ``s`` in the code): the first feeds a
    per-modality classifier, the second feeds a classifier shared by all
    modalities, and the third is averaged across modalities and fed to the
    synergy classifier.
    """

    def __init__(self, modality_type, encoder_fns, classifier_dim=64, output_dim=8, seed=42):
        """
        Args:
            modality_type: list of 0/1, e.g., [1, 0, 1]
            encoder_fns: list of encoder constructor functions, same length as modality_type
            classifier_dim: input width of every classifier / projection layer
                (assumes each encoder emits 3 * classifier_dim features — TODO confirm)
            output_dim: number of logits produced by each classifier
            seed: value the global torch RNG is reset to before each layer is built
        """
        super(ParseDL, self).__init__()
        self.modality_type = modality_type
        self.num_modalities = len(modality_type)

        def seeded(factory, *args):
            # Reset the global RNG right before construction so that every
            # layer's initial weights depend only on ``seed``.
            torch.manual_seed(seed)
            return factory(*args)

        self.encoders = nn.ModuleList()
        self.classifiers = nn.ModuleList()
        self.projects = nn.ModuleList()
        for idx in range(self.num_modalities):
            self.encoders.append(seeded(encoder_fns[idx]))
            self.classifiers.append(seeded(nn.Linear, classifier_dim, output_dim))
            self.projects.append(seeded(nn.Linear, classifier_dim, classifier_dim))
        self.shared_classifier = seeded(nn.Linear, classifier_dim, output_dim)
        self.synergy_classifier = seeded(nn.Linear, classifier_dim, output_dim)

    def forward(self, *inputs):
        """Run every active modality and fuse the results.

        Args:
            *inputs: one positional input per modality slot; slots whose flag
                in ``modality_type`` is falsy are never read.

        Returns:
            tuple: ``(outs, embeddings, out_s)`` where ``outs`` is the list of
            per-modality logits. With exactly one active modality the result is
            ``(outs, None, outs[0])``. Otherwise ``embeddings`` is the list of
            L2-normalised projections (all ``c`` parts, then all ``m`` parts,
            then all ``s`` parts) and ``out_s`` is the sum of ``outs`` plus the
            synergy classifier applied to the mean ``s`` part.
        """
        modality_logits = []
        projected_c, projected_m, projected_s = [], [], []
        raw_s_parts = []

        for idx, active in enumerate(self.modality_type):
            if not active:
                continue
            encoded = self.encoders[idx](inputs[idx])
            part_m, part_c, part_s = encoded.chunk(3, dim=1)

            specific_logits = self.classifiers[idx](part_m)
            shared_logits = self.shared_classifier(part_c)
            modality_logits.append(specific_logits + shared_logits)

            # The same per-modality projection head is applied to all three parts.
            project = self.projects[idx]
            projected_c.append(project(part_c))
            projected_m.append(project(part_m))
            projected_s.append(project(part_s))

            raw_s_parts.append(part_s)

        if len(modality_logits) == 1:
            return modality_logits, None, modality_logits[0]

        mean_s = torch.stack(raw_s_parts).mean(dim=0)
        fused_logits = sum(modality_logits) + self.synergy_classifier(mean_s)

        # Group order is c, then m, then s.
        grouped = projected_c + projected_m + projected_s
        embeddings = [F.normalize(vec, dim=1) for vec in grouped]

        return modality_logits, embeddings, fused_logits

class ParseClient(BaseClient):
    """Client that trains a ``ParseDL`` model with cross-entropy plus a weighted contrastive term."""

    def __init__(self, dataloader, modality_type, client_id=0, encoder_fns=None, classifier_dim=64, output_dim=8, configs=None):
        """
        Args:
            dataloader (torch.utils.data.DataLoader)
            modality_type (list[int]): 0/1 flag per modality, e.g., [1, 0, 1]
            client_id (int)
            encoder_fns: list of encoder constructor functions, forwarded to ParseDL
            classifier_dim (int): forwarded to ParseDL
            output_dim (int): forwarded to ParseDL
            configs: object providing ``seed``, ``dataset``, ``learning_rate`` and
                ``contrastive``; required in practice despite the ``None`` default.
        """
        super().__init__(dataloader, modality_type, client_id)
        self.model = ParseDL(modality_type, encoder_fns, classifier_dim, output_dim, configs.seed)
        if configs.dataset == "cap":
            # The optimizer must be stored on the instance: ``train`` reads
            # ``self.optimizer`` regardless of which dataset is configured.
            self.optimizer = optim.Adam(self.model.parameters(),
                                        lr=configs.learning_rate,
                                        betas=(0.9, 0.999))
        else:
            self.optimizer = optim.SGD(self.model.parameters(),
                                       lr=configs.learning_rate,
                                       momentum=0.9,
                                       weight_decay=5e-4)
        # Weight applied to the contrastive loss in ``train``.
        self.contrastive = configs.contrastive

    def train(self, n_epochs=1):
        """Run ``n_epochs`` passes over ``self.dataloader``.

        The loss is the cross-entropy of the fused logits plus
        ``self.contrastive`` times the multi-view contrastive loss of the
        embeddings returned by the model.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        CE = nn.CrossEntropyLoss().to(device)
        self.model.to(device)
        self.model.train()
        for _ in range(n_epochs):
            for data in self.dataloader:
                inputs, labels = self.extract_modal_inputs(data)
                self.optimizer.zero_grad()

                _, feats, out_s = self.model(*inputs)
                loss = CE(out_s, labels)
                # ParseDL.forward returns None for the embeddings when only a
                # single modality is active; there is nothing to contrast then.
                if feats is not None:
                    loss = loss + self.contrastive * multi_view_contrastive_loss_diverse(feats, sum(self.modality_type))
                loss.backward()
                self.optimizer.step()

    def test(self, dataloader):
        """Return the accuracy of the fused logits over ``dataloader``.

        Returns:
            float: correct / total, or 0.0 when ``dataloader`` yields no samples.
        """
        # Keep the model on the same device the inputs are moved to, even when
        # ``test`` is called before ``train``.
        self.model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
        self.model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for data in dataloader:
                inputs, labels = self.extract_modal_inputs(data)
                _, _, out_s = self.model(*inputs)
                _, predicted = torch.max(out_s, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        # Guard against an empty loader instead of raising ZeroDivisionError.
        return correct / total if total else 0.0
    
    
class ModalityDL(nn.Module):
    """One independent encoder + linear classifier per modality, with no fusion."""

    def __init__(self, modality_type, encoder_fns, classifier_dim=64, output_dim=8, seed=42):
        """
        Args:
            modality_type: list of 0/1, e.g., [1, 0, 1]
            encoder_fns: list of encoder constructor functions, same length as modality_type
            classifier_dim: input width of each per-modality classifier
            output_dim: number of logits produced by each classifier
            seed: value the global torch RNG is reset to before each layer is built
        """
        super(ModalityDL, self).__init__()
        self.modality_type = modality_type
        self.num_modalities = len(modality_type)

        def seeded(factory, *args):
            # Re-seed before every construction so initial weights depend only on ``seed``.
            torch.manual_seed(seed)
            return factory(*args)

        self.encoders = nn.ModuleList()
        self.classifiers = nn.ModuleList()

        for idx in range(self.num_modalities):
            self.encoders.append(seeded(encoder_fns[idx]))
            self.classifiers.append(seeded(nn.Linear, classifier_dim, output_dim))

    def forward(self, *inputs):
        """Classify every active modality on its own.

        Args:
            *inputs: one positional input per modality slot; inactive slots are never read.

        Returns:
            tuple: ``(outs, None, None)`` where ``outs`` holds one logits tensor
            per active modality, in slot order.
        """
        per_modality_logits = [
            self.classifiers[idx](self.encoders[idx](inputs[idx]))
            for idx, active in enumerate(self.modality_type)
            if active
        ]
        return per_modality_logits, None, None
    
    
class ModalityClient(BaseClient):
    """Client that trains a ``ModalityDL`` model with the inherited train/test loops."""

    def __init__(self, dataloader, modality_type, client_id=0, encoder_fns=None, classifier_dim=64, output_dim=8, configs=None):
        """
        Args:
            dataloader (torch.utils.data.DataLoader)
            modality_type (list[int]): 0/1 flag per modality, e.g., [1, 0, 1]
            client_id (int)
            encoder_fns: list of encoder constructor functions, forwarded to ModalityDL
            classifier_dim (int): forwarded to ModalityDL
            output_dim (int): forwarded to ModalityDL
            configs: object providing ``seed``, ``dataset`` and ``learning_rate``;
                required in practice despite the ``None`` default.
        """
        super().__init__(dataloader, modality_type, client_id)
        self.model = ModalityDL(modality_type, encoder_fns, classifier_dim, output_dim, configs.seed)
        if configs.dataset == "cap":
            # The optimizer must be stored on the instance: the inherited
            # ``train`` reads ``self.optimizer`` for every dataset.
            self.optimizer = optim.Adam(self.model.parameters(),
                                        lr=configs.learning_rate,
                                        betas=(0.9, 0.999))
        else:
            self.optimizer = optim.SGD(self.model.parameters(),
                                       lr=configs.learning_rate,
                                       momentum=0.9,
                                       weight_decay=5e-4)

    
class TaskDL(nn.Module):
    """Per-modality encoders whose outputs are averaged and fed to one shared classifier."""

    def __init__(self, modality_type, encoder_fns, classifier_dim=64, output_dim=8, seed=42):
        """
        Args:
            modality_type: list of 0/1, e.g., [1, 0, 1]
            encoder_fns: list of encoder constructor functions, same length as modality_type
            classifier_dim: input width of the shared classifier
            output_dim: number of logits produced by the classifier
            seed: value the global torch RNG is reset to before each layer is built
        """
        super(TaskDL, self).__init__()
        self.modality_type = modality_type
        self.num_modalities = len(modality_type)

        def seeded(factory, *args):
            # Re-seed before every construction so initial weights depend only on ``seed``.
            torch.manual_seed(seed)
            return factory(*args)

        self.encoders = nn.ModuleList()
        for idx in range(self.num_modalities):
            self.encoders.append(seeded(encoder_fns[idx]))
        self.classifier = seeded(nn.Linear, classifier_dim, output_dim)

    def forward(self, *inputs):
        """Average the encodings of the active modalities and classify the result.

        Args:
            *inputs: one positional input per modality slot; inactive slots are never read.

        Returns:
            tuple: ``(outs, None, None)`` where ``outs`` is a one-element list
            holding the logits of the averaged feature.
        """
        pooled = 0.0
        for idx, active in enumerate(self.modality_type):
            if not active:
                continue
            pooled += self.encoders[idx](inputs[idx])
        # The divisor is the sum of the flags, i.e. the number of active
        # modalities when the flags are 0/1.
        averaged = pooled / sum(self.modality_type)
        return [self.classifier(averaged)], None, None
    
    
class TaskClient(BaseClient):
    """Client that trains a ``TaskDL`` model with the inherited train/test loops."""

    def __init__(self, dataloader, modality_type, client_id=0, encoder_fns=None, classifier_dim=64, output_dim=8, configs=None):
        """
        Args:
            dataloader (torch.utils.data.DataLoader)
            modality_type (list[int]): 0/1 flag per modality, e.g., [1, 0, 1]
            client_id (int)
            encoder_fns: list of encoder constructor functions, forwarded to TaskDL
            classifier_dim (int): forwarded to TaskDL
            output_dim (int): forwarded to TaskDL
            configs: object providing ``seed``, ``dataset`` and ``learning_rate``;
                required in practice despite the ``None`` default.
        """
        super().__init__(dataloader, modality_type, client_id)
        self.model = TaskDL(modality_type, encoder_fns, classifier_dim, output_dim, configs.seed)
        if configs.dataset == "cap":
            # The optimizer must be stored on the instance: the inherited
            # ``train`` reads ``self.optimizer`` for every dataset.
            self.optimizer = optim.Adam(self.model.parameters(),
                                        lr=configs.learning_rate,
                                        betas=(0.9, 0.999))
        else:
            self.optimizer = optim.SGD(self.model.parameters(),
                                       lr=configs.learning_rate,
                                       momentum=0.9,
                                       weight_decay=5e-4)
    
    
class HybridClient(BaseClient):
    """Client built around a ``TaskDL`` model.

    NOTE(review): as written this constructs exactly the same model and
    optimizer as ``TaskClient``; no hybrid-specific behaviour is visible in
    this class.
    """

    def __init__(self, dataloader, modality_type, client_id=0, encoder_fns=None, classifier_dim=64, output_dim=8, configs=None):
        """
        Args:
            dataloader (torch.utils.data.DataLoader)
            modality_type (list[int]): 0/1 flag per modality, e.g., [1, 0, 1]
            client_id (int)
            encoder_fns: list of encoder constructor functions, forwarded to TaskDL
            classifier_dim (int): forwarded to TaskDL
            output_dim (int): forwarded to TaskDL
            configs: object providing ``seed``, ``dataset`` and ``learning_rate``;
                required in practice despite the ``None`` default.
        """
        super().__init__(dataloader, modality_type, client_id)
        self.model = TaskDL(modality_type, encoder_fns, classifier_dim, output_dim, configs.seed)
        if configs.dataset == "cap":
            # The optimizer must be stored on the instance: the inherited
            # ``train`` reads ``self.optimizer`` for every dataset.
            self.optimizer = optim.Adam(self.model.parameters(),
                                        lr=configs.learning_rate,
                                        betas=(0.9, 0.999))
        else:
            self.optimizer = optim.SGD(self.model.parameters(),
                                       lr=configs.learning_rate,
                                       momentum=0.9,
                                       weight_decay=5e-4)
        
        
        