#!/usr/bin/env python
# coding: utf-8

rootdir = './cifar/'

import os
import copy
import math
import pickle
import dill
import random
from math import factorial
from itertools import chain, combinations
from tqdm import tqdm

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.colors as mcolors
from matplotlib.lines import Line2D


# Reproducibility: one seed for every RNG this script touches (Python, NumPy,
# PyTorch CPU and CUDA), with cuDNN autotuning off and deterministic mode on.
seed = 1234
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
for _seed_rng in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
):
    _seed_rng(seed)
del _seed_rng

def get_device():
    """Return the CUDA device when CUDA is available, otherwise the CPU device."""
    backend = 'cuda' if torch.cuda.is_available() else 'cpu'
    return torch.device(backend)
# Resolved once at import time and shared as a module-level default.
device = get_device()


# Embedding dataset
def get_cifar100_embedded_dataset(rootdir, force_recompute=False):
    """Return CIFAR-100 train/test splits as datasets of ResNet-50 embeddings.

    Embeddings and labels are cached as four ``.pth`` files under
    ``<rootdir>/emb``.  When all four exist and ``force_recompute`` is false
    they are loaded directly.  Otherwise the raw images under
    ``<rootdir>/rawdata`` (downloaded if missing) are passed through a
    pretrained torchvision ResNet-50 with its last child module removed, and
    the resulting tensors are written to the cache.

    The raw dataset and the backbone are only built on a cache miss, so a warm
    cache needs neither the images nor the pretrained weights.

    Args:
        rootdir: Directory holding the ``emb`` cache and the ``rawdata`` images.
        force_recompute: Ignore any cached tensors and recompute them.

    Returns:
        ``(train_emb_dataset, test_emb_dataset)``: two ``TensorDataset``
        objects yielding ``(embedding, label)`` pairs.
    """
    # Paths for caching
    emb_dir = os.path.join(rootdir, 'emb')
    emb_train_path = os.path.join(emb_dir, 'cifar100_train_embeddings.pth')
    emb_test_path = os.path.join(emb_dir, 'cifar100_test_embeddings.pth')
    label_train_path = os.path.join(emb_dir, 'cifar100_train_labels.pth')
    label_test_path = os.path.join(emb_dir, 'cifar100_test_labels.pth')
    cache_paths = (emb_train_path, emb_test_path, label_train_path, label_test_path)

    if not force_recompute and all(os.path.exists(p) for p in cache_paths):
        print("Loading cached embeddings...")
        X_train = torch.load(emb_train_path)
        X_test = torch.load(emb_test_path)
        y_train = torch.load(label_train_path)
        y_test = torch.load(label_test_path)
    else:
        print("Computing embeddings from pretrained model...")
        raw_data_dir = os.path.join(rootdir, 'rawdata')

        # Upscale to 224 px and normalise each channel before the backbone.
        transform = transforms.Compose([transforms.Resize(224),
                                        transforms.ToTensor(),
                                        transforms.Normalize((0.5071, 0.4867, 0.4408),
                                                             (0.2675, 0.2565, 0.2761))])
        train_dataset = torchvision.datasets.CIFAR100(raw_data_dir, train=True, transform=transform, download=True)
        test_dataset = torchvision.datasets.CIFAR100(raw_data_dir, train=False, transform=transform, download=True)

        model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
        model.to(device)
        model.eval()
        # Keep every child except the last one, so the output is the pooled
        # feature map rather than class logits.
        backbone = nn.Sequential(*list(model.children())[:-1]).to(device)

        def compute_embeddings(dataset):
            """Run ``dataset`` through the backbone; return (embeddings, labels) on CPU."""
            loader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=False)
            all_embs = []
            all_labels = []
            with torch.no_grad():
                for imgs, labels in loader:
                    imgs = imgs.to(device)
                    # Drop the two trailing singleton spatial dimensions.
                    emb = backbone(imgs).squeeze(-1).squeeze(-1)
                    all_embs.append(emb.cpu())
                    all_labels.append(labels)
            return torch.cat(all_embs, dim=0), torch.cat(all_labels, dim=0)

        X_train, y_train = compute_embeddings(train_dataset)
        X_test, y_test = compute_embeddings(test_dataset)

        # torch.save does not create parent directories; without this the
        # first run on a fresh rootdir fails after all the work is done.
        os.makedirs(emb_dir, exist_ok=True)
        torch.save(X_train, emb_train_path)
        torch.save(X_test, emb_test_path)
        torch.save(y_train, label_train_path)
        torch.save(y_test, label_test_path)

    train_emb_dataset = torch.utils.data.TensorDataset(X_train, y_train)
    test_emb_dataset = torch.utils.data.TensorDataset(X_test, y_test)
    return train_emb_dataset, test_emb_dataset

