import random
from typing import List

import torch
from torchvision import transforms
from itertools import combinations
from torch.nn import functional as F

from utils.exp_configs import create_exp_config_setup
from clients.base import Client
from methods.local_training import base_train_function


def create_fedmuscle_setup(
    args,
    models_path,
    dataset_path_dict,
    device,
) -> List[Client]:
    """Create the clients for a FedMuscle run.

    Delegates to ``create_exp_config_setup``, forwarding the caller's
    arguments unchanged and fixing the local-training configuration:
    one local epoch, ``base_train_function`` as the local training
    routine, and an empty dict of extra training arguments.

    Args:
        args: Experiment arguments, forwarded as-is.
        models_path: Forwarded as-is to the setup factory.
        dataset_path_dict: Forwarded as-is to the setup factory.
        device: Forwarded as-is to the setup factory.

    Returns:
        Whatever ``create_exp_config_setup`` returns (annotated as a list
        of ``Client`` objects).
    """
    # Local-training settings specific to this method.
    local_training = {
        "num_local_epochs": 1,
        "local_train_function": base_train_function,
        "train_args": {},
    }

    return create_exp_config_setup(
        args=args,
        models_path=models_path,
        dataset_path_dict=dataset_path_dict,
        device=device,
        **local_training,
    )

# Random colour perturbation; wrapped in RandomApply below so it only fires
# on a fraction of the calls.
color_jitter = transforms.ColorJitter(
    brightness=0.8,
    contrast=0.8,
    saturation=0.8,
    hue=0.2,
)

# Augmentation pipeline applied to image batches before they are fed to a
# vision client's model. There is no ToTensor step: the functions in this
# file call it on already-stacked tensors.
# NOTE(review): the normalisation constants look like the commonly used
# CIFAR-10 channel statistics -- confirm they match the public dataset.
data_transforms = transforms.Compose(
    [
        transforms.RandomResizedCrop(size=(224, 224), antialias=True),
        transforms.RandomHorizontalFlip(),
        transforms.RandomApply(transforms=[color_jitter], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.Normalize(
            mean=(0.4914, 0.4822, 0.4465),
            std=(0.2023, 0.1994, 0.2010),
        ),
    ]
)


def client_rep_gen_function(args, client, public_dataset, batch_indices, device):
    """Compute one client's representations for a batch of the public dataset.

    The client is treated as an NLP client when the substring ``"text"``
    occurs in ``client.task`` and as a CV client otherwise. The model is put
    in eval mode, moved to ``device`` for a gradient-free forward pass and
    moved back to ``"cpu"`` afterwards, even if the forward pass raises.

    Args:
        args: Unused here; kept for a uniform call signature.
        client: Object exposing ``task`` and ``model`` and, for NLP clients,
            ``tokenizer`` and ``max_length``.
        public_dataset: Indexable dataset. Each sample is either a tuple whose
            first item is an image tensor and whose second item is a text (or
            a list of alternative texts), or a bare image tensor.
        batch_indices: Indices of the samples forming the batch.
        device: Device on which the forward pass runs.

    Returns:
        CV clients: the ``.logits`` attribute of the model output.
        NLP clients: the model output itself.
        (Presumably both are ``(batch, dim)`` tensors -- confirm against the
        model wrappers.)

    Raises:
        ValueError: If an NLP client receives samples without a text part.
    """
    batch = [public_dataset[idx] for idx in batch_indices]
    has_pairs = isinstance(batch[0], tuple)
    is_text_client = "text" in client.task

    # Prepare only the modality this client consumes; stacking the images
    # for a text client would be wasted work.
    if is_text_client:
        if not has_pairs:
            raise ValueError(
                "Text client requires (image, text) samples, "
                "but the public dataset returned bare items."
            )
        texts = [item[1] for item in batch]
        # If text is augmented, pick one randomly
        if isinstance(texts[0], list):
            texts = [random.choice(captions) for captions in texts]
    else:
        images = torch.stack([item[0] for item in batch] if has_pairs else batch)

    client.model.eval()
    client.model.to(device)
    try:
        with torch.no_grad():
            if is_text_client:
                text_encoding = client.tokenizer(
                    texts,
                    return_tensors='pt',
                    max_length=client.max_length,
                    padding='max_length',
                    truncation=True,
                )
                input_ids = text_encoding['input_ids'].to(device)
                attention_mask = text_encoding['attention_mask'].to(device)
                representations = client.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    projection=True,
                )
            else:
                representations = client.model(
                    data_transforms(images).to(device), projection=True
                ).logits.to(device)
    finally:
        # Always release the device, also when the forward pass fails.
        client.model.to("cpu")

    return representations

