from collections import defaultdict
import os
import json
import random
import math
import sys
import statistics
import numpy as np

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from continuum.metrics import Logger


from tqdm import tqdm
from utils import get_class_order, get_classes_names_cur_task, customed_forgetting, cosine_lr
from utils.utils import seed_worker

from dataset.datasets_scenarios import build_cl_scenarios
import clip.clip as tune_clip
from clip.tuner import Tuner


def run_clip_tuner_incremental(cfg, device):
    """Run the full class-incremental tuning pipeline and log the results.

    For every task in the training scenario the text-side tuner is trained
    first, then the image-side tuner, and finally the model is evaluated on
    all tasks seen so far. A summary line with the last and the mean accuracy
    is appended to ``cfg.log_path_image`` once all tasks are done.

    Args:
        cfg: Run configuration. This function reads ``prompt_template``,
            ``initial_increment``, ``increment`` and ``log_path_image`` and
            writes ``class_order``; the helpers it calls read further fields.
        device: Device on which the model, the tuner and the tokenized
            prompts are placed.
    """
    # Backbone stays frozen; only the tuner's parameters are ever trained.
    model, transforms_train = tune_clip.load("ViT-B/16", device=device, jit=False)
    for param in model.parameters():
        param.requires_grad = False

    tuner = Tuner(model, use_image_tuner=True, use_text_tuner=True).to(device)

    cfg.class_order = get_class_order(cfg, file_name=None)
    # The same transforms object is used for both the eval and train scenarios.
    eval_scenarios, classes_names = build_cl_scenarios(
        cfg, is_train=False, transforms=transforms_train
    )
    train_scenarios, _ = build_cl_scenarios(
        cfg, is_train=True, transforms=transforms_train
    )

    seen_classnames = []
    acc_list_image = []
    metric_logger_image = Logger(list_subsets=["test"])

    banner_rule = "#####################################################################"
    # Training stages, executed in this order for every task.
    stages = (("stage1", train_text), ("stage2", train_image))

    for task_id, train_dataset in enumerate(train_scenarios):
        # Evaluate on every task up to and including the current one.
        evalset = eval_scenarios[:task_id + 1]

        current_classnames = get_classes_names_cur_task(
            task_id, classes_names, cfg.initial_increment, cfg.increment
        )
        print(current_classnames)

        # Prompts restricted to the current task's classes (used for training).
        texts_train = tune_clip.tokenize(
            [cfg.prompt_template.format(name) for name in current_classnames]
        ).to(device)

        # Prompts for every class seen so far (used for evaluation).
        seen_classnames.extend(current_classnames)
        text_tokens = tune_clip.tokenize(
            [cfg.prompt_template.format(name) for name in seen_classnames]
        ).to(device)

        for stage_name, train_stage in stages:
            print(banner_rule)
            print(f"##### Task {task_id} {stage_name}: trainset lenght:{len(train_dataset)}, evalset lenght:{len(evalset)}. #####")
            print(banner_rule)
            train_stage(model, tuner, task_id, cfg, train_dataset, texts_train)

        test(
            cfg, task_id, acc_list_image, model, tuner, evalset,
            metric_logger_image, text_tokens, device, cfg.log_path_image,
        )

    with open(cfg.log_path_image, 'a+') as f:
        summary = {
            'last_image': round(acc_list_image[-1], 2),
            'avg_image': round(statistics.mean(acc_list_image), 2),
        }
        f.write(json.dumps(summary) + '\n')

    print(f"Last_image: {round(acc_list_image[-1], 2)}, Avg_image: {round(statistics.mean(acc_list_image), 2)}")