# Materialise the embedded train/test splits once at import time; cached
# tensors are reused because no recompute is requested.
train_emb_dataset, test_emb_dataset = get_cifar100_embedded_dataset(
    rootdir,
    force_recompute=False,
)


# Models
class MLP(nn.Module):
    """Three-layer classifier head: 2048 inputs -> two hidden layers -> 100 logits.

    Args:
        num_hid1: Width of the first hidden layer.
        num_hid2: Width of the second hidden layer.
    """

    def __init__(self, num_hid1=1024, num_hid2=512):
        super().__init__()
        in_features, num_classes = 2048, 100
        self.linear1 = nn.Linear(in_features, num_hid1)
        self.linear2 = nn.Linear(num_hid1, num_hid2)
        self.linear3 = nn.Linear(num_hid2, num_classes)
        self.dropout1 = nn.Dropout(0.5)
        self.dropout2 = nn.Dropout(0.3)

    def forward(self, x):
        """Flatten everything after the batch dimension and return raw logits."""
        flat = x.flatten(start_dim=1)
        hidden = self.dropout1(F.relu(self.linear1(flat)))
        hidden = self.dropout2(F.relu(self.linear2(hidden)))
        return self.linear3(hidden)
    
def test_wrapper(test_loader):
    def test(model):
        correct = 0
        loss = 0
        total = 0
        device = next(model.parameters()).device
        model.eval()
        with torch.no_grad():
            for (x, y) in test_loader:
                x, y = x.to(device), y.to(device)
                o = model(x)
                loss += nn.CrossEntropyLoss(reduction='sum')(o, y).item()
                p = o.softmax(dim=1)
                p = p.argmax(dim=1).to(y.dtype)
                p = (p == y).float()
                correct += p.sum()
                total += len(y)
        acc = (correct / total).item()
        loss = loss / total
        return (acc, loss)
    return test

# Shared evaluator over the embedded test split, iterated in fixed order.
test_loader = torch.utils.data.DataLoader(
    test_emb_dataset,
    shuffle=False,
    batch_size=128,
)
test = test_wrapper(test_loader)

def train(train_set, model_save_path=None, *, batch_size=128, lr=1e-4,
          epochs=20, weight_decay=1e-4):
    """Train a fresh ``MLP`` on ``train_set`` and return its best checkpoint.

    After every epoch the model is scored with the module-level ``test``
    evaluator.  Whenever the accuracy strictly exceeds the best seen so far, a
    deep copy of the model is kept and, if ``model_save_path`` is given, its
    state dict is written there.  If no epoch scores above 0 (or ``epochs`` is
    0) the untrained initial copy is returned.

    NOTE(review): the checkpoint is selected with the same ``test`` evaluator
    that callers use to report accuracy, so reported values are optimistic.
    Confirm this is intended.

    Args:
        train_set: Dataset of ``(x, y)`` pairs to fit on.
        model_save_path: Optional path for the best state dict.
        batch_size: Mini-batch size for the shuffled training loader.
        lr: Adam learning rate.
        epochs: Number of passes over ``train_set``.
        weight_decay: Adam weight decay.

    Returns:
        A deep copy of the ``MLP`` at its best-scoring epoch.
    """
    device = get_device()
    criterion = nn.CrossEntropyLoss()
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
    mlp = MLP()
    mlp.to(device)
    optimizer = optim.Adam(mlp.parameters(), lr=lr, weight_decay=weight_decay)
    best_acc = 0
    best_model = copy.deepcopy(mlp)
    for _ in range(epochs):
        # test() leaves the model in eval mode, so re-enable dropout each epoch.
        mlp.train()
        for (x, y) in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            o = mlp(x)
            loss = criterion(o, y)
            loss.backward()
            optimizer.step()
        test_acc, _ = test(mlp)
        if test_acc > best_acc:
            best_acc = test_acc
            best_model = copy.deepcopy(mlp)
            if model_save_path is not None:
                torch.save(mlp.state_dict(), model_save_path)
    return best_model