def aggregation_func(args, server_rec_rep, device):
    """Build, for every client, its peer representations and negative term.

    Each client in turn acts as the anchor. Up to ``args.sel_rep_num`` other
    clients are chosen (randomly sampled when more are available) and their
    representation matrices are placed after the anchor's own matrix.
    For the resulting list of ``M`` matrices, a tensor of shape ``(B,) * M``
    is accumulated: for every pair ``(i, j)`` of *non-anchor* matrices, the
    ``B x B`` matrix of pairwise cosine similarities is placed on axes ``i``
    and ``j``, broadcast over the remaining axes and scaled by ``-gamma``,
    where ``gamma = 1 / args.tau_prime - 1 / args.tau``.

    Pairs involving the anchor (index 0) contribute nothing here, so their
    similarities are not computed.

    Args:
        args: Namespace providing ``sel_rep_num``, ``tau`` and ``tau_prime``.
        server_rec_rep: Representations indexable by ``0 .. num_clients - 1``;
            each entry must be a 2-D tensor (all with the same batch size for
            the broadcasting to work).
        device: Device on which the accumulated tensor is allocated.

    Returns:
        Tuple ``(agg_rep, alpha)`` of dicts keyed by client id:
        ``agg_rep[c]`` is the list ``[anchor, peer_1, ...]`` and
        ``alpha[c]`` the accumulated ``(B,) * M`` tensor.
    """
    num_clients = len(server_rec_rep)
    gamma = 1 / args.tau_prime - 1 / args.tau

    agg_rep = {}
    alpha = {}

    for client_id in range(num_clients):
        # The client with client_id is considered as the anchor
        peer_ids = [cid for cid in range(num_clients) if cid != client_id]
        if len(peer_ids) > args.sel_rep_num:
            peer_ids = random.sample(peer_ids, args.sel_rep_num)

        rep_matrices = [server_rec_rep[client_id]]
        rep_matrices.extend(server_rec_rep[cid] for cid in peer_ids)

        M = len(rep_matrices)
        B, _ = rep_matrices[0].shape  # also enforces a 2-D anchor matrix

        neg_expr_ij = torch.zeros((B,) * M, device=device)

        # Only pairs that exclude the anchor (index 0) are accumulated.
        for i, j in combinations(range(1, M), 2):
            cos_sim = F.cosine_similarity(
                rep_matrices[i][:, None, :],
                rep_matrices[j][None, :, :],
                dim=-1,
            )

            # Insert singleton axes (in ascending order) so the B x B matrix
            # sits on axes i and j of the M-dimensional tensor.
            for dim in range(M):
                if dim != i and dim != j:
                    cos_sim = cos_sim.unsqueeze(dim)

            neg_expr_ij += (-gamma) * cos_sim

        agg_rep[client_id] = rep_matrices
        alpha[client_id] = neg_expr_ij

    return agg_rep, alpha


