from __future__ import absolute_import, division, print_function

import os
import numpy as np
import torch as t
import torch.nn as nn
import torch.nn.functional as F
from matplotlib import pyplot as plt
from random import random
from datetime import datetime

from functions import PIDController, Generator, calculate_t, MLP, linear_interpolation, calculate_distances
from classic_auction import alpha_VCG_Mechanism, alpha_GSP_Mechanism
from models import Args, Score_VCG, Hyper_VCG

# Compute device: the default CUDA device when one is available, otherwise the CPU.
# The MLPs and the distribution-descriptor tensors built in this module are moved here.
device = t.device("cuda" if t.cuda.is_available() else "cpu")

class Learner:
    """Learner for the hyper-network that drives the Hyper_VCG auction.

    Two small MLPs, ``w_net`` and ``b_net``, map a 4-element distribution
    descriptor ``[low, high, alpha, beta]`` to the last two inputs of
    ``Hyper_VCG``. Each MLP has its own Adam optimiser and a StepLR scheduler
    (step_size=1, gamma=0.9999) that is stepped once per ``update_auction``.
    """

    def __init__(self, args):
        self.args = args
        self.auct_model = Hyper_VCG(args)

        descriptor_len = 4  # [low, high, alpha, beta]
        hidden = descriptor_len * 10

        def build_net():
            return MLP([descriptor_len, hidden, hidden, 4], t.tanh).to(device)

        # Build w_net before b_net so parameter initialisation order is fixed.
        self.w_net = build_net()
        self.b_net = build_net()

        self.optimizer_w, self.scheduler_w = self._make_optimizer(self.w_net)
        self.optimizer_b, self.scheduler_b = self._make_optimizer(self.b_net)

    @staticmethod
    def _make_optimizer(net):
        """Return an (Adam optimiser, per-step StepLR scheduler) pair for ``net``."""
        optimizer = t.optim.Adam(net.parameters(), lr=4e-4, betas=(0.9, 0.999))
        scheduler = t.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.9999, last_epoch=-1)
        return optimizer, scheduler

    def update_auction(self):
        """Take one gradient step on w_net and b_net in a randomly drawn environment.

        The upper value bound, alpha and beta are drawn with ``np.random.uniform``
        (in that order). A fresh batch is drawn from a new ``Generator``. The loss
        is the negative mean of the first output of ``auct_model``.
        """
        high = 2. * np.random.uniform() + 0.5
        low = 0.
        alpha = 1.5 * np.random.uniform()
        beta = 8 * np.random.uniform()
        descriptor = t.tensor([low, high, alpha, beta]).to(device)

        gen = Generator(self.args)
        ctr_ads = gen.generate_uniform(0, 1.)
        ctr_og = gen.generate_uniform(0, 2.5)
        cvr_ads = gen.generate_uniform(0, 0.17)
        cvr_og = gen.generate_uniform(0, 0.17)
        values = gen.generate_uniform(low, high)

        w = self.w_net(descriptor)
        b = self.b_net(descriptor)
        outputs = self.auct_model(values, ctr_ads, ctr_og, cvr_ads, cvr_og, alpha, beta, w, b)
        loss = -t.mean(outputs[0])

        optimizers = (self.optimizer_w, self.optimizer_b)
        for opt in optimizers:
            opt.zero_grad()
        loss.backward()
        for opt in optimizers:
            opt.step()
        for sched in (self.scheduler_w, self.scheduler_b):
            sched.step()