# split into 10 parties
def get_party_labels(n=10, num_classes=100):
    """Randomly partition the class labels ``0..num_classes-1`` among ``n`` parties.

    The labels are shuffled with the global ``random`` module and cut into
    ``n`` consecutive chunks of ``num_classes // n`` labels.  The last party
    additionally receives whatever remains when the division is not exact.

    Args:
        n: Number of parties.
        num_classes: Number of distinct class labels to distribute.

    Returns:
        A list of ``n`` lists of labels; together they contain every label
        exactly once.
    """
    all_labels = list(range(num_classes))
    random.shuffle(all_labels)
    chunk_size = len(all_labels) // n
    last = n - 1

    party_labels = []
    for i in range(n):
        start = i * chunk_size
        # An open-ended slice lets the last party absorb the remainder.
        stop = None if i == last else start + chunk_size
        party_labels.append(all_labels[start:stop])
    return party_labels

# return a list of lists of indices, where p_indices[i] are all indices for party i's classes.
def get_party_indices(train_dataset, party_labels):
    """Map each party's classes to the dataset indices carrying those labels.

    Args:
        train_dataset: Object whose ``tensors`` attribute is a pair with the
            label sequence as its second element.
        party_labels: One iterable of class labels per party.

    Returns:
        A list parallel to ``party_labels``; entry ``i`` is the ascending list
        of sample indices whose label belongs to party ``i``.
    """
    _, label_array = train_dataset.tensors
    # Convert the labels to plain ints once instead of re-reading the tensor
    # element by element for every party.
    if isinstance(label_array, torch.Tensor) and label_array.dim() == 1:
        labels = label_array.tolist()
    else:
        labels = label_array
    labels = [int(label) for label in labels]

    p_indices = []
    for classes_for_this_party in party_labels:
        cls_set = set(classes_for_this_party)
        p_indices.append([i for i, label in enumerate(labels) if label in cls_set])
    return p_indices

def _load_pickle(filename):
    """Unpickle ``rootdir + filename``; only use on files this project wrote."""
    with open(rootdir + filename, 'rb') as handle:
        return pickle.load(handle)


# Load the stored party split instead of regenerating it. To regenerate:
#     p_labels = get_party_labels(10)
#     p_indices = get_party_indices(train_emb_dataset, p_labels)
p_indices = _load_pickle('p_indices.pkl')
p_labels = _load_pickle('p_labels.pkl')


# Memo table of the characteristic function, keyed by frozenset of party ids.
# Start from `v = {}` instead to rebuild the table from scratch.
v = _load_pickle('v.pkl')

# memoized version
def get_v_c(c):
    """Return the memoized value of coalition ``c``.

    The value is the accuracy, under the module-level ``test`` evaluator, of a
    model trained on the pooled data of the parties in ``c``.  The empty
    coalition is worth 0.0.  Results are stored in the module-level dict ``v``
    under ``frozenset(c)``.
    """
    key = frozenset(c)
    if key in v:
        return v[key]

    if len(c) == 0:
        v[key] = 0.0
        return 0.0

    # Pool the sample indices of every member, in the order given by c.
    member_indices = [idx for party in c for idx in p_indices[party]]
    coalition_data = torch.utils.data.Subset(train_emb_dataset, member_indices)
    accuracy, _ = test(train(coalition_data))
    v[key] = accuracy
    return accuracy

def get_v_dual(c, get_v, n_parties=10):
    """Return the dual-game value of coalition ``c``.

    The dual value is what the grand coalition loses when ``c`` leaves:
    ``get_v(N) - get_v(N minus c)``, where ``N`` is ``{0, ..., n_parties-1}``.

    Args:
        c: Iterable of party ids.
        get_v: Characteristic function taking a frozenset of party ids.
        n_parties: Size of the grand coalition.

    Returns:
        The difference of the two ``get_v`` results.
    """
    grand_coalition = frozenset(range(n_parties))
    return get_v(grand_coalition) - get_v(grand_coalition - frozenset(c))


