from argparse import ArgumentParser
from copy import deepcopy
import os
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
from utils import (
    get_3d_toy_tasks,
    get_model,
)
from modules import (
    LogisticLoss
)


def _build_task_order(order, num_task, num_stages):
    """Return the task index to train on at each stage.

    Args:
        order: 'cyclic' (0, 1, ..., num_task-1, 0, 1, ...) or 'random'
            (indices drawn with torch.randint, so the result depends on the
            current torch RNG state).
        num_task: number of available tasks.
        num_stages: number of stages, i.e. length of the returned list.

    Raises:
        ValueError: if `order` is neither 'cyclic' nor 'random'.
    """
    if order == 'cyclic':
        return [stage % num_task for stage in range(num_stages)]
    if order == 'random':
        return torch.randint(0, num_task, size=(num_stages,)).tolist()
    raise ValueError(
        f"Unsupported task_order {order!r}; expected 'cyclic' or 'random'."
    )


def _run_stage(model, optimizer, criterion, X, y, num_iters, desc):
    """Take `num_iters` optimizer steps on the whole of (X, y), with a progress bar."""
    for _ in tqdm(range(num_iters), desc=desc):
        optimizer.zero_grad()
        loss = criterion(model(X), y)
        loss.backward()
        optimizer.step()


def _record_losses(model, criterion, X_task, y_task, X_tot, y_tot, per_task_log, total_log):
    """Append the model's current per-task and total losses to the given logs.

    Must be called under torch.no_grad(); `.item()` already yields a plain
    Python float, so no explicit detach / cpu transfer is needed.
    """
    for X_val, y_val, task_log in zip(X_task, y_task, per_task_log):
        task_log.append(criterion(model(X_val), y_val).item())
    total_log.append(criterion(model(X_tot), y_tot).item())


def _save_trajectory_plot(save_path, X_task, trajectory_alone, trajectory_joint,
                          trajectory_continual, order_name):
    """Draw datapoints and all parameter trajectories in 3D and save to `save_path`."""
    fig, ax = plt.subplots(
        1, 1, figsize=(6, 4),
        subplot_kw={'projection': '3d'}
    )
    try:
        # A fresh figure is created on every call, so the init marker is always
        # labelled; otherwise it would be missing from the legend of every
        # figure except the very first one.
        ax.scatter([0], [0], [0], s=20, c='k', marker='$\\rm o$', label="Init = (0,0,0)")

        for m, X in enumerate(X_task):
            # Move to CPU before handing to matplotlib so that plotting also
            # works when the data lives on a CUDA device.
            ax.scatter(*X.T.cpu(), marker='P', s=50, alpha=0.5, label=f'Task {m} datapoints')

        # Only two colors are listed; zip stops at the shorter sequence.
        for m, (traj, color) in enumerate(zip(trajectory_alone, ['tab:blue', 'tab:orange'])):
            ax.plot(*torch.stack(traj, dim=1).cpu(), marker='.', c=color, alpha=0.5, label=f'Task {m} Only')

        ax.plot(
            *torch.stack(trajectory_joint, dim=1).cpu(),
            marker='.', c='r', alpha=0.5, label='Joint Training'
        )
        ax.plot(
            *torch.stack(trajectory_continual, dim=1).cpu(),
            marker='.', c='tab:green', alpha=0.5, label=f'Continual ({order_name.capitalize()})'
        )
        ax.view_init(elev=15, azim=15, roll=0)
        ax.xaxis.set_ticklabels([])
        ax.yaxis.set_ticklabels([])
        ax.zaxis.set_ticklabels([])
        ax.legend()

        fig.subplots_adjust(bottom=-0.1)
        fig.savefig(save_path, bbox_inches='tight')
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)


