from torch import nn
import torch
import random
import pytorch_lightning as pl
import timm
import torchvision
import os
import random
import numpy as np
import copy
import torch.optim as optim
from lightly.loss import NegativeCosineSimilarity
import logging
from utilities import load_dataset
from coopt.method import *
from coopt.server_client.data_utils import (
    random_fourier_transform,
    random_matrix,
    random_project,
    )



class OneHotEncoder(nn.Module):
    """Module that turns integer class indices into one-hot rows."""

    def __init__(self, num_classes):
        """Store the number of classes, i.e. the width of each encoded row."""
        super().__init__()
        self.num_classes = num_classes

    def forward(self, class_indices):
        """Return a ``(class_indices.size(0), num_classes)`` one-hot tensor.

        The output is allocated on the same device as ``class_indices`` and
        uses the default dtype of ``torch.zeros``.
        """
        batch = class_indices.size(0)
        target_device = class_indices.device
        rows = torch.arange(batch, device=target_device)
        encoded = torch.zeros(batch, self.num_classes, device=target_device)
        # One entry per row is switched on: the column given by the class index.
        encoded[rows, class_indices] = 1
        return encoded




def extract_features_and_labels(model, dataloader, device):
    """Run ``model`` over ``dataloader`` and gather flattened outputs and labels.

    Every image batch is moved to ``device``, passed through ``model`` with
    gradients disabled, reshaped to ``(batch, -1)`` and copied to the CPU.
    Labels are collected exactly as the loader yields them.

    Returns:
        A ``(features, labels)`` pair, each concatenated along dim 0.
    """
    collected_features = []
    collected_labels = []
    with torch.no_grad():
        for batch_images, batch_labels in dataloader:
            batch_images = batch_images.to(device)
            outputs = model(batch_images)
            # One row per sample, whatever the model's output shape is.
            flat = outputs.view(batch_images.size(0), -1)
            collected_features.append(flat.cpu())
            collected_labels.append(batch_labels)
    all_features = torch.cat(collected_features, dim=0)
    all_labels = torch.cat(collected_labels, dim=0)
    return all_features, all_labels


def uniform_loss(x, t=2):
    """Return the log of the mean of ``exp(-t * d^2)`` over all row pairs.

    Rows of ``x`` are first divided by their L2 norm; ``d`` is the Euclidean
    distance between two such rows as computed by ``torch.pdist``.
    """
    row_norms = x.norm(p=2, dim=1, keepdim=True)
    unit_rows = x / row_norms
    squared_dists = torch.pdist(unit_rows, p=2).pow(2)
    kernel = torch.exp(squared_dists * (-t))
    return torch.log(kernel.mean())


def compute_model_quality(dataset, all_local_model_types):
    """Score each saved local model on ``dataset`` with the uniformity metric.

    For every entry of ``all_local_model_types`` the checkpoint
    ``save/model/{model_type}_cifar10_resnet18.pt`` is loaded, run over
    ``dataset`` in batches of 500, and its flattened outputs are scored with
    ``uniform_loss``. Missing checkpoints are reported and skipped, so the
    result may be shorter than ``all_local_model_types``.

    Args:
        dataset: dataset yielding ``(images, labels)`` pairs.
        all_local_model_types: iterable of model-type names used to build
            the checkpoint paths.

    Returns:
        list[float]: one uniformity value per checkpoint that was found, in
        iteration order.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=500)
    all_uniform_loss = []
    for model_type in all_local_model_types:
        model_path = f"save/model/{model_type}_cifar10_resnet18.pt"
        if not os.path.exists(model_path):
            print(f"Model path {model_path} does not exist.")
            continue

        # NOTE(security): this unpickles a whole module object; only load
        # checkpoints from a trusted source.
        # map_location lets checkpoints saved from a GPU load on the CPU
        # fallback device chosen above.
        model = torch.load(model_path, map_location=device)
        model = model.to(device)
        model.eval()

        features, _ = extract_features_and_labels(model, dataloader, device)
        # `features` is already a tensor: cast it instead of re-wrapping it
        # with torch.tensor(), which warns and makes a redundant copy.
        features_tensor = features.to(torch.float32)
        uniform_metric = uniform_loss(features_tensor)
        all_uniform_loss.append(uniform_metric.item())
        print("All_uniform_value =", all_uniform_loss)
    return all_uniform_loss

class redata:
    """View of ``dataset`` restricted to, and ordered by, ``re_indices``."""

    def __init__(self, dataset, re_indices):
        """Keep references to the wrapped dataset and the index list."""
        self._dataset = dataset
        self.re_indices = re_indices

    def __len__(self):
        """Return the number of selected indices."""
        return len(self.re_indices)

    def __getitem__(self, idx):
        """Return the wrapped dataset's item at position ``re_indices[idx]``."""
        return self._dataset[self.re_indices[idx]]

