import argparse
from tqdm import trange, tqdm

from torch.utils.data import DataLoader
from torchmetrics.classification import BinaryAccuracy
import copy
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

import os
from tree_data import *
from models import *
from pathlib import Path
import warnings
import wandb
import seaborn as sns

# Suppress every warning category for the rest of the process.
warnings.simplefilter("ignore")

def get_args():
    """Parse the command-line options for this evaluation script.

    Returns:
        argparse.Namespace with data options (``num_example``, ``num_var``),
        training options (``steps``, ``batch_size``, ``seed``), model options
        (``layers``, ``heads``, ``hid_dim``, ``model``) and logging options
        (``log_every``, ``wandb``).
    """
    parser = argparse.ArgumentParser(
        description='Training Transformers for Bayesian Inference')

    # data args
    parser.add_argument(
        '--num-example',
        help='Number of examples for in-context learning',
        type=int,
        default=500)
    parser.add_argument(
        '--num-var',
        help='Number of variables in Markov Chain',
        type=int,
        default=7)

    # training args
    parser.add_argument(
        '--steps',
        help='Number of training steps',
        type=int,
        default=4000)
    parser.add_argument('--batch-size', help='Batch size', type=int, default=64)
    # The evaluation functions read ``args.seed`` to locate the checkpoint
    # directory, so it must be defined here.
    # NOTE(review): default 0 is a guess; it must match the seed of the
    # training run whose checkpoint is being evaluated.
    parser.add_argument(
        '--seed',
        help='Random seed of the training run (part of the checkpoint path)',
        type=int,
        default=0)

    # model args
    parser.add_argument(
        '--layers',
        help='Number of transformer layers',
        type=int,
        default=12)
    parser.add_argument(
        '--heads',
        help='Number of transformer attention heads',
        type=int,
        default=4)
    # '--hid_dim' is kept for existing command lines; '--hid-dim' matches the
    # hyphenated style of the other options. Both map to ``args.hid_dim``.
    parser.add_argument(
        '--hid_dim', '--hid-dim',
        dest='hid_dim',
        help='Size of hidden dimension',
        type=int,
        default=128)
    parser.add_argument(
        '--model',
        help='Model architecture name (part of the checkpoint path)',
        type=str,
        default="transformer")

    # log args
    parser.add_argument(
        '--log-every',
        help='log every X steps',
        type=int,
        default=10)
    # The evaluation functions read ``args.wandb`` to decide whether to log.
    parser.add_argument(
        '--wandb',
        help='Log evaluation results to Weights & Biases',
        action='store_true')

    return parser.parse_args()
# torch.manual_seed(1111)

