import os.path
import sys

import numpy as np
import torch
import torch.nn as nn
import wandb
import yaml
import random
from tqdm.auto import tqdm
from model import VFL,MVFL,DeepMVFL
from model_arxiv import DeepMVFL_Uncon
from functools import partial
from model import MessageMean

class Solver(object):
    ''' Training and testing IL models'''

    def __init__(self,
                 train_loader,
                 val_loader,
                 test_loader,
                 config):

        # config
        self.config = config
        self.device = config.device
        self.seed = config.seed

        # data
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.n_class = config.n_class

        # training
        self.epochs = config.epochs
        self.lr = config.lr

        # model
        self.n_device = config.n_device
        self.model_setup = config.model_setup
        self.repeat = config.repeat
        self.gossip_mode = config.gossip_mode
        self.n_gossip = config.n_gossip
        self.image_shape = config.image_shape
        self.d_inter_init = config.d_inter_init
        self.d_inter = config.d_inter
        self.activation = config.activation

        # setup
        self.drop_mode = config.drop_mode
        self.graph_type = config.graph_type
        self.drop_rate_train = config.drop_rate_train
        self.rgg_radius = config.rgg_radius
        self.k_mvfl = config.k_mvfl
        if self.model_setup == 'VFL':
            self.get_batch_edge_index = partial(self.get_batch_edge_index_VFL, drop_mode=self.drop_mode)
            #self.test = partial(self.test_VFL, drop_mode=self.drop_mode)
            self.test = self.test_VFL
        elif self.model_setup == 'MVFL':
            self.get_batch_edge_index = partial(self.get_batch_edge_index_MVFL, drop_mode=self.drop_mode)
            #self.test = partial(self.test_MVFL, drop_mode=self.drop_mode)
            self.test = self.test_MVFL
        elif self.model_setup == 'DeepMVFL' or self.model_setup == 'DeepMVFL_unconstrain':
            self.get_batch_edge_index = partial(self.get_batch_edge_index_MVFL, drop_mode=self.drop_mode)
            #self.test = partial(self.test_MVFL, drop_mode=self.drop_mode)
            self.test = self.test_MVFL
        else:
            raise ValueError('model_setup not recognized')

        # log
        self.wandb = config.wandb
        self.save = config.save
        self.save_dir = config.save_dir
        self.gossip_passing = MessageMean() #Added to account for gossip during inference and not training


        self.build_model()

    def build_model(self):
        """Instantiate the network for ``self.model_setup``, pick the aggregator(s) and create the optimizer.

        Sets ``self.model``, ``self.server_idx`` and ``self.opt``.

        * ``'VFL'``: ``self.server_idx`` is a single random device index (Python int),
          fixed for the whole experiment.
        * ``'MVFL'``, ``'DeepMVFL'``, ``'DeepMVFL_unconstrain'``: ``self.server_idx`` is a
          sorted 1-D tensor holding ``self.k_mvfl`` distinct random device indices.

        Raises:
            ValueError: if ``self.model_setup`` is not recognized.
        """
        # Keyword arguments shared by every model class.
        shared_kwargs = dict(d_inter_init=self.d_inter_init,
                             d_inter=self.d_inter,
                             image_shape=self.image_shape,
                             activation=self.activation)

        if self.model_setup == 'VFL':
            self.model = VFL(self.n_device, **shared_kwargs)
            # choose a fixed server for each experiment
            self.server_idx = torch.randint(0, self.n_device, (1,)).item()
        elif self.model_setup in ('MVFL', 'DeepMVFL', 'DeepMVFL_unconstrain'):
            if self.model_setup == 'MVFL':
                model_cls = MVFL
            elif self.model_setup == 'DeepMVFL':
                model_cls = DeepMVFL
            else:
                model_cls = DeepMVFL_Uncon
            self.model = model_cls(self.n_device,
                                   repeat=self.repeat,
                                   gossip_mode=self.gossip_mode,
                                   n_gossip=self.n_gossip,
                                   **shared_kwargs)
            # The deep variants are wired to the MVFL edge builder and test loop,
            # both of which read self.server_idx, so they need aggregators as well.
            # NOTE(review): drawing them exactly like MVFL is an assumption - confirm
            # it matches the intended DeepMVFL experiment.
            self.server_idx, _ = torch.sort(torch.randperm(self.n_device)[: self.k_mvfl])
        else:
            raise ValueError('model_setup not recognized')

        self.opt = torch.optim.Adam(self.model.parameters(), self.lr)
        self.model.to(self.device)

    def reset_grad(self):
        self.opt.zero_grad()
    def set_model_mode(self, set_to_train=True):
        self.model.train(set_to_train)

    def device_dist(self, sender_idx, receiver_idx):
        """Return the communication distance between two devices under ``self.graph_type``.

        * ``'uni'``: always ``1``.
        * ``'grid'``: devices sit on a square grid of side ``int(sqrt(n_device))``; returns the
          Euclidean distance (a 0-dim tensor) when it is at most 1, otherwise ``inf``.
        * ``'rgg'``: same grid coordinates; returns ``1`` when the Euclidean distance is
          within ``self.rgg_radius``, otherwise ``inf``.

        Raises:
            NotImplementedError: for ``'grid_soft'``.
            ValueError: for any other unknown ``graph_type``.
        """
        topology = self.graph_type
        if topology == 'uni':
            return 1
        if topology == 'grid_soft':
            raise NotImplementedError
        if topology not in ('grid', 'rgg'):
            raise ValueError('graph_type not recognized')

        # Both remaining topologies place device k at (row, col) = (k // side, k % side).
        side = int(np.sqrt(self.n_device))
        src_pos = torch.Tensor([sender_idx // side, sender_idx % side])
        dst_pos = torch.Tensor([receiver_idx // side, receiver_idx % side])
        separation = torch.norm(src_pos - dst_pos)

        if topology == 'grid':
            # only direct neighbours (and the device itself) are reachable
            return float('inf') if separation > 1 else separation
        # 'rgg': reachable devices all count as distance 1
        return float('inf') if separation > self.rgg_radius else 1
    def get_batch_edge_index_VFL(self, drop_mode, drop_rate, n_device, bs):
        # get edge_index for one sample
        edge_index = []
        n_sender = n_device

        if drop_mode == 'comm':
            for i in range(n_sender):
                # SimpleConv expects [sender, receiver]
                # i - sender
                # j - receiver
                if i == self.server_idx or torch.rand(1) > drop_rate * self.device_dist(i,self.server_idx):
                    edge_index.append([i, self.server_idx])
        elif drop_mode == 'device':
            for i in range(n_sender):
                # SimpleConv expects [sender, receiver]
                # i - sender
                # j - receiver
                if i == self.server_idx or torch.rand(1) > drop_rate * self.device_dist(i,self.server_idx):
                    edge_index.append([i, self.server_idx])
        else:
            raise ValueError('drop_mode not recognized')
        edge_index = torch.Tensor(edge_index).T
        # repeat for  all samples in the batch
        edge_index_batch = []
        for i in range(bs):
            edge_index_batch.append(edge_index.clone() + n_device*i)
        edge_index = torch.cat(edge_index_batch, dim=1).long()
        return edge_index

    def get_batch_edge_index_MVFL(self, drop_mode, drop_rate, n_device, bs, return_device_valid=False):
        # get edge_index for one sample
        # TODO: add a check on self.server_idx to accomodate MVFL
        if len(self.server_idx) == 0:
            self.server_idx = torch.tensor(list(range(n_device)))
        edge_index = []
        if drop_mode == 'comm':
            #n_receiver, n_sender = n_device, n_device
            n_receiver, n_sender = len(self.server_idx), n_device #KMVFL
            #device_valid = list(range(n_device))
            device_valid = (self.server_idx).tolist() #KMVFL

            #for j in range(n_receiver):
            for j in self.server_idx: #KMVFL
                for i in range(n_sender):
                    # SimpleConv expects [sender, receiver]
                    # i - sender
                    # j - receiver
                    # always keep self-loop for MVFL
                    if i == j or torch.rand(1) > drop_rate * self.device_dist(i,j):
                        # if device dist is infinity and drop rate 0, the device will be dropped (comparing number with nan with always return False)
                        edge_index.append([i, j])
        elif drop_mode == 'device':
            #device_valid = [i for i in range(n_device) if torch.rand(1) > drop_rate]
            non_agg_device = [x for x in list(range(n_device)) if x not in self.server_idx.tolist()] #KMVFL
            non_agg_valid = [i for i in non_agg_device if torch.rand(1) > drop_rate] #KMVFL
            device_valid = [i for i in (self.server_idx).tolist() if torch.rand(1) > drop_rate] #KMVFL Valid aggregators
            devices_valid = sorted(set(non_agg_valid + device_valid))
            #for j in range(n_device):
            for j in self.server_idx: #KMVFL
                if j in device_valid:
                    #for i in device_valid:
                    for i in devices_valid:
                        if 2 > self.device_dist(i, j):
                            # for grid graph, this distance is 0, 1 or infinity
                            # for rgg this is either 1 or infinity
                            # for grid_soft graph, this distance is always <=1
                            edge_index.append([i, j])
                else:
                    # always keep self-loop for MVFL
                    edge_index.append([j, j])


        elif drop_mode == 'comm_ring':
            edge_index_base = []
            n_receiver, n_sender = n_device, n_device
            for j in range(n_receiver):
                for i in range(j, j + 2):
                    if i == j and j == 0:
                        edge_index_base.append([i, j])
                        edge_index_base.append([n_device - 1, j])  # Remove this line for path graph
                    elif i == j and j != 0:
                        edge_index_base.append([i - 1, j])
                        edge_index_base.append([i, j])
                    else:
                        if i == n_device:
                            edge_index_base.append([0, j])  # Remove this line for path graph
                        else:
                            edge_index_base.append([i, j])

            edge_index_base = [x for x in edge_index_base if x[1] in (self.server_idx).tolist()]
            edgeMask = [-1 for _ in range(len(edge_index_base))]
            for index, ed_set in enumerate(edge_index_base):
                if ed_set[0] != ed_set[1]:
                    if torch.rand(1) > drop_rate:
                        edgeMask[index] = 1
                    else:
                        edgeMask[index] = 0
                else:
                    edgeMask[index] = 1
            edge_index = [elem for elem, flag in zip(edge_index_base, edgeMask) if flag == 1]

            edge_index_base = []
            n_receiver, n_sender = n_device, n_device
            for j in range(n_receiver):
                for i in range(j, j + 2):
                    if i == j and j == 0:
                        edge_index_base.append([i, j])
                        edge_index_base.append([n_device - 1, j])  # Remove this line for path graph
                    elif i == j and j != 0:
                        edge_index_base.append([i - 1, j])
                        edge_index_base.append([i, j])
                    else:
                        if i == n_device:
                            edge_index_base.append([0, j])  # Remove this line for path graph
                        else:
                            edge_index_base.append([i, j])

            edge_index_base = [x for x in edge_index_base if x[1] in (self.server_idx).tolist()]

            non_agg_device = [x for x in list(range(n_device)) if x not in self.server_idx.tolist()]
            non_agg_valid = [i for i in non_agg_device if torch.rand(1) > drop_rate]
            non_agg_invalid = [x for x in non_agg_device if x not in non_agg_valid]
            # non_agg_invalid = non_agg_device - non_agg_valid

            device_valid = [i for i in (self.server_idx).tolist() if torch.rand(1) > drop_rate]
            agg_invalid = [x for x in (self.server_idx).tolist() if x not in device_valid]
            # agg_invalid = (self.server_idx).tolist() - device_valid

            device_invalid = sorted(set(agg_invalid + non_agg_invalid))
            # devices_valid = sorted(set(non_agg_valid + device_valid))

            edgeMask = [1 for _ in range(len(edge_index_base))]
            # device_invalid = [i for i in range(n_device) if torch.rand(1) < drop_rate]
            for device in device_invalid:
                positions = [index for index, sublist in enumerate(edge_index_base) if
                             ((sublist[0] == device or sublist[1] == device) and sublist[1] != sublist[0])]
                for pos in positions:
                    edgeMask[pos] = 0

            edge_index = [elem for elem, flag in zip(edge_index_base, edgeMask) if flag == 1]


        else:
            raise ValueError('drop_mode not recognized')
        edge_index = torch.Tensor(edge_index).T
        # repeat for  all samples in the batch
        edge_index_batch = []
        for i in range(bs):
            edge_index_batch.append(edge_index.clone() + n_device*i)
        edge_index = torch.cat(edge_index_batch, dim=1).long()
        if return_device_valid:
            return edge_index, device_valid
        else:
            return edge_index
    def gossip_gm_test(self, x, edge_index):
        #x = nn.functional.log_softmax(x, dim=-1) #Deactivated it to support the gossiping only during inference
        if self.n_gossip > 0:
            for i in range(self.n_gossip):
                x = self.gossip_passing(x, edge_index)
                x = x - torch.logsumexp(x, dim=-1, keepdim=True)
        return x
    def train_and_test(self):
        """Train for ``self.epochs`` epochs, validating and testing after every epoch.

        Per epoch:
            1. Train on ``self.train_loader`` with ``nn.NLLLoss``; edges are sampled with
               ``self.drop_rate_train``. The mean batch loss is logged as ``Train/Loss``.
            2. Validate on ``self.val_loader`` the same way (no gradients) and log
               ``Val/Loss``. When ``self.save`` is set, write ``ckpt_{epoch}.pt`` every epoch
               and ``ckpt_best.pt`` whenever the validation loss improves.
            3. Run ``self.test`` for drop rates 0 to 0.5 and, when saving, dump the config to
               ``config.yml`` once.

        Raises:
            NotImplementedError: if ``self.model_setup`` has no loss defined here.
        """
        mvfl_setups = ('MVFL', 'DeepMVFL', 'DeepMVFL_unconstrain')

        def batch_loss(batch):
            # Forward one (data, label) batch and return the NLL loss at the aggregator(s).
            data, label = batch
            data = data.to(self.device)
            label = label.to(self.device)
            is_mvfl = self.model_setup in mvfl_setups
            if not is_mvfl and self.model_setup != 'VFL':
                raise NotImplementedError
            # edges are sampled with the training drop rate
            edge_index = self.get_batch_edge_index(drop_rate=self.drop_rate_train,
                                                   n_device=self.n_device,
                                                   bs=data.shape[0])
            edge_index = edge_index.to(self.device)
            output = self.model(data, edge_index)
            output = output[:, self.server_idx, :]
            if not is_mvfl:
                return criterion(output, label)
            # One target per aggregator. Take the width from the output itself rather
            # than from k_mvfl, which differs from the number of aggregators when
            # k_mvfl is 0 (meaning "all") or exceeds n_device.
            target = label.unsqueeze(1).repeat(1, output.shape[1])
            # NLLLoss wants the class dimension second for multi-dimensional input
            return criterion(output.permute(0, 2, 1), target)

        criterion = nn.NLLLoss()
        best_val_loss = float('inf')

        if self.save and self.save_dir:
            # make sure checkpoints and the config have somewhere to go
            os.makedirs(self.save_dir, exist_ok=True)

        for epoch in range(self.epochs):

            # =================================================================================== #
            #                         1. Training                                                 #
            # =================================================================================== #
            total_train_loss = 0
            self.set_model_mode(set_to_train=True)

            for bdx, batch in tqdm(enumerate(self.train_loader), total=len(self.train_loader),
                                   desc=f'Epoch {epoch}: Training'):
                self.reset_grad()
                loss = batch_loss(batch)
                loss.backward()
                self.opt.step()
                total_train_loss += loss.item()

            avg_train_loss = total_train_loss / len(self.train_loader)
            if self.wandb:
                wandb.log({'Train/Loss': avg_train_loss,
                           }, step=epoch)

            # =================================================================================== #
            #                         2. Validation                                               #
            # =================================================================================== #

            with torch.no_grad():
                total_val_loss = 0
                self.set_model_mode(set_to_train=False)

                for bdx, batch in tqdm(enumerate(self.val_loader), total=len(self.val_loader),
                                       desc=f'Epoch {epoch}: Validating'):
                    total_val_loss += batch_loss(batch).item()

                avg_val_loss = total_val_loss / len(self.val_loader)
                if self.wandb:
                    wandb.log({'Val/Loss': avg_val_loss,
                               }, step=epoch)

                if self.save:
                    torch.save(self.model.state_dict(), f'{self.save_dir}/ckpt_{epoch}.pt')
                if avg_val_loss < best_val_loss:
                    best_val_loss = avg_val_loss
                    if self.save:
                        torch.save(self.model.state_dict(), f'{self.save_dir}/ckpt_best.pt')

            # =================================================================================== #
            #                         3. Test                                                     #
            # =================================================================================== #
            # hard coding for now
            all_dp = [0, 0.1, 0.2, 0.3, 0.4, 0.5]

            self.test(self.drop_mode, all_dp, epoch)

            # save config file (only once)
            if not os.path.exists(f'{self.save_dir}/config.yml') and self.save:
                with open(f'{self.save_dir}/config.yml', 'w') as f:
                    yaml.dump(self.config, f)

    def test_MVFL(self, drop_mode, all_dp, epoch):
        """Evaluate the multi-aggregator model on ``self.test_loader`` for every drop rate in ``all_dp``.

        For each batch and drop rate a fresh edge index is sampled, the model output is
        passed through ``self.gossip_gm_test`` and four accuracies are accumulated:

        * ``avg``: accuracy over all (sample, valid aggregator) pairs.
        * ``rand``: per sample, the prediction of one uniformly chosen aggregator, where
          aggregators not in ``device_valid`` answer with a random class.
        * ``ora_best``: a sample counts as correct if any valid aggregator is correct.
        * ``ora_worst``: a sample counts as correct only if every valid aggregator is correct.

        Results are logged to wandb as ``Test_{metric}/dp{dp:.1f}`` when ``self.wandb`` is set.

        Args:
            drop_mode: ``'device'``, ``'device_ring'``, ``'comm'`` or ``'comm_ring'``.
            all_dp: iterable of drop rates to evaluate.
            epoch: current epoch, used for the progress bar and as the wandb step.

        Raises:
            ValueError: if ``drop_mode`` is not one of the four modes above.
        """
        with torch.no_grad():
            self.set_model_mode(set_to_train=False)
            tracker = dict()  # tracker[dp][metric] -> running count of correct predictions
            tracker_avg_ct = dict()  # tracker_avg_ct[dp] -> number of (sample, aggregator) pairs seen
            oracle_ct = 0  # samples seen, summed over all drop rates

            all_metric = ['avg', 'ora_best', 'ora_worst', 'rand']

            for dp in all_dp:
                tracker[dp] = {}
                tracker_avg_ct[dp] = 0
                for metric in all_metric:
                    tracker[dp][metric] = 0

            for bdx, batch in tqdm(enumerate(self.test_loader), total=len(self.test_loader),desc=f'Epoch {epoch}: Testing'):

                data, label = batch
                data = data.to(self.device)
                label = label.to(self.device)

                for dp in all_dp:
                    # 1. for device drop, keep the same faulty devices when computing the output
                    #    and when aggregating the final prediction
                    # 2. for comm drop, the aggregators whose output can be read are drawn
                    #    independently of the edge index
                    # NOTE(review): the edge builder bound to self.get_batch_edge_index must
                    # accept the same drop_mode (including the *_ring variants) - confirm.
                    if drop_mode == 'device' or drop_mode == 'device_ring':
                        edge_index, device_valid = self.get_batch_edge_index(drop_rate=dp,
                                                                             n_device=self.n_device,
                                                                             bs=data.shape[0],
                                                                             return_device_valid=True)
                    elif drop_mode == 'comm' or drop_mode == 'comm_ring':
                        edge_index = self.get_batch_edge_index(drop_rate=dp,
                                                               n_device=self.n_device,
                                                               bs=data.shape[0])
                        # determine whether we can get output from each aggregator
                        device_valid = [i for i in (self.server_idx).tolist() if torch.rand(1) > dp]  # aggregators only
                    else:
                        raise ValueError('drop_mode not recognized')

                    edge_index = edge_index.to(self.device)
                    output = self.model(data, edge_index)
                    output = self.gossip_gm_test(output, edge_index)  # gossip rounds applied at test time only



                    if len(device_valid) == 0:
                        # No aggregator is reachable: score every aggregator with a random guess.
                        device_valid = self.server_idx.tolist()
                        output_avg = output[:, device_valid, :]
                        pred_avg = output_avg.argmax(dim=-1)  # only used for its shape here
                        random_entries = torch.randint(0, self.n_class, size=pred_avg.size()).to(self.device)
                        pred_avg = random_entries
                    else:
                        output_avg = output[:, device_valid, :]
                        pred_avg = output_avg.argmax(dim=-1)

                    # ================ average ======================= #
                    tracker[dp]['avg'] += pred_avg.eq(
                        label.unsqueeze(1).repeat(1,pred_avg.shape[1])).sum().item()
                    tracker_avg_ct[dp] += pred_avg.shape[0] * pred_avg.shape[1]

                    # ================ random ======================== #

                    pred = output[:, self.server_idx.tolist(), :].argmax(dim=-1)  # one column per aggregator
                    # for faulty aggregators, replace the output with a random guess
                    # NOTE(review): when no aggregator was valid, device_valid was reset to all
                    # aggregators above, so device_invalid is empty here and 'rand' uses the
                    # real predictions - confirm this is intended.
                    device_invalid = [i for i, idx in enumerate(self.server_idx.tolist()) if idx not in device_valid]  # column positions
                    if len(device_invalid) != 0:
                        pred[:, device_invalid] = torch.randint(0, self.n_class, (pred.shape[0], len(device_invalid))).to(self.device)
                    # randomly choose one aggregator per sample

                    pred_rand = pred[torch.arange(pred.shape[0]),
                                torch.randint(0,pred.shape[1], size=(pred.shape[0],))]
                    tracker[dp]['rand'] += pred_rand.eq(label).sum().item()

                    # ================ oracle best =================== #
                    if pred_avg.shape[1] > 0:
                        # max over aggregators of the correctness flag: any aggregator correct
                        tracker[dp]['ora_best'] += pred_avg.eq(label.unsqueeze(1).repeat(1,pred_avg.shape[1])
                                                           ).max(dim=-1)[0].sum().item()

                        # ================ oracle worst =================== #
                        # min over aggregators of the correctness flag: all aggregators correct
                        tracker[dp]['ora_worst'] += pred_avg.eq(label.unsqueeze(1).repeat(1, pred_avg.shape[1])
                                                           ).min(dim=-1)[0].sum().item()

                        oracle_ct += pred_avg.shape[0]

            # Normalise by the number of drop rates: every drop rate sees the same samples,
            # so this recovers the per-drop-rate sample count.
            oracle_ct = int(oracle_ct/len(all_dp))
            for dp in all_dp:
                tracker[dp]['avg'] /= tracker_avg_ct[dp]
                # 'ora_best', 'ora_worst' and 'rand' are all per-sample counts
                for metric in all_metric[1:]:
                    tracker[dp][metric] /= oracle_ct

                if self.wandb:
                    for metric in all_metric:
                        wandb.log({f'Test_{metric}/dp{dp:.1f}': tracker[dp][metric],
                                   }, step=epoch)


    def test_VFL(self, drop_mode, all_dp, epoch):
        """Evaluate the single-server VFL model on ``self.test_loader`` for every drop rate in ``all_dp``.

        For each batch and drop rate a fresh edge index is sampled and the server's output
        is read. With probability ``dp`` the server (or its final message) is considered
        lost and the whole batch is answered with random classes. The resulting accuracy is
        logged to wandb as ``Test_avg/dp{dp:.1f}`` when ``self.wandb`` is set.

        Args:
            drop_mode: unused here; device drop and comm drop coincide in the VFL setting.
            all_dp: iterable of drop rates to evaluate.
            epoch: current epoch, used as the wandb step.
        """
        with torch.no_grad():
            self.set_model_mode(set_to_train=False)
            # running number of correct predictions / of predictions, per drop rate
            tracker = {dp: {'avg': 0} for dp in all_dp}
            tracker_avg_ct = {dp: 0 for dp in all_dp}

            for _, (data, label) in tqdm(enumerate(self.test_loader), desc='Testing'):
                data = data.to(self.device)
                label = label.to(self.device)

                for dp in all_dp:
                    edges = self.get_batch_edge_index(drop_rate=dp,
                                                      n_device=self.n_device,
                                                      bs=data.shape[0])
                    edges = edges.to(self.device)
                    server_out = self.model(data, edges)[:, self.server_idx, :]

                    # decide whether the server itself, or its final output
                    # message, survives for this batch
                    server_alive = torch.rand(1) > dp
                    if server_alive:
                        pred = server_out.argmax(dim=-1)
                    else:
                        # lost: fall back to a random guess per sample
                        pred = torch.randint(0, self.n_class, (server_out.shape[0],)).to(self.device)

                    tracker[dp]['avg'] += pred.eq(label).sum().item()
                    tracker_avg_ct[dp] += pred.shape[0]

            for dp in all_dp:
                tracker[dp]['avg'] /= tracker_avg_ct[dp]
                if self.wandb:
                    wandb.log({f'Test_avg/dp{dp:.1f}': tracker[dp]['avg'],
                               }, step=epoch)





