from argparse import ArgumentParser
from copy import deepcopy
import os
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
from utils import (
    get_2d_tasks,
    get_model,
    get_sine_angle,
    get_orthogonal_line,
    get_predictions_xyz,
)
from modules import (
    LogisticLoss
)


# One colour per task. The loss plots index this list by task id, so they
# support at most three tasks; the scatter plot zips against it instead.
_TASK_COLORS = ['tab:blue', 'tab:orange', 'tab:pink']


def _build_task_order(order, num_task, num_stages):
    """Return the list of task ids (one per stage) for the given ordering.

    Args:
        order: 'cyclic' (0, 1, ..., num_task-1, 0, 1, ...) or 'random'
            (ids drawn with ``torch.randint`` from the global torch RNG).
        num_task: number of distinct tasks.
        num_stages: length of the returned schedule.

    Raises:
        ValueError: if ``order`` is not one of the two supported values.
    """
    if order == 'cyclic':
        return [stage % num_task for stage in range(num_stages)]
    if order == 'random':
        return torch.randint(0, num_task, size=(num_stages,)).tolist()
    raise ValueError(f"unknown task_order {order!r}; expected 'cyclic' or 'random'")


def _train_stage(model, optimizer, criterion, X, y, num_iters, desc):
    """Take ``num_iters`` optimizer steps on the whole of ``(X, y)``, with a tqdm bar labelled ``desc``."""
    for _ in tqdm(range(num_iters), desc=desc):
        optimizer.zero_grad()
        loss = criterion(model(X), y)
        loss.backward()
        optimizer.step()


def _log_losses(model, criterion, X_task, y_task, X_tot, y_tot, per_task_log, total_log):
    """Append the current loss on every task, and on ``(X_tot, y_tot)``, to the given logs."""
    with torch.no_grad():
        for X_val, y_val, task_log in zip(X_task, y_task, per_task_log):
            # .item() already yields a Python number regardless of device.
            task_log.append(criterion(model(X_val), y_val).item())
        total_log.append(criterion(model(X_tot), y_tot).item())


def _draw_linear_trajectories(ax, bounds, trajectory_joint, trajectory_continual):
    """Draw both parameter trajectories and their reference lines for the Linear model."""
    x1min, x1max, x2min, x2max = bounds
    diag_end = max(x1max, x2max)
    # Ray from the origin along the (1, 1) direction.
    ax.plot([0, diag_end], [0, diag_end], c='gray', alpha=0.5, linewidth=1.5, linestyle='-.', label="Joint Max-margin Direction")
    # Segment of the line x2 = -x1, i.e. orthogonal to the (1, 1) direction.
    ax.plot(
        [min(x1min, -x2max), max(x1max, -x2min)],
        [max(-x1min, x2max), min(-x1max, x2min)],
        c='gray', alpha=0.3, linewidth=2, linestyle='--',
    )
    ax.plot(*torch.stack(trajectory_joint, dim=1).cpu(), marker='.', c='tab:red', alpha=0.8, linewidth=2.5, label="Joint Training trajectory")
    ax.plot(*torch.stack(trajectory_continual, dim=1).cpu(), marker='.', c='tab:green', alpha=0.5, linewidth=2.5, label="Continual Learning trajectory")
    # Lines derived from the latest parameters by get_orthogonal_line
    # (presumably the current decision boundaries -- confirm against utils).
    ax.plot(*get_orthogonal_line(trajectory_joint[-1].cpu()), c='tab:red', alpha=0.8, linestyle='--', linewidth=2)
    ax.plot(*get_orthogonal_line(trajectory_continual[-1].cpu()), c='tab:green', alpha=0.5, linestyle='--', linewidth=2)


def _draw_relu_boundaries(ax, x1min, grid_joint, grid_continual):
    """Draw contour plots of the joint and continual model outputs for the ReLU model.

    ``grid_joint`` / ``grid_continual`` are the ``(x1, x2, y)`` triples
    returned by ``get_predictions_xyz``.
    """
    cm = plt.cm.RdBu
    styled = ((grid_joint, 'tab:red', 0.8), (grid_continual, 'tab:green', 0.5))
    # Each pass draws joint first, then continual, so the layering of the
    # original drawing order is kept.
    for grid, _, _ in styled:
        ax.contourf(*grid, 10, cmap=cm, alpha=0.25)
    for grid, color, _ in styled:
        ax.contour(*grid, 10, colors=color, linestyles='dotted', linewidths=0.5)
    for grid, color, alpha in styled:
        ax.contour(*grid, 0, colors=color, linestyles='--', linewidths=2, alpha=alpha)
    # Proxy lines placed left of x1min (outside the x-limits set by the
    # caller) purely so the dashed contours get legend entries.
    ax.plot([x1min-2, x1min-2], [0, 1], c='tab:red', linestyle='--', linewidth=2, label="Joint Training\nDecision Boundary")
    ax.plot([x1min-2, x1min-2], [0, 1], c='tab:green', linestyle='--', linewidth=2, label="Continual Learning\nDecision Boundary")


