import torch
from torch.utils.data import DataLoader, Subset, Dataset
from argparse import ArgumentParser, Namespace
from pathlib import Path
import pickle
import json
import sys
import numpy as np
import os
from copy import deepcopy
import time 
# Parent of the directory that contains this file, made absolute.
PROJECT_DIR = Path(__file__).parent.parent.absolute()

# Append the project root and its ``datapreprocess`` directory to the import
# search path so the ``datapreprocess.models`` import below can be resolved.
# Appending (not prepending) means already-importable modules take precedence.
sys.path.append(PROJECT_DIR.as_posix())
sys.path.append(PROJECT_DIR.joinpath("datapreprocess").as_posix())


from datapreprocess.models import IndexNet_MLP, IndexNet_Transformer
import torch.nn.functional as F

class FeatureSet(Dataset):
    """Map-style dataset pairing per-sample input features with label features.

    Both containers are indexed with the same index, so they are expected to
    be aligned along their first dimension. The length is taken from
    ``image_features``.
    """

    def __init__(self, image_features, label_features):
        self.image_features, self.label_features = image_features, label_features

    def __len__(self):
        return len(self.image_features)

    def __getitem__(self, idx):
        pair = (self.image_features[idx], self.label_features[idx])
        return pair

    def union(self, other, max_addition=128):
        """Return a new FeatureSet: all of ``self`` followed by a prefix of ``other``.

        At most ``max_addition`` leading samples of ``other`` are appended.
        Neither input is modified.
        """
        take = len(other) if len(other) < max_addition else max_addition
        merged_images = torch.concat((self.image_features, other.image_features[:take]))
        merged_labels = torch.concat((self.label_features, other.label_features[:take]))
        return FeatureSet(merged_images, merged_labels)

    def shuffle(self):
        """Return a new FeatureSet whose samples follow one random permutation.

        The same permutation is applied to both containers, so pairs stay aligned.
        """
        order = torch.randperm(len(self))
        return FeatureSet(self.image_features[order], self.label_features[order])


def get_trainer_argparser() -> ArgumentParser:
    """Build the command-line parser used by the index trainer.

    ``--dataset`` is restricted to a fixed list of dataset names; every other
    option is a typed value with a default. Options are registered in a fixed
    order, which is also the attribute order of the parsed namespace.
    """
    dataset_choices = [
        "mnist", "cifar10", "cifar100", "synthetic", "femnist", "emnist",
        "fmnist", "celeba", "medmnistS", "medmnistA", "medmnistC", "covid19",
        "svhn", "usps", "tiny_imagenet", "cinic10", "domain", "shakespeare",
    ]

    parser = ArgumentParser()
    parser.add_argument("-d", "--dataset", type=str, choices=dataset_choices, default="domain")

    # (option strings, value type, default) for every remaining option.
    option_specs = [
        (("--seed",), int, 42),
        (("-lr", "--local_lr"), float, 1e-1),
        (("-m", "--momentum"), float, 0.9),
        (("-e", "--epochs"), int, 5),
        (("-b", "--batch_size"), int, 128),
        (("-de", "--device"), int, 0),
        (("-s", "--save_dir"), str, "./indexs"),
        (("-wd", "--weight_decay"), float, 5e-5),
        (("-tp", "--train_type"), str, 'global'),
        (("-id", "--trail_id"), str, '0'),
        (("-max", "--client_max"), int, -1),
        (("-mp", "--model_type"), str, 'transformer'),
        (("-div", "--diverse"), int, 1),
    ]
    for flags, value_type, default_value in option_specs:
        parser.add_argument(*flags, type=value_type, default=default_value)
    return parser