def optimize_data(
    prior_model: nn.Module,
    feature_dim,
    max_feature_dim,  # to judge if client_feature_dim == max_feature_dim
    resolution,
    x: tuple,
    dim_up=False,
    align_W=None, align_b=None
):
    """Encode a dataset subset with ``prior_model`` and optionally lift its dimension.

    The subset is resized to ``resolution``, normalized, and fed through
    ``prior_model`` on CUDA in batches of 256. If ``prior_model`` is an
    ``nn.Embedding`` or ``OneHotEncoder`` it is applied to the labels,
    otherwise to the images.

    Side effect: ``dataset.transform`` is replaced while encoding and is
    always set to ``ToTensor()`` afterwards, even if encoding raises.

    Args:
        prior_model: module producing the target features.
        feature_dim: output dimension of ``prior_model``.
        max_feature_dim: dimension to lift to when ``dim_up`` is set.
        resolution: size passed to ``torchvision.transforms.Resize``.
        x: ``(dataset, indices)`` pair; ``dataset`` must expose a settable
            ``transform`` attribute and yield ``(image, label)`` items.
        dim_up: when true and ``feature_dim < max_feature_dim``, the
            features are passed through ``random_matrix``.
        align_W, align_b: both ``None`` selects the first stage; both given
            selects the second stage, where they are forwarded to
            ``random_matrix``.

    Returns:
        First stage: ``(Y, uniform_value, W, b)`` where ``Y`` is a CPU
        tensor, ``uniform_value`` is ``uniform_loss(Y)`` as a float, and
        ``W``/``b`` come from ``random_matrix`` (or are ``0`` when no
        lifting happened).
        Second stage: the CPU tensor ``Y`` only.

    Raises:
        ValueError: if exactly one of ``align_W``/``align_b`` is given, or
            if ``dim_up`` is set while ``feature_dim > max_feature_dim``.
            These combinations previously fell through and returned
            ``None``.
    """
    first_stage = align_W is None and align_b is None
    second_stage = align_W is not None and align_b is not None
    if not (first_stage or second_stage):
        raise ValueError("align_W and align_b must be given together or both omitted")
    if dim_up and feature_dim > max_feature_dim:
        raise ValueError(
            f"feature_dim ({feature_dim}) exceeds max_feature_dim "
            f"({max_feature_dim}); cannot lift the dimension"
        )

    dataset, indices = x

    transform = torchvision.transforms.Compose(
        [
            torchvision.transforms.Resize(resolution),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
            ),
        ]
    )

    # Label-consuming priors encode the targets, everything else the images.
    uses_labels = isinstance(prior_model, (nn.Embedding, OneHotEncoder))

    dataset.transform = transform
    try:
        data_loader = torch.utils.data.DataLoader(
            redata(dataset, indices),
            batch_size=256,
            shuffle=False,
            num_workers=8,
        )

        if uses_labels:
            print("it's human!")
        else:
            print(type(prior_model))

        prior_model = prior_model.eval().cuda()
        encoded = []
        with torch.no_grad():
            for images, labels in data_loader:
                batch = labels if uses_labels else images
                batch = batch.cuda(non_blocking=True)
                encoded.append(prior_model(batch).float().detach())
        Y = torch.cat(encoded, dim=0)
    finally:
        # The transform is reset even when encoding fails.
        dataset.transform = torchvision.transforms.ToTensor()

    needs_dim_up = dim_up and feature_dim < max_feature_dim

    if first_stage:
        if needs_dim_up:
            Y, W, b = random_matrix(Y, feature_dim, max_feature_dim)
        else:
            # Placeholders so the first stage always returns four values.
            W, b = 0, 0
        uniform_value = uniform_loss(Y).item()
        return Y.detach().cpu(), uniform_value, W, b

    if needs_dim_up:
        Y = random_matrix(Y, feature_dim, max_feature_dim, align_W, align_b)
    return Y.detach().cpu()
    