# Monte Carlo permutation sampling for Shapley values
def get_shapley_values(parties, get_v, num_samples=1000):
    """Estimate Shapley values by averaging marginals over random permutations.

    For each sample the parties are shuffled with the global ``random`` module
    and added one at a time.  Each party is credited with the change in
    ``get_v`` that its arrival causes, and the credits are averaged over
    ``num_samples`` permutations.

    Side effect: the module-level cache ``v`` is written to
    ``rootdir + 'v.pkl'`` before returning.  The write also happens when
    sampling is interrupted or raises, so coalition values computed so far are
    not lost.  It goes through a temporary file that is then renamed over the
    target, so an interrupted write cannot truncate the existing cache.

    Args:
        parties: List of party ids (copied with a slice, so it must support one).
        get_v: Characteristic function taking a frozenset of party ids.
        num_samples: Number of random permutations to average over.

    Returns:
        Dict mapping each party to its estimated Shapley value.

    Raises:
        ZeroDivisionError: If ``num_samples`` is 0 and ``parties`` is non-empty.
    """
    shapleys = {party: 0.0 for party in parties}

    try:
        for _ in tqdm(range(num_samples)):
            perm = parties[:]
            random.shuffle(perm)

            current_coalition = frozenset()
            current_value = get_v(current_coalition)

            for party in perm:
                new_coalition = current_coalition | {party}
                new_value = get_v(new_coalition)

                # Marginal contribution of `party` given who arrived before it.
                shapleys[party] += new_value - current_value

                current_coalition = new_coalition
                current_value = new_value

        for party in parties:
            shapleys[party] /= num_samples
    finally:
        cache_path = rootdir + 'v.pkl'
        tmp_path = cache_path + '.tmp'
        with open(tmp_path, 'wb') as f:
            pickle.dump(v, f)
        os.replace(tmp_path, cache_path)

    return shapleys


# Precompute Monte Carlo Shapley estimates (100 permutations) for parties 1..9,
# i.e. everyone except the first party, under the dual of get_v_c.
# NOTE(review): get_v_dual always measures against the full 10-party grand
# coalition, so this 9-party game is still evaluated relative to all 10
# parties. Confirm that is intended.
shapley9 = get_shapley_values(list(range(1, 10)), lambda c: get_v_dual(c, get_v_c), num_samples=100)
with open(rootdir + 'shapley9.pkl', 'wb') as f:
    pickle.dump(shapley9, f)

# Precompute the same estimate for all 10 parties. This draws from the global
# `random` stream after the call above, so the order of the two calls matters
# for reproducibility.
shapley10 = get_shapley_values(list(range(10)), lambda c: get_v_dual(c, get_v_c), num_samples=100)
with open(rootdir + 'shapley10.pkl', 'wb') as f:
    pickle.dump(shapley10, f)

# Precompute the scaling factor since it is the same for both reward methods:
# the dual value of the grand coalition divided by the largest 10-party
# Shapley estimate. shapley10 is re-read from the file written just above.
with open(rootdir + 'shapley10.pkl', 'rb') as f:
    shapley10 = pickle.load(f)
scaling_factor = get_v_dual(frozenset(range(10)), get_v_c) / max(shapley10.values())

with open(rootdir + 'scale.pkl', 'wb') as f:
    pickle.dump(scaling_factor, f)

def scale_rewards(scaling_factor, rewards):
    """Return a new dict with every reward multiplied by ``scaling_factor``."""
    scaled = {}
    for party, reward in rewards.items():
        scaled[party] = reward * scaling_factor
    return scaled