class IndexTrainer:
    """Train an IndexNet on per-client extracted features and export per-client indexes.

    Options are read from ``sys.argv`` via ``get_trainer_argparser()``. Three
    training modes are provided, each finishing with ``save_indexs``:

    * ``train``: ``self.model`` passes sequentially over every client loader.
    * ``global_train``: ``self.model`` trains on one pooled, shuffled loader.
    * ``federated_train``: FedAvg-style rounds over a random subset of clients.
    """

    # Names (and order) of the per-batch losses recorded by ``train_epoch``.
    _METRIC_KEYS = (
        'orthogonal_loss',
        'reconstruction_loss',
        'classification_loss',
        'contrastive_loss',
        'total_loss',
    )

    def __init__(self, clients_id):
        """
        Args:
            clients_id: sequence of client ids. Each id selects the feature
                file ``extracted-{id}.pkl``; per-client lists built by this
                class follow the order of this sequence.

        Raises:
            ValueError: if ``--model_type`` is neither 'mlp' nor 'transformer'.
            FileNotFoundError: if a client's feature file is missing.
        """
        self.clients_id = clients_id

        self.args = get_trainer_argparser().parse_args()
        self.model = self._new_model()
        self.optimizer = self._new_optimizer(self.model)
        self.schedular = self._new_scheduler(self.optimizer)

        self.device = torch.device(self.args.device if torch.cuda.is_available() else "cpu")

        # One metrics dict per epoch/round; filled by the training methods.
        self.metrics = []

        self.initilize_data()

    # ------------------------------------------------------------------ setup

    def _new_model(self):
        """Instantiate a fresh IndexNet of the type selected by ``--model_type``."""
        if self.args.model_type == 'mlp':
            return IndexNet_MLP()
        if self.args.model_type == 'transformer':
            return IndexNet_Transformer()
        raise ValueError('Unknown model type')

    def _new_optimizer(self, model):
        """Adam optimizer for ``model`` using ``--local_lr`` and ``--weight_decay``."""
        return torch.optim.Adam(
            params=model.parameters(),
            lr=self.args.local_lr,
            betas=(0.9, 0.99),
            weight_decay=self.args.weight_decay,
        )

    def _new_scheduler(self, optimizer):
        """Cosine-annealing LR scheduler with ``T_max = --epochs``."""
        # NOTE(review): train_epoch steps the scheduler once per *batch*, while
        # T_max counts epochs, so the LR reaches its minimum after `epochs`
        # batches and then cycles. Kept unchanged to preserve training
        # behaviour -- confirm whether per-epoch stepping was intended.
        return torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer=optimizer, T_max=self.args.epochs)

    # ----------------------------------------------------------------- losses

    def cos_sim_loss(self, x, y, tau=0.5):
        """Mean cosine distance ``1 - cos(x_i, y_i)`` between matching rows.

        ``tau`` is accepted for interface compatibility and is not used.
        """
        sims = 1 - F.cosine_similarity(x, y)
        return torch.mean(sims)

    def reconstruction_loss(self, x, y):
        """Mean squared error between reconstruction ``x`` and target ``y``."""
        return F.mse_loss(x, y)

    def orthogonal_loss(self, x, y):
        """Mean absolute dot product between matching rows of ``x`` and ``y``.

        ``x`` and ``y`` must be 2-D (``torch.bmm`` needs the 3-D views built here).
        """
        return torch.mean(torch.abs(torch.bmm(x.unsqueeze(1), y.unsqueeze(1).transpose(1, 2))))

    def contrastive_loss(self, x, tau=0.5):
        """Mean over rows of ``log(sum_{j != i} exp(cos(x_i, x_j) / tau))``.

        Lower values mean the rows of ``x`` are more spread out. Each row's
        similarity with itself is excluded. A batch with fewer than two rows
        has no other rows to compare against and yields 0.
        """
        if x.shape[0] < 2:
            # Masking the only entry would give log(0) = -inf and NaN gradients.
            return x.new_zeros(())
        # F.normalize guards against division by zero for all-zero rows.
        x_norm = F.normalize(x, dim=1)
        sim_matrix = torch.matmul(x_norm, x_norm.T)
        # Drop self-similarity. This must be assigned back: the previous code
        # called the out-of-place masked_fill and discarded its result.
        mask = torch.eye(x.shape[0], dtype=torch.bool, device=x.device)
        sim_matrix = sim_matrix.masked_fill(mask, float('-inf'))
        # logsumexp == log(sum(exp(.))) but numerically stable.
        return torch.mean(torch.logsumexp(sim_matrix / tau, dim=1))

    # ------------------------------------------------------------------- data

    def initilize_data(self):
        """Load every client's extracted features and build the data loaders.

        Populates:
            self.datasets: one shuffled FeatureSet per client (full size).
            self.dataloaders: one DataLoader per client, restricted to the
                first ``--client_max`` samples when that option is positive.
            self.global_data_set / self.global_dataloader: shuffled pool made
                of a prefix of each client's data (``FeatureSet.union`` cap).

        Raises:
            FileNotFoundError: if a client's ``extracted-{id}.pkl`` is missing.
        """
        self.datasets = []
        self.dataloaders = []
        self.global_data_set = FeatureSet(torch.tensor([]), torch.tensor([]))
        features_dir = PROJECT_DIR / "datapreprocess" / "features" / self.args.dataset
        for client_id in self.clients_id:
            extracted_path = features_dir / "extracted-{}.pkl".format(client_id)
            try:
                with open(extracted_path, "rb") as f:
                    # NOTE: pickle can execute code on load; only load trusted files.
                    extracted = pickle.load(f)
            except FileNotFoundError as err:
                raise FileNotFoundError(f"Please partition {self.args.dataset} first.") from err

            # Older feature files store the inputs under "data_features".
            if "image_features" in extracted:
                image_features = extracted["image_features"]
            else:
                image_features = extracted["data_features"]
            label_features = extracted["label_features"]

            dataset = FeatureSet(image_features, label_features).shuffle()
            if self.args.client_max > 0:
                # Slicing a FeatureSet returns a (features, labels) tuple, which
                # DataLoader would treat as a 2-item dataset; build a truncated
                # FeatureSet instead.
                limit = self.args.client_max
                loader_set = FeatureSet(dataset.image_features[:limit], dataset.label_features[:limit])
            else:
                loader_set = dataset
            self.datasets.append(dataset)
            self.dataloaders.append(DataLoader(loader_set, self.args.batch_size))
            self.global_data_set = self.global_data_set.union(dataset)
        self.global_data_set = self.global_data_set.shuffle()
        self.global_dataloader = DataLoader(self.global_data_set, self.args.batch_size)

    # ---------------------------------------------------------------- metrics

    @classmethod
    def _empty_metrics(cls):
        """Fresh dict mapping every loss name to an empty list."""
        return {key: [] for key in cls._METRIC_KEYS}

    @staticmethod
    def _extend_metrics(total, current):
        """Append the per-batch values in ``current`` to ``total`` in place."""
        for key in total:
            total[key].extend(current[key])

    @staticmethod
    def _log_round(epoch, metrics):
        """Print the mean of every loss recorded during one epoch/round."""
        print('rount {} --> oth-loss: {}, cls-loss: {}, rec-loss: {}, con-loss: {}, total: {}'.format(
            epoch,
            np.mean(metrics["orthogonal_loss"]),
            np.mean(metrics["classification_loss"]),
            np.mean(metrics["reconstruction_loss"]),
            np.mean(metrics["contrastive_loss"]),
            np.mean(metrics["total_loss"]),
        ))

    # --------------------------------------------------------------- training

    def train_epoch(self, train_loader, t, train_model=None, train_optimizer=None, train_schedular=None):
        """Run one optimisation pass over ``train_loader``.

        Args:
            train_loader: iterable of ``(data, target)`` batches.
            t: epoch/round index; not used.
            train_model, train_optimizer, train_schedular: used instead of
                self.model / self.optimizer / self.schedular when given.

        Returns:
            dict mapping each loss name to a list with one float per batch.

        The model is moved to ``self.device`` for the pass and back to CPU after.
        """
        model = self.model if train_model is None else train_model
        optimizer = self.optimizer if train_optimizer is None else train_optimizer
        schedular = self.schedular if train_schedular is None else train_schedular
        model.train()
        model.to(self.device)

        metrics = self._empty_metrics()
        for data, target in train_loader:
            data, target = data.to(self.device), target.to(self.device)
            optimizer.zero_grad()
            f_1, f_2 = model(data)

            orthogonal_loss = self.orthogonal_loss(f_1, f_2)
            # f_1 is pulled towards the label features of the sample.
            classification_loss = self.cos_sim_loss(f_1, target)
            if self.args.diverse:
                contrastive_loss = self.contrastive_loss(f_2)
            else:
                contrastive_loss = torch.tensor(0.0, device=self.device)

            f = model.reconstruction_forward(f_1, f_2)
            reconstruction_loss = self.reconstruction_loss(f, data)
            # Fixed loss weights: classification x10, contrastive x0.1.
            loss = orthogonal_loss + 10 * classification_loss + reconstruction_loss + 0.1 * contrastive_loss
            loss.backward()
            optimizer.step()
            if schedular is not None:
                schedular.step()

            metrics['orthogonal_loss'].append(orthogonal_loss.item())
            metrics['reconstruction_loss'].append(reconstruction_loss.item())
            metrics['classification_loss'].append(classification_loss.item())
            metrics['contrastive_loss'].append(contrastive_loss.item())
            metrics['total_loss'].append(loss.item())

        model.to('cpu')
        return metrics

    def train(self):
        """Train ``self.model`` on every client loader in turn, ``--epochs`` times.

        Returns:
            Result of ``save_indexs(self.dataloaders)``.
        """
        self.metrics = []
        for epoch in range(self.args.epochs):
            metrics = self._empty_metrics()
            for dataloader in self.dataloaders:
                self._extend_metrics(metrics, self.train_epoch(dataloader, epoch))
            self._log_round(epoch, metrics)
            self.metrics.append(metrics)

        return self.save_indexs(self.dataloaders)

    def global_train(self):
        """Train ``self.model`` on the pooled global loader for ``--epochs`` epochs.

        Returns:
            Result of ``save_indexs(self.dataloaders)``.
        """
        self.metrics = []
        for epoch in range(self.args.epochs):
            metrics = self.train_epoch(self.global_dataloader, epoch)
            self._log_round(epoch, metrics)
            self.metrics.append(metrics)

        return self.save_indexs(self.dataloaders)

    def federated_train(self, clients_per_round=None, local_epochs=10):
        """FedAvg-style training of ``self.model`` for ``--epochs`` rounds.

        Each round samples ``clients_per_round`` clients without replacement,
        loads the global weights into each sampled client's model, trains it
        for ``local_epochs`` passes over that client's loader, then replaces
        the global weights with the average of the sampled clients' weights,
        weighted by their number of batches. Each client keeps its own
        optimizer and scheduler across rounds.

        Args:
            clients_per_round: clients sampled per round; defaults to 10% of
                the clients (at least one).
            local_epochs: local passes per sampled client per round.

        Returns:
            Result of ``save_indexs(self.dataloaders)``.

        Only the last local pass of each client is recorded in the metrics.
        """
        self.metrics = []
        num_clients = len(self.clients_id)
        if clients_per_round is None:
            # Plain `num_clients // 10` is 0 for fewer than 10 clients, which
            # left nothing to aggregate and divided by zero below.
            clients_per_round = max(1, num_clients // 10)

        client_models = [self._new_model() for _ in range(num_clients)]
        client_optimizers = [self._new_optimizer(model) for model in client_models]
        client_schedulers = [self._new_scheduler(optimizer) for optimizer in client_optimizers]

        for epoch in range(self.args.epochs):
            metrics = self._empty_metrics()
            # Sample positions in self.clients_id (not the ids themselves) so
            # the per-client lists are indexed correctly for any id values.
            sampled_clients = np.random.choice(num_clients, clients_per_round, replace=False)
            for i in sampled_clients:
                client_model = client_models[i]
                client_model.load_state_dict(deepcopy(self.model.state_dict()))
                current_metric = self._empty_metrics()
                for _ in range(local_epochs):
                    current_metric = self.train_epoch(
                        self.dataloaders[i], epoch, client_model, client_optimizers[i], client_schedulers[i])
                self._extend_metrics(metrics, current_metric)

            self._log_round(epoch, metrics)
            self.metrics.append(metrics)

            weights = [len(self.dataloaders[i]) for i in sampled_clients]
            total_len = sum(weights)
            if total_len == 0:
                # Every sampled loader was empty: nothing trained, keep global weights.
                continue
            aggregated_dict = {k: torch.zeros_like(v) for k, v in self.model.state_dict().items()}
            for i, weight in zip(sampled_clients, weights):
                for k, v in client_models[i].state_dict().items():
                    aggregated_dict[k] = aggregated_dict[k] + v * weight
            self.model.load_state_dict({k: v / total_len for k, v in aggregated_dict.items()})
        return self.save_indexs(self.dataloaders)

    # ----------------------------------------------------------------- export

    @torch.no_grad()
    def save_indexs(self, train_loaders):
        """Run ``self.model`` over each client's loader and pickle the result.

        For the i-th client (``self.clients_id[i]``, ``train_loaders[i]``) this
        writes ``index-{client_id}.pkl`` holding a dict with:
            client_index: column-wise mean of ``sample_indexs``.
            casual_features: all ``f_1`` outputs, concatenated along dim 0.
            sample_indexs: ``[f_2, target]`` joined along dim 1, then
                concatenated along dim 0.

        Returns:
            dict mapping client_id -> client_index.
        """
        # NOTE(review): --save_dir is not used; output always goes under ./indexs.
        dataset_root = './indexs/{}/{}/{}/e{}_lr{}_div_{}'.format(self.args.dataset, self.args.train_type, self.args.model_type, self.args.epochs, self.args.local_lr, self.args.diverse)

        self.model.eval()
        self.model.to(self.device)

        client_indexs = {}
        for i, client_id in enumerate(self.clients_id):
            train_loader = train_loaders[i]
            casual_parts = []
            sample_parts = []
            for data, target in train_loader:
                data, target = data.to(self.device), target.to(self.device)
                f_1, f_2 = self.model(data)
                casual_parts.append(f_1.cpu())
                sample_parts.append(torch.cat((f_f2_cpu := f_2.cpu(), target.cpu()), 1))

            # Concatenate once instead of growing a tensor per batch. The empty
            # seed tensor keeps the previous result for an empty loader.
            casual_features = torch.cat([torch.tensor([]), *casual_parts], 0)
            sample_indexs = torch.cat([torch.tensor([]), *sample_parts], 0)

            client_index = torch.mean(sample_indexs, 0).squeeze()

            index = {
                "client_index": client_index,
                "casual_features": casual_features,
                "sample_indexs": sample_indexs
            }

            os.makedirs(dataset_root, exist_ok=True)
            with open(dataset_root + "/index-{}.pkl".format(client_id), "wb") as f:
                pickle.dump(index, f)

            client_indexs[client_id] = client_index
            print(client_index)

        return client_indexs






if __name__ == "__main__":
    client_num = 100
    index_trainder = IndexTrainer(list(range(client_num)))
    args = index_trainder.args
    # Must match the directory IndexTrainer.save_indexs writes per-client files to.
    dataset_root = './indexs/{}/{}/{}/e{}_lr{}_div_{}'.format(args.dataset, args.train_type, args.model_type, args.epochs, args.local_lr, args.diverse)

    # Save the parsed arguments, one "name:value" line per option.
    os.makedirs(dataset_root, exist_ok=True)
    with open(dataset_root + "/args.txt", mode="w") as f:
        f.write("".join(arg + ":" + str(value) + "\n" for arg, value in vars(args).items()))

    # Select the training routine for --train_type.
    train_fns = {
        'global': index_trainder.global_train,
        'federated': index_trainder.federated_train,
    }
    if args.train_type not in train_fns:
        raise ValueError("train_type must be global or federated")

    time_begin = time.perf_counter()
    client_indexs = train_fns[args.train_type]()
    elapsed = time.perf_counter() - time_begin
    # max(..., 1): with `-e 0` a plain division raised ZeroDivisionError here,
    # before the summary and metrics below were saved.
    print("training time:", elapsed, "s\t each epoch takes", elapsed / max(args.epochs, 1), "s")

    with open(dataset_root + "/index-summary.pkl", "wb") as f:
        pickle.dump(client_indexs, f)

    with open(dataset_root + "/train_metrics.pkl", "wb") as f:
        pickle.dump(index_trainder.metrics, f)

    print(client_indexs)
    