def client_rep_align_function(args, client, public_dataset, batch_indices, rep_matrices, neg_expr, device):
    """Run one alignment step of a client's model on a public-data batch.

    The client's model recomputes (with gradients) the representation of the
    batch; this fresh matrix takes the anchor position (index 0) among the
    representation matrices. The cosine similarities between the anchor and
    every other matrix, scaled by ``1 / args.tau``, are added to ``neg_expr``
    to form the logits of a softmax over all non-anchor axes. The loss is the
    negative mean log-probability of the diagonal entries. One optimizer step
    is taken.

    Neither ``rep_matrices`` nor ``neg_expr`` is modified: the function works
    on a shallow copy of the list and accumulates the logits out-of-place, so
    repeated calls with the same inputs do not compound and the caller's
    tensor does not get an autograd graph attached.

    NOTE(review): unlike ``client_rep_gen_function`` the model is left on
    ``device`` after the step -- confirm this is intended.

    Args:
        args: Namespace providing ``tau``.
        client: Object exposing ``task``, ``model``, ``local_optimizer`` and,
            for NLP clients, ``tokenizer`` and ``max_length``.
        public_dataset: Indexable dataset. Each sample is either a tuple whose
            first item is an image tensor and whose second item is a text (or
            a list of alternative texts), or a bare image tensor.
        batch_indices: Indices of the samples forming the batch.
        rep_matrices: List of 2-D representation tensors; entry 0 is replaced
            (locally) by the freshly computed anchor representation.
        neg_expr: Tensor broadcast-compatible with shape ``(B,) * M``
            (presumably the ``alpha`` entry produced by ``aggregation_func``
            for this client -- confirm at the call site).
        device: Device used for the forward/backward pass.

    Returns:
        ``{"loss": <numpy scalar>}`` with the detached loss value.

    Raises:
        ValueError: If an NLP client receives samples without a text part.
    """
    batch = [public_dataset[idx] for idx in batch_indices]
    has_pairs = isinstance(batch[0], tuple)
    is_text_client = "text" in client.task

    # Prepare only the modality this client consumes.
    if is_text_client:
        if not has_pairs:
            raise ValueError(
                "Text client requires (image, text) samples, "
                "but the public dataset returned bare items."
            )
        texts = [item[1] for item in batch]
        # If text is augmented, pick one randomly
        if isinstance(texts[0], list):
            texts = [random.choice(captions) for captions in texts]
    else:
        images = torch.stack([item[0] for item in batch] if has_pairs else batch)

    # Shallow copy: replacing the anchor must not leak into the caller's list.
    rep_matrices = list(rep_matrices)

    client.local_optimizer.zero_grad()
    client.model.train()
    client.model.to(device)

    if is_text_client:
        # NLP client
        text_encoding = client.tokenizer(
            texts,
            return_tensors='pt',
            max_length=client.max_length,
            padding='max_length',
            truncation=True,
        )
        input_ids = text_encoding['input_ids'].to(device)
        attention_mask = text_encoding['attention_mask'].to(device)
        rep_matrices[0] = client.model(
            input_ids=input_ids, attention_mask=attention_mask, projection=True
        )
    else:
        # CV client
        rep_matrices[0] = client.model(
            data_transforms(images).to(device), projection=True
        ).logits.to(device)

    M = len(rep_matrices)
    B, _ = rep_matrices[0].shape
    anchor = rep_matrices[0]

    # Cosine similarity with the anchor vector, accumulated out-of-place so
    # the caller's neg_expr stays untouched.
    logits = neg_expr
    for j in range(1, M):
        cos_sim = F.cosine_similarity(
            anchor[:, None, :],
            rep_matrices[j][None, :, :],
            dim=-1,
        )

        # Place the B x B matrix on axes 0 and j of the M-dimensional tensor.
        for dim in range(M):
            if dim != 0 and dim != j:
                cos_sim = cos_sim.unsqueeze(dim)

        logits = logits + (1 / args.tau) * cos_sim

    # Positives are the entries where every axis carries the same index.
    mask = torch.zeros((B,) * M, device=device)
    mask.fill_diagonal_(1)

    # Compute mutualKT loss
    # logsumexp is the numerically stable form of log(sum(exp(.))).
    non_anchor_dims = tuple(range(1, M))
    log_prob_ij = logits - torch.logsumexp(
        logits, dim=non_anchor_dims, keepdim=True
    )
    mean_log_prob_pos_ij = (mask * log_prob_ij).sum(non_anchor_dims)
    loss = -mean_log_prob_pos_ij.mean()

    loss.backward()
    client.local_optimizer.step()

    print(loss)

    return {"loss": loss.cpu().detach().numpy()}