def align(
    client_data: torch.Tensor,
    client_model: nn.Module,
    client_feature_dim,
    client_resolution,
    align_feature_dim,
    align_data,
    align_features,
    dim_up=False,
    align_W=None, align_b=None
) -> torch.Tensor:
    """Map ``client_data`` into the alignment feature space.

    The alignment subset is encoded with ``client_model`` on CUDA, applied
    to labels when the model is an ``nn.Embedding`` or ``OneHotEncoder`` and
    to images otherwise. A linear map is then trained by ``c_matrix`` from
    those encodings to ``align_features`` and applied to ``client_data``.

    Side effect: ``dataset.transform`` is replaced while encoding and is
    always set to ``ToTensor()`` afterwards, even if encoding raises.

    Args:
        client_data: client features to transform; moved to CUDA here.
        client_model: module that encodes the alignment subset.
        client_feature_dim: output dimension of ``client_model``.
        client_resolution: size passed to ``torchvision.transforms.Resize``.
        align_feature_dim: dimension of ``align_features``.
        align_data: ``(dataset, indices)`` pair for the alignment subset.
        align_features: target features for the alignment subset; moved to
            CUDA here.
        dim_up: when true and ``client_feature_dim < align_feature_dim``,
            the encodings are first passed through ``random_matrix``.
        align_W, align_b: forwarded to ``random_matrix`` when lifting.
            NOTE(review): presumably the values returned by the first stage
            of ``optimize_data`` -- confirm against the caller.

    Returns:
        The transformed ``client_data`` as a detached CPU tensor.

    Raises:
        ValueError: if ``dim_up`` is set while ``client_feature_dim`` is
            larger than ``align_feature_dim``. This combination previously
            ended in an ``UnboundLocalError`` after all the work was done.
    """
    if dim_up and client_feature_dim > align_feature_dim:
        raise ValueError(
            f"client_feature_dim ({client_feature_dim}) exceeds "
            f"align_feature_dim ({align_feature_dim}); cannot lift the dimension"
        )

    dataset, indices = align_data

    transform = torchvision.transforms.Compose(
        [
            torchvision.transforms.Resize(client_resolution),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
            ),
        ]
    )

    # Label-consuming models encode the targets, everything else the images.
    uses_labels = isinstance(client_model, (nn.Embedding, OneHotEncoder))

    dataset.transform = transform
    try:
        dataset_ = redata(dataset, indices)

        print("len of data_allocate: ", len(dataset_))

        data_loader = torch.utils.data.DataLoader(
            dataset_,
            batch_size=256,
            shuffle=False,
            num_workers=8,
        )

        if uses_labels:
            print("it's human when align!")

        client_model = client_model.eval().cuda()
        encoded = []
        with torch.no_grad():
            for images, labels in data_loader:
                batch = labels if uses_labels else images
                batch = batch.cuda(non_blocking=True)
                encoded.append(client_model(batch).float().detach())
        client_align_Y = torch.cat(encoded, dim=0)
    finally:
        # The transform is reset even when encoding fails.
        dataset.transform = torchvision.transforms.ToTensor()

    print("uniform-value before dim up: ", uniform_loss(client_align_Y), client_align_Y.shape)

    client_data = client_data.to('cuda')
    align_features = align_features.to('cuda')

    if dim_up and client_feature_dim < align_feature_dim:
        client_align_Y = random_matrix(
            client_align_Y, client_feature_dim, align_feature_dim, align_W, align_b
        )

        print("uniform-value after dim up: ", uniform_loss(client_align_Y), client_align_Y.shape)

        # After lifting, the encodings already have align_feature_dim columns.
        map_in_dim = align_feature_dim
    else:
        print("uniform-value without dim up: ", uniform_loss(client_align_Y), client_align_Y.shape)

        map_in_dim = client_feature_dim

    Y = c_matrix(
        client_data,
        client_align_Y, align_features,  # training pairs for the linear map
        map_in_dim, align_feature_dim,  # in/out dims of the linear map
    )

    return Y.detach().cpu()


