import os
import sys
import types
from copy import deepcopy
from datetime import datetime
from itertools import chain

import numpy as np
import timm
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from datasets.cont_data import *
from common import parse_args
from utils.utils import *
from networks.net import Net

if __name__ == '__main__':
    device = "cuda" if torch.cuda.is_available() else "cpu"

    args = parse_args()
    args.logger = Logger(args, args.folder)
    args.logger.now()

    # Assign None to feature extractors
    args.model_clip, args.model_vit = None, None

    if args.dynamic is not None:
        args.n_components = args.dynamic

    # np.random.seed(args.seed)
    args.device = device

    args.logger.print('\n\n',
                        os.uname()[1] + ':' + os.getcwd(),
                        'python', ' '.join(sys.argv),
                      '\n\n')

    args.num_cls_per_task = int(args.total_cls // args.n_tasks)
    args.logger.print('\n', args, '\n')

    train_data, test_data = get_data(args)

    if args.task_type == 'standardCL_randomcls':
        task_list = generate_random_cl(args)
        train_data = StandardCL(train_data, args, task_list)
        test_data = StandardCL(test_data, args, task_list)

    args.sup_labels = []
    for task in task_list:
        args.logger.print(task)
        for name in task[0]:
            if name not in args.sup_labels:
                args.sup_labels.append(name)

    ############## transformer; Deit or ViT ############
    if 'adapter' in args.model:
        if 'vitadapter' in args.model:
            model_type = 'vit_base_patch16_224'
            args.in_dim = 768
            from networks.my_vit_hat import vit_base_patch16_224 as transformer
        elif 'deitadapter' in args.model:
            model_type = 'deit_small_patch16_224'
            args.in_dim = 384
            from networks.my_vit_hat import deit_small_patch16_224 as transformer
        
        num_classes = args.num_cls_per_task
        args.net = transformer(pretrained=True, num_classes=num_classes, latent=args.adapter_latent, args=args).to(device)
        
        if args.distillation:
            teacher = timm.create_model(model_type, pretrained=False, num_classes=num_classes).cuda()

        if 'deitadapter' in args.model:
            load_deit_pretrain(args, args.net, args.n_pre_cls)

        if args.model == 'vitadapter_ewt' or args.model == 'deitadapter_ewt':
            args.model_clip, args.clip_init = None, None
            from apprs.vitadapter_ewt import ViTAdapterEWT as Model

    args.criterion = Criterion(args, args.net)
    model = Model(args)
    # model = Model(args.net, nn.CrossEntropyLoss(), args, transforms.ToTensor())

    if args.distillation:
        if args.model in ['vitadapter', 'clipadapter', 'clipadapter_hat']:
            args.logger.print("Load teacher")
            model.teacher = teacher
        if args.model in ['clipadapter', 'clipadapter_hat']:
            args.logger.print("Load teacher net")
            model.teacher_net = teacher_net

    from pipeline import Pipeline
    pipeline = Pipeline(task_list, args, train_data, test_data, model)

    if args.train_clf_pree_id is None and not args.test_task_id:
        args.logger.print("\nTraining starts\n")
        pipeline.train_all()
    elif all([
            args.train_clf_pree_id is None,
            args.test_task_id,
            not args.test_pree,
            not args.test_joint
        ]):
        args.logger.print("\nTesting starts\n")
        for task_id in range(args.test_task_id + 1):
            pipeline.load_task_MD_stats(task_id)
            pipeline.preprocess_task(task_id)
            pipeline.load_train_step(task_id)
            pipeline.load_model_step(task_id)

            if args.report_auc_at_each_update:
                for p_task_id in range(task_id):
                    pipeline.test_auc(p_task_id, epoch=args.n_epochs)
            pipeline.test_auc(task_id, epoch=args.n_epochs)
            pipeline.test_all(task_id, pipeline.test_loaders[:task_id + 1])
            args.logger.print()
    elif all([
            args.train_clf_pree_id is not None, # just set any value. This is just for simplicity of code.
            args.train_single_head_id is not None,
            not args.test_single_head,
        ]):
        args.logger.print("\nTraining regular hat single head\n")
               
        for task_id in range(args.train_single_head_id + 1):
            pipeline.load_task_MD_stats(task_id)
            pipeline.preprocess_task(task_id)
            pipeline.load_train_step(task_id)
            pipeline.load_model_step(task_id)

            if args.finetune_clf_using_real_pseudo_features_after_train_id is not None:
                assert not args.no_pree
                checkpoint = custom_load(os.path.join(args.load_path,
                                        f'finetune_clf_using_real_pseudo_{task_id}'))
                custom_load_state_dict(args, model.net, checkpoint)
                assert 'real_pseudo' in args.single_head_model_name

            original_head = deepcopy(model.net.head)

            if task_id >= 1:
                n_cls = (task_id + 1) * args.num_cls_per_task

                if args.generate_ood:
                    n_cls += 1

                head = nn.Linear(args.in_dim, n_cls).to(args.device)
                model.net.head = head

                def forward_classifier(self, t, x, normalize=False):
                    if normalize:
                        x = x / x.norm(dim=-1, keepdim=True)
                        unit_w = self.head.weight / self.head.weight.norm(dim=-1, keepdim=True)
                        x = 100 * x @ unit_w.T
                    else:
                        x = self.head(x)
                    return x

                model.net.forward_classifier = types.MethodType(forward_classifier, model.net)

                checkpoint = custom_load(os.path.join(args.load_path, f'model_task_{task_id}'))
                for k, v in checkpoint['state_dict'].items():
                    if 'head' in k:
                        splitted = k.split('.')
                        t_id, param = int(splitted[1]), splitted[2]
                        if param == 'weight':
                            model.net.head.weight.data[t_id * args.num_cls_per_task:(t_id + 1) * args.num_cls_per_task] = v[:args.num_cls_per_task].to(args.device)
                        elif param == 'bias':
                            model.net.head.bias.data[t_id * args.num_cls_per_task:(t_id + 1) * args.num_cls_per_task] = v[:args.num_cls_per_task].to(args.device)
                        else:
                            raise NotImplementedError()

                # single head
                if task_id > 1:
                    new_head = custom_load(os.path.join(args.logger.dir(),
                                args.single_head_model_name + f'_{task_id - 1}'))['state_dict']
                    args.logger.print("Replacing the parameters")
                    for t_id in range(task_id):
                        model.net.head.weight.data[t_id * args.num_cls_per_task:(t_id + 1) * args.num_cls_per_task] = new_head['head.weight'].data[t_id * args.num_cls_per_task:(t_id + 1) * args.num_cls_per_task].to(args.device)
                        if args.generate_ood:
                            model.net.head[t_id].weight.data = torch.cat([model.net.head[t_id].weight, new_head['head.weight'][-1:]]).data
                        model.net.head.bias.data[t_id * args.num_cls_per_task:(t_id + 1) * args.num_cls_per_task] = new_head['head.bias'].data[t_id * args.num_cls_per_task:(t_id + 1) * args.num_cls_per_task].to(args.device)
                        if args.generate_ood:
                            model.net.head[t_id].bias.data = torch.cat([model.net.head[t_id].bias, new_head['head.bias'][-1:]]).data

                pipeline.single_head(task_id)

                new_head = deepcopy(model.net.head)
                model.net.head = original_head
                for t_id in range(task_id + 1):
                    model.net.head[t_id].weight.data = new_head.weight.data[t_id * args.num_cls_per_task:(t_id + 1) * args.num_cls_per_task]
                    if args.generate_ood:
                        model.net.head[t_id].weight.data = torch.cat([model.net.head[t_id].weight, new_head.weight[-1:]]).data
                    model.net.head[t_id].bias.data = new_head.bias.data[t_id * args.num_cls_per_task:(t_id + 1) * args.num_cls_per_task]
                    if args.generate_ood:
                        model.net.head[t_id].bias.data = torch.cat([model.net.head[t_id].bias, new_head.bias[-1:]]).data
                def forward_classifier(self, t, x, normalize=False):
                    if normalize:
                        x = x / x.norm(dim=-1, keepdim=True)
                        unit_w = self.head[t].weight / self.head[t].weight.norm(dim=-1, keepdim=True)
                        x = 100 * x @ unit_w.T
                    else:
                        x = self.head[t](x)
                    return x
                model.net.forward_classifier = types.MethodType(forward_classifier, model.net)

            pipeline.test_auc(task_id, epoch=args.n_epochs)
            pipeline.test_all(task_id, pipeline.test_loaders[:task_id + 1])
    elif all([
            args.train_clf_pree_id is not None, # just set any value. This is just for simplicity of code.
            args.train_single_head_id is not None,
            args.test_single_head
        ]):
        args.logger.print(f"\nTesting regular hat single head: {args.single_head_model_name}\n")
               
        for task_id in range(args.train_single_head_id + 1):
            pipeline.preprocess_task(task_id)
            pipeline.load_train_step(task_id)
            pipeline.load_model_step(task_id)
            pipeline.load_task_MD_stats(task_id)

            if task_id >= 1:
                new_head = custom_load(os.path.join(args.load_path,
                            args.single_head_model_name + f'_{task_id}'))['state_dict']
                for t_id in range(task_id + 1):
                    model.net.head[t_id].weight.data = new_head['head.weight'].data[t_id * args.num_cls_per_task:(t_id + 1) * args.num_cls_per_task]
                    if args.generate_ood:
                        model.net.head[t_id].weight.data = torch.cat([model.net.head[t_id].weight, new_head['head.weight'][-1:]]).data
                    model.net.head[t_id].bias.data = new_head['head.bias'].data[t_id * args.num_cls_per_task:(t_id + 1) * args.num_cls_per_task]
                    if args.generate_ood:
                        model.net.head[t_id].bias.data = torch.cat([model.net.head[t_id].bias, new_head['head.bias'][-1:]]).data

            if args.report_auc_at_each_update:
                for p_task_id in range(task_id):
                    pipeline.test_auc(p_task_id, epoch=args.n_epochs)
            pipeline.test_auc(task_id, epoch=args.n_epochs)
            pipeline.test_all(task_id, pipeline.test_loaders[:task_id + 1])