def run_test(args):
    """Evaluate a trained checkpoint on each masked variable and write per-variable CSVs.

    The checkpoint is loaded from ``{log_path}/ckpt.pt``. ``log_path`` is built
    from ``args.model``, ``args.num_example``, ``args.layers``, ``args.hid_dim``,
    ``args.heads``, ``args.batch_size`` and ``args.seed``.

    For every ``var_idx`` in ``range(args.num_var)``:

    * the test set is re-masked with ``testset.update_mask_out(var_idx)``;
    * for each context size ``n`` in ``n_examples``, the model receives the
      first ``n`` rows of each sequence plus its last (query) row;
    * the per-``n`` results are written to
      ``{log_path}/var_{var_idx}_eval.csv`` and, if ``args.wandb`` is set,
      logged to wandb.

    Each metric is summed over all test batches and divided by ``eval_num``:

    * ``prob difference``: absolute difference between softmax column 0 of
      the model output and the reference probability taken from ``probs``.
    * ``transformer acc``: BinaryAccuracy of the model's argmax prediction.
    * ``graph acc``: BinaryAccuracy of the dataset-provided ``graph_pred``.
      It does not depend on ``n``.

    Args:
        args: Parsed CLI namespace. Reads the fields listed above, plus
            ``wandb``, ``steps`` and ``num_var``.

    Raises:
        ValueError: when a variable index above 6 is reached. The reference
            probability is only defined here for indices 0-6.
    """
    if args.wandb:
        wandb.init(
            # Set the project where this run will be logged
            project="Transformers-Bayesian-Inference-Test-Tree",
            # Track hyperparameters and run metadata
            config={
                "steps": args.steps,
                "hid_dim": args.hid_dim,
                "heads": args.heads,
                "layers": args.layers,
                "num_example": args.num_example,
                "num variables": args.num_var,
            },
        )

    metric = BinaryAccuracy().cuda()

    eval_num = 1500
    n_examples = [2, 3, 4, 5, 10, 20, 30, 40, 50, 75, 100]

    model = EncoderTransformer(
        n_dims=args.num_var * 3,
        n_embd=args.hid_dim,
        n_layer=args.layers,
        n_head=args.heads,
    ).cuda()

    log_path = f"ckpt/tree_3/{args.model}_{args.num_example}_{args.layers}_{args.hid_dim}_{args.heads}_{args.batch_size}_seed{args.seed}/"
    ckpt_path = f"{log_path}/ckpt.pt"
    model.load_state_dict(torch.load(ckpt_path, weights_only=True))
    # Put the model in inference mode (disables dropout etc., if the model has any).
    model.eval()
    # Sequences are built for the largest context size. Smaller context
    # sizes use a prefix of the same sequences.
    testset = InContextDatasetTest(eval_num, n_examples[-1], mask_out=0, num_var=args.num_var)

    for var_idx in range(args.num_var):
        print("Running Evaluation on Masking Variable:", var_idx)
        if var_idx > 6:
            raise ValueError(
                f"No reference-probability rule for variable {var_idx}; "
                "only variables 0-6 are supported")

        testset.update_mask_out(var_idx)
        test_loader = DataLoader(testset, batch_size=4, shuffle=False)

        prob_diff = {n: 0.0 for n in n_examples}
        model_cor = {n: 0.0 for n in n_examples}
        graph_cor = 0.0

        with torch.no_grad():
            for x, y, probs, graph_pred in test_loader:
                x = x.cuda()
                y = y.long().cuda().squeeze(-1)
                probs = probs.cuda()
                graph_pred = graph_pred.cuda()
                batch_size = x.size(0)

                # BinaryAccuracy returns the batch accuracy, so multiplying
                # by the batch size gives the number of correct predictions.
                # The graph baseline does not depend on n, so count it once.
                graph_cor += (metric(graph_pred, y) * batch_size).item()

                # The reference probability does not depend on n, so compute
                # it once per batch.
                if var_idx == 0:
                    true_prob = probs[:, 0]
                else:
                    # Variables 1-6 share one rule: select the row of `probs`
                    # given by the argmax of the query row's first two features.
                    first_var = x[:, -1, :2].argmax(-1)
                    rows = torch.arange(probs.size(0), device=probs.device)
                    true_prob = probs[rows, first_var][:, 0]

                query = x[:, -1, :].unsqueeze(1)
                for n in n_examples:
                    inp_x = torch.cat([x[:, :n, :], query], dim=1)
                    pred_prob = torch.softmax(model(inp_x), dim=-1)
                    model_pred = pred_prob.argmax(-1)
                    model_cor[n] += (metric(model_pred, y) * batch_size).item()
                    prob_diff[n] += torch.abs(pred_prob[:, 0] - true_prob).sum().item()

        eval_data = {
            "# example": [],
            "prob difference": [],
            "transformer acc": [],
            "graph acc": [],
        }
        graph_acc = graph_cor / eval_num
        for n in n_examples:
            avg_prob_diff = prob_diff[n] / eval_num
            model_acc = model_cor[n] / eval_num

            eval_data["# example"].append(n)
            eval_data["prob difference"].append(avg_prob_diff)
            eval_data["transformer acc"].append(model_acc)
            eval_data["graph acc"].append(graph_acc)

            # Log every context size. Previously only the last `n` was
            # logged, because this call sat outside the loop.
            if args.wandb:
                wandb.log({
                    f"Var {var_idx} # example": n,
                    f"Var {var_idx} prob difference": avg_prob_diff,
                    f"Var {var_idx} transformer acc": model_acc,
                    f"Var {var_idx} graph acc": graph_acc,
                })

        pd.DataFrame(eval_data).to_csv(f"{log_path}/var_{var_idx}_eval.csv", index=False)

    if args.wandb:
        wandb.finish()



def _plot_eval_curves(curve_df, var_idx):
    """Save the accuracy and prob-difference curves (vs. training step) for one masked variable.

    Args:
        curve_df: DataFrame with columns ``step``, ``transformer acc``,
            ``graph acc`` and ``prob difference``.
        var_idx: Index of the masked variable, used in the output file names.
    """
    sns.lineplot(
        data=curve_df, x="step", y="transformer acc", label="Transformer", color="green", alpha=0.8
    )
    sns.lineplot(
        data=curve_df, x="step", y="graph acc", label="Optimal", color="red", alpha=0.8
    )
    # No xlim. The previous xlim(5, 200) hid almost all of the 0..max_step axis.
    plt.xlabel("Step")
    plt.ylabel("Test Accuracy")
    plt.title("Test Accuracy Curve")  # You can comment this line out if you don't need title
    plt.tight_layout()
    plt.savefig(f"ckpt/tree_3/acc_curves_{var_idx}.png")
    plt.clf()

    sns.lineplot(
        data=curve_df, x="step", y="prob difference", label="Transformer", color="green", alpha=0.8
    )
    plt.xlabel("Step")
    plt.ylabel("Prob. Difference")
    plt.title("Test Prob. Difference Curve")  # You can comment this line out if you don't need title
    plt.tight_layout()
    plt.savefig(f"ckpt/tree_3/prob_diff_curves_{var_idx}.png")
    plt.clf()