def filter_pairs(x, y):
    """Keep the non-negative, non-dominated (x, y) pairs.

    Pair ``i`` is dropped when ``x[i] < 0`` or ``y[i] < 0``. It is also dropped
    when another pair ``j`` is strictly larger in *both* coordinates
    (``x[j] > x[i]`` and ``y[j] > y[i]``). Surviving pairs keep their original
    order.

    Args:
        x: 1-D sequence (array or list) of x coordinates.
        y: 1-D sequence of y coordinates, same length as ``x``.

    Returns:
        Tuple ``(x_kept, y_kept)`` of NumPy arrays.

    Raises:
        ValueError: if ``x`` and ``y`` are not 1-D sequences of equal length.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    if x.ndim != 1 or x.shape != y.shape:
        raise ValueError("filter_pairs expects two 1-D sequences of equal length")

    negative = (x < 0) | (y < 0)
    # beats[i, j] is True when pair j is strictly larger than pair i in both
    # coordinates. The diagonal is always False because the comparison is strict.
    beats = (x[np.newaxis, :] > x[:, np.newaxis]) & (y[np.newaxis, :] > y[:, np.newaxis])
    keep = ~(negative | beats.any(axis=1))
    return x[keep], y[keep]

# Shared alpha grid. Score-VCG net l (offline) and net l + 5 (online) use
# (_ALPHA_GRID[l], _SCORE_BETAS[l]). These values are passed in the positions
# the learner call names alpha/beta.
_ALPHA_GRID = [0.01, 0.1, 0.2, 0.5, 1]
_SCORE_BETAS = [0.4, 1.4, 2.2, 4.5, 7]
# Second-parameter grids for the classic alpha-VCG / alpha-GSP baselines.
_VCG_BETAS = [0.6, 2, 3, 5.5, 8]
_GSP_BETAS = [0.7, 2.2, 3.2, 6.3, 8.5]


def _sample_ctr_cvr(generator):
    """Draw (ctr_ads, ctr_og, cvr_ads, cvr_og) from ``generator``, in that order."""
    ctr_ads = generator.generate_uniform(0, 1.)
    ctr_og = generator.generate_uniform(0, 2.5)
    cvr_ads = generator.generate_uniform(0, 0.17)
    cvr_og = generator.generate_uniform(0, 0.17)
    return ctr_ads, ctr_og, cvr_ads, cvr_og


def _to_numpy(tensor):
    """Detach ``tensor`` and return it as a NumPy array on the CPU."""
    return tensor.cpu().detach().numpy()


def _evaluate_hyper(learner, batch, alpha, beta, low, high):
    """Run the learned hyper-auction on ``batch`` with parameters (alpha, beta).

    ``batch`` is (values, ctr_ads, ctr_og, cvr_ads, cvr_og). The hyper-nets are
    conditioned on the descriptor ``[low, high, alpha, beta]``.

    Returns:
        (revenue, cost, click, perc, cvr). All are detached to NumPy except
        ``perc``, which is returned exactly as the model produced it.
    """
    distribution = t.tensor([low, high, alpha, beta]).float().to(device)
    revenue, cost, click, perc, cvr = learner.auct_model(
        *batch, alpha, beta, learner.w_net(distribution), learner.b_net(distribution))
    return _to_numpy(revenue), _to_numpy(cost), _to_numpy(click), perc, _to_numpy(cvr)


def train_linear(args, args2, nets, rollouts, learner):
    """Train the hyper-auction learner and ten Score-VCG baselines, logging metrics.

    Rollouts ``i < 100`` (warm-up):
        - One ``learner.update_auction()`` call.
        - One ``seller_backward`` step for each of ``nets[0:10]``.
        - The learner's revenue on a probe batch is printed.

    Rollouts ``i >= 100``:
        - Two learner updates.
        - The hyper-auction is evaluated on a batch from ``generator2`` over
          ``[0, calculate_t(i)]`` at two parameter settings:
          ``(alpha, beta)`` with descriptor ``high = calculate_t(i)``
          (plotted as "AMMD (online)"), and ``(alpha2, beta2)`` with
          ``high = 1.5`` ("AMMD (offline)").
        - Each parameter is rescaled by ``exp(PID.update(...))`` of ``perc`` or ``cvr``.
        - ``nets[0:5]`` train on values from ``[0, 1.5]`` and ``nets[5:10]`` on
          values from ``[0, calculate_t(i)]``; each is then evaluated on the same batch.

    Rollouts ``i > 100`` additionally:
        - Compute normalised distances with ``calculate_distances`` against the
          alpha-VCG / alpha-GSP baselines.
        - Plot them every 200 rollouts.

    Args:
        args: Args for ``generator1`` (net training data and baseline samples).
        args2: Args for ``generator2`` (evaluation batches).
        nets: Sequence of at least 10 Score-VCG style nets.
        rollouts: Number of iterations to run.
        learner: A ``Learner`` instance.

    Side effects: prints progress, shows plots, and calls ``save_models`` every
    1000 rollouts (including rollout 0).
    """
    # Per-rollout metric histories of the hyper-auction. Suffix 1 belongs to
    # (alpha, beta) and suffix 2 to (alpha2, beta2). Each starts with a placeholder 0.
    hyper_losspr1, hyper_losscost1, hyper_lossclick1, hyper_losscvr1 = [0], [0], [0], [0]
    hyper_losspr2, hyper_losscost2, hyper_lossclick2, hyper_losscvr2 = [0], [0], [0], [0]
    hyper_perc1, hyper_perc2 = [0], [0]

    # One independent history per Score-VCG net, each seeded with a placeholder 0.
    score_losspr1 = [[0] for _ in range(10)]
    score_losscost1 = [[0] for _ in range(10)]
    score_lossclick1 = [[0] for _ in range(10)]
    score_losscvr1 = [[0] for _ in range(10)]

    generator1 = Generator(args)
    generator2 = Generator(args2)

    # Distance rows for plot_interpolation2. They are collected in a list and
    # stacked only when plotting, to avoid a quadratic np.vstack per rollout.
    # Row 0 is a placeholder of ones.
    point_rows = [np.ones([1, 6])]

    # Auction parameters, rescaled each rollout by the PID controllers below.
    alpha, alpha2 = 0.5, 0.5
    beta, beta2 = 0.01, 0.01

    PID1 = PIDController(0.05, 0.002, 1., 0.5)   # updates alpha from perc
    PID12 = PIDController(0.5, 0.02, 1., 0.103)  # updates beta from cvr
    PID2 = PIDController(0.05, 0.002, 1., 0.5)   # updates alpha2 from perc
    PID22 = PIDController(0.5, 0.02, 1, 0.103)   # updates beta2 from cvr

    def record_score_metrics(net_idx, grid_idx, eval_batch):
        """Evaluate ``nets[net_idx]`` on ``eval_batch`` and append its detached metrics."""
        # NOTE(review): unpacked as (revenue, click, cost, cvr). This order differs
        # from the hyper-auction's (revenue, cost, click, perc, cvr); confirm it
        # against Score_VCG's return order.
        revenue, click, cost, cvr = nets[net_idx](
            *eval_batch, _ALPHA_GRID[grid_idx], _SCORE_BETAS[grid_idx])
        score_losspr1[net_idx].append(_to_numpy(revenue))
        score_losscost1[net_idx].append(_to_numpy(cost))
        score_lossclick1[net_idx].append(_to_numpy(click))
        score_losscvr1[net_idx].append(_to_numpy(cvr))

    for i in range(rollouts):
        if i < 100:
            # Warm-up phase.
            learner.update_auction()
            low, high = 0., 0.5 + 0.5 * 1
            # NOTE(review): the descriptor uses 0.1 in the beta slot while the model
            # receives beta (still 0.01 here). Only the printed probe is affected;
            # confirm intent.
            distribution = t.tensor([low, high, 0.5, 0.1]).float().to(device)
            train_data2 = generator2.generate_uniform(low, high)
            ctr_ads2, ctr_og2, cvr_ads2, cvr_og2 = _sample_ctr_cvr(generator2)
            revenue = learner.auct_model(
                train_data2, ctr_ads2, ctr_og2, cvr_ads2, cvr_og2, alpha, beta,
                learner.w_net(distribution), learner.b_net(distribution))[0]
            print(_to_numpy(revenue))

            for l in range(5):
                train_data = generator1.generate_uniform(0, 1.5)
                ctr_ads, ctr_og, cvr_ads, cvr_og = _sample_ctr_cvr(generator1)
                nets[l].seller_backward(args, train_data, ctr_ads, ctr_og, cvr_ads, cvr_og,
                                        _ALPHA_GRID[l], _SCORE_BETAS[l])

            for l in range(5):
                high = calculate_t(i)
                train_data = generator1.generate_uniform(0, high)
                ctr_ads, ctr_og, cvr_ads, cvr_og = _sample_ctr_cvr(generator1)
                nets[l + 5].seller_backward(args, train_data, ctr_ads, ctr_og, cvr_ads, cvr_og,
                                            _ALPHA_GRID[l], _SCORE_BETAS[l])
        else:
            for _ in range(2):
                learner.update_auction()

            # One evaluation batch shared by both hyper-auction settings and all nets.
            low, high = 0., calculate_t(i)
            eval_batch = (generator2.generate_uniform2(low, high),) + _sample_ctr_cvr(generator2)

            # "Online" setting: the descriptor tracks the drifting upper bound.
            rev, cost, click, perc, cvr = _evaluate_hyper(learner, eval_batch, alpha, beta, low, high)
            hyper_losspr1.append(rev)
            hyper_losscost1.append(cost)
            hyper_lossclick1.append(click)
            hyper_losscvr1.append(cvr)
            hyper_perc1.append(perc)
            alpha *= np.exp(PID1.update(perc))
            beta *= np.exp(PID12.update(cvr))

            # "Offline" setting: same batch, descriptor fixed at high = 1.5.
            rev, cost, click, perc, cvr = _evaluate_hyper(learner, eval_batch, alpha2, beta2, low, 1.5)
            hyper_losspr2.append(rev)
            hyper_losscost2.append(cost)
            hyper_lossclick2.append(click)
            hyper_losscvr2.append(cvr)
            hyper_perc2.append(perc)
            alpha2 *= np.exp(PID2.update(perc))
            beta2 *= np.exp(PID22.update(cvr))

            # Offline Score-VCG nets train on values from the fixed range [0, 1.5].
            for l in range(5):
                train_data = generator1.generate_uniform(0, 1.5)
                ctr_ads, ctr_og, cvr_ads, cvr_og = _sample_ctr_cvr(generator1)
                nets[l].seller_backward(args, train_data, ctr_ads, ctr_og, cvr_ads, cvr_og,
                                        _ALPHA_GRID[l], _SCORE_BETAS[l])
                record_score_metrics(l, l, eval_batch)

            # Online Score-VCG nets share one value sample from [0, calculate_t(i)].
            train_data = generator1.generate_uniform(0, calculate_t(i))
            for l in range(5):
                ctr_ads, ctr_og, cvr_ads, cvr_og = _sample_ctr_cvr(generator1)
                nets[l + 5].seller_backward(args, train_data, ctr_ads, ctr_og, cvr_ads, cvr_og,
                                            _ALPHA_GRID[l], _SCORE_BETAS[l])
                record_score_metrics(l + 5, l, eval_batch)

        if i > 100:
            value_ads = generator1.generate_uniform(0, 1.5)
            ctr_ads, ctr_og, cvr_ads, cvr_og = _sample_ctr_cvr(generator1)
            x = np.array([hyper_lossclick1[-1], hyper_losscost1[-1]])
            xx = np.array([hyper_lossclick2[-1], hyper_losscost2[-1]])

            # Latest (click, cost) of each Score-VCG net. Every history starts with
            # a placeholder 0, so [-1] is this rollout's entry, the same rollout as
            # x/xx. Indexing by i - 100 would lag one rollout, and building the full
            # 10 x i history array every step is quadratic over the run.
            A1 = np.column_stack((np.array([h[-1] for h in score_lossclick1[:5]]),
                                  np.array([h[-1] for h in score_losscost1[:5]])))
            A2 = np.column_stack((np.array([h[-1] for h in score_lossclick1[5:]]),
                                  np.array([h[-1] for h in score_losscost1[5:]])))

            # 8 slots, of which only the first 5 are filled.
            # NOTE(review): the (0, 0) padding rows reach calculate_distances;
            # confirm they are harmless there.
            x_point, y_point = np.zeros(8), np.zeros(8)
            x_point2, y_point2 = np.zeros(8), np.zeros(8)
            for l, (a, vcg_b, gsp_b) in enumerate(zip(_ALPHA_GRID, _VCG_BETAS, _GSP_BETAS)):
                x_point[l], y_point[l], _ = alpha_VCG_Mechanism(
                    value_ads, ctr_ads, ctr_og, cvr_ads, cvr_og, a, vcg_b)
                x_point2[l], y_point2[l], _ = alpha_GSP_Mechanism(
                    value_ads, ctr_ads, ctr_og, cvr_ads, cvr_og, a, gsp_b)

            A3 = np.column_stack((x_point, y_point))
            A4 = np.column_stack((x_point2, y_point2))
            add_point = np.zeros([1, 6])
            (add_point[0, 0], add_point[0, 1], add_point[0, 2],
             add_point[0, 3], add_point[0, 4], add_point[0, 5]) = calculate_distances(x, xx, A1, A2, A3, A4)

            # Normalise by column 4, which is the 'VCG' column in plot_interpolation2's labels.
            # NOTE(review): this yields inf/nan if that distance is 0.
            add_point = add_point / add_point[0, 4]
            point_rows.append(add_point)

            if i % 200 == 0:
                plot_interpolation2(np.vstack(point_rows), i)

        if i % 2 == 0:
            print('i=', i)
        if i % 100 == 0 and i > 100:
            plot_results(hyper_losspr1, hyper_perc1, 'revenue', 'percentage')
        if i % 200 == 0 and i > 100:
            plot_experiment_results(generator1, nets, hyper_losspr1, hyper_losscost1, hyper_lossclick1,
                                    hyper_losscvr1, hyper_losspr2, hyper_losscost2, hyper_lossclick2,
                                    hyper_losscvr2, hyper_perc1, hyper_perc2, i, score_losspr1,
                                    score_losscost1, score_lossclick1, score_losscvr1)
        if i % 1000 == 0:
            save_models(learner, nets)

def plot_interpolation2(points, step):

    points = np.array(points)

    points[:, :] *= 100

    points[points > 100] = 100

    def aggregate_data(data):
        aggregated_mean = []
        aggregated_lower = []
        aggregated_upper = []
        for i in range(0, len(data), 10):
            chunk = data[i:i+10]
            mean = np.mean(chunk)
            lower = np.percentile(chunk, 20)
            upper = np.percentile(chunk, 80)
            aggregated_mean.append(mean)
            aggregated_lower.append(lower)
            aggregated_upper.append(upper)
        return aggregated_mean, aggregated_lower, aggregated_upper
    
    def avg_data(data):
        data2 = np.zeros(len(data))
        for i in range(len(data)):
            if i == 0:
                data2[i] = data[i]
            else:
                data2[i] = (data2[i-1] * (i) + data[i]) / (i+1)
        return data2
    
    aggregated_points = [aggregate_data(points[:, i]) for i in range(points.shape[1])]
    
    x = range(len(aggregated_points[0][0]))  
    
    plt.figure()
    labels = ['AMMD (online)', 'AMMD (offline)', 'SW-VCG (offline)', 'SW-VCG (online)', 'VCG', 'GSP']
    colors = ['brown', 'red', 'gray', 'violet', 'green', 'blue']
    markers = ['d', 'p', 'v', 'x', 'o', 's']
    
    for i, (mean, lower, upper) in enumerate(aggregated_points):
        plt.plot(x, avg_data(mean), color=colors[i], linestyle='-', label=labels[i])
        plt.fill_between(x, avg_data(np.array(mean)-(np.array(mean)-np.array(lower))/2), avg_data(np.array(mean)+(np.array(upper)-np.array(mean))/2), color=colors[i], alpha=0.1)
        plt.scatter(x[::10], avg_data(mean)[::10], marker=markers[i], color=colors[i])  # Add markers every 10 points
    
    plt.xlabel('Traffic Samples')
    plt.ylabel('Utopia Distance (%)')
    plt.title('Utopia Distance of Different Mechanisms')
    plt.legend()
    plt.show()

    for i, (mean, _, _) in enumerate(aggregated_points):
        print(f'{labels[i]}  {mean[-1]}')

def plot_results(losspr, perc, label1, label2, window=48):
    """Plot the last ``window`` entries of two metric histories on a fresh figure.

    Args:
        losspr: First history (plotted with ``label1``).
        perc: Second history (plotted with ``label2``).
        label1: Legend label for ``losspr``.
        label2: Legend label for ``perc``.
        window: Number of most recent entries to show (default 48).

    A new figure is opened explicitly so the curves never land on a previous,
    still-open figure (as happens with non-interactive backends). The figure is
    closed after ``plt.show()``.
    """
    plt.figure()
    plt.plot(losspr[-window:], label=label1)
    plt.plot(perc[-window:], label=label2)
    plt.legend()  # without this the labels passed above are never displayed
    plt.show()
    plt.close()

def plot_experiment_results(generator, nets, losspr2, losscost2, lossclick2, losscvr2, losspr4, losscost4, lossclick4, losscvr4, perc2, perc4, i, score_losspr1, score_losscost1, score_lossclick1, score_losscvr1):
    """Plot the click-vs-cost points of every mechanism and print summary stats.

    Points plotted:
        - "AMMD (online)" / "AMMD (offline)": mean (click, cost) over the last 48
          entries of ``lossclick2``/``losscost2`` and ``lossclick4``/``losscost4``.
        - "SW-VCG (offline)" / "SW-VCG (online)": means over the last 48 entries
          of Score-VCG histories 0-4 and 5-9.
        - "VCG" / "GSP": ``alpha_VCG_Mechanism`` / ``alpha_GSP_Mechanism`` on one
          fresh sample from ``generator``, over a fixed 5-point parameter grid.

    Negative and strictly dominated points are removed with ``filter_pairs``
    before plotting. The figure is saved via ``save_plot``, shown, and then closed.

    ``nets``, ``losspr2``, ``losspr4``, ``i`` and ``score_losspr1`` are accepted
    for call compatibility but are not used.
    """
    window = 48  # number of most recent entries averaged per point
    alphas = [0.01, 0.1, 0.2, 0.5, 1]
    vcg_betas = [0.6, 2, 3, 5.5, 8]
    gsp_betas = [0.7, 2.2, 3.2, 6.3, 8.5]
    n_grid = len(alphas)

    fixed_point = (np.mean(lossclick2[-window:]), np.mean(losscost2[-window:]))
    fixed_point2 = (np.mean(lossclick4[-window:]), np.mean(losscost4[-window:]))
    value_ads = generator.generate_uniform(0, 1.5)
    ctr_ads, ctr_og = generator.generate_uniform(0, 1.), generator.generate_uniform(0, 2.5)
    cvr_ads, cvr_og = generator.generate_uniform(0, 0.17), generator.generate_uniform(0, 0.17)

    # Size the arrays to the grid so no zero padding can reach the plot as
    # spurious (0, 0) points.
    x_point, y_point, z_point = np.zeros(n_grid), np.zeros(n_grid), np.zeros(n_grid)
    x_point2, y_point2, z_point2 = np.zeros(n_grid), np.zeros(n_grid), np.zeros(n_grid)
    x_point3, y_point3, z_point3 = np.zeros(n_grid), np.zeros(n_grid), np.zeros(n_grid)
    x_point4, y_point4, z_point4 = np.zeros(n_grid), np.zeros(n_grid), np.zeros(n_grid)

    for l, (a, vcg_b, gsp_b) in enumerate(zip(alphas, vcg_betas, gsp_betas)):
        x_point[l], y_point[l], z_point[l] = alpha_VCG_Mechanism(
            value_ads, ctr_ads, ctr_og, cvr_ads, cvr_og, a, vcg_b)
        x_point2[l], y_point2[l], z_point2[l] = alpha_GSP_Mechanism(
            value_ads, ctr_ads, ctr_og, cvr_ads, cvr_og, a, gsp_b)
        # Histories 0-4 are the offline Score-VCG nets and 5-9 the online ones.
        x_point3[l] = np.mean(score_lossclick1[l][-window:])
        y_point3[l] = np.mean(score_losscost1[l][-window:])
        z_point3[l] = np.mean(score_losscvr1[l][-window:])
        x_point4[l] = np.mean(score_lossclick1[l + 5][-window:])
        y_point4[l] = np.mean(score_losscost1[l + 5][-window:])
        z_point4[l] = np.mean(score_losscvr1[l + 5][-window:])

    print('x_point:', x_point)
    print('y_point:', y_point)
    print('z_point:', z_point)
    print('x_point2:', x_point2)
    print('y_point2:', y_point2)
    print('z_point2:', z_point2)
    print('x_point3:', x_point3)
    print('y_point3:', y_point3)
    print('z_point3:', z_point3)
    print('x_point4:', x_point4)
    print('y_point4:', y_point4)
    print('z_point4:', z_point4)

    x_point, y_point = filter_pairs(x_point, y_point)
    x_point2, y_point2 = filter_pairs(x_point2, y_point2)
    x_point3, y_point3 = filter_pairs(x_point3, y_point3)
    x_point4, y_point4 = filter_pairs(x_point4, y_point4)

    plt.figure(dpi=600)
    plt.plot(x_point, y_point, marker='o', linestyle='-', color='green', label='VCG ')
    plt.plot(x_point2, y_point2, marker='s', linestyle='-', color='blue', label='GSP')
    plt.plot(x_point3, y_point3, marker='v', linestyle='-', color='gray', label='SW-VCG (offline)')
    plt.plot(x_point4, y_point4, marker='x', linestyle='-', color='violet', label='SW-VCG (online)')
    plt.scatter(*fixed_point2, color='red', marker='p', label='AMMD (offline)')
    plt.scatter(*fixed_point, color='brown', marker='d', label='AMMD (online)')
    plt.xlabel('click')
    plt.ylabel('cost')
    plt.legend()
    plt.title('Experiments in dynamic traffic environment (CVR)')

    # Save the plot as PDF before showing it.
    save_plot()

    plt.show()
    plt.close()  # release the high-dpi figure; otherwise figures pile up in headless runs

    print('AMMD (online):', fixed_point)
    print('AMMD (online CVR):', np.mean(losscvr2[-window:]))
    print('AMMD (offline):', fixed_point2)
    print('AMMD (offline CVR):', np.mean(losscvr4[-window:]))
    print('orgs_perc:', np.mean(perc2[-window:]))
    print('orgs_perc_max:', np.max(perc2[-window:]))
    print('orgs_perc_min:', np.min(perc2[-window:]))
    print('orgs_percoff:', np.mean(perc4[-window:]))
    print('orgs_perc_maxoff:', np.max(perc4[-window:]))
    print('orgs_perc_minoff:', np.min(perc4[-window:]))

def save_plot(directory='imgs'):
    """Save the current matplotlib figure as a timestamped PDF.

    The file is written to ``<directory>/plot_YYYYmmdd_HHMMSS.pdf``, creating
    the directory if needed. The timestamp has one-second resolution, so two
    calls within the same second write to the same path.
    """
    # exist_ok avoids the race between an existence check and makedirs.
    os.makedirs(directory, exist_ok=True)
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    plt.savefig(os.path.join(directory, f'plot_{timestamp}.pdf'))

def save_models(learner, nets, directory='checkpoints'):
    """Save the learner and every net into ``directory`` with ``torch.save``.

    Files written:
        - ``learner.pth``: the learner.
        - ``net_<idx>.pth``: one file per entry of ``nets``.

    Whole objects are saved, not state dicts. The directory is created if needed.
    """
    # exist_ok avoids the race between an existence check and makedirs.
    os.makedirs(directory, exist_ok=True)
    t.save(learner, os.path.join(directory, 'learner.pth'))
    for idx, net in enumerate(nets):
        t.save(net, os.path.join(directory, f'net_{idx}.pth'))

def _load_pickled(path, map_location=None):
    """Load a whole pickled object saved with ``torch.save``.

    Recent PyTorch defaults ``torch.load`` to ``weights_only=True``, which
    rejects arbitrary objects such as a pickled Learner. This helper therefore
    opts out explicitly. Releases that predate the argument raise TypeError, and
    in that case the call is retried without it.
    """
    try:
        return t.load(path, map_location=map_location, weights_only=False)
    except TypeError:
        return t.load(path, map_location=map_location)


def load_models(directory='checkpoints', num_nets=10, map_location=None):
    """Load the learner and nets written by ``save_models``.

    Args:
        directory: Folder holding ``learner.pth`` and ``net_<idx>.pth``.
        num_nets: Number of net files to load (``net_0`` .. ``net_{num_nets - 1}``).
        map_location: Forwarded to ``torch.load``, e.g. ``'cpu'`` to load GPU
            checkpoints on a CPU-only machine. ``None`` keeps torch's default.

    Returns:
        ``(learner, nets)``, where ``nets`` is a list of length ``num_nets``.

    Warning:
        These checkpoints are full pickles and are loaded with
        ``weights_only=False``, which can execute arbitrary code. Only load
        trusted files.
    """
    learner = _load_pickled(os.path.join(directory, 'learner.pth'), map_location)
    nets = [_load_pickled(os.path.join(directory, f'net_{idx}.pth'), map_location)
            for idx in range(num_nets)]
    return learner, nets

if __name__ == "__main__":
    # Set this flag to True to load the learner and nets from checkpoints/
    # instead of building fresh ones.
    LOAD_FROM_CHECKPOINT = False

    args = Args((4, 1, "uniform", 10, 10, 2000, 2000, 1))
    args2 = Args((4, 1, "uniform", 10, 10, 2000, 2000, 1))

    if LOAD_FROM_CHECKPOINT:
        learner, nets = load_models()
    else:
        # Ten Score-VCG baselines, built only when not loading them. In
        # train_linear, nets[0:5] are trained offline and nets[5:10] online.
        nets = [Score_VCG(args) for _ in range(10)]
        learner = Learner(args2)

    train_linear(args, args2, nets, rollouts=100001, learner=learner)