def _scatter_tasks(ax, X_task, y_task):
    """Scatter every task's points, using 'o' for y=+1 and 'X' for y=-1."""
    for i, (X_i, y_i, color) in enumerate(zip(X_task, y_task, _TASK_COLORS)):
        labels = y_i.view(-1)
        ax.scatter(*X_i[labels == 1].T.cpu(), marker='o', s=50, c=color, edgecolors='k', linewidths=0.5, label=f"Task {i} Data ($y=+1$)")
        ax.scatter(*X_i[labels == -1].T.cpu(), marker='X', s=50, c=color, edgecolors='k', linewidths=0.5, label=f"Task {i} Data ($y=-1$)")


def _plot_loss_curves(per_task_losses, total_losses, title, save_path):
    """Save a log-scale plot of per-task and total loss against the (1-based) stage index."""
    stages = range(1, len(total_losses) + 1)
    fig, ax = plt.subplots(1, 1)
    for m, task_losses in enumerate(per_task_losses):
        ax.plot(stages, task_losses, c=_TASK_COLORS[m], label=f"task{m} loss")
    ax.plot(stages, total_losses, c='tab:red', label="total loss")
    ax.set_yscale('log')
    ax.set_xlabel("Stage")
    ax.set_ylabel("Loss")
    ax.legend()
    ax.set_title(title)
    fig.tight_layout()
    fig.savefig(save_path, bbox_inches="tight")
    plt.close(fig)


def _plot_sine_angles(sine_angles_continual, sine_angles_joint, save_path):
    """Save a log-log plot of both models' sine-angle histories against the stage index."""
    stages = range(1, len(sine_angles_continual) + 1)
    fig, ax = plt.subplots(1, 1)
    ax.plot(stages, sine_angles_continual, marker='.', c='tab:green', label="Continually learned model")
    ax.plot(stages, sine_angles_joint, marker='.', c='tab:red', label="Jointly trained model")
    ax.set_xscale('log')
    ax.set_yscale('log')
    ax.set_xlabel("Stage")
    ax.set_ylabel("Sine Angle")
    ax.set_ylim(8e-3, None)
    ax.legend()
    ax.set_title("Sine Angle($=\\sqrt{{(1-{{\\rm cossim}}^2)}}$)\nbetween Linear Model vs. Joint Max-Margin")
    fig.tight_layout()
    fig.savefig(save_path, bbox_inches="tight")
    plt.close(fig)