def run_test_plot_curve(args, max_step=20000, step_every=50):
    """Evaluate intermediate checkpoints and plot the test curves against training step.

    For every masked variable ``var_idx`` in ``range(args.num_var)``, each
    checkpoint ``{log_path}/step{step}.pt`` for ``step`` in
    ``range(0, max_step + 1, step_every)`` is loaded. Each one is evaluated
    with ``n_examples`` in-context rows plus the query row.

    The per-step metrics are written to
    ``{log_path}/var_{var_idx}_eval_curve.csv``. The curves are saved as
    ``ckpt/tree_3/acc_curves_{var_idx}.png`` and
    ``ckpt/tree_3/prob_diff_curves_{var_idx}.png``.

    Metrics are summed over the test set and divided by ``eval_num``, the same
    normalisation ``run_test`` uses.

    NOTE(review): this assumes the training run saved a checkpoint for every
    step in the range; a missing file makes ``torch.load`` raise.

    Args:
        args: Parsed CLI namespace. Reads ``num_var``, ``hid_dim``, ``layers``,
            ``heads``, ``model``, ``num_example``, ``batch_size`` and ``seed``.
        max_step: Last checkpoint step to evaluate (inclusive).
        step_every: Spacing between evaluated checkpoint steps.

    Raises:
        ValueError: when a variable index above 6 is reached. The reference
            probability is only defined here for indices 0-6.
    """
    metric = BinaryAccuracy().cuda()

    eval_num = 1500
    # Number of in-context examples shown before the query row.
    n_examples = 100

    model = EncoderTransformer(
        n_dims=args.num_var * 3,
        n_embd=args.hid_dim,
        n_layer=args.layers,
        n_head=args.heads,
    ).cuda()
    # load_state_dict does not change the mode, so setting eval once is enough.
    model.eval()

    log_path = f"ckpt/tree_3/{args.model}_{args.num_example}_{args.layers}_{args.hid_dim}_{args.heads}_{args.batch_size}_seed{args.seed}/"

    testset = InContextDatasetTest(eval_num, n_examples, mask_out=0, num_var=args.num_var)
    steps = range(0, max_step + 1, step_every)

    for var_idx in range(args.num_var):
        print("Running Evaluation on Masking Variable:", var_idx)
        if var_idx > 6:
            raise ValueError(
                f"No reference-probability rule for variable {var_idx}; "
                "only variables 0-6 are supported")

        # The mask and loader do not depend on the checkpoint, so build them
        # once per variable.
        testset.update_mask_out(var_idx)
        test_loader = DataLoader(testset, batch_size=eval_num, shuffle=False)

        eval_data = {
            "step": [],
            "prob difference": [],
            "transformer acc": [],
            "graph acc": [],
        }

        for step in steps:
            step_path = f"{log_path}/step{step}.pt"
            model.load_state_dict(torch.load(step_path, weights_only=True))

            prob_diff = 0.0
            model_cor = 0.0
            graph_cor = 0.0

            with torch.no_grad():
                for x, y, probs, graph_pred in test_loader:
                    x = x.cuda()
                    y = y.long().cuda().squeeze(-1)
                    probs = probs.cuda()
                    graph_pred = graph_pred.cuda()
                    batch_size = x.size(0)

                    inp_x = torch.cat([x[:, :n_examples, :], x[:, -1, :].unsqueeze(1)], dim=1)
                    pred_prob = torch.softmax(model(inp_x), dim=-1)
                    model_pred = pred_prob.argmax(-1)
                    # Batch accuracy * batch size = number of correct predictions.
                    model_cor += (metric(model_pred, y) * batch_size).item()
                    graph_cor += (metric(graph_pred, y) * batch_size).item()

                    if var_idx == 0:
                        true_prob = probs[:, 0]
                    else:
                        # Variables 1-6 share one rule: select the row of `probs`
                        # given by the argmax of the query row's first two features.
                        first_var = x[:, -1, :2].argmax(-1)
                        rows = torch.arange(probs.size(0), device=probs.device)
                        true_prob = probs[rows, first_var][:, 0]

                    prob_diff += torch.abs(pred_prob[:, 0] - true_prob).sum().item()

            # One row per step, so every column has the same length.
            eval_data["step"].append(step)
            eval_data["prob difference"].append(prob_diff / eval_num)
            eval_data["transformer acc"].append(model_cor / eval_num)
            eval_data["graph acc"].append(graph_cor / eval_num)

        curve_df = pd.DataFrame(eval_data)
        curve_df.to_csv(f"{log_path}/var_{var_idx}_eval_curve.csv", index=False)
        # Plot inside the variable loop, so every variable gets its own figures.
        _plot_eval_curves(curve_df, var_idx)


# args = get_args()
# run_test(args)
