from __future__ import absolute_import, division, print_function

import numpy as np
import torch as t
from matplotlib import pyplot as plt
from functions import PIDController, Generator, Generator2, MLP, calculate_t, Args, filter_pairs
from models import Line_Net, Hyper_Myerson
import os

# Device onto which the networks and tensors created in this file are moved.
device = t.device("cpu")

# Module-level default configuration. The functions below take their
# configuration as explicit parameters, and this name is rebound under the
# `__main__` guard when the file is run as a script.
# NOTE(review): the meaning of the tuple fields is defined by `Args` in
# `functions` and is not visible here.
args = Args((2, 1, "uniform", 20, 20, 10000, 10000, 1))

class Learner:
    """Two Player Auction Learner.

    Holds a `Hyper_Myerson` auction model together with two MLPs
    (`w_net`, `b_net`). Each MLP maps a 3-element vector
    ``(low, high, alpha)`` to 2 outputs, which are handed to the auction
    model on every call. Each MLP has its own Adam optimizer and StepLR
    schedule.
    """

    def __init__(self, args):
        """Build the auction model, both MLPs and their optimizers.

        Args:
            args: configuration object forwarded to `Hyper_Myerson` and kept
                on the instance for the data generator used during updates.
        """
        self.args = args
        self.auct_model = Hyper_Myerson(args)

        # Input is the (low, high, alpha) triple; hidden layers are 10x wider.
        in_dim = 3
        hidden_dim = in_dim * 10
        layer_sizes = [in_dim, hidden_dim, hidden_dim, 2]

        # Each network receives its own copy of the layer-size list.
        self.w_net = MLP(list(layer_sizes), t.tanh).to(device)
        self.b_net = MLP(list(layer_sizes), t.tanh).to(device)

        self.optimizers_auct = t.optim.Adam(
            self.w_net.parameters(), lr=4e-4, betas=(0.9, 0.999)
        )
        self.lrupdate = t.optim.lr_scheduler.StepLR(
            self.optimizers_auct, step_size=1, gamma=0.9999, last_epoch=-1
        )
        self.optimizers_auct2 = t.optim.Adam(
            self.b_net.parameters(), lr=4e-4, betas=(0.9, 0.999)
        )
        self.lrupdate2 = t.optim.lr_scheduler.StepLR(
            self.optimizers_auct2, step_size=1, gamma=0.9999, last_epoch=-1
        )

    def update_auct(self):
        """Run one gradient step of both MLPs on a randomly drawn setting.

        Draws ``high`` uniformly from [0.5, 2.5) and ``alpha`` uniformly from
        [0, 1), samples fresh data from a new `Generator`, and maximizes the
        mean of the auction model's first output (the value unpacked as
        ``revenue`` elsewhere in this file). Both learning-rate schedules
        are advanced afterwards.
        """
        # `high` is drawn before `alpha`; the order fixes the RNG stream.
        low = 0.
        high = 2 * np.random.uniform() + 0.5
        alpha = np.random.uniform()
        hyper_input = t.tensor([low, high, alpha]).to(device)

        sampler = Generator(self.args)
        values = sampler.generate_uniform2(low, high)
        ctr_ads = sampler.generate_uniform(0, 1)
        ctr_og = sampler.generate_ctr(0, 2)

        w_out = self.w_net(hyper_input)
        b_out = self.b_net(hyper_input)
        outputs = self.auct_model('train', values, ctr_ads, ctr_og, alpha, w_out, b_out)
        # Negated because the optimizers minimize.
        loss = -t.mean(outputs[0])

        optimizers = (self.optimizers_auct, self.optimizers_auct2)
        for optimizer in optimizers:
            optimizer.zero_grad()
        loss.backward()
        for optimizer in optimizers:
            optimizer.step()
        for scheduler in (self.lrupdate, self.lrupdate2):
            scheduler.step()

# Number of most recent samples averaged for every plotted point.
_PLOT_WINDOW = 24
# Positions in the per-net history lists of the two Regret Net groups.
_OFFLINE_NETS = range(8)
_ONLINE_NETS = range(8, 16)


def _mean_point(click_series, cost_series, window):
    """Return ``(mean click, mean cost)`` over the slice ``window``."""
    return (np.mean(click_series[window]), np.mean(cost_series[window]))