def main(args):
    """Train a continual and a joint model on the 2D tasks and plot their progress.

    Two copies of the same initial model are trained side by side. At every
    stage the "continual" copy takes ``args.K`` SGD steps on the data of the
    single task scheduled for that stage, while the "joint" copy takes
    ``args.K`` SGD steps on ``X_tot, y_tot`` as returned by ``get_2d_tasks``
    (presumably the pooled data of all tasks). After each stage the losses are
    logged and every figure is re-saved, so the PDFs on disk always reflect
    the latest completed stage.

    Args:
        args: namespace providing ``result_path``, ``model``, ``task_order``,
            ``task_sampling``, ``data_per_task``, ``num_stages``, ``K``,
            ``lr``, ``data_scaler``, ``seed`` and ``device``.

    Raises:
        ValueError: if ``args.task_order`` is neither 'cyclic' nor 'random'.
    """
    save_dir = os.path.join(args.result_path, args.task_order, args.task_sampling, args.model)
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(save_dir, exist_ok=True)
    # Common stem of every output file; each plot appends its own suffix.
    file_prefix = os.path.join(save_dir, f"{args.task_order}_{args.task_sampling}_{args.model.lower()}_2d")
    is_linear = args.model == 'Linear'
    is_relu = args.model == 'ReLU'

    def sample_tasks():
        # Single definition of the sampling arguments, shared by the initial
        # draw and the per-stage re-draw used with 'online' sampling.
        return get_2d_tasks(
            args.data_per_task,
            scaler=args.data_scaler,
            device=args.device,
            include_supp=(args.task_sampling != 'online'),
            make_less_separable=(args.model != 'Linear'),
        )

    ## Get datasets (3 tasks) ##
    torch.manual_seed(args.seed)
    X_task, y_task, X_tot, y_tot = sample_tasks()
    # Plot limits: range of the initial sample padded by eps. They are not
    # recomputed when the data is re-sampled in 'online' mode.
    eps = args.data_scaler / 0.3
    x1min, x1max = X_tot[:, 0].min().item() - eps, X_tot[:, 0].max().item() + eps
    x2min, x2max = X_tot[:, 1].min().item() - eps, X_tot[:, 1].max().item() + eps
    task_order = _build_task_order(args.task_order, len(X_task), args.num_stages)

    # True linear model
    if is_linear:
        true_joint_max_margin = torch.tensor([1, 1]).to(args.device)

    ## Loss, model, optimizer ##
    criterion = LogisticLoss()
    model_continual = get_model(args.model).to(args.device)
    if is_linear:
        # Start the linear model from zero weights.
        model_continual.model[0].weight.data.zero_()
    # Both training regimes start from identical parameters.
    model_joint = deepcopy(model_continual)
    optimizer_continual = torch.optim.SGD(model_continual.parameters(), lr=args.lr)
    optimizer_joint = torch.optim.SGD(model_joint.parameters(), lr=args.lr)

    ## Logging ##
    losses_continual = [[] for _ in X_task]
    losses_joint = [[] for _ in X_task]
    total_losses_continual = []
    total_losses_joint = []
    if is_linear:
        ## Trajectory in param space ##
        trajectory_continual = [model_continual.get_param()]
        trajectory_joint = [model_joint.get_param()]
        ## Sine Angles: the lower the closer
        sine_angles_continual = []
        sine_angles_joint = []

    for t, task_id in enumerate(task_order):
        X, y = X_task[task_id], y_task[task_id]

        ## Continual Learning: train on the current task only ##
        _train_stage(
            model_continual, optimizer_continual, criterion, X, y, args.K,
            desc=f"Continual: Stage {t}/{args.num_stages} Task {task_id}/{len(X_task)}",
        )
        _log_losses(model_continual, criterion, X_task, y_task, X_tot, y_tot,
                    losses_continual, total_losses_continual)
        if is_linear:
            with torch.no_grad():
                trajectory_continual.append(model_continual.get_param())

        ## Joint Training: train on X_tot / y_tot ##
        _train_stage(
            model_joint, optimizer_joint, criterion, X_tot, y_tot, args.K,
            desc=f"Joint: Stage {t}/{args.num_stages}",
        )
        _log_losses(model_joint, criterion, X_task, y_task, X_tot, y_tot,
                    losses_joint, total_losses_joint)
        if is_linear:
            with torch.no_grad():
                trajectory_joint.append(model_joint.get_param())

        if is_linear:
            ## Sine Angle ##
            sine_angles_continual.append(get_sine_angle(model_continual, true_joint_max_margin))
            sine_angles_joint.append(get_sine_angle(model_joint, true_joint_max_margin))
        elif is_relu:
            ## Decision Boundary ##
            grid_continual = get_predictions_xyz(model_continual, x1min, x1max, x2min, x2max, device=args.device)
            grid_joint = get_predictions_xyz(model_joint, x1min, x1max, x2min, x2max, device=args.device)

        ## Visualization ##
        fig, ax = plt.subplots(1, 1, figsize=(10, 5))
        if is_linear:
            _draw_linear_trajectories(ax, (x1min, x1max, x2min, x2max), trajectory_joint, trajectory_continual)
        elif is_relu:
            _draw_relu_boundaries(ax, x1min, grid_joint, grid_continual)
        _scatter_tasks(ax, X_task, y_task)
        ax.set_xlim(x1min, x1max)
        ax.set_ylim(x2min, x2max)
        ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
        ax.set_aspect('equal')
        ax.xaxis.set_major_locator(plt.NullLocator())
        ax.yaxis.set_major_locator(plt.NullLocator())
        fig.tight_layout()
        fig.savefig(file_prefix + "_param.pdf", bbox_inches="tight")
        # For ReLU, additionally keep per-stage snapshots at selected stages.
        if is_relu and t + 1 in (1, 2, 3, 10, 50, 100, args.num_stages):
            fig.savefig(f"{file_prefix}_param_t{t+1}.pdf", bbox_inches="tight")
        plt.close(fig)

        _plot_loss_curves(losses_continual, total_losses_continual,
                          "Loss: Continually learned model", file_prefix + "_loss.pdf")
        _plot_loss_curves(losses_joint, total_losses_joint,
                          "Loss: Jointly learned model", file_prefix + "_loss_joint.pdf")

        if is_linear:
            _plot_sine_angles(sine_angles_continual, sine_angles_joint, file_prefix + "_angle.pdf")

        if args.task_sampling == 'online':
            ## Re-sample the Datasets for the next stage
            X_task, y_task, X_tot, y_tot = sample_tasks()
    

if __name__ == "__main__":
    # 2D 3-task Continual Linear Classification

    # Command-line options as (flag, add_argument keyword arguments) pairs,
    # registered in the order listed.
    option_specs = [
        ('--result_path', dict(type=str, default='./2D_results', help='path of result to be saved')),
        ('--model', dict(type=str, default='Linear', choices=['ReLU', 'Linear'])),
        ('--task_order', dict(type=str, default='cyclic', choices=['cyclic', 'random'])),
        ('--task_sampling', dict(type=str, default='once', choices=['once', 'online'])),
        ('--data_per_task', dict(type=int, default=100, help="number of data points per task")),
        ('--num_stages', dict(type=int, default=300, help="number of stages")),
        ('--K', dict(type=int, default=1000, help='iter. per stage')),
        ('--lr', dict(type=float, default=0.1, help='learning rate')),
        ('--data_scaler', dict(type=float, default=0.3, help="scaler multiplied to every x's")),
        ('--seed', dict(type=int, default=2024, help='random seed')),
        ('--device', dict(type=str, default='cpu', help='cpu / cuda:#')),
    ]

    cli = ArgumentParser()
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    # Parse sys.argv and hand the resulting namespace to the experiment.
    args = cli.parse_args()
    main(args)