import copy
import math
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import metric
from loss import feature_loss, ContrastiveLoss
from net import Network
from utils import batch_compute_similarity_tensor, valid_by_feature, compute_acc, clustering_acc, \
    most_similar_per_sample, all_combinations, draw_scatter, valid_by_labels

class Client():
    """A single participant holding multi-view data and its own ``Network``.

    The client keeps a local copy of the model, trains it on its own views
    (``train`` / ``train1``) and can merge a list of peer networks into its
    own model before a short local fine-tuning phase (``train2``).

    The network's forward pass is used throughout as
    ``xrs, zs, hs, h, ws, pt, labels, centers = net(xs, views)``; only
    ``xrs``, ``zs``, ``h`` and ``pt`` are consumed in this class.
    """

    def __init__(self, id, all_view, have_view, data, args, device, writer):
        """Store the configuration, tensorise the data and record feature stats.

        Args:
            id: Identifier of this client (also used to build ``self.name``).
            all_view: Iterable of every view index known to the system.
            have_view: View indices actually available to this client.
            data: Object exposing ``dataset.dataset_x`` (per-view samples,
                indexable by view) and ``dataset.dataset_y`` (labels).
            args: Hyper-parameter namespace; this class reads ``class_num``,
                ``learning_rate``, ``weight_decay``, ``temperature``,
                ``acc_rate``, ``p1`` and ``p2`` from it.
            device: Torch device the per-view tensors are moved to.
            writer: Logging handle; stored but not used by the visible code.
        """
        self.id = id
        self.all_view = all_view
        self.have_view = have_view
        self.data = data
        self.args = args
        self.device = device
        self.writer = writer
        self.net = Network(id, all_view, have_view, args, device).float()
        self.name = str(self.id) + str(self.have_view)
        self.class_num = args.class_num

        # NOTE(review): ``xs`` is the dataset's own container, so the
        # assignments below replace its entries in place — confirm that the
        # dataset is not expected to keep the raw arrays.
        xs = self.data.dataset.dataset_x
        ys = torch.tensor(self.data.dataset.dataset_y)
        for v in self.all_view:
            # Flatten every sample of view ``v`` to one float32 row vector.
            flat = torch.tensor(xs[v]).to(self.device).reshape(xs[v].shape[0], -1)
            xs[v] = flat.to(torch.float32)

        self.xs = xs
        self.ys = ys
        self.label = [0] * len(ys)
        self.cl = ContrastiveLoss(0.5)

        # Per-view mean / (biased) std of the initial features ``zs[v]``,
        # appended in the iteration order of ``all_view``.
        self.means = []
        self.stds = []
        with torch.no_grad():
            self.net.eval()
            _, zs, *_ = self.net(xs, self.have_view)
            for v in self.all_view:
                self.means.append(zs[v].mean(dim=0))
                self.stds.append(zs[v].std(dim=0, unbiased=False))

    def train(self, epos):
        """Run ``epos`` local epochs over every subset of the available views.

        For each subset returned by ``all_combinations(self.have_view)`` one
        optimisation step is taken on
        ``loss1 + args.p1 * loss2 + args.p2 * loss3`` where

        * ``loss1`` is the mean MSE between ``xs[v]`` and ``xrs[v]``,
        * ``loss2`` is a contrastive term between ``zs[v]`` and ``h`` in which
          the matching row index is the positive pair,
        * ``loss3`` pulls the scaled correlation ``pt.T @ pt`` towards a blend
          of the identity and the uniform matrix controlled by
          ``args.acc_rate``.

        Args:
            epos: Number of epochs to run.

        Returns:
            Tuple ``(accs_v, acc_h, h)``: ``accs_v`` is a list of zeros with
            one entry per view in ``all_view``, ``acc_h`` is the result of
            ``valid_by_feature`` after the last step and ``h`` is the last
            ``h`` produced by the network. When no step is executed
            (``epos == 0`` or no view subsets) ``acc_h`` is ``0`` and ``h``
            is ``None``.
        """
        xs = self.xs
        optimizer = torch.optim.Adam(self.net.parameters(), lr=self.args.learning_rate,
                                     weight_decay=self.args.weight_decay)
        self.net.train()

        # The correlation target only depends on hyper-parameters, so build
        # it once instead of on every optimisation step.
        acc_rate = self.args.acc_rate
        diag = torch.eye(self.class_num, self.class_num).to(self.device)
        aver = torch.ones(self.class_num, self.class_num).to(self.device) / self.class_num
        cor_goal = acc_rate * diag + (1 - acc_rate) * aver

        acc_h = 0
        h = None
        for _ in range(epos):
            for give_view in all_combinations(self.have_view):
                xrs, zs, _hs, h, _ws, pt, _labels, _centers = self.net(xs, give_view)

                # Reconstruction term, averaged over the available views.
                recon_terms = [F.mse_loss(xs[v], xrs[v]) for v in self.have_view]
                loss1 = sum(recon_terms) / len(recon_terms)

                # Contrastive term between each view feature and ``h``.
                contrast_terms = []
                for v in self.have_view:
                    cos_sim = F.cosine_similarity(zs[v].unsqueeze(1), h.unsqueeze(0), dim=2)
                    cos_sim = torch.clamp(cos_sim, -1.0, 1.0)
                    exp_sim = torch.exp(cos_sim / self.args.temperature)
                    # Build the positive mask on the same device as the
                    # similarities rather than assuming the default GPU.
                    same_prob = torch.eye(xs[v].shape[0], device=exp_sim.device,
                                          dtype=exp_sim.dtype)
                    posi_sim = torch.sum(same_prob * exp_sim, dim=1)
                    nega_sim = torch.sum((1 - same_prob) * exp_sim, dim=1)
                    contrast_terms.append(
                        - torch.sum(torch.log(posi_sim / nega_sim)) / same_prob.shape[0])
                loss2 = sum(contrast_terms) / len(contrast_terms)

                # Correlation term on the ``pt`` output.
                cor = (pt.T @ pt) * self.class_num / h.shape[0]
                loss3 = F.mse_loss(cor, cor_goal)

                loss = loss1 + self.args.p1 * loss2 + self.args.p2 * loss3

                # Clear gradients left over from the previous step; without
                # this every update would use the sum of all past gradients.
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # NOTE(review): evaluated on every step although only the last
                # value is returned; kept in place because ``valid_by_feature``
                # is defined elsewhere and may have side effects.
                acc_h = valid_by_feature(self.ys, h, self.args.class_num)

        accs_v = [0] * len(self.all_view)
        return accs_v, acc_h, h

    def train1(self):
        """Run the initial local training phase (50 epochs); returns nothing."""
        self.train(50)

    def train2(self, nets, i):
        """Merge peer networks into the local model, then fine-tune briefly.

        Each network in ``nets`` gets a weight equal to the summed Frobenius
        distance (divided by the number of rows of the local ``h``) between
        the local ``h`` and that network's ``zs[v]`` for every ``v`` in the
        network's ``have_view``. The weighted average of the networks replaces
        the local model, which is then trained for 5 epochs.

        Args:
            nets: Networks to aggregate; each must expose ``have_view`` and be
                callable like ``self.net``.
            i: Round index; only used to name the scatter plot.
        """
        with torch.no_grad():
            _, _, _, o_h, *_ = self.net(self.xs, self.have_view)
            fws = []
            for net in nets:
                w = 0
                for v in net.have_view:
                    # NOTE(review): the forward pass does not depend on ``v``;
                    # it is repeated per view exactly as before because the
                    # network may be stateful — confirm before hoisting.
                    _, zs, *_ = net(self.xs, self.have_view)
                    w += self.matrix_euclidean_distance(o_h, zs[v]) / o_h.shape[0]
                fws.append(w)
            # NOTE(review): a larger distance yields a larger aggregation
            # weight — confirm this is the intended direction.
            print(fws)
            agg_net = self.aggregate_nets(nets, fws)
            self.net.load_state_dict(agg_net.state_dict())

        _accs_v, _acc_h, h = self.train(5)

        if self.id == 0:
            draw_scatter(h, self.ys, self.name + '_' + str(i))

    @torch.no_grad()
    def aggregate_nets(self, nets, fws):
        """Return a new network whose parameters are a weighted average.

        Args:
            nets: Non-empty sequence of modules with identical state-dict keys.
            fws: One weight per network; weights are normalised by their sum.
                If the sum is zero, all networks are weighted equally so the
                result does not degenerate into NaN values.

        Returns:
            A deep copy of ``nets[0]`` whose floating-point (or complex)
            state-dict entries hold the weighted average. Entries of any other
            dtype (for example integer counters) cannot be averaged in place
            and keep the value copied from ``nets[0]``.
        """
        weights = list(fws)
        total = sum(weights)
        if total == 0:
            weights = [1.0] * len(weights)
            total = float(len(weights))

        agg_net = copy.deepcopy(nets[0])
        agg_state = agg_net.state_dict()
        net_states = [net.state_dict() for net in nets]
        for key, target in agg_state.items():
            if not (target.is_floating_point() or target.is_complex()):
                # An in-place float add into an integer tensor would raise.
                continue
            target.zero_()
            for net_state, weight in zip(net_states, weights):
                target.add_(net_state[key] * weight / total)
        agg_net.load_state_dict(agg_state)
        return agg_net

    @torch.no_grad()
    def get_y_h(self):
        """Return ``(ys, h)``: the labels and the network's ``h`` in eval mode."""
        self.net.eval()
        _, _, _, h, *_ = self.net(self.xs, self.have_view)
        return self.ys, h

    @torch.no_grad()
    def add_gaussian_noise(self, xs, mean=0.0, std=0.01):
        """Return a deep copy of ``xs`` with normalised, noised views.

        Every view in ``all_view`` is passed through ``F.normalize`` and then
        perturbed with Gaussian noise of the given ``mean`` and ``std``. The
        input ``xs`` is left untouched.
        """
        noise_xs = copy.deepcopy(xs)
        for v in self.all_view:
            noise = torch.randn_like(noise_xs[v]) * std + mean
            noise_xs[v] = F.normalize(noise_xs[v]) + noise
        return noise_xs

    def matrix_euclidean_distance(self, matrix1, matrix2):
        """Return the Frobenius norm of ``matrix1 - matrix2`` as a float.

        Raises:
            ValueError: If the two tensors do not have the same shape.
        """
        if matrix1.shape != matrix2.shape:
            raise ValueError(
                f"shape mismatch: {tuple(matrix1.shape)} vs {tuple(matrix2.shape)}")

        distance = torch.norm(matrix1 - matrix2, p='fro')
        return distance.item()