def test(cfg, task_id, acc_list, model, tuner, evalset, metric_logger, text_tokens, device, log_path):
    """Evaluate on ``evalset``, then print and log the incremental metrics.

    Predictions are accumulated in ``metric_logger``; accuracy, average
    incremental accuracy, per-task accuracy, forgetting, backward transfer
    and forward transfer are read back from it, printed, and appended to
    ``log_path`` as one JSON line. The unrounded accuracy (in percent) is
    appended to ``acc_list`` and the logger's current task is closed.

    Args:
        cfg: Run configuration; only ``workers`` is read here.
        task_id: Index of the task that was just trained (used for logging).
        acc_list: List that receives this evaluation's accuracy in percent.
        model: Model called as ``model(inputs, text_tokens, tuner)``.
        tuner: Tuner module passed through to ``model``.
        evalset: Dataset yielding ``(inputs, targets, task_ids)`` batches.
        metric_logger: Metrics accumulator with a ``"test"`` subset.
        text_tokens: Tokenized prompts passed as the text input to ``model``.
        device: Device the input batches are moved to.
        log_path: File to which the JSON record is appended.
    """
    model.eval()
    tuner.eval()
    loader = DataLoader(evalset, batch_size=64, num_workers=cfg.workers)

    with torch.no_grad():
        for images, labels, batch_task_ids in tqdm(loader):
            images = images.to(device)
            labels = labels.to(device)
            image_logits, _ = model(images, text_tokens, tuner)
            predictions = image_logits.softmax(dim=-1).cpu().argmax(dim=1)
            metric_logger.add([predictions, labels.cpu(), batch_task_ids], subset="test")

    # Insertion order below defines both the printed line and the JSON record.
    record = {
        'task': task_id,
        'acc': round(100 * metric_logger.accuracy, 2),
        'avg_acc': round(100 * metric_logger.average_incremental_accuracy, 2),
        'acc_per_task': [
            round(100 * task_acc, 2) for task_acc in metric_logger.accuracy_per_task
        ],
    }

    all_preds, all_targets, all_task_ids = metric_logger._get_best_epochs(subset="test")
    record['forgetting'] = round(
        100 * customed_forgetting(all_preds, all_targets, all_task_ids), 6
    )
    record['bwt'] = round(100 * metric_logger.backward_transfer, 2)
    record['fwt'] = round(100 * metric_logger.forward_transfer, 2)

    print(", ".join(f"{key}: {value}" for key, value in record.items()))

    # The caller averages these values, so keep the unrounded accuracy.
    acc_list.append(100 * metric_logger.accuracy)
    with open(log_path, 'a+') as f:
        f.write(f"{json.dumps(record)}\n")

    metric_logger.end_task()


def train_text(model, tuner, task_id, cfg, trainset, texts):
    """Train the text-side tuner parameters on the current task.

    All ``model`` parameters are frozen and only tuner parameters whose name
    contains ``"text_tuner"`` receive gradients. Training uses AdamW with the
    schedule returned by ``cosine_lr`` and a cross-entropy loss over the
    image-to-text logits.

    Args:
        model: Model called as ``model(inputs, texts, tuner)``; its first
            return value is used as the per-image class logits.
        tuner: Module holding the trainable tuner parameters.
        task_id: Zero-based index of the current task; used to shift the
            dataset labels so they index into ``texts``.
        cfg: Run configuration. Reads ``batch_size``, ``workers``, ``lr``,
            ``weight_decay``, ``ls``, ``increment`` and, when present,
            ``initial_increment``.
        trainset: Dataset yielding ``(inputs, targets, task_ids)`` batches.
        texts: Tokenized prompt tensor for the current task's classes. Its
            device is the device the training batches are moved to.
    """
    model.train()
    tuner.train()

    # Follow the device the prompts already live on instead of hard-coding
    # CUDA, so CPU runs and non-default GPUs work as well.
    device = texts.device

    train_loader = DataLoader(trainset, batch_size=cfg.batch_size,
                              shuffle=True, num_workers=cfg.workers,
                              worker_init_fn=seed_worker,
                              generator=torch.Generator().manual_seed(0))

    num_epochs = 1
    total_iterations = num_epochs * len(train_loader)

    # Freeze the backbone; enable gradients only for the text tuner.
    for p in model.parameters():
        p.requires_grad = False
    for name, p in tuner.named_parameters():
        p.requires_grad = "text_tuner" in name

    params = [p for p in tuner.parameters() if p.requires_grad]
    n_parameters = sum(p.numel() for p in params)
    print('Number of tuned params in tuning Text Encoder:', n_parameters)

    optimizer = torch.optim.AdamW(params, lr=cfg.lr, weight_decay=cfg.weight_decay)
    # 30 is passed as the third positional argument of cosine_lr
    # (presumably the warm-up length -- confirm against utils.cosine_lr).
    scheduler = cosine_lr(
        optimizer, cfg.lr, 30, total_iterations
    )

    # Offset that maps dataset labels to indices into `texts`. The first task
    # may have a different size (cfg.initial_increment) than later ones, so
    # `task_id * cfg.increment` is only right when the two sizes are equal.
    # NOTE(review): assumes labels are contiguous in task order and that a
    # missing/zero initial_increment means "same as increment" -- confirm.
    if task_id == 0:
        shift = 0
    else:
        first_task_size = getattr(cfg, "initial_increment", 0) or cfg.increment
        shift = first_task_size + (task_id - 1) * cfg.increment

    train_iter = iter(train_loader)
    # The loop runs one step more than a full pass over the loader, so the
    # iterator is exhausted on the final step and has to be restarted.
    for iteration in tqdm(range(total_iterations + 1)):
        scheduler(iteration)
        try:
            inputs, targets, _ = next(train_iter)
        except StopIteration:
            # Only exhaustion triggers a restart; real loader errors and
            # KeyboardInterrupt propagate instead of being swallowed.
            train_iter = iter(train_loader)
            inputs, targets, _ = next(train_iter)

        inputs = inputs.to(device)
        targets = (targets - shift).to(device)

        logits_per_image, _ = model(inputs, texts, tuner)
        loss = F.cross_entropy(logits_per_image, targets, label_smoothing=cfg.ls)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


