import os
import time
import pickle
import copy
import random
import gc

from collections import defaultdict
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as datasets
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import NaturalIdPartitioner
from scipy.optimize import linear_sum_assignment
from sklearn.metrics import confusion_matrix

import numpy as np

from util import *

# Module-level switch read by TrainFEMNISTCluster.lr_schedule: when True the
# learning rate is multiplied by 0.1 on every epoch that is a non-zero
# multiple of 50; when False the learning rate is left unchanged.
LR_DECAY = False


class TrainFEMNISTCluster(object):
    def __init__(self, config, device):
        self.config = config
        self.device = device

        # assert self.config['m'] % self.config['p'] == 0

    def setup(self):

        os.makedirs(self.config['project_dir'], exist_ok = True)

        self.result_fname = os.path.join(self.config['project_dir'], 'results.pickle')
        self.checkpoint_fname = os.path.join(self.config['project_dir'], 'checkpoint.pt')

        self.setup_datasets()
        self.setup_models()
        self.setup_adjacency()

        self.epoch = None
        self.lr = None
        #self.cluster_switch = None


    def setup_datasets(self):
        """Build ``self.dataset`` reproducibly from ``config['data_seed']``.

        After the call ``self.dataset`` has the keys ``'train'`` and
        ``'test'``; each maps to a dict with ``'X'``, ``'y'`` and
        ``'data_indices'`` taken from ``_setup_dataset()``.
        """
        seed = self.config['data_seed']
        np.random.seed(seed)
        # _setup_dataset draws each client's train/test split with
        # torch.randperm, which uses torch's RNG. Seeding NumPy alone left
        # the split different on every run, so torch is seeded as well.
        torch.manual_seed(seed)

        self.dataset = {}

        X_train, y_train, train_indices, X_test, y_test, test_indices = self._setup_dataset()
        self.dataset['train'] = {'X': X_train, 'y': y_train, 'data_indices': train_indices}
        self.dataset['test'] = {'X': X_test, 'y': y_test, 'data_indices': test_indices}

    def _setup_dataset(self):
        """Load partitions ``0..m-1`` of FEMNIST and split each one 80/20.

        The dataset ``flwrlabs/femnist`` is partitioned by ``writer_id``, so
        every partition is treated as one client. Each client's samples are
        shuffled with ``torch.randperm`` and the first 80% become its train
        split, the rest its test split.

        Returns:
            ``(X_train, y_train, train_indices, X_test, y_test, test_indices)``
            where the ``X``/``y`` tensors are the concatenation over all
            clients (moved to ``self.device``) and each ``*_indices`` dict
            maps a client id to the list of its row positions in the
            corresponding concatenated tensor.
        """
        num_clients = self.config['m']

        federated = FederatedDataset(
            dataset="flwrlabs/femnist",
            partitioners={"train": NaturalIdPartitioner(partition_by="writer_id")}
        )

        # One bucket per split: tensors to concatenate, the per-client row
        # positions, and the running row offset into the concatenation.
        buckets = {
            'train': {'X': [], 'y': [], 'indices': {}, 'offset': 0},
            'test': {'X': [], 'y': [], 'indices': {}, 'offset': 0},
        }

        for cid in range(num_clients):
            partition = federated.load_partition(partition_id=cid)
            raw_images = partition["image"]
            raw_labels = partition["character"]

            # Stack the images, add a channel axis and divide by 255
            # (presumably 8-bit grayscale pixels - TODO confirm).
            pixels = [torch.tensor(np.array(img), dtype=torch.float32) for img in raw_images]
            images = torch.stack(pixels).unsqueeze(1) / 255.0
            images = images.view(-1, 1, 28, 28)
            labels = torch.tensor(raw_labels, dtype=torch.long)

            # Random 80/20 split of this client's samples.
            sample_count = len(labels)
            order = torch.randperm(sample_count)
            cut = int(0.8 * sample_count)

            for split_name, chosen in (('train', order[:cut]), ('test', order[cut:])):
                bucket = buckets[split_name]
                split_labels = labels[chosen]
                bucket['X'].append(images[chosen])
                bucket['y'].append(split_labels)

                start = bucket['offset']
                stop = start + len(split_labels)
                bucket['indices'][cid] = list(range(start, stop))
                bucket['offset'] = stop

        device = self.device
        train_bucket = buckets['train']
        test_bucket = buckets['test']

        X_train = torch.cat(train_bucket['X'], dim=0).to(device)
        y_train = torch.cat(train_bucket['y'], dim=0).to(device)
        X_test = torch.cat(test_bucket['X'], dim=0).to(device)
        y_test = torch.cat(test_bucket['y'], dim=0).to(device)

        return X_train, y_train, train_bucket['indices'], X_test, y_test, test_bucket['indices']
   
   
    def setup_models(self):
        """Create ``p`` cluster models for each of the ``m`` clients.

        ``self.models[m_i][p_i]`` is client ``m_i``'s copy of cluster model
        ``p_i``. When ``config['local_model_init']`` is truthy every client
        gets independently initialised models; otherwise ``p`` models are
        initialised once and deep-copied to every client, so all clients
        start from the same weights. Also sets ``self.criterion`` to
        cross-entropy loss.
        """
        seed = self.config['train_seed']
        # Seed every RNG used from here on. Python's `random` drives the
        # 'fixed' graph sampling in setup_adjacency and the neighbor shuffle
        # in sequential_aggregation; it was previously left unseeded, which
        # made those steps differ between runs with the same train_seed.
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

        p = self.config['p']
        m = self.config['m']
        h1 = self.config['h1']

        if self.config['local_model_init']:
            self.models = [
                [SimpleCNN(h1=h1).to(self.device) for _ in range(p)]
                for _ in range(m)
            ]
        else:
            global_models = [SimpleCNN(h1=h1).to(self.device) for _ in range(p)]
            self.models = [
                [copy.deepcopy(model) for model in global_models]
                for _ in range(m)
            ]

        self.criterion = torch.nn.CrossEntropyLoss()

    def setup_adjacency(self):
        """Build ``self.adjacency``: for each client, its list of neighbors.

        ``config['graph']`` selects the construction:

        * ``'er'``: for every ordered pair ``(i, j)`` with ``i != j``, ``j``
          becomes a neighbor of ``i`` with probability ``Gp`` (drawn with
          NumPy's RNG). The draws for ``(i, j)`` and ``(j, i)`` are
          independent, so the relation need not be symmetric.
        * ``'fixed'``: every client gets the same number of neighbors,
          ``int(Gp * (m - 1))`` but at least 1, sampled without replacement
          with Python's ``random``.

        Raises:
            NotImplementedError: for any other ``graph`` value.
        """
        graph = self.config['graph']
        Gp = self.config['Gp']
        m = self.config['m']

        if graph == 'er':
            adjacency = []
            for i in range(m):
                draws = np.random.choice([0, 1], size=m, p=[1 - Gp, Gp])
                adjacency.append([j for j in range(m) if draws[j] == 1 and j != i])
        elif graph == 'fixed':
            # A client has only m - 1 possible neighbors. Without the upper
            # clamp random.sample raised ValueError for m == 1 or Gp > 1.
            num_neighbors = min(m - 1, max(1, int(Gp * (m - 1))))
            adjacency = []
            for i in range(m):
                others = [j for j in range(m) if j != i]
                adjacency.append(random.sample(others, num_neighbors))
        else:
            raise NotImplementedError("only er and fixed supported")

        self.adjacency = adjacency


    def run(self):
        """Run the full experiment and return the per-epoch results.

        First evaluates the untrained models (recorded as epoch ``-1``), then
        for each of ``config['num_epochs']`` epochs: picks the learning rate,
        trains with the cluster assignment found by the previous evaluation
        on the training data, and re-evaluates on train and test data.
        Results are pickled to ``self.result_fname`` every 10th epoch and
        after the last one.

        Returns:
            list of dicts, one per epoch, with the keys ``'epoch'``,
            ``'train'`` and ``'test'`` (evaluation results plus timings) and,
            for trained epochs, ``'lr'``.
        """
        num_epochs = self.config['num_epochs']
        results = []

        # Epoch -1: evaluate the initial models before any training.
        self.epoch = -1
        result = {'epoch': -1}
        for split, is_train in (('train', True), ('test', False)):
            res = self._timed_test(is_train)
            result[split] = res
            self.print_epoch_stats(res)
        results.append(result)

        # Assignment used by the first training epoch.
        cluster_assign = result['train']['cluster_assign']

        for epoch in range(num_epochs):
            self.epoch = epoch

            lr = self.lr_schedule(epoch)
            result = {'epoch': epoch, 'lr': lr}

            t0 = time.time()
            self.train(cluster_assign, lr=lr)
            train_time = time.time() - t0

            res = self._timed_test(True)
            res['train_time'] = train_time
            res['lr'] = lr
            result['train'] = res
            self.print_epoch_stats(res)

            res = self._timed_test(False)
            result['test'] = res
            self.print_epoch_stats(res)

            results.append(result)

            # Assignment used by the next epoch's gradient update.
            cluster_assign = result['train']['cluster_assign']

            if epoch % 10 == 0 or epoch == num_epochs - 1:
                with open(self.result_fname, 'wb') as outfile:
                    pickle.dump(results, outfile)
                print(f'result written at {self.result_fname}')
                # No checkpoint is saved here, so the former
                # "checkpoint written" message was removed as misleading.

        return results

    def _timed_test(self, train):
        """Evaluate via ``test`` and record the wall time as 'infer_time'."""
        t0 = time.time()
        res = self.test(train=train)
        res['infer_time'] = time.time() - t0
        return res





    def lr_schedule(self, epoch):
        """Return the learning rate to use for ``epoch``.

        The rate starts at ``config['lr']`` the first time this is called
        (while ``self.lr`` is still ``None``). On every epoch that is a
        non-zero multiple of 50 it is multiplied by 0.1, but only when the
        module-level ``LR_DECAY`` flag is truthy. The current rate is kept in
        ``self.lr``.
        """
        current = self.config['lr'] if self.lr is None else self.lr

        # LR_DECAY is checked last so it is only consulted on decay epochs.
        if epoch % 50 == 0 and epoch != 0 and LR_DECAY:
            current = current * 0.1

        self.lr = current
        return current


    def print_epoch_stats(self, res):
        """Print a one-line summary of one evaluation result.

        Args:
            res: dict with ``'is_train'``, ``'loss'``, ``'acc'``, ``'cl_ct'``
                and ``'infer_time'``; ``'train_time'`` and ``'lr'`` are
                printed too when present.
        """
        data_str = 'tr' if res['is_train'] else 'tst'

        if 'train_time' in res:
            time_str = f"{res['train_time']:.3f}sec(train) {res['infer_time']:.3f}sec(infer)"
        else:
            time_str = f"{res['infer_time']:.3f}sec"

        # `.4f` prints four decimals. The former `4f` spec only set a minimum
        # field width, so the rate came out with the default six decimals.
        lr_str = f" lr {res['lr']:.4f}" if 'lr' in res else ""

        str0 = f"Epoch {self.epoch} {data_str}: l {res['loss']:.3f} a {res['acc']:.3f} clct{res['cl_ct']} {lr_str} {time_str}"

        print(str0)

    def train(self, cluster_assign, lr):
        """Run one round: local gradient steps, then neighbor aggregation.

        For every client, the model of its assigned cluster takes
        ``config['tau']`` full-batch gradient steps on the client's training
        data. Afterwards models are combined across neighbors using the
        scheme named by ``config['agg']`` (``'batch'`` or ``'seq'``).

        Args:
            cluster_assign: per-client index of the cluster model to train.
            lr: step size for the local updates.

        Raises:
            NotImplementedError: if ``config['agg']`` is neither value.
        """
        VERBOSE = 0

        cfg = self.config
        num_clients = cfg['m']
        local_steps = cfg['tau']

        # ---- local updates ----
        started = time.time()
        for client in range(num_clients):
            if VERBOSE and client % 100 == 0:
                print(f'm {client}/{num_clients} processing \r', end ='')

            X, y = self.load_data(client)
            model = self.models[client][cluster_assign[client]]

            for _ in range(local_steps):
                loss = self.criterion(model(X), y)

                model.zero_grad()
                loss.backward()
                self.local_param_update(model, lr)

            model.zero_grad()

        if VERBOSE:
            print(f'local update {time.time() - started:.3f}sec')

        # ---- aggregation across neighbors ----
        started = time.time()
        scheme = cfg['agg']
        if scheme == 'batch':
            self.batch_aggregation(cluster_assign)
        elif scheme == 'seq':
            self.sequential_aggregation(cluster_assign)
        else:
            raise NotImplementedError("only seq and batch supported")

        if VERBOSE:
            print(f'global update {time.time() - started:.3f}sec')

    def check_local_model_loss(self, local_models):
        """Debug helper: training loss of one given model per client.

        Args:
            local_models: sequence indexed by client id; entry ``m_i`` is
                evaluated on client ``m_i``'s training data.

        Returns:
            NumPy array with one loss value per client.
        """
        def loss_for(client):
            X, y = self.load_data(client)
            return self.criterion(local_models[client](X), y).item()

        return np.array([loss_for(client) for client in range(self.config['m'])])
    
    def get_cluster_accuracy(self, actual, pred):
        """Accuracy of a clustering under the best one-to-one relabeling.

        Cluster ids are arbitrary, so predicted ids are first matched to the
        true ids with the Hungarian algorithm (maximising the number of
        agreeing samples) and accuracy is computed after that relabeling.

        Args:
            actual: 1-D sequence of true cluster ids.
            pred: 1-D sequence of predicted cluster ids, same length.

        Returns:
            Fraction of samples whose relabeled prediction equals ``actual``.

        Raises:
            ValueError: if the inputs are empty, not 1-D, or differ in length.
        """
        actual = np.asarray(actual)
        pred = np.asarray(pred)
        if actual.ndim != 1 or actual.shape != pred.shape:
            raise ValueError("actual and pred must be 1-D sequences of equal length")
        if actual.size == 0:
            raise ValueError("actual and pred must not be empty")

        # Rows/columns of the contingency table are positions in the sorted
        # set of all labels, not the label values themselves. The earlier
        # code used those positions directly as labels, which broke whenever
        # the ids were not exactly 0..k-1.
        labels = np.unique(np.concatenate([actual, pred]))
        actual_pos = np.searchsorted(labels, actual)
        pred_pos = np.searchsorted(labels, pred)

        cm = np.zeros((labels.size, labels.size), dtype=np.int64)
        np.add.at(cm, (actual_pos, pred_pos), 1)

        # Negate so that the minimum-cost assignment maximises agreement.
        row_ind, col_ind = linear_sum_assignment(-cm)
        col_to_row = np.empty(labels.size, dtype=np.int64)
        col_to_row[col_ind] = row_ind

        remapped_preds = labels[col_to_row[pred_pos]]
        cl_acc = np.mean(remapped_preds == actual)

        return cl_acc

    @torch.no_grad()
    def get_inference_stats(self, train = True):
        """Evaluate every client on all of its cluster models.

        For each client the loss of each of its ``p`` models is computed on
        the client's data (training split if ``train`` else test split). The
        client is assigned to the model with the smallest loss, and the
        reported loss/accuracy use that best model only.

        Returns:
            dict with
            ``'cluster_assign'``: per-client index of the lowest-loss model,
            ``'num_data'``: total number of evaluated samples,
            ``'loss'``: mean over clients of the best model's loss,
            ``'acc'``: correct predictions of the best models / ``num_data``,
            ``'cl_ct'``: number of clients assigned to each cluster,
            ``'is_train'``: the ``train`` argument.
        """
        # NOTE(review): no model is switched to eval() anywhere in this file,
        # so BatchNorm layers presumably run with batch statistics here -
        # confirm that this is intended.
        cfg = self.config
        num_clients = cfg['m']
        num_clusters = cfg['p']

        total_samples = 0
        cluster_assign = []
        best_losses = []
        best_corrects = []

        for client in range(num_clients):
            X, y = self.load_data(client, train=train)

            client_losses = []
            client_corrects = []
            for cluster in range(num_clusters):
                y_logit = self.models[client][cluster](X)
                client_losses.append(self.criterion(y_logit, y).item())
                client_corrects.append(self.n_correct(y_logit, y))

            total_samples += X.shape[0]

            # The client joins the cluster whose model fits its data best.
            best = np.argmin(client_losses)
            cluster_assign.append(best)
            best_losses.append(client_losses[best])
            best_corrects.append(client_corrects[best])

        assigned = np.array(cluster_assign)
        cl_ct = [np.sum(assigned == cluster) for cluster in range(num_clusters)]

        return {
            'cluster_assign': cluster_assign,
            'num_data': total_samples,
            'loss': np.mean(best_losses),
            'acc': np.sum(best_corrects) / total_samples,
            'cl_ct': cl_ct,
            'is_train': train,
        }

    def n_correct(self, y_logit, y):
        """Count the samples whose highest-scoring class equals the label.

        Args:
            y_logit: per-sample class scores; the maximum is taken along
                dimension 1.
            y: target class indices.

        Returns:
            Number of correct predictions as a Python number.
        """
        predicted = torch.max(y_logit.data, 1).indices
        hits = predicted.eq(y)
        return hits.sum().item()
    
    
    def load_data(self, m_i, train=True):
        """Return ``(X, y)`` holding all samples of client ``m_i``.

        Args:
            m_i: client id, a key of the split's ``'data_indices'`` mapping.
            train: read from the training split if true, else the test split.
        """
        split = self.dataset['train' if train else 'test']
        rows = split['data_indices'][m_i]
        return split['X'][rows], split['y'][rows]


    def local_param_update(self, model, lr):

        # gradient update manually

        for name, param in model.named_parameters():
            if param.requires_grad:
                param.data -= lr * param.grad

        model.zero_grad()

        # import ipdb; ipdb.set_trace() # we need to check the output of name, check if duplicate exists

    def weighted_avg_update(self, i_model, j_model, r):
        """Blend ``j_model`` into ``i_model`` in place.

        Each parameter of ``i_model`` becomes
        ``r/(r+1) * own + 1/(r+1) * other``; ``j_model`` is not modified.
        Parameters are paired by position, so both models are expected to
        have the same architecture.
        """
        keep = r / (r + 1)
        take = 1 / (r + 1)
        for own, other in zip(i_model.parameters(), j_model.parameters()):
            own.data = keep * own.data + take * other.data


    def sequential_aggregation(self, cluster_assign):
        """Fold neighbors' models into each client's models one at a time.

        Clients are processed in index order. For each client its neighbors
        are visited in random order; the k-th visited neighbor (k = 1, 2, ...)
        contributes the model of *its own* assigned cluster, which is blended
        into the client's model of that same cluster with weight ``1/(k+1)``.
        Updates happen in place, so later clients see already-blended models.
        """
        # NOTE(review): k counts all visited neighbors, not just those of the
        # same cluster, so with several clusters this is not an equal-weight
        # running average per cluster - confirm that this is intended.
        models = self.models
        for m_i in range(self.config['m']):
            visit_order = list(self.adjacency[m_i])
            random.shuffle(visit_order)
            for step, m_j in enumerate(visit_order, start=1):
                cluster = cluster_assign[m_j]
                self.weighted_avg_update(models[m_i][cluster], models[m_j][cluster], step)


    def _batch_aggregation(self, c_model, n_models):
        """Overwrite ``c_model``'s parameters with the mean over ``n_models``.

        Does nothing when ``n_models`` is ``None`` or empty. Parameters are
        matched by name; ``c_model``'s previous values do not enter the mean.
        """
        if not n_models:
            return

        count = len(n_models)
        per_model = [dict(model.named_parameters()) for model in n_models]

        for name, param in c_model.named_parameters():
            # Sum in the order the models were given, then divide once.
            total = torch.zeros_like(per_model[0][name].data)
            for params in per_model:
                total += params[name].data
            total /= count
            param.data = total


    def batch_aggregation(self, cluster_assign):
        """Replace each client's cluster models by neighbor averages.

        For every client and every cluster ``p_i``, the client's model
        ``p_i`` is overwritten with the mean of the ``p_i`` models of those
        neighbors currently assigned to cluster ``p_i``; it is left unchanged
        if there is no such neighbor. Clients are processed in index order
        and updated in place.
        """
        num_clusters = self.config['p']
        for m_i in range(self.config['m']):
            neighbors = self.adjacency[m_i]
            for p_i in range(num_clusters):
                target = self.models[m_i][p_i]
                same_cluster = [
                    self.models[m_j][p_i]
                    for m_j in neighbors
                    if cluster_assign[m_j] == p_i
                ]
                self._batch_aggregation(target, same_cluster)



    def test(self, train=False):
        return self.get_inference_stats(train=train)
    

    def save_checkpoint(self):
        """Save the weights of every model to ``self.checkpoint_fname``.

        The file holds ``{'models': states}`` where ``states[m_i][p_i]`` is
        the ``state_dict`` of client ``m_i``'s cluster model ``p_i``.
        """
        # self.models is a list (clients) of lists (cluster models).
        # Iterating only the outer list called state_dict() on a Python list
        # and raised AttributeError.
        models_to_save = [
            [model.state_dict() for model in client_models]
            for client_models in self.models
        ]
        torch.save({'models': models_to_save}, self.checkpoint_fname)