def main(args):
    """Run the 3D toy experiment and save a plot of the parameter trajectories.

    Four models are trained from the same zero-initialized weights: one
    continually (one task per stage, following the task order), one jointly on
    (X_tot, y_tot), and one per task on that task alone. After every stage the
    parameter trajectories are re-plotted and written to
    `<result_path>/<task_order>/3d_toy_param.pdf`, overwriting the previous
    snapshot.

    Args:
        args: namespace providing `result_path`, `task_order`, `num_stages`,
            `K` (optimizer steps per stage), `lr`, `seed` and `device`.

    Raises:
        ValueError: if `args.task_order` is neither 'cyclic' nor 'random'.
    """
    save_dir = os.path.join(args.result_path, args.task_order)
    os.makedirs(save_dir, exist_ok=True)  # no exists()/makedirs() race

    ## Get datasets (3D, 2 tasks) ##
    torch.manual_seed(args.seed)
    X_task, y_task, X_tot, y_tot = get_3d_toy_tasks(device=args.device)
    num_task = len(X_task)
    task_order = _build_task_order(args.task_order, num_task, args.num_stages)

    ## Loss, model, optimizer ##
    criterion = LogisticLoss()
    model_continual = get_model('Linear', hidden_dims=[3, 1]).to(args.device)
    model_continual.model[0].weight.data.zero_()  # init at (0,0,0)
    model_joint = deepcopy(model_continual)
    model_alone = [deepcopy(model_continual), deepcopy(model_continual)]  # for task 1 and 2
    optimizer_continual = torch.optim.SGD(model_continual.parameters(), lr=args.lr)
    optimizer_joint = torch.optim.SGD(model_joint.parameters(), lr=args.lr)
    optimizer_alone = [torch.optim.SGD(m.parameters(), lr=args.lr) for m in model_alone]

    ## Logging ##
    # NOTE: these loss logs are filled in below but not written anywhere by
    # this function.
    losses_continual = [[] for _ in X_task]
    losses_joint = [[] for _ in X_task]
    total_losses_continual = []
    total_losses_joint = []
    ## Trajectory in param space ##
    # get_param() is defined on the project's model class; its results are
    # stacked with torch.stack when plotting.
    trajectory_continual = [model_continual.get_param()]
    trajectory_joint = [model_joint.get_param()]
    trajectory_alone = [[m.get_param()] for m in model_alone]

    save_path = os.path.join(save_dir, "3d_toy_param.pdf")

    for t, task_id in enumerate(task_order):
        ## Continual Learning ##
        _run_stage(
            model_continual, optimizer_continual, criterion,
            X_task[task_id], y_task[task_id], args.K,
            desc=f"Continual: Stage {t}/{args.num_stages} Task {task_id}/{len(X_task)}",
        )
        with torch.no_grad():
            _record_losses(
                model_continual, criterion, X_task, y_task, X_tot, y_tot,
                losses_continual, total_losses_continual,
            )
            trajectory_continual.append(model_continual.get_param())

        ## Joint Training ##
        _run_stage(
            model_joint, optimizer_joint, criterion, X_tot, y_tot, args.K,
            desc=f"Joint: Stage {t}/{args.num_stages}",
        )
        with torch.no_grad():
            _record_losses(
                model_joint, criterion, X_task, y_task, X_tot, y_tot,
                losses_joint, total_losses_joint,
            )
            trajectory_joint.append(model_joint.get_param())

        ## Individual Training ##
        # Distinct loop names so the stage-level data is not shadowed.
        for m, (X_m, y_m, model, opt, traj) in enumerate(
            zip(X_task, y_task, model_alone, optimizer_alone, trajectory_alone)
        ):
            _run_stage(
                model, opt, criterion, X_m, y_m, args.K,
                desc=f"Task {m} only: Stage {t}/{args.num_stages}",
            )
            with torch.no_grad():
                traj.append(model.get_param())

        ## Visualization ##
        # Re-drawn and overwritten after every stage, so the file on disk
        # always reflects the trajectories up to the latest finished stage.
        _save_trajectory_plot(
            save_path, X_task, trajectory_alone, trajectory_joint,
            trajectory_continual, args.task_order,
        )
        


if __name__ == "__main__":
    # 3D 2-task Continual Linear Classification

    # Command-line options, registered in this order: (flag, add_argument kwargs).
    cli_options = (
        ('--result_path', {'type': str, 'default': './3D_results', 'help': 'path of result to be saved'}),
        ('--task_order', {'type': str, 'default': 'cyclic', 'choices': ['cyclic', 'random']}),
        ('--num_stages', {'type': int, 'default': 200, 'help': "number of stages"}),
        ('--K', {'type': int, 'default': 1000, 'help': 'iter. per stage'}),
        ('--lr', {'type': float, 'default': 0.1, 'help': 'learning rate'}),
        ('--seed', {'type': int, 'default': 2024, 'help': 'random seed'}),
        ('--device', {'type': str, 'default': 'cpu', 'help': 'cpu / cuda:#'}),
    )

    cli = ArgumentParser()
    for flag, spec in cli_options:
        cli.add_argument(flag, **spec)

    # Parse the command line and run the experiment.
    main(cli.parse_args())