def _net_points(lossclick_list, losscost_list, net_ids, window):
    """Return one ``(mean click, mean cost)`` pair per net in ``net_ids``."""
    return [_mean_point(lossclick_list[k], losscost_list[k], window) for k in net_ids]


def _draw_comparison(ax, offline_curve, online_curve, ammd_offline, ammd_online):
    """Draw both Regret Net curves and both AMMD points on ``ax``."""
    ax.plot(*offline_curve, marker='o', linestyle='-', color='green', label='Regret Net (offline)')
    ax.plot(*online_curve, marker='x', linestyle='-', color='blue', label='Regret Net (online)')
    ax.scatter(*ammd_offline, marker='s', color='red', label='AMMD (offline)')
    ax.scatter(*ammd_online, marker='p', color='brown', label='AMMD (online)')


def _save_show_close(fig, path):
    """Write ``fig`` to ``path`` under ``imgs/``, show it, then release it."""
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs('imgs', exist_ok=True)
    fig.savefig(path)
    plt.show()
    # Without an explicit close, one 600-dpi figure per plotting step stays
    # alive for the whole run.
    plt.close(fig)


def plot_results(losspr2, losspr22, perc2, perc22, lossclick2, losscost2, lossclick22, losscost22, losspr_list, losscost_list, lossclick_list, perc_list, i):
    """Print and plot click/cost summaries at fixed step intervals.

    Does nothing unless ``i > 100`` and ``i`` is a multiple of 200 (single
    summary figure) and/or of 1000 (three-panel figure). Every plotted point
    is the mean of 24 consecutive history entries.

    Args:
        losspr2, losspr22: unused here; kept for interface compatibility.
        perc2, perc22: histories whose recent mean is printed as the
            'percentage' of the online / offline AMMD runs.
        lossclick2, losscost2: click and cost histories plotted as
            'AMMD (online)'.
        lossclick22, losscost22: click and cost histories plotted as
            'AMMD (offline)'.
        losspr_list, perc_list: unused here; kept for interface compatibility.
        losscost_list, lossclick_list: per-net histories; entries 0-7 are
            plotted as 'Regret Net (offline)', entries 8-15 as
            'Regret Net (online)'.
        i: current step index.

    Returns:
        None. Figures are written to the ``imgs`` directory and shown.
    """
    if i % 200 == 0 and i > 100:
        recent = slice(-_PLOT_WINDOW, None)
        fixed_point = _mean_point(lossclick2, losscost2, recent)
        fixed_point2 = _mean_point(lossclick22, losscost22, recent)
        # Computed once and reused for both the printout and the plot.
        offline_points = _net_points(lossclick_list, losscost_list, _OFFLINE_NETS, recent)
        online_points = _net_points(lossclick_list, losscost_list, _ONLINE_NETS, recent)

        print('AMMD (online):', fixed_point)
        print('AMMD (offline):', fixed_point2)
        print('Regret Net (offline):', *zip(*offline_points))
        print('Regret Net (online):', *zip(*online_points))
        print('percentage (AMMD online):', np.mean(perc2[recent]))
        print('percentage (AMMD offline):', np.mean(perc22[recent]))

        # filter_pairs comes from `functions`; it receives the click means
        # and the cost means and returns the x/y sequences to draw.
        offline_curve = filter_pairs(*zip(*offline_points))
        online_curve = filter_pairs(*zip(*online_points))

        fig, ax = plt.subplots(dpi=600)
        _draw_comparison(ax, offline_curve, online_curve, fixed_point2, fixed_point)
        ax.set_xlabel('click')
        ax.set_ylabel('cost')
        ax.legend()
        ax.set_title('Experiments in dynamic environments')

        _save_show_close(fig, f'imgs/dynamic_env_step{i}.png')

    if i % 1000 == 0 and i > 100:
        fig, axs = plt.subplots(1, 3, figsize=(15, 5), dpi=600)
        for j in range(3):
            # Three windows ending 400, 200 and 0 entries before index i - 100.
            idx = i - 100 - (2 - j) * 200
            window = slice(idx - _PLOT_WINDOW, idx)
            fixed_point = _mean_point(lossclick2, losscost2, window)
            fixed_point2 = _mean_point(lossclick22, losscost22, window)
            offline_curve = filter_pairs(*zip(*_net_points(lossclick_list, losscost_list, _OFFLINE_NETS, window)))
            online_curve = filter_pairs(*zip(*_net_points(lossclick_list, losscost_list, _ONLINE_NETS, window)))

            _draw_comparison(axs[j], offline_curve, online_curve, fixed_point2, fixed_point)
            axs[j].set_xlabel('click', size=13)
            axs[j].set_ylabel('cost', size=13)
            axs[j].set_title(f'Traffic {j+1}', size=13)

        # One shared legend below the panels instead of one per panel.
        handles, labels = axs[-1].get_legend_handles_labels()
        fig.legend(handles, labels, loc='lower center', ncol=4, bbox_to_anchor=(0.5, 0.01), framealpha=1, prop={'size': 15})
        fig.tight_layout(rect=[0, 0.1, 1, 1])

        # NOTE(review): the file name advertises steps i-3000..i, but the
        # panels cover windows within roughly the last 500 entries (and
        # i-3000 is negative for i < 3000). Name kept unchanged so existing
        # output paths stay stable; confirm the intended range.
        _save_show_close(fig, f'imgs/dynamic_env_steps{i-3000}_to_{i}.png')