class SimpleCNN(nn.Module):
    """Small two-stage convolutional classifier with 62 output classes.

    Each stage is conv(3x3, padding 1) -> ReLU -> 2x2 max-pool -> batch norm,
    going from 1 to 32 to 64 channels. The flattened features (sized for
    7x7 maps, i.e. 28x28 inputs) pass through a hidden layer of width ``h1``
    and a 62-way output layer. ``forward`` returns log-probabilities.
    """

    def __init__(self, h1=512):
        """Build the layers; ``h1`` is the width of the hidden linear layer."""
        super().__init__()
        # Layers are created in this order on purpose: it fixes both the
        # parameter order and the sequence of random initialisations.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.batchnorm1 = nn.BatchNorm2d(32)
        self.batchnorm2 = nn.BatchNorm2d(64)
        self.fc1 = nn.Linear(64 * 7 * 7, h1)
        self.fc2 = nn.Linear(h1, 62)

    def forward(self, x):
        """Map a batch of single-channel images to per-class log-probabilities."""
        stages = ((self.conv1, self.batchnorm1), (self.conv2, self.batchnorm2))
        for conv, norm in stages:
            x = norm(self.pool(F.relu(conv(x))))

        flat = x.view(-1, 64 * 7 * 7)
        hidden = F.relu(self.fc1(flat))
        return F.log_softmax(self.fc2(hidden), dim=1)