# Reward value for beta
# for simplicity, only the first party's time value t varies, the time values of the rest are all zero
def get_reward_beta(get_v, t, beta):
    """Return scaled rewards as a ``beta``-weighted average over steps 0..t.

    For every step ``tau < t`` party 0 is credited with ``get_v({0})`` and
    parties 1..9 with their ``shapley9`` value.  At step ``t`` all ten parties
    are credited with their ``shapley10`` value.  Step ``tau`` carries weight
    ``beta**tau``; the weighted sums are divided by the total weight and then
    multiplied by the module-level ``scaling_factor``.
    """
    parties = list(range(10))
    later_parties = parties[1:]
    rewards = dict.fromkeys(parties, 0.0)
    w_sum = 0.0

    # Steps strictly before t.
    for tau in range(t):
        w_tau = beta**tau
        w_sum += w_tau
        rewards[0] += w_tau * get_v(frozenset([0]))
        for party in later_parties:
            rewards[party] += w_tau * shapley9[party]

    # Step t itself.
    w_tau = beta**t
    w_sum += w_tau
    for party in parties:
        rewards[party] += w_tau * shapley10[party]

    normalised = {party: total / w_sum for party, total in rewards.items()}
    return scale_rewards(scaling_factor, normalised)


# Sweep t = 0..4 with beta fixed to 1, keeping the reward dicts in order of t.
rewards_beta = []
for t in range(5):
    rewards = get_reward_beta(
        lambda coalition: get_v_dual(coalition, get_v_c),
        t,
        1,
    )
    rewards_beta += [rewards]

with open(rootdir + 'rewards_beta.pkl', 'wb') as f:
    pickle.dump(rewards_beta, f)

# Dual value of the grand coalition, then the largest single-party dual value.
v_extrema = (
    get_v_dual(frozenset(range(10)), get_v_c),
    max([get_v_dual(frozenset({i}), get_v_c) for i in range(10)]),
)
v_N, v_max = v_extrema

with open(rootdir + 'v_extrema.pkl', 'wb') as f:
    pickle.dump(v_extrema, f)


# Reward value for gamma

# efficient computation of time-aware valuation function
def get_v_t(get_v, gamma, c, t):
    """Return the time-aware value of coalition ``c`` when party 0 has time value ``t``.

    Coalitions that do not contain party 0, and every coalition when
    ``t == 0``, keep their plain value ``get_v(c)``.  For a coalition that
    contains party 0, the value blends two cases with weight
    ``coop = exp(-gamma * t)``:

    * full cooperation, worth ``get_v(c)``, weighted by ``coop``;
    * party 0 acting alone, worth ``get_v(c minus {0}) + get_v({0})``,
      weighted by ``1 - coop``.

    The membership test must be ``0 not in c``.  With ``0 in c`` the
    adjustment lands on coalitions without party 0, where removing party 0 is
    a no-op, and the empty coalition ends up with a non-zero value.

    Args:
        get_v: Characteristic function taking an iterable of party ids.
        gamma: Decay rate of the cooperation weight.
        c: Coalition as an iterable of party ids supporting ``in``.
        t: Time value of party 0.

    Returns:
        The time-aware coalition value.
    """
    if t == 0 or 0 not in c:
        return get_v(c)
    coop = math.exp(-gamma * t)
    c_no_first_party = [i for i in c if i != 0]
    return (get_v(c_no_first_party) * (1 - coop)
            + get_v(c) * coop
            + get_v([0]) * (1 - coop))

def get_reward_gamma(get_v, t, gamma):
    """Return scaled Shapley rewards of the ten parties under the time-aware game.

    The game is ``get_v`` passed through ``get_v_t`` with the given ``gamma``
    and ``t``.  Shapley values are estimated from 1000 sampled permutations
    and multiplied by the module-level ``scaling_factor``.
    """
    def v_t(coalition):
        return get_v_t(get_v, gamma, coalition, t)

    raw_rewards = get_shapley_values(list(range(10)), v_t, num_samples=1000)
    return scale_rewards(scaling_factor, raw_rewards)

# Sweep t = 0..4 with gamma fixed to 1, keeping the reward dicts in order of t.
rewards_gamma = []
for t in range(5):
    rewards = get_reward_gamma(
        lambda coalition: get_v_dual(coalition, get_v_c),
        t,
        1,
    )
    rewards_gamma += [rewards]

with open(rootdir + 'rewards_gamma.pkl', 'wb') as f:
    pickle.dump(rewards_gamma, f)