# Number of initial steps during which nothing is recorded.
_WARMUP_STEPS = 100
# Third argument passed to each of the 16 nets when it is evaluated.
# NOTE(review): training uses 0.2 * l + 0.1 instead, which differs from these
# values for l >= 5 (1.1/1.3/1.5 vs 1.2/1.5/2.0) -- confirm this is intended.
_EVAL_COEFFS = (0.1, 0.3, 0.5, 0.7, 0.9, 1.2, 1.5, 2.0,
                0.1, 0.3, 0.5, 0.7, 0.9, 1.2, 1.5, 2.0)


def _scalar(value):
    """Convert a single-element tensor to a plain Python number."""
    return value.cpu().detach().numpy().item()


def _train_regret_nets(nets, args, generator, step):
    """Run one `seller_backward` step on every net.

    The first 8 nets are trained on data drawn with upper bound 1.0; the
    remaining nets use the upper bound returned by ``calculate_t(step)``.
    """
    for l, net in enumerate(nets[:8]):
        values = generator.generate_uniform2(0, 0.5 + 0.5)
        ctr = generator.generate_uniform(0, 1)
        net.seller_backward(args, values, ctr, 0.2 * l + 0.1, 'train')

    for l, net in enumerate(nets[8:]):
        high = calculate_t(step)
        values = generator.generate_uniform2(0, high)
        ctr = generator.generate_uniform(0, 1)
        net.seller_backward(args, values, ctr, 0.2 * l + 0.1, 'train')


