import torch
import torch.nn
import torch.nn.functional as F
from IPython import embed
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt
import time
import argparse
import torch.nn.functional as F
import time
import os
import pickle
from dataset import TrajDataset
from net import Net, Transformer, TransformerTall, TransformerBERT, TransformerTall2

# Run on the GPU when one is visible to PyTorch, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')


if __name__ == '__main__':
    # Output directories for loss curves and model checkpoints (no-op if present).
    os.makedirs('figs/loss', exist_ok=True)
    os.makedirs('models', exist_ok=True)

    parser = argparse.ArgumentParser()

    parser.add_argument("--envs", type=int, required=False, default=1000, help="Envs")
    parser.add_argument("--hists", type=int, required=False, default=1, help="Histories")
    parser.add_argument("--samples", type=int, required=False, default=1, help="Samples")
    parser.add_argument("--H", type=int, required=False, default=10, help="Context horizon")
    parser.add_argument("--embd", type=int, required=False, default=32, help="Embedding")
    parser.add_argument("--head", type=int, required=False, default=1, help="Number of attention heads")
    parser.add_argument("--layer", type=int, required=False, default=3, help="Number of layers")
    parser.add_argument("--lr", type=float, required=False, default=1e-3, help="Learning rate")
    parser.add_argument("--dim", type=int, required=False, default=1, help="Dimension")
    parser.add_argument("--k", type=int, required=False, default=1, help="Top K value")
    parser.add_argument("--orig", type=int, required=False, default=2, help="Orig")
    parser.add_argument("--opt", type=int, required=False, default=0, help="Optimizer type")
    parser.add_argument("--dropout", type=float, required=False, default=0, help="Dropout")
    parser.add_argument("--var", type=float, required=False, default=0.0, help="Variance")
    parser.add_argument("--cov", type=float, required=False, default=0.0, help="Coverage")
    parser.add_argument("--trans", type=int, required=False, default=0, help="Transformer type")
    parser.add_argument("--env", type=str, required=True, help="Environment")
    parser.add_argument("--lin_d", type=int, required=False, default=1, help="Linear Dimension")

    parser.add_argument('--full', default=False, action='store_true')
    parser.add_argument('--shuffle', default=False, action='store_true')

    args = vars(parser.parse_args())
    print("Args:")
    print(args)

    env = args['env']
    n_envs = args['envs']
    n_hists = args['hists']
    n_samples = args['samples']
    H = args['H']
    dim = args['dim']
    dx = dim
    du = dim
    n_embd = args['embd']
    n_head = args['head']
    n_layer = args['layer']
    useQ = False
    use_net = False
    lr = args['lr']
    shuffle = args['shuffle']
    full = args['full']
    opt = args['opt']
    dropout = args['dropout']
    var = args['var']
    trans = args['trans']
    cov = args['cov']
    topk = args['k']
    orig = args['orig']
    lin_d = args['lin_d']
    warm_start = False

    # A negative --envs selects direct sampling from the environment
    # distribution instead of the pre-generated train/test datasets.
    distr = n_envs < 0

    EPOCHS = 90000 if distr else 3000

    # Per-env dataset paths and the checkpoint/plot filename stem.
    if env == 'bandit':
        dx = 1
        prefix = 'bandit'
        path_train = f'datasets/trajs_{env}_envs{n_envs}_hists{n_hists}_samples{n_samples}_H{H}_d{dim}_var{var}_cov{cov}_orig{orig}_train.pkl'
        path_test =  f'datasets/trajs_{env}_envs{n_envs}_hists{n_hists}_samples{n_samples}_H{H}_d{dim}_var{var}_cov{cov}_orig{orig}_test.pkl'
        filename = f'{env}_trans{trans}_full{full}_shuf{shuffle}_opt{opt}_lr{lr}_do{dropout}_embd{n_embd}_layer{n_layer}_head{n_head}_envs{n_envs}_hists{n_hists}_samples{n_samples}_var{var}_cov{cov}_orig{orig}_H{H}_d{dim}'

    elif env == 'bandit_thompson':
        dx = 1
        prefix = 'bandit_thompson'
        path_train = f'datasets/trajs_{env}_envs{n_envs}_hists{n_hists}_samples{n_samples}_H{H}_d{dim}_var{var}_cov{cov}_train.pkl'
        path_test =  f'datasets/trajs_{env}_envs{n_envs}_hists{n_hists}_samples{n_samples}_H{H}_d{dim}_var{var}_cov{cov}_test.pkl'
        filename = f'{env}_trans{trans}_full{full}_shuf{shuffle}_opt{opt}_lr{lr}_do{dropout}_embd{n_embd}_layer{n_layer}_head{n_head}_envs{n_envs}_hists{n_hists}_samples{n_samples}_var{var}_cov{cov}_H{H}_d{dim}'

    elif env == 'bandit_ood':
        dx = 1
        prefix = 'bandit_ood'
        path_train = f'datasets/trajs_{env}_envs{n_envs}_hists{n_hists}_samples{n_samples}_H{H}_d{dim}_var{var}_cov{cov}_train.pkl'
        path_test =  f'datasets/trajs_{env}_envs{n_envs}_hists{n_hists}_samples{n_samples}_H{H}_d{dim}_var{var}_cov{cov}_test.pkl'
        filename = f'{env}_trans{trans}_full{full}_shuf{shuffle}_opt{opt}_lr{lr}_do{dropout}_embd{n_embd}_layer{n_layer}_head{n_head}_envs{n_envs}_hists{n_hists}_samples{n_samples}_var{var}_cov{cov}_H{H}_d{dim}'

    elif env == 'bandit_topk':
        dx = 1
        prefix = 'bandit_topk'
        path_train = f'datasets/trajs_{env}_envs{n_envs}_hists{n_hists}_samples{n_samples}_H{H}_d{dim}_var{var}_k{topk}_train.pkl'
        path_test =  f'datasets/trajs_{env}_envs{n_envs}_hists{n_hists}_samples{n_samples}_H{H}_d{dim}_var{var}_k{topk}_test.pkl'
        filename = f'{env}_trans{trans}_full{full}_shuf{shuffle}_opt{opt}_lr{lr}_do{dropout}_embd{n_embd}_layer{n_layer}_head{n_head}_envs{n_envs}_hists{n_hists}_samples{n_samples}_var{var}_k{topk}_H{H}_d{dim}'

    elif env == 'linear_bandit':
        dx = 1
        prefix = 'linear_bandit'
        path_train = f'datasets/trajs_{env}_envs{n_envs}_hists{n_hists}_samples{n_samples}_H{H}_d{dim}_dlin{lin_d}_var{var}_ws{warm_start}_train.pkl'
        path_test =  f'datasets/trajs_{env}_envs{n_envs}_hists{n_hists}_samples{n_samples}_H{H}_d{dim}_dlin{lin_d}_var{var}_ws{warm_start}_test.pkl'
        filename = f'{env}_trans{trans}_full{full}_shuf{shuffle}_opt{opt}_lr{lr}_do{dropout}_embd{n_embd}_layer{n_layer}_head{n_head}_envs{n_envs}_hists{n_hists}_samples{n_samples}_var{var}_H{H}_d{dim}_dlin{lin_d}_ws{warm_start}'

    elif env in ['darkroom', 'darkroom_heldout', 'darkroom_stitch']:
        prefix = env
        dx = 2
        du = 5
        path_train = f'datasets/trajs_{env}_envs{n_envs}_hists{n_hists}_samples{n_samples}_H{H}_d{dim}_train.pkl'
        path_test = f'datasets/trajs_{env}_envs{n_envs}_hists{n_hists}_samples{n_samples}_H{H}_d{dim}_test.pkl'
        filename = f'{env}_trans{trans}_full{full}_shuf{shuffle}_opt{opt}_lr{lr}_do{dropout}_embd{n_embd}_layer{n_layer}_head{n_head}_envs{n_envs}_hists{n_hists}_samples{n_samples}_H{H}_d{dim}'

    else:
        raise NotImplementedError(f"Unsupported env: {env}")

    batch_size = 64

    # Only consumed by the direct-sampling path (distr=True).
    sampler_config = {
        'bsize': batch_size,
        'n_hists': n_hists,
        'n_samples': n_samples,
        'H': H,
        'dim': dim,
        'var': var,
        'cov': cov,
    }

    # Shared by the model constructors and TrajDataset.
    config = {
        'H': H,
        'dx': dx,
        'du': du,
        'n_layer': n_layer,
        'n_embd': n_embd,
        'n_head': n_head,
        'Q': useQ,
        'shuffle': shuffle,
        'full': full,
        'dropout': dropout,
    }

    if use_net:                 model = Net(config).to(device)
    elif trans == 0:            model = Transformer(config).to(device)
    elif trans == 1:            model = TransformerTall(config).to(device)
    elif trans == 2:            model = TransformerTall2(config).to(device)
    else:                       model = TransformerBERT(config).to(device)

    params = {'batch_size': batch_size,
              'shuffle': True}

    if not distr:
        ds_train = TrajDataset(path_train, config)
        ds_test = TrajDataset(path_test, config)

        n_train = len(ds_train)
        n_test = len(ds_test)

        train_loader = torch.utils.data.DataLoader(ds_train, **params)
        test_loader = torch.utils.data.DataLoader(ds_test, **params)
    else:
        # BanditSampler is not imported anywhere in this script. Fail with an
        # explicit message instead of a bare NameError so the missing
        # dependency is obvious.
        sampler_cls = globals().get('BanditSampler')
        if sampler_cls is None:
            raise NotImplementedError(
                "Direct sampling (--envs < 0) requires BanditSampler, "
                "which is not imported in this script."
            )
        sampler = sampler_cls(sampler_config)

    if opt == 0:    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    elif opt == 1:  optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    else:           raise NotImplementedError

    # All supported envs use the same summed cross-entropy. The env was already
    # validated above, so no per-env dispatch is needed here.
    loss_fn = nn.CrossEntropyLoss(reduction='sum')

    def compute_loss(pred_actions, true_actions):
        """Return the summed cross-entropy between predictions and target actions.

        With --full, the single per-sample target is repeated along dim 1 to
        match pred_actions.shape[1] (presumably one prediction per context
        step; TODO confirm against the model). Both tensors are then flattened
        to (-1, du).
        """
        if full:
            true_actions = true_actions.unsqueeze(1).repeat(1, pred_actions.shape[1], 1)
        return loss_fn(pred_actions.reshape(-1, du), true_actions.reshape(-1, du))

    test_loss = []
    train_loss = []

    for epoch in range(EPOCHS):
        print(f"Epoch: {epoch}")

        if not distr:
            # Evaluate in eval mode so dropout does not perturb the test loss.
            model.eval()
            with torch.no_grad():
                epoch_test_loss = 0.0

                for batch in test_loader:
                    pred_actions = model(batch)
                    loss = compute_loss(pred_actions, batch['actions'])
                    epoch_test_loss += loss.item() / H

                test_loss.append(epoch_test_loss / n_test)
                print(f'Test Loss:        {test_loss[-1]}')

            # Re-enable dropout (and other train-time behaviour) for the update pass.
            model.train()
            epoch_train_loss = 0.0
            start_time = time.time()

            for batch in train_loader:
                pred_actions = model(batch)

                optimizer.zero_grad()
                loss = compute_loss(pred_actions, batch['actions'])
                loss.backward()
                optimizer.step()
                epoch_train_loss += loss.item() / H

            end_time = time.time()
            diff = end_time - start_time
            train_loss.append(epoch_train_loss / n_train)
            print(f'Train Loss:       {train_loss[-1]}')
            print(f'Batch time:       {diff}\n\n')

        else:
            # One freshly sampled batch per "epoch"; the model is never put in
            # eval mode on this path, so it stays in train mode.
            start_time = time.time()
            batch = sampler.sample()
            pred_actions = model(batch)

            optimizer.zero_grad()
            loss = compute_loss(pred_actions, batch['actions'])
            loss.backward()
            optimizer.step()
            epoch_train_loss = loss.item() / H / batch_size

            train_loss.append(epoch_train_loss)
            end_time = time.time()
            diff = end_time - start_time

            print(f'Train Loss:       {train_loss[-1]}')
            print(f'Batch time:       {diff}\n\n')

        # Periodic checkpoint (every 50 epochs on datasets, every 500 sampled batches).
        if (not distr and ((epoch + 1) % 50 == 0)) or (distr and ((epoch + 1) % 500 == 0)):
            torch.save(model.state_dict(), f'models/{filename}_epoch{epoch+1}.pt')

        # Refresh the loss-curve plot every 10 epochs (first point skipped).
        if (epoch + 1) % 10 == 0:
            plt.yscale('log')
            plt.plot(train_loss[1:], label="Train final")
            if test_loss:
                plt.plot(test_loss[1:], label="Test final")
            plt.legend()
            plt.savefig(f"figs/loss/{filename}_train_loss.png")
            plt.clf()

    torch.save(model.state_dict(), f'models/{filename}.pt')
    print("Done.")
