import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import numpy as np
from torch.utils.data import DataLoader, TensorDataset, Subset
from net.models import build_model, Linear
from utils.process_data import CustomDataset

class Party:
    """
    Base class for users (parties) in federated learning.

    Each party owns a local training set, a test set, and two models built by
    ``build_model``:

    * ``self.model`` -- trained on local data only (``train`` / ``test``).
    * ``self.model_new`` -- trained on local data mixed with "reward" samples
      (``train_new`` / ``test_new``); the two losses are weighted by
      ``self.alpha``.

    ``reward_data`` / ``reward_label`` start empty and nothing in this class
    fills them; presumably the caller populates them before ``construct_data``
    / ``train_new`` are used.
    """

    def __init__(self, args, id, sub_dataset, test_data):
        """
        Args:
            args: run configuration; reads ``batch_size``, ``epochs``,
                ``dataset`` and ``cuda_num``.
            id: identifier of this party (used in the saved-model filename).
                The name shadows the builtin but is kept for callers that pass
                it by keyword.
            sub_dataset: local training data -- a ``Subset``, an ``(X, y)``
                tuple of array-likes (converted to float32 / int64 tensors),
                or any other dataset object, which is used as-is.
            test_data: dataset used by the ``test*`` methods; must yield
                ``(x, y)`` pairs.
        """
        if isinstance(sub_dataset, Subset):
            self.dataset = sub_dataset
        elif isinstance(sub_dataset, tuple):
            user_X = torch.tensor(sub_dataset[0], dtype=torch.float32)
            user_y = torch.tensor(sub_dataset[1], dtype=torch.int64)
            self.dataset = TensorDataset(user_X, user_y)
        else:
            # Any other dataset type used to leave ``self.dataset`` unset and
            # crash below with AttributeError; accept it directly instead.
            self.dataset = sub_dataset
        self.batch_size = args.batch_size
        self.epochs = args.epochs
        self.test_data_size = len(test_data)
        self.train_loader = DataLoader(self.dataset, batch_size=self.batch_size, shuffle=True)
        self.test_data = test_data
        self.data_name = args.dataset
        self.id = id

        # Extra samples / labels mixed into training by construct_data().
        self.reward_data = []
        self.reward_label = []

        # MNIST uses an MLP; every other dataset name uses a CNN.
        self.model_type = 'MLP' if args.dataset == 'MNIST' else 'CNN'
        self.device = torch.device(args.cuda_num)
        self.model = build_model(self.model_type, 10)
        self.model.to(self.device)
        self.optimizer = torch.optim.Adam(params=self.model.parameters(),
                                          lr=1e-3, betas=(0.9, 0.999))
        self.loss_fn = nn.CrossEntropyLoss()

        # Second model, trained on local + reward data by train_new().
        self.model_new = build_model(self.model_type, 10)
        self.model_new.to(self.device)
        self.optimizer_new = torch.optim.Adam(params=self.model_new.parameters(),
                                              lr=1e-3, betas=(0.9, 0.999))
        # Weight of the reward-data loss; overwritten by get_optimal_alpha().
        self.alpha = 0.5

    def ce_loss(self, pred, label):
        """Cross entropy between probabilities ``pred`` and target ``label``.

        Computes ``mean_over_rows(-sum_dim1(label * log(pred + 1e-8)))``.
        ``pred`` is expected to hold probabilities (callers apply softmax),
        and ``label`` is presumably a soft / one-hot target with the same
        shape as ``pred`` -- TODO confirm against the reward-label producer.
        """
        # 1e-8 keeps log() finite when a predicted probability is exactly 0.
        loss = -torch.sum(label * torch.log(pred + 1e-8), dim=1)
        return torch.mean(loss)

    def construct_data(self):
        """Return a shuffled DataLoader pairing local and reward samples.

        The DataLoader wraps ``CustomDataset(self.dataset, reward_dataset)``;
        the training loops in this class unpack each batch as
        ``(source_batch, reward_batch)``.
        """
        # Stack into one array first: building a tensor from a list of
        # arrays element by element is very slow. Features are cast to
        # float32 to match the models' parameters (as in __init__).
        reward_X = torch.tensor(np.asarray(self.reward_data), dtype=torch.float32)
        reward_y = torch.tensor(np.asarray(self.reward_label))
        reward_dataset = TensorDataset(reward_X, reward_y)
        combined_dataset = CustomDataset(self.dataset, reward_dataset)
        return DataLoader(combined_dataset, batch_size=self.batch_size, shuffle=True)

    def get_optimal_alpha(self):
        """Estimate the reward-loss weight and store it in ``self.alpha``.

        A ``Linear(self.data_name)`` model is regressed with MSE towards 0 on
        flattened local samples and towards 1 on flattened reward samples for
        10 epochs. The lowest per-epoch mean batch loss is then turned into a
        closed-form alpha that also depends on the local size ``m`` and the
        reward size ``T``. With no reward data, alpha is set to 0.
        """
        # TODO
        m = len(self.dataset)
        T = len(self.reward_data)
        if T == 0:
            # No reward samples: the closed form below carries a T/(m+T)
            # factor (i.e. 0) and would otherwise divide by zero in 1/T.
            self.alpha = 0
            print('estimated alpha: ', self.alpha)
            return

        linear_model = Linear(self.data_name)
        linear_model.to(self.device)
        linear_model.train()
        optimizer_linear = torch.optim.Adam(params=linear_model.parameters(),
                                            lr=1e-3, betas=(0.9, 0.999))
        dataloader = self.construct_data()
        num_batches = len(dataloader)
        loss_fn = nn.MSELoss()
        min_loss = float('inf')
        epochs = 10
        for _ in range(epochs):
            loss_sum = 0.0
            for source_batch, reward_batch in dataloader:
                source_X = source_batch[0].to(self.device)
                reward_X = reward_batch[0].to(self.device)

                # Local samples are regressed towards 0 ...
                output_1 = linear_model(source_X.reshape(source_X.shape[0], -1))
                loss_1 = loss_fn(output_1, torch.zeros_like(output_1))

                # ... reward samples towards 1. The target must match
                # output_2's shape (the last batches may differ in size).
                output_2 = linear_model(reward_X.reshape(reward_X.shape[0], -1))
                loss_2 = loss_fn(output_2, torch.ones_like(output_2))

                loss = loss_1 + loss_2
                loss_sum += loss.item()

                loss.backward()
                optimizer_linear.step()
                optimizer_linear.zero_grad()
            min_loss = min(min_loss, loss_sum / num_batches)
        print('loss: ', min_loss)

        # Closed-form mapping from the best loss to alpha. The constants
        # (2, 50, 0.01, 9) are not explained in this file.
        # NOTE(review): document the derivation of this formula.
        A = -2 * (1 - min_loss) / 50 + 0.01
        C = A ** 2 / 9
        print('C: ', C, ' 1/t+1/m: ', 1 / m + 1 / T)
        if C >= 1 / m + 1 / T:
            self.alpha = 0
        else:
            # C < 1/m + 1/T implies m + T - C*T*m > 0, so the root is real.
            self.alpha = T / (m + T) * (1 - m * C ** 0.5 / (m + T - C * T * m) ** 0.5)
        print('estimated alpha: ', self.alpha)

    def _train_mixed(self, model, optimizer, dataloader, alpha):
        """Train ``model`` for ``self.epochs`` epochs on (source, reward) batches.

        Per batch the loss is ``(1 - alpha) * CE(local logits, hard labels)
        + alpha * ce_loss(softmax(reward logits), reward labels)``.
        """
        model.train()
        for _ in range(self.epochs):
            for source_batch, reward_batch in dataloader:
                source_X, source_y = source_batch[0], source_batch[1]
                reward_X, reward_y = reward_batch[0], reward_batch[1]

                output_1 = model(source_X.to(self.device))
                loss_1 = self.loss_fn(output_1, source_y.to(self.device))

                output_2 = F.softmax(model(reward_X.to(self.device)), dim=1)
                loss_2 = self.ce_loss(output_2, reward_y.to(self.device))

                loss = (1 - alpha) * loss_1 + alpha * loss_2
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

    def get_emp_optimal_alpha(self):
        """Score candidate alphas empirically on the test set.

        For each alpha from ``np.arange(0.1, 1, 0.1)`` (roughly 0.1 ... 0.9), a
        fresh model is trained on local + reward data and evaluated.
        ``self.alpha`` and ``self.model_new`` are not modified.

        Returns:
            list of ``(alpha, test_accuracy)`` tuples.
        """
        # TODO
        # The loaders are loop-invariant; a shuffling DataLoader reshuffles on
        # every iteration anyway.
        dataloader = self.construct_data()
        alpha_record = []
        for alpha in np.arange(0.1, 1, 0.1):
            model_new = build_model(self.model_type, 10)
            model_new.to(self.device)
            optimizer_new = torch.optim.Adam(params=model_new.parameters(),
                                             lr=1e-3, betas=(0.9, 0.999))
            self._train_mixed(model_new, optimizer_new, dataloader, alpha)
            alpha_record.append((alpha, self._evaluate(model_new)))
        return alpha_record

    def train_new(self):
        """Estimate ``self.alpha``, then train ``self.model_new`` on mixed data."""
        self.get_optimal_alpha()
        dataloader = self.construct_data()
        self._train_mixed(self.model_new, self.optimizer_new, dataloader, self.alpha)

    def get_pred(self, data):
        """Return ``self.model``'s softmax probabilities for ``data`` as a NumPy array.

        ``data`` is moved to ``self.device``. The model is scored in eval mode,
        and its previous train/eval mode is restored afterwards.
        """
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.inference_mode():
                output = self.model(torch.as_tensor(data, device=self.device))
                output = F.softmax(output, dim=1)
        finally:
            self.model.train(was_training)
        return output.cpu().numpy()

    def train(self):
        """Train ``self.model`` on local data for ``self.epochs`` epochs."""
        self.model.train()
        for _ in range(self.epochs):
            for local_X, local_y in self.train_loader:
                output = self.model(local_X.to(self.device))
                loss = self.loss_fn(output, local_y.to(self.device))
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

    def _evaluate(self, model):
        """Return ``model``'s top-1 accuracy on ``self.test_data``.

        The whole test set is scored as a single batch. The model runs in eval
        mode (deterministic dropout / batch-norm), and its previous mode is
        restored afterwards.
        """
        test_loader = DataLoader(self.test_data, batch_size=self.test_data_size, shuffle=False)
        was_training = model.training
        model.eval()
        try:
            with torch.inference_mode():
                test_correct = 0
                for x, y in test_loader:
                    output = model(x.to(self.device))
                    predicted = torch.argmax(output, dim=1)
                    test_correct += (predicted == y.to(self.device)).sum().item()
        finally:
            model.train(was_training)
        return test_correct / self.test_data_size

    def test(self):
        """Return the test accuracy of ``self.model`` (local-only model)."""
        return self._evaluate(self.model)

    def test_new(self):
        """Return the test accuracy of ``self.model_new`` (mixed-data model)."""
        return self._evaluate(self.model_new)

    def save_model(self):
        """Save ``self.model`` with ``torch.save`` to ``new_models/<dataset>/user_<id>.pt``."""
        # Use the dataset *name* for the directory; ``self.dataset`` is the
        # Dataset object. str-format the id, which may be an int.
        model_path = os.path.join("new_models", self.data_name)
        os.makedirs(model_path, exist_ok=True)
        torch.save(self.model, os.path.join(model_path, f"user_{self.id}.pt"))