import argparse
from tqdm import trange, tqdm

from torch.utils.data import DataLoader
from torchmetrics.classification import BinaryAccuracy

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

import os
from tree_data import *
from models import *
import wandb
from test_tree import run_test

from pathlib import Path
import warnings
warnings.filterwarnings("ignore")


# Cap torch's intra-op CPU thread pool for this process.
torch.set_num_threads(5)


def _str2bool(value):
    """Parse a boolean command-line value such as ``true``, ``False``, ``1``, ``no``.

    ``argparse`` with ``type=bool`` converts every non-empty string (including
    ``"False"``) to True, so boolean flags need an explicit converter.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(
        f"expected a boolean value (true/false), got {value!r}")


parser = argparse.ArgumentParser(
    description='Training Transformers for Bayesian Inference')

# data args
parser.add_argument(
    '--num-example',
    help='Number of examples for in-context learning',
    type=int,
    default=100)
parser.add_argument(
    '--num-var',
    help='Number of variables in the Tree',
    type=int,
    default=7)
parser.add_argument(
    '--seed',
    help='seed',
    type=int,
    default=1111)


# training args
parser.add_argument(
    '--steps',
    help='Number of training steps',
    type=int,
    default=30000)
parser.add_argument(
    '--init_lr',
    help='Initial learning rate',
    type=float,
    default=1e-4)
parser.add_argument('--batch-size', help='Batch size', type=int, default=64)

# model args
parser.add_argument(
    '--layers',
    help='Number of transformer layers',
    type=int,
    default=6)
parser.add_argument(
    '--heads',
    help='Number of transformer attention heads',
    type=int,
    default=8)
parser.add_argument(
    '--hid_dim',
    help='Size of hidden dimension',
    type=int,
    default=256)
parser.add_argument(
    '--model',
    help='model type: [transformer, attn]',
    type=str,
    default="transformer")

# log args
parser.add_argument(
    '--log-every',
    help='log every X steps',
    type=int,
    default=50)
parser.add_argument(
    '--wandb',
    help='Log the run to Weights & Biases: true/false (a bare --wandb means true)',
    type=_str2bool,
    # Accept "--wandb", "--wandb true" and "--wandb false".
    nargs='?',
    const=True,
    default=True)
parser.add_argument(
    '--note',
    help='note',
    type=str,
    default="None")


args = parser.parse_args()


# Seed torch's global RNG from --seed.
torch.manual_seed(args.seed)

class Curriculum:
    """Stage counter for a curriculum that advances on training accuracy.

    ``cur`` starts at 1 and is bumped by one each time ``update`` is given a
    training accuracy of at least ``acc_threshold``, stopping once it equals
    ``max_cur`` (``num_var - 1``).
    """

    def __init__(self, num_var=7):
        self.cur = 1
        # Step at which ``cur`` last advanced; -1 until the first advance.
        self.last_update = -1
        # Not read by ``update``; kept as an attribute for external readers.
        self.max_interval_update = 3000
        self.acc_threshold = 0.65
        self.max_cur = num_var - 1

    def update(self, steps, train_acc):
        """Try to advance the curriculum.

        Args:
            steps: Current training step, stored in ``last_update`` on advance.
            train_acc: Training accuracy compared against ``acc_threshold``.

        Returns:
            True if ``cur`` was incremented, otherwise False (including when
            ``cur`` already equals ``max_cur``).
        """
        promote = self.cur != self.max_cur and bool(train_acc >= self.acc_threshold)
        if promote:
            self.last_update = steps
            self.cur += 1
        return promote

class EarlyStop:
    """Patience-based stop signal driven by a loss value.

    Tracks the lowest loss seen (initialised to 0.59). ``update`` signals a
    stop only when the current loss is more than ``threshold`` above that
    minimum AND more than ``max_steps`` steps have elapsed since the minimum
    was last lowered.
    """

    def __init__(self):
        self.min_loss = 0.59
        self.threshold = 0.05
        self.max_steps = 2000
        self.last_update = 0

    def update(self, step, loss):
        """Record ``loss`` observed at ``step``.

        Returns:
            True if training should stop, otherwise False. A new minimum
            (``loss <= min_loss``) always returns False and resets the
            patience window to ``step``.
        """
        if loss <= self.min_loss:
            self.min_loss, self.last_update = loss, step
            return False
        beyond_tolerance = not (loss <= self.min_loss + self.threshold)
        patience_exhausted = step - self.last_update > self.max_steps
        return bool(beyond_tolerance and patience_exhausted)


def lr_decay(optimizer, factor=0.95):
    """Multiply the learning rate of every parameter group by ``factor`` in place.

    Args:
        optimizer: An optimizer exposing ``param_groups`` (a list of dicts
            with an ``'lr'`` entry), e.g. any ``torch.optim`` optimizer.
        factor: Multiplicative decay applied to each group's ``lr``.
            Defaults to 0.95.

    Returns:
        The same optimizer object, so callers may reassign it.
    """
    for group in optimizer.param_groups:
        group['lr'] *= factor
    return optimizer

def sample_seeds(total_seeds, count):
    """Draw ``count`` distinct integer seeds uniformly from ``[0, total_seeds)``.

    Sampling uses torch's global RNG, so it follows ``torch.manual_seed``.

    Args:
        total_seeds: Size of the seed range; valid seeds are 0..total_seeds-1.
        count: Number of distinct seeds to return.

    Returns:
        A ``set`` of ``count`` Python ints (empty when ``count <= 0``).

    Raises:
        ValueError: if ``count`` exceeds ``total_seeds`` (no such set exists).
    """
    if count <= 0:
        return set()
    if count > total_seeds:
        raise ValueError(
            f"cannot draw {count} distinct seeds from a range of {total_seeds}")
    # A random permutation gives distinct values in one pass; .tolist() yields
    # plain ints, which (unlike tensors) hash and compare by value in a set.
    return set(torch.randperm(total_seeds)[:count].tolist())


def train_step(model, optimizer, x, y, loss_func, metric, args):
    """Run one optimisation step on a single batch.

    Args:
        model: Classifier whose forward returns logits with classes on dim 1
            (the layout ``CrossEntropyLoss`` expects for 1-D targets).
        optimizer: Optimizer over ``model``'s parameters.
        x: Input batch; moved to the device of ``model``'s parameters.
        y: Targets; cast to int64, moved to the same device, and the last
            dimension squeezed (presumably shape (B, 1) -> (B,); TODO confirm).
        loss_func: Callable ``loss_func(logits, targets)`` returning a scalar.
        metric: Callable ``metric(predictions, targets)`` returning a 0-d tensor.
        args: Unused; kept so existing call sites keep working.

    Returns:
        ``(accuracy, loss)`` for this batch as Python floats.
    """
    # Follow the model's device instead of hard-coding CUDA; train() puts the
    # model on the GPU, so behaviour there is unchanged.
    device = next(model.parameters()).device
    x = x.to(device)
    y = y.long().to(device).squeeze(-1)

    optimizer.zero_grad()
    output = model(x)
    loss = loss_func(output, y)
    loss.backward()
    optimizer.step()

    # Softmax is monotonic, so the argmax over raw logits picks the same class.
    prediction = output.detach().argmax(dim=1)
    acc = metric(prediction, y)

    # NOTE(review): releasing the CUDA cache on every step returns memory to
    # other processes but forces extra synchronisation; consider dropping it.
    torch.cuda.empty_cache()

    return acc.item(), loss.item()

def _build_model(args):
    """Instantiate the architecture selected by ``args.model`` and move it to CUDA."""
    if args.model == "transformer":
        model = TransformerModel(
            args.num_var * 3,
            n_positions=args.num_example + 1,
            n_embd=args.hid_dim,
            n_layer=args.layers,
            n_head=args.heads)
    else:
        model = EncoderTransformer(
            n_dims=args.num_var * 3,
            n_embd=args.hid_dim,
            n_layer=args.layers,
            n_head=args.heads)
    return model.cuda()


def _make_loader(dataset, batch_size):
    """Build the shuffled training DataLoader (rebuilt when the curriculum advances)."""
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)


def _plot_curves(df, out_dir, args):
    """Save training loss/accuracy curves from ``df`` to ``out_dir/curve_plot.png``."""
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))
    sns.lineplot(
        data=df,
        x="Step",
        y="Training Loss",
        ax=axes[0],
        color="orange",
        alpha=0.8,
        label="Train")
    sns.lineplot(
        data=df,
        x="Step",
        y="Training Acc.",
        ax=axes[1],
        color="orange",
        alpha=0.8,
        label="Train")
    # A figure-level title describes both panels; plt.title would only title
    # the current (right-hand) axes.
    fig.suptitle(f"Steps ({args.steps}), ICL Examples ({args.num_example})")
    fig.tight_layout()
    fig.savefig(out_dir / "curve_plot.png")
    # Release the figure; pyplot otherwise keeps it alive for the process.
    plt.close(fig)


def train(args):
    """Train a model on ``InContextDataset`` under an accuracy-driven curriculum.

    Every ``args.log_every`` steps (excluding step 0) this:
      * reports the mean training loss/accuracy since the previous log
        (progress bar, in-memory log, optional wandb, ``log.txt``);
      * decays the learning rate by ``lr_decay`` while the curriculum has not
        advanced for more than 1000 steps;
      * saves ``ckpt.pt`` when the mean loss is a new best and the curriculum
        is at its last stage (``args.num_var - 1``);
      * asks the curriculum to advance and, if it does, widens
        ``trainset.mask_out_range`` and rebuilds the loader.

    Outputs are written under
    ``ckpt/tree_3/<model>_<num_example>_<layers>_<hid_dim>_<heads>_<batch_size>_seed<seed>/``:
    ``log.txt``, ``ckpt.pt`` (the final weights if no best checkpoint was
    saved), ``logs.csv`` and ``curve_plot.png``.
    """
    if args.wandb:
        wandb.init(
            # Set the project where this run will be logged
            project="Transformers-Bayesian-Inference-Tree3",
            # Track hyperparameters and run metadata
            config={
                "learning_rate": args.init_lr,
                "steps": args.steps,
                "hid_dim": args.hid_dim,
                "heads": args.heads,
                "layers": args.layers,
                "num_example": args.num_example,
                "num var": args.num_var
            },
        )

    trainset = InContextDataset(
        args.steps * args.batch_size,
        args.num_example,
        args.num_var)
    # Same batch size as the loaders rebuilt on curriculum changes below.
    train_loader = _make_loader(trainset, args.batch_size)

    print("Preparing model...")
    model = _build_model(args)

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.init_lr)
    print("Preparing metrics...")
    loss_func = torch.nn.CrossEntropyLoss()
    metric = BinaryAccuracy().cuda()

    log = {
        "Training Loss": [],
        "Training Acc.": [],
        "Step": []
    }

    print("Start Training...")
    log_dir = Path(
        f"ckpt/tree_3/{args.model}_{args.num_example}_{args.layers}_"
        f"{args.hid_dim}_{args.heads}_{args.batch_size}_seed{args.seed}")
    log_dir.mkdir(parents=True, exist_ok=True)
    ckpt_path = log_dir / "ckpt.pt"

    train_loss = []
    train_acc = []
    curriculum = Curriculum(args.num_var)
    best_loss = float("inf")
    p_bar = trange(args.steps + 1)
    # One persistent iterator: creating a new one every step would reshuffle
    # the full index space each time.
    batch_iter = iter(train_loader)

    with open(log_dir / "log.txt", "w") as log_file:
        for step in p_bar:
            try:
                x, y = next(batch_iter)
            except StopIteration:
                # Loader exhausted; begin a fresh shuffled pass.
                batch_iter = iter(train_loader)
                x, y = next(batch_iter)

            model.train()
            train_acc_step, train_loss_step = train_step(
                model, optimizer, x, y, loss_func, metric, args)
            train_loss.append(train_loss_step)
            train_acc.append(train_acc_step)

            if step == 0 or step % args.log_every != 0:
                continue

            mean_loss = np.mean(train_loss)
            mean_acc = np.mean(train_acc)

            p_bar.set_description(
                f"| Step: {step} | Train Loss: {str(round(mean_loss, 4))}")

            log["Training Loss"].append(mean_loss)
            log["Training Acc."].append(mean_acc)
            log["Step"].append(step)

            if args.wandb:
                wandb.log({
                    "Training Loss": mean_loss,
                    "Training Acc.": mean_acc,
                    "Step": step,
                    "Curriculum": curriculum.cur
                })

            # Decay the LR while the curriculum has stalled for > 1000 steps.
            if step - curriculum.last_update > 1000:
                optimizer = lr_decay(optimizer)

            # Keep the best checkpoint, but only once on the final stage.
            if mean_loss <= best_loss and curriculum.cur == args.num_var - 1:
                torch.save(model.state_dict(), ckpt_path)
                best_loss = mean_loss

            if curriculum.update(step, mean_acc):
                trainset.mask_out_range = [0, curriculum.cur]
                train_loader = _make_loader(trainset, args.batch_size)
                batch_iter = iter(train_loader)

            log_file.write(
                f"|Step|\t{step}\t|Train Loss:|\t{str(round(mean_loss, 4))}\t"
                f"|Train Acc:|\t{str(round(mean_acc, 4))} Curriculum: "
                f"{curriculum.cur}\n")
            log_file.write("\n")
            # Flush so the log can be followed while training runs.
            log_file.flush()

            train_loss = []
            train_acc = []

    # Fall back to the final weights if no best checkpoint was ever written.
    if not ckpt_path.exists():
        torch.save(model.state_dict(), ckpt_path)

    df = pd.DataFrame(log)
    df.to_csv(log_dir / "logs.csv", index=False)
    _plot_curves(df, log_dir, args)

    if args.wandb:
        wandb.finish()

# Script entry: train with the parsed CLI args, then run run_test (imported
# from test_tree) with the same args.
# NOTE(review): these run at import time as well; consider wrapping them in
# `if __name__ == "__main__":` if this module is ever imported.
train(args)
run_test(args)