def train_linear(args, args2, nets, rollouts):
    """Train the learner and the baseline nets side by side and plot results.

    For the first 100 steps only training happens. Afterwards, every step
    evaluates the learner twice on the same data -- once with the current
    ``calculate_t(i)`` upper bound in its input vector ("online") and once
    with a fixed upper bound of 1.0 ("offline") -- evaluates all nets, and
    records revenue, cost and click histories for `plot_results`.

    Args:
        args: configuration used for the generator feeding ``nets`` and
            forwarded to ``net.seller_backward``.
        args2: configuration used for the `Learner` and its data generator.
        nets: sequence of 16 nets; the first 8 are trained on a fixed
            distribution, the rest on the step-dependent one.
        rollouts: number of steps to run.
    """
    generate_train2 = Generator(args2)
    generate_train = Generator(args)
    learner = Learner(args2)
    pid_controllers = [PIDController(0.001, 0.0001, 0.1, 0.5), PIDController(0.05, 0.0002, 1., 0.5)]

    # Every history starts with a single 0 entry.
    losspr2, losscost2, lossclick2, perc2 = [0], [0], [0], [0]
    losspr22, losscost22, lossclick22, perc22 = [0], [0], [0], [0]
    losspr_list = [[0] for _ in range(16)]
    losscost_list = [[0] for _ in range(16)]
    lossclick_list = [[0] for _ in range(16)]
    perc_list = [[0] for _ in range(16)]
    # Separate multipliers for the online and offline evaluations, each
    # adjusted by its own PID controller.
    alpha = 0.5
    alpha2 = 0.5

    for i in range(rollouts):
        learner.update_auct()

        if i < _WARMUP_STEPS:
            low, high = 0., 1.0
            distribution = t.tensor([low, high, alpha]).float().to(device)
            train_data = generate_train2.generate_uniform2(low, high)
            ctr_ads = generate_train2.generate_uniform(0, 1)
            ctr_og = generate_train2.generate_ctr(0, 2)
            revenue, _, _, _, _, _, _ = learner.auct_model('train', train_data, ctr_ads, ctr_og, alpha, learner.w_net(distribution), learner.b_net(distribution))
            print(revenue.cpu().detach().numpy())

            _train_regret_nets(nets, args, generate_train, i)
        else:
            # Online evaluation: the input vector carries the current bound.
            low, high = 0., calculate_t(i)
            distribution = t.tensor([low, high, alpha]).float().to(device)
            train_data = generate_train2.generate_uniform2(low, high)
            ctr_ads = generate_train2.generate_uniform(0, 1)
            ctr_og = generate_train2.generate_ctr(0, 2)
            revenue, _, _, _, cost, click, perc = learner.auct_model('train', train_data, ctr_ads, ctr_og, alpha, learner.w_net(distribution), learner.b_net(distribution))
            perc_value = _scalar(perc)
            losspr2.append(_scalar(revenue))
            losscost2.append(_scalar(cost))
            lossclick2.append(_scalar(click))
            perc2.append(perc_value)
            alpha = alpha * np.exp(pid_controllers[0].update(perc_value))

            # Offline evaluation: same data, but a fixed bound of 1.0 in the
            # input vector. It uses its own multiplier `alpha2`; previously
            # the freshly updated online `alpha` was used here, so `alpha2`
            # was maintained by the second PID controller but never applied.
            distribution = t.tensor([low, 1., alpha2]).float().to(device)
            revenue, _, _, _, cost, click, perc = learner.auct_model('train', train_data, ctr_ads, ctr_og, alpha2, learner.w_net(distribution), learner.b_net(distribution))
            perc_value = _scalar(perc)
            losspr22.append(_scalar(revenue))
            losscost22.append(_scalar(cost))
            lossclick22.append(_scalar(click))
            perc22.append(perc_value)
            alpha2 = alpha2 * np.exp(pid_controllers[1].update(perc_value))

            # Drawn before the nets are trained to keep the RNG stream order.
            fixed_random = np.random.uniform()
            _train_regret_nets(nets, args, generate_train, i)

            for l, net in enumerate(nets[:16]):
                # The net is always evaluated, even when its result is
                # replaced by the placeholder values below.
                revenue, _, _, _, cost, click = net(train_data, ctr_ads, _EVAL_COEFFS[l])
                if 0.5 < fixed_random:
                    losspr_list[l].append(_scalar(revenue))
                    losscost_list[l].append(_scalar(cost))
                    lossclick_list[l].append(_scalar(click))
                else:
                    # With probability 0.5 the whole step records fixed
                    # placeholder values for every net instead.
                    losspr_list[l].append(0.1)
                    losscost_list[l].append(0)
                    lossclick_list[l].append(1.)

        if i % 20 == 0:
            print('i=', i)
        plot_results(losspr2, losspr22, perc2, perc22, lossclick2, losscost2, lossclick22, losscost22, losspr_list, losscost_list, lossclick_list, perc_list, i)

if __name__ == "__main__":
    # Two configurations that differ only in the 6th and 7th tuple fields
    # (100 vs 10000). `args` is used for the nets and their data generator,
    # `args2` for the learner inside train_linear.
    # NOTE(review): field meanings are defined by `Args` in `functions`.
    args = Args((2, 1, "uniform", 10, 10, 100, 100, 1))
    args2 = Args((2, 1, "uniform", 10, 10, 10000, 10000, 1))
    # One train and one test sample, both uniform on (0, 1), shared by all
    # nets at construction time.
    generate_train2 = Generator2(args)
    generate_test2 = Generator2(args)
    train_data = generate_train2.generate_uniform(0, 1)
    test_data = generate_test2.generate_uniform(0, 1)

    # 16 nets: train_linear treats the first 8 and the last 8 as two groups.
    nets = [Line_Net(args, train_data, test_data) for _ in range(16)]
    # 200001 steps so that the final step index (200000) is itself a
    # plotting step (multiple of both 200 and 1000).
    train_linear(args, args2, nets, 200001)
