import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
import matplotlib.pyplot as plt
import random
import os
import copy
import json

import common_plot

seed_num = 37
# Seed every RNG the script touches. `torch` initialises the MLP weights, so
# seeding only `random` would not make two runs with the same seed_num agree.
random.seed(seed_num)
np.random.seed(seed_num)
torch.manual_seed(seed_num)

# DPO parameters: beta1 scales the chosen-response log-ratio term and beta2
# the rejected-response term (beta1 == beta2 gives the usual single-beta DPO).
beta1 = 0.01
beta2 = 0.01
# IPO parameters: the squared loss pulls the log-ratio margin towards 1 / (2 * tau).
tau = 0.1
# SLiC parameters: `delta` is the hinge margin, `mu` weights the extra
# chosen-response log-likelihood term.
delta = 5
mu = 0.1

# algorithm_num selects the preference-optimisation objective:
#   0: DPO
#   1: IPO
#   2: SLiC
#   3: NCA (listed for completeness, not implemented yet)
algorithm_num = 0

# Results are written under <algorithm dir>/<seed_num>/<experiment name>.
_ALGORITHM_DIRS = {0: "DPO", 1: "IPO", 2: "SLiC"}
if algorithm_num not in _ALGORITHM_DIRS:
    raise ValueError(
        f"This algorithm hasn't been implemented yet (algorithm_num={algorithm_num})"
    )
base_filepath = _ALGORITHM_DIRS[algorithm_num]

# True: run supervised fine-tuning and save 'sft_model.pth';
# False: skip SFT and load 'sft_model.pth' instead.
SFT_TRAIN = True