def train_image(model, tuner, task_id, cfg, trainset, texts):
    """Train the image-side tuner parameters on the current task.

    All ``model`` parameters are frozen and only tuner parameters whose name
    contains ``"image_tuner"`` or ``"task_gate"`` receive gradients. Training
    uses AdamW with the schedule returned by ``cosine_lr`` and a
    cross-entropy loss over the image-to-text logits.

    Args:
        model: Model called as ``model(inputs, texts, tuner)``; its first
            return value is used as the per-image class logits.
        tuner: Module holding the trainable tuner parameters.
        task_id: Zero-based index of the current task; used to shift the
            dataset labels so they index into ``texts``.
        cfg: Run configuration. Reads ``batch_size``, ``workers``, ``lr``,
            ``weight_decay``, ``ls``, ``increment`` and, when present,
            ``initial_increment``.
        trainset: Dataset yielding ``(inputs, targets, task_ids)`` batches.
        texts: Tokenized prompt tensor for the current task's classes. Its
            device is the device the training batches are moved to.
    """
    model.train()
    tuner.train()

    # Follow the device the prompts already live on instead of hard-coding
    # CUDA, so CPU runs and non-default GPUs work as well.
    device = texts.device

    train_loader = DataLoader(trainset, batch_size=cfg.batch_size,
                              shuffle=True, num_workers=cfg.workers,
                              worker_init_fn=seed_worker,
                              generator=torch.Generator().manual_seed(0))

    num_epochs = 1
    total_iterations = num_epochs * len(train_loader)

    # Freeze the backbone; enable gradients only for the image tuner and
    # the task gate.
    for p in model.parameters():
        p.requires_grad = False
    for name, p in tuner.named_parameters():
        p.requires_grad = "image_tuner" in name or "task_gate" in name

    params = [p for p in tuner.parameters() if p.requires_grad]
    n_parameters = sum(p.numel() for p in params)
    print('Number of tuned params in tuning Image Encoder:', n_parameters)

    optimizer = torch.optim.AdamW(params, lr=cfg.lr, weight_decay=cfg.weight_decay)
    # 30 is passed as the third positional argument of cosine_lr
    # (presumably the warm-up length -- confirm against utils.cosine_lr).
    scheduler = cosine_lr(
        optimizer, cfg.lr, 30, total_iterations
    )

    # Offset that maps dataset labels to indices into `texts`. The first task
    # may have a different size (cfg.initial_increment) than later ones, so
    # `task_id * cfg.increment` is only right when the two sizes are equal.
    # NOTE(review): assumes labels are contiguous in task order and that a
    # missing/zero initial_increment means "same as increment" -- confirm.
    if task_id == 0:
        shift = 0
    else:
        first_task_size = getattr(cfg, "initial_increment", 0) or cfg.increment
        shift = first_task_size + (task_id - 1) * cfg.increment

    train_iter = iter(train_loader)
    # The loop runs one step more than a full pass over the loader, so the
    # iterator is exhausted on the final step and has to be restarted.
    for iteration in tqdm(range(total_iterations + 1)):
        scheduler(iteration)
        try:
            inputs, targets, _ = next(train_iter)
        except StopIteration:
            # Only exhaustion triggers a restart; real loader errors and
            # KeyboardInterrupt propagate instead of being swallowed.
            train_iter = iter(train_loader)
            inputs, targets, _ = next(train_iter)

        inputs = inputs.to(device)
        targets = (targets - shift).to(device)

        logits_per_image, _ = model(inputs, texts, tuner)
        loss = F.cross_entropy(logits_per_image, targets, label_smoothing=cfg.ls)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        