def c_matrix(
        client_data,
        client_align_Y, align_features,
        client_feature_dim, align_feature_dim,
        iterations=8192, batch_size=128):
    """Fit a linear map from client features to aligned features and apply it.

    An ``nn.Linear(client_feature_dim, align_feature_dim)`` layer is trained
    with AdamW (lr 0.001, weight decay 0.01) to minimize the negative cosine
    similarity between its output on rows of ``client_align_Y`` and the
    matching rows of ``align_features``. Each step uses a random mini-batch
    drawn without replacement via ``np.random.choice``.

    Args:
        client_data: tensor to transform with the fitted layer; must be on
            the same device as ``client_align_Y``.
        client_align_Y: training inputs, one row per alignment sample.
        align_features: training targets, row-aligned with
            ``client_align_Y`` and on the same device.
        client_feature_dim: input width of the linear layer.
        align_feature_dim: output width of the linear layer.
        iterations: number of optimizer steps.
        batch_size: rows per step; capped at the number of alignment
            samples so that small alignment sets no longer raise.

    Returns:
        ``client_data`` passed through the fitted layer, computed without
        gradient tracking.

    Raises:
        ValueError: if ``client_align_Y`` has no rows.
    """
    num_samples = client_align_Y.shape[0]
    if num_samples == 0:
        raise ValueError("c_matrix needs at least one alignment sample")

    # Train on the device the inputs already live on, not a hard-coded "cuda".
    device = client_align_Y.device

    C_matrix = nn.Linear(client_feature_dim, align_feature_dim).to(device)

    optimizer = optim.AdamW(C_matrix.parameters(), lr=0.001, weight_decay=0.01)
    criterion = NegativeCosineSimilarity().to(device)

    index_all = np.arange(num_samples)
    # Sampling without replacement cannot draw more rows than exist.
    step_size = min(batch_size, num_samples)

    for _ in range(iterations):
        batch_idx = np.random.choice(index_all, size=step_size, replace=False)

        inputs = client_align_Y[batch_idx]
        targets = align_features[batch_idx]

        loss = criterion(C_matrix(inputs), targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    with torch.no_grad():
        client_data = C_matrix(client_data)

    return client_data



def optimize_model(
    args,
    x: torch.Tensor,
    optimal_data: torch.Tensor,
    y: torch.Tensor,
    backbone: nn.Module,
    model_type,
):
    """Train a model on the optimized targets and return its best backbone.

    The method class and training transform are looked up in ``METHODS``
    under ``args.method``. Training data is an ``Opt_Data`` built from a deep
    copy of ``x``, ``optimal_data`` and ``y``. Validation data is loaded via
    ``load_dataset("CIFAR10", train=False)``.

    Note that the ``backbone`` argument is not used: it is overwritten by
    the checkpoint loaded from ``save/model/random_cifar10_resnet18.pt``.

    Returns:
        ``(best_backbone, best_acc)``: a deep copy of ``model.best_backbone``
        and the value of ``model.best_acc`` after fitting.
    """
    METHOD, TRANSFORM = METHODS[args.method]

    normalize = torchvision.transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )
    val_transform = torchvision.transforms.Compose(
        [
            torchvision.transforms.Resize(32),
            torchvision.transforms.ToTensor(),
            normalize,
        ]
    )

    # The incoming backbone is always replaced by the saved random-init one.
    backbone = torch.load("save/model/random_cifar10_resnet18.pt")
    model = METHOD(backbone, args)

    train_transform = TRANSFORM(args.dataset, args.input_size)
    model.trainset = Opt_Data(copy.deepcopy(x), optimal_data, y, train_transform)

    valset = load_dataset("CIFAR10", transform=val_transform, train=False)

    # Per-device, per-accumulation-step batch size shared by both loaders.
    loader_batch_size = args.batch_size // len(args.devices) // args.grad_accu

    train_dataloader = torch.utils.data.DataLoader(
        model.trainset,
        batch_size=loader_batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=args.num_workers,
    )
    val_dataloader = torch.utils.data.DataLoader(
        valset,
        batch_size=loader_batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=args.num_workers,
    )

    logging.getLogger("pytorch_lightning").setLevel(logging.WARNING)

    # Options common to the single-process and the distributed trainer.
    trainer_options = {
        "max_epochs": args.epochs,
        "devices": args.devices,
        "accelerator": "gpu",
        "logger": False,
        "enable_checkpointing": False,
        "accumulate_grad_batches": args.grad_accu,
    }
    # Equality test kept on purpose: any value that does not compare equal
    # to False (None included) selects the distributed configuration.
    if args.distributed == False:  # noqa: E712
        trainer_options["enable_progress_bar"] = False
    else:
        trainer_options.update(
            strategy="ddp_find_unused_parameters_true",
            precision=16,
            sync_batchnorm=True,
            use_distributed_sampler=True,
        )
    trainer = pl.Trainer(**trainer_options)

    trainer.fit(
        model=model,
        train_dataloaders=train_dataloader,
        val_dataloaders=val_dataloader,
    )

    print(f"{model_type}_Training_Acc: {model.val_acc_list}")

    top_acc = model.best_acc
    return copy.deepcopy(model.best_backbone), top_acc