class MLP(nn.Module):
    """Two-layer perceptron that outputs log-probabilities.

    Architecture: Linear(input_size -> hidden_size) -> ReLU ->
    Linear(hidden_size -> output_size) -> LogSoftmax over dim 1.
    The stack is stored as ``self.layers`` so checkpoint keys are
    ``layers.0.*`` (first Linear) and ``layers.2.*`` (second Linear).
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Build the Linear layers in forward order so parameter initialisation
        # consumes the torch RNG in the same sequence.
        hidden_layer = nn.Linear(input_size, hidden_size)
        output_layer = nn.Linear(hidden_size, output_size)
        self.layers = nn.Sequential(
            hidden_layer,
            nn.ReLU(),
            output_layer,
            nn.LogSoftmax(dim=1),
        )

    def forward(self, x):
        """Return log-probabilities for a 2-D batch ``x`` (LogSoftmax is over dim 1)."""
        log_probs = self.layers(x)
        return log_probs

# Toy problem size: 4 one-hot "prompts" in, a 64-unit hidden layer, and a
# distribution over 10 candidate "responses" out.
input_size, hidden_size, output_size = 4, 64, 10

# Experiments to run. Each entry pairs an output sub-directory name with the
# target distribution (10 probabilities, one per response) that the SFT model
# is fitted to. Uncomment an alternative to run it as well.
save_dirs = []
target_distributions = []
# save_dirs.append(f'chosen_offline_reject_offline_origin_beta1_{str(beta1)}_beta2_{str(beta2)}_tau_{str(tau)}_delta_{str(delta)}_mu_{str(mu)}'); target_distributions.append(torch.tensor([0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.8, 0.04]))
# save_dirs.append(f'chosen_offline_reject_online_origin_beta1_{str(beta1)}_beta2_{str(beta2)}_tau_{str(tau)}_delta_{str(delta)}_mu_{str(mu)}');  target_distributions.append(torch.tensor([0.02, 0.02, 0.02, 0.02, 0.12, 0.12, 0.12, 0.12, 0.4, 0.04]))
# save_dirs.append(f'chosen_online_reject_offline_origin_beta1_{str(beta1)}_beta2_{str(beta2)}_tau_{str(tau)}_delta_{str(delta)}_mu_{str(mu)}'); target_distributions.append(torch.tensor([0.12, 0.12, 0.12, 0.12, 0.02, 0.02, 0.02, 0.02, 0.4, 0.04]))
save_dirs.append(
    f'chosen_online_reject_online_origin_beta1_{str(beta1)}_beta2_{str(beta2)}_tau_{str(tau)}_delta_{str(delta)}_mu_{str(mu)}'
)
target_distributions.append(
    torch.tensor([0.12, 0.12, 0.12, 0.12, 0.12, 0.12, 0.12, 0.12, 0.03, 0.01])
)

# save_dirs.append(f'chosen_very_offline_reject_very_offline_origin_beta1_{str(beta1)}_beta2_{str(beta2)}'); target_distributions.append(torch.tensor([0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.9, 0.1]))
# save_dirs.append(f'chosen_online_reject_very_offline_origin_beta1_{str(beta1)}_beta2_{str(beta2)}'); target_distributions.append(torch.tensor([0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.9, 0.1]))
## Supervised Fine-tuning

# ---------------------------------------------------------------------------
# Helpers for the experiment loop below. None of them depends on the loop
# variables except through explicit arguments, so they are defined once.
# ---------------------------------------------------------------------------

def train(model, target_distribution, epochs=1000, lr=0.001):
    """Supervised fine-tuning: fit ``model`` so every prompt maps to ``target_distribution``.

    Minimises ``nn.KLDivLoss(reduction='batchmean')`` with full-batch Adam.
    ``model`` must return log-probabilities (``MLP`` ends in LogSoftmax); the
    target is passed as probabilities, which is what KLDivLoss expects by
    default. Training stops early if the loss becomes NaN.

    Args:
        model: policy network, trained in place.
        target_distribution: 1-D tensor with one probability per response.
        epochs: maximum number of optimisation steps.
        lr: Adam learning rate.
    """
    criterion = nn.KLDivLoss(reduction='batchmean')
    optimizer = optim.Adam(model.parameters(), lr=lr)

    inputs = torch.eye(input_size)  # one one-hot row per prompt
    # Every prompt is trained towards the same distribution.
    targets = target_distribution.repeat(input_size, 1)

    model.train()
    for epoch in range(epochs):
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        if torch.isnan(loss):
            print(f"NaN loss detected at epoch {epoch}")
            break
        loss.backward()
        optimizer.step()

        if epoch % 100 == 0:
            print(f'Epoch {epoch}: Loss = {loss.item()}')


class DPO(nn.Module):
    """Thin wrapper holding the policy being preference-optimised as ``self.mlp``."""

    def __init__(self, sft_model):
        super().__init__()
        self.mlp = sft_model

    def forward(self, x):
        return self.mlp(x)


# In all loss/gradient functions below, ``model_outputs`` and
# ``reference_model_outputs`` are log-probabilities indexed [prompt][response].
# Pair i of ``preferences`` (in dict order) is scored on row i, with
# (preferred, non_preferred) giving the chosen/rejected response columns.

def dpo_preference_loss(preferences, model_outputs, reference_model_outputs, beta1=beta1, beta2=beta2):
    """Summed DPO loss with separate betas for the chosen and rejected terms.

    For each pair: z = beta1 * (ref[w] - logp[w]) + beta2 * (logp[l] - ref[l])
    and loss = log(1 + exp(z)) = softplus(z). With beta1 == beta2 this is the
    standard DPO objective -log sigmoid(beta * (log-ratio[w] - log-ratio[l])).
    """
    loss = 0
    for i, (preferred, non_preferred) in enumerate(preferences.values()):
        z = (beta1 * (reference_model_outputs[i][preferred] - model_outputs[i][preferred])
             + beta2 * (model_outputs[i][non_preferred] - reference_model_outputs[i][non_preferred]))
        # softplus(z) == log(1 + exp(z)) without overflowing for large z.
        loss += nn.functional.softplus(z)
    return loss


def ipo_preference_loss(preferences, model_outputs, reference_model_outputs, tau=tau):
    """Summed IPO loss: (h - 1/(2*tau))**2 per pair, where
    h = (logp[w] - ref[w]) - (logp[l] - ref[l]) is the log-ratio margin."""
    loss = 0
    for i, (preferred, non_preferred) in enumerate(preferences.values()):
        margin = (model_outputs[i][preferred] - model_outputs[i][non_preferred]
                  + reference_model_outputs[i][non_preferred] - reference_model_outputs[i][preferred])
        loss += (margin - 1 / (2 * tau)) ** 2
    return loss


def slic_preference_loss(preferences, model_outputs, reference_model_outputs, delta=delta, mu=mu):
    """Summed SLiC loss: max(0, delta - logp[w] + logp[l]) - mu * logp[w] per pair.

    ``reference_model_outputs`` is accepted for a uniform signature but unused.
    """
    loss = 0
    for i, (preferred, non_preferred) in enumerate(preferences.values()):
        hinge = torch.clamp(delta - model_outputs[i][preferred] + model_outputs[i][non_preferred], min=0)
        loss += hinge - mu * model_outputs[i][preferred]
    return loss


def calculate_dpo_gradient(preferences, model_outputs, reference_model_outputs, beta1=beta1, beta2=beta2):
    """Mean analytic gradient of the DPO loss w.r.t. the chosen/rejected probabilities.

    With p = exp(logp) and z as in ``dpo_preference_loss``:
        dloss/dp[w] = -beta1 * sigmoid(z) / p[w]
        dloss/dp[l] = +beta2 * sigmoid(z) / p[l]

    Returns:
        (mean over pairs of dloss/dp[w], mean over pairs of dloss/dp[l]).
    """
    model_outputs = model_outputs.detach().numpy()
    reference_model_outputs = reference_model_outputs.detach().numpy()
    chosen_prob_gradients = []
    rejected_prob_gradients = []
    for i, (preferred, non_preferred) in enumerate(preferences.values()):
        logp_w = model_outputs[i][preferred]
        logp_l = model_outputs[i][non_preferred]
        z = (beta1 * (reference_model_outputs[i][preferred] - logp_w)
             + beta2 * (logp_l - reference_model_outputs[i][non_preferred]))
        # d softplus(z) / dz = sigmoid(z)  (not exp(z)).
        sigmoid_z = 1.0 / (1.0 + np.exp(-z))
        chosen_prob_gradients.append(-beta1 * sigmoid_z / np.exp(logp_w))
        rejected_prob_gradients.append(beta2 * sigmoid_z / np.exp(logp_l))
    return np.mean(chosen_prob_gradients), np.mean(rejected_prob_gradients)


def calculate_ipo_gradient(preferences, model_outputs, reference_model_outputs, tau=tau):
    """Mean analytic gradient of the IPO loss w.r.t. the chosen/rejected probabilities.

    With margin h as in ``ipo_preference_loss``:
        dloss/dp[w] =  2 * (h - 1/(2*tau)) / p[w]
        dloss/dp[l] = -2 * (h - 1/(2*tau)) / p[l]
    """
    model_outputs = model_outputs.detach().numpy()
    reference_model_outputs = reference_model_outputs.detach().numpy()
    chosen_prob_gradients = []
    rejected_prob_gradients = []
    for i, (preferred, non_preferred) in enumerate(preferences.values()):
        margin = (model_outputs[i][preferred] - model_outputs[i][non_preferred]
                  + reference_model_outputs[i][non_preferred] - reference_model_outputs[i][preferred])
        alpha = 2 * (margin - 1 / (2 * tau))
        chosen_prob_gradients.append(alpha / np.exp(model_outputs[i][preferred]))
        rejected_prob_gradients.append(-alpha / np.exp(model_outputs[i][non_preferred]))
    return np.mean(chosen_prob_gradients), np.mean(rejected_prob_gradients)


def calculate_slic_gradient(preferences, model_outputs, reference_model_outputs, delta=delta, mu=mu):
    """Mean analytic gradient of the SLiC loss w.r.t. the chosen/rejected probabilities.

    While the hinge is active (delta > logp[w] - logp[l]):
        dloss/dp[w] = -(1 + mu) / p[w],   dloss/dp[l] = +1 / p[l]
    otherwise only the -mu * logp[w] term remains:
        dloss/dp[w] = -mu / p[w],         dloss/dp[l] = 0
    """
    model_outputs = model_outputs.detach().numpy()
    chosen_prob_gradients = []
    rejected_prob_gradients = []
    for i, (preferred, non_preferred) in enumerate(preferences.values()):
        logp_w = model_outputs[i][preferred]
        logp_l = model_outputs[i][non_preferred]
        if delta > logp_w - logp_l:
            chosen_prob_gradients.append(-(1 + mu) / np.exp(logp_w))
            # The hinge grows with logp[l], so its gradient on p[l] is positive.
            rejected_prob_gradients.append(1 / np.exp(logp_l))
        else:
            chosen_prob_gradients.append(-mu / np.exp(logp_w))
            rejected_prob_gradients.append(0)
    return np.mean(chosen_prob_gradients), np.mean(rejected_prob_gradients)


def calculate_nca_gradient(preferences, model_outputs, reference_model_outputs):
    """Placeholder for NCA (algorithm_num == 3); not implemented yet."""
    raise NotImplementedError("NCA gradient has not been implemented yet")


_LOSS_FUNCTIONS = {
    0: dpo_preference_loss,
    1: ipo_preference_loss,
    2: slic_preference_loss,
}
_GRADIENT_FUNCTIONS = {
    0: calculate_dpo_gradient,
    1: calculate_ipo_gradient,
    2: calculate_slic_gradient,
}


def switch_loss(case):
    """Return the loss function for ``algorithm_num`` ``case``; ValueError if unknown."""
    try:
        return _LOSS_FUNCTIONS[case]
    except KeyError:
        raise ValueError("un-implemented algorithm loss") from None


def switch_calculate_gradient(case):
    """Return the analytic-gradient function for ``case``; ValueError if unknown."""
    try:
        return _GRADIENT_FUNCTIONS[case]
    except KeyError:
        raise ValueError("un-implemented algorithm gradient") from None


def train_dpo(model, reference_model, preferences, epochs=2000, lr=2.5e-4, save_dir=None):
    """Preference-optimise ``model`` against the frozen ``reference_model``.

    The objective is selected by the module-level ``algorithm_num``. Each
    epoch records the probability of every pair's chosen and rejected
    response, the probability of all remaining (prompt, response) cells,
    and the analytic gradients w.r.t. the chosen/rejected probabilities.
    Plots go through ``common_plot``; the full probability history is
    written to ``<save_dir>/all_prob_data.json``.

    Raises:
        ValueError: if ``save_dir`` is not given.
    """
    if save_dir is None:
        raise ValueError("train_dpo requires save_dir")
    os.makedirs(save_dir, exist_ok=True)

    loss_fn = switch_loss(algorithm_num)
    gradient_fn = switch_calculate_gradient(algorithm_num)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    inputs = torch.eye(input_size)  # one-hot prompts, identical every epoch
    # The reference model is never optimised, so its outputs are constant;
    # compute them once and keep them out of the autograd graph.
    with torch.no_grad():
        reference_outputs = reference_model(inputs)

    # Cell indices used for logging: row i holds pair i (same as the losses).
    pairs = list(preferences.values())
    rows = np.arange(len(pairs))
    chosen_cols = np.array([chosen for chosen, _ in pairs])
    rejected_cols = np.array([rejected for _, rejected in pairs])
    # "Other" cells: everything except each row's chosen and rejected response.
    other_mask = np.ones((input_size, output_size), dtype=bool)
    other_mask[rows, chosen_cols] = False
    other_mask[rows, rejected_cols] = False

    average_chosen_prob = []
    min_chosen_prob = []
    average_reject_prob = []
    max_reject_prob = []
    max_others_prob = []
    average_others_prob = []
    all_probability = []
    chosen_prob_grads = []
    rejected_prob_grads = []

    model.train()
    for epoch in range(epochs):
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_fn(preferences, outputs, reference_outputs)
        chosen_prob_grad, rejected_prob_grad = gradient_fn(preferences, outputs, reference_outputs)
        chosen_prob_grads.append(chosen_prob_grad)
        rejected_prob_grads.append(rejected_prob_grad)

        loss.backward()
        optimizer.step()

        if epoch % 50 == 0:
            print(f'Epoch {epoch}: Loss = {loss.item()}')
            common_plot.plot_heatmap(epoch, outputs, save_dir)

        # Probabilities from this epoch's forward pass (before the update).
        probabilities = torch.exp(outputs).detach().numpy()
        all_probability.append(probabilities)
        chosen_probs = probabilities[rows, chosen_cols]
        rejected_probs = probabilities[rows, rejected_cols]
        other_probs = probabilities[other_mask]
        average_chosen_prob.append(np.mean(chosen_probs))
        min_chosen_prob.append(np.min(chosen_probs))
        average_reject_prob.append(np.mean(rejected_probs))
        max_reject_prob.append(np.max(rejected_probs))
        max_others_prob.append(np.max(other_probs))
        average_others_prob.append(np.mean(other_probs))

    common_plot.plot_probability(average_chosen_prob, min_chosen_prob, average_reject_prob, max_reject_prob, max_others_prob, average_others_prob, save_dir)
    common_plot.plot_gradient(chosen_prob_grads, rejected_prob_grads, beta1, beta2, save_dir)

    data = {
        'all_prob': np.asarray(all_probability).tolist(),
    }
    json_file_path = os.path.join(save_dir, 'all_prob_data.json')
    with open(json_file_path, 'w') as json_file:
        json.dump(data, json_file)


# ---------------------------------------------------------------------------
# Experiment loop: for each configured target distribution, run SFT (or load
# the saved SFT checkpoint), build the toy preference dataset, then run the
# selected preference-optimisation algorithm against a frozen SFT copy.
# ---------------------------------------------------------------------------
for save_dir, target_distribution in zip(save_dirs, target_distributions):
    save_dir = f"{base_filepath}/{str(seed_num)}/{save_dir}"

    ## Supervised fine-tuning
    model = MLP(input_size, hidden_size, output_size)
    # NOTE(review): the SFT checkpoint path is shared by every experiment, so
    # with SFT_TRAIN = False all experiments load the same SFT model regardless
    # of their target distribution; consider a per-experiment path.
    if SFT_TRAIN:
        train(model, target_distribution)
        torch.save(model.state_dict(), 'sft_model.pth')
    else:
        model.load_state_dict(torch.load('sft_model.pth'))

    ## Preference dataset: prompt i -> [chosen response, rejected response]
    preferences = {
        0: [0, 4],
        1: [1, 6],
        2: [2, 7],
        3: [3, 5],
    }
    # Random alternative:
    # for i in range(4):
    #     preferences[i] = [i, np.random.randint(4, 8)]

    common_plot.draw_preference_dataset(preferences, save_dir)

    ## Draw the optimal policy as a heatmap
    common_plot.draw_optimal_policy(save_dir)

    ## Preference optimisation against a frozen copy of the SFT model
    dpo_model = DPO(model)
    reference_model = copy.deepcopy(model)
    reference_model.eval()
    train_dpo(dpo_model, reference_model, preferences, save_dir=save_dir)
    torch.save(dpo_model.state_dict(), 'dpo_model.pth')