import os
from tqdm import tqdm, trange

import torch
from torch.nn import functional as F
from torch.optim import SGD, lr_scheduler
import torch.nn as nn

from util.util import info
from util.eval_util import AverageMeter
from models import vision_transformer as vits
from models.l2p_utils.vision_transformer import vit_base_patch16_224_dino

import numpy as np


import os
import pickle 
from collections import defaultdict
import random
from toolkit import ContrastiveLoss, real_time_eval, post_eval


import models.TTD_l2p as T
from collections import deque

import matplotlib.pyplot as plt
import seaborn as sns
import ptitprince as pt
import pandas as pd
import time



class Memory:
    """Class bookkeeping, per-class centroids and a bounded replay memory.

    ``initialize_memory`` reads ``self.unlabeled_test_data``,
    ``self.original_model``, ``self.model`` and ``self.stage_i``; none of them
    are assigned in this class.
    NOTE(review): presumably the method runs on an object that already owns
    those attributes -- confirm against the caller.
    """

    def initialize_memory(self):
        """Reset the bookkeeping, load centroids and seed the old-class memory.

        Steps:
          1. Reset the class index ranges, counters, the per-class ``deque``
             memories (bounded by ``mem_per_cls``) and zero centroids.
          2. Overwrite centroids ``0 .. len(loaded) - 1`` with the contents of
             ``centroids.pkl`` (opened relative to the working directory).
          3. Run every batch of ``self.unlabeled_test_data['default']`` through
             both models and group the CPU features / images by label in
             ``self.samples_`` / ``self.images_``.
          4. For every old class that has samples, store up to ``mem_per_cls``
             randomly chosen images in ``self.memory``.

        Raises:
            OSError: if ``centroids.pkl`` cannot be opened.
        """
        self.knownclass = 70
        self.totalclass = 100
        self.addclass = 0

        self.old_classes = list(range(self.knownclass))
        self.new_classes = list(range(self.knownclass, self.totalclass + self.addclass))
        self.all_classes = list(range(self.totalclass + self.addclass))
        self.data = None

        self.count = {k: 0 for k in self.all_classes}
        self.mem_per_cls = 10

        # Bounded per-class buffers: appending beyond ``mem_per_cls`` evicts
        # the oldest entry.
        self.memory = {k: deque(maxlen=self.mem_per_cls) for k in self.all_classes}
        self.memory_feature = {k: deque(maxlen=self.mem_per_cls) for k in self.all_classes}

        # Index of the most recently discovered class; starts at the last
        # known class so that the first increment yields ``knownclass``.
        self.discovered_class = self.knownclass - 1
        self.pred_count = {k: self.totalclass for k in self.old_classes}

        self.centroids = {k: torch.zeros(768, dtype=torch.float32) for k in range(self.totalclass + self.addclass)}

        self.samples_ = {}
        self.images_ = {}

        centroids_filepath = 'centroids.pkl'

        # SECURITY: pickle.load can execute arbitrary code while loading; only
        # point this at a centroid file produced by a trusted run.
        with open(centroids_filepath, 'rb') as f:
            loaded_centroids = pickle.load(f)
        for k in range(len(loaded_centroids)):
            self.centroids[k] = torch.tensor(loaded_centroids[k], dtype=torch.float32)

        # Cap on how many samples per label are buffered during the scan.
        max_buffered_per_label = 10000

        for batch in tqdm(self.unlabeled_test_data['default'], desc='Batches', leave=False, bar_format="{desc}{percentage:3.0f}%|{bar}{r_bar}", ncols=80):
            images, labels, uq_idxs, mask_lab = batch
            with torch.no_grad():
                labels = labels.cpu().numpy()

                # Move the batch to the GPU once and reuse it for both models.
                gpu_images = images.cuda()
                dino_features = self.original_model(gpu_images)['pre_logits']
                features = self.model(gpu_images, task_id=self.stage_i, cls_features=dino_features)['x'][:, 0]
                features = features.cpu()

            for x, y, z in zip(features, labels, images):
                y = int(y.item())
                if y not in self.samples_:
                    self.samples_[y] = []
                    self.images_[y] = []
                self.samples_[y].append(x)
                self.images_[y].append(z)

                if len(self.samples_[y]) > max_buffered_per_label:
                    self.samples_[y].pop(0)
                    self.images_[y].pop(0)

        for label in self.old_classes:
            # Classes without any collected sample are skipped. The selection
            # loop lives inside this guard: running it for a missing class
            # would either fail on unbound names or reuse the previous class's
            # images for the wrong label.
            if label not in self.samples_:
                continue

            features = self.samples_[label]
            images = self.images_[label]

            shuffled_indices = torch.randperm(len(features))
            selected_indices = shuffled_indices[:self.mem_per_cls]

            for idx in selected_indices.tolist():
                self.memory[label].append(images[idx].cpu())

                         

class TTD_simple:
    def __init__(self, args, model, ori_model, projection_head, adapter, data):
        """Keep references to the collaborators and build optimizer + scheduler.

        Args:
            args: configuration object; ``base_lr``, ``momentum``,
                ``weight_decay`` and ``epochs`` are read here.
            model: module whose parameters are optimised together with the
                projection head's.
            ori_model: stored as ``self.original_model``.
            projection_head: module whose parameters come first in the
                optimizer's parameter list.
            adapter: stored as ``self.adapter``.
            data: stored as ``self.data``.
        """
        super().__init__()

        # Collaborators supplied by the caller.
        self.args = args
        self.discovered_class = 70
        self.model, self.original_model = model, ori_model
        self.projection_head = projection_head
        self.data = data
        self.adapter = adapter

        # Running buffers and counters.
        self.stored_features, self.stored_labels = [], []
        self.aux_count = 0
        self.record = defaultdict(float)
        self.cen_set = defaultdict(list)
        self.t = 1
        self.cos_count = 0
        self.lsh_count = 0

        # One independent, initially empty list per class index.
        self.cluster_labels = {}
        for class_idx in range(self.discovered_class):
            self.cluster_labels[class_idx] = []

        # A single SGD parameter group: projection head first, then the model.
        base_lr = self.args.base_lr
        trainable = [*self.projection_head.parameters(), *self.model.parameters()]
        self.optimizer = SGD(
            trainable,
            lr=base_lr,
            momentum=self.args.momentum,
            weight_decay=self.args.weight_decay,
        )

        # Cosine decay down to a thousandth of the base learning rate.
        self.exp_lr_scheduler = lr_scheduler.CosineAnnealingLR(
            self.optimizer,
            T_max=self.args.epochs,
            eta_min=base_lr * 1e-3,
        )


    def predict_and_discover(self,unlabeled_test_data, stage_i, TTT=True, replay=True, self_correction=True):
        """Stream the test data once: predict, evaluate, adapt and self-correct.

        For every batch of ``unlabeled_test_data['default']`` the method
        extracts features with both models, passes them through the adapter,
        assigns each sample to a class with
        ``predict_and_discover_with_cosine_similarity`` and then optionally
        runs one test-time-training step and a memory self-correction pass.
        ``post_eval`` is called once after the last batch.

        Args:
            unlabeled_test_data: mapping whose ``'default'`` entry yields
                ``(inputs, labels, uq_idxs, mask_lab)`` batches.
            stage_i: stored as ``self.stage_i`` and forwarded to the model as
                ``task_id``.
            TTT: when true, call ``test_time_training`` after every batch.
            replay: forwarded to ``test_time_training``. Prediction itself
                always runs with ``replay=True``.
            self_correction: when true, call ``self_memory_correction`` on
                every other batch, starting with the 9th.
        """
        self.stage_i = stage_i
        self.model.eval()

        # The adapter optimizer is created lazily so that repeated calls keep
        # reusing the same optimizer instance.
        if not hasattr(self, 'adapter_optimizer'):
            params = list(self.adapter.parameters())
            self.adapter_optimizer = torch.optim.SGD(params, lr=0.01, weight_decay=1e-4)

        # Sliding evaluation window over the last ``eval_window`` batches.
        # Predictions and labels are both stored per batch, so the two windows
        # always cover exactly the same samples, even when a batch (typically
        # the last one) is smaller than ``args.batch_size``.
        eval_window = 10
        pred_eval = deque(maxlen=eval_window)
        label_eval = deque(maxlen=eval_window)

        print("############ pre_eval finished ##############")
        progress = tqdm(unlabeled_test_data['default'], desc='Batches', leave=False, bar_format="{desc}{percentage:3.0f}%|{bar}{r_bar}", ncols=80)
        for batch_idx, batch in enumerate(progress, start=1):
            ############ 1. forward and obtain features ###################
            inputs, labels, uq_idxs, mask_lab = batch

            # Gradients are deliberately left enabled here: ``feats`` is
            # back-propagated through in ``test_time_training``.
            inputs_gpu = inputs.cuda()
            dino_features = self.original_model(inputs_gpu)['pre_logits']
            raw_feats = self.model(inputs_gpu, task_id=self.stage_i, cls_features=dino_features)['x'][:, 0]

            raw_feats_norm = torch.nn.functional.normalize(raw_feats, dim=-1).cuda()
            # The adapter is applied on CPU; its normalised output goes back
            # to the GPU for the similarity search.
            feats = self.adapter(raw_feats.cpu())
            feats = torch.nn.functional.normalize(feats, dim=-1).cuda()

            ############ 2. predict and discover novel class ##############
            preds = self.predict_and_discover_with_cosine_similarity(inputs, feats, labels, replay=True)

            pred_eval.append(list(preds))
            label_eval.append(labels)

            # Running evaluation starts once the window holds 10 batches.
            if batch_idx > 9:
                all_preds = [pred for batch_preds in pred_eval for pred in batch_preds]
                all_labels = torch.cat(list(label_eval))
                real_time_eval(self.data, self.record, all_preds, all_labels)

            ############ 3. test time training ############################
            if TTT:
                self.test_time_training(replay, feats, raw_feats_norm, preds)

            ############ 4. memory self-correction ########################
            # Same schedule as a counter that starts at 1, is incremented
            # after each batch and triggers when it is > 9 and even.
            if self_correction and batch_idx >= 9 and (batch_idx + 1) % 2 == 0:
                self.self_memory_correction()

            del inputs, inputs_gpu, labels, uq_idxs, mask_lab, feats
            torch.cuda.empty_cache()

        ############ 5. post evaluation ####################
        post_eval(self, self.args, self.data, unlabeled_test_data)


    def test_time_training(self, replay, feats, raw_feats, preds):
        self.projection_head.eval()
        self.adapter.train()


        for label in self.data.memory.keys():
            self.data.memory[label] = [tensor.cpu() for tensor in self.data.memory[label]]

        max_samples_per_label = 2
        max_total_replay_samples = 20

        if replay:
            replay_samples, replay_labels = [], []
            for label, samples in self.data.memory.items():
                if label in self.data.old_classes:
                    limited_samples = random.sample(samples, min(max_samples_per_label, len(samples))) 
                    replay_samples.extend(limited_samples)
                    replay_labels.extend([label] * len(limited_samples))

            if len(replay_samples) > max_total_replay_samples:
                combined = list(zip(replay_samples, replay_labels))
                random.shuffle(combined)
                replay_samples, replay_labels = zip(*combined[:max_total_replay_samples])
                replay_samples = list(replay_samples)
                replay_labels = list(replay_labels)

            replay_samples_tensor = [replay_tensor.to('cuda') for replay_tensor in replay_samples]                    
            replay_samples_tensor = torch.stack(replay_samples)
            replay_labels = list(replay_labels)

            self.args.ttd_model == 'TTD_L2P_known_K'
            dino_features = self.original_model(replay_samples_tensor.cuda())['pre_logits']
            replay_feats = self.model(replay_samples_tensor.cuda(), task_id=self.stage_i, cls_features=dino_features)['x'][:, 0]
            
            norm_replayfeats = torch.nn.functional.normalize(replay_feats, dim=-1).cuda()
            ada_replayfeats = self.adapter(replay_feats.cpu())
            raw_feats = torch.concat([raw_feats, norm_replayfeats], dim= 0)
            ada_replayfeats = torch.nn.functional.normalize(ada_replayfeats, dim=-1).cuda()

            feats = torch.concat([feats, ada_replayfeats], dim= 0)
            preds = preds + replay_labels
        

        old_mask = [p in self.data.old_classes for p in preds]
        new_mask = [not m for m in old_mask]
        if sum(old_mask) > 0:
            print("sum(old_mask)", sum(old_mask))
            adapted_old_feats = feats[old_mask]
            raw_feats_old = raw_feats[old_mask]
            loss_old = F.mse_loss(adapted_old_feats, raw_feats_old.detach())
        else:
            loss_old = 0

        if sum(new_mask) > 0:
            new_feats = feats[new_mask]
            new_labels = [pred for pred, mask in zip(preds, new_mask) if mask]

            loss_new = ContrastiveLoss(0.5)(new_feats, new_labels)
        else:
            loss_new = 0

        total_loss = loss_old + loss_new


        self.adapter_optimizer.zero_grad()
        total_loss.backward()
        self.adapter_optimizer.step()

    


    def self_memory_correction(self):
        """Re-classify a random draw of stored novel-class samples.

        Up to 50 samples are drawn (with repetition of classes) from the
        memories of the classes ``knownclass .. discovered_class``. A drawn
        sample is removed from its memory slot unless it is the last one left
        there. The drawn samples are re-encoded with the current models and
        adapter, re-predicted with ``replay=True`` (which may store them again
        under the newly predicted class) and used for one training step.

        Returns:
            None. The call is a no-op when no novel class has been discovered
            yet or when none of the drawn classes has a stored sample.

        NOTE(review): entries of ``data.memory_feature`` are not removed here,
        so image memory and feature memory can drift apart -- confirm this is
        intended.
        """
        first_novel = self.data.knownclass
        last_novel = self.data.discovered_class
        # random.randint raises ValueError on an empty range, which is what
        # happens before the first novel class has been discovered.
        if last_novel < first_novel:
            return

        replay_samples, replay_labels = [], []
        for _ in range(50):
            label = random.randint(first_novel, last_novel)
            slot = self.data.memory[label]
            if len(slot) == 0:
                # Nothing stored for this class; drawing an index would fail.
                continue
            idx = random.randint(0, len(slot) - 1)
            replay_samples.append(slot[idx])
            replay_labels.append(label)

            # Always leave at least one sample behind in every slot.
            if len(slot) > 1:
                del slot[idx]

        if not replay_samples:
            return

        replay_samples_tensor = torch.stack(replay_samples)
        replay_samples_gpu = replay_samples_tensor.cuda()

        dino_features = self.original_model(replay_samples_gpu)['pre_logits']
        updated_features = self.model(replay_samples_gpu, task_id=self.stage_i, cls_features=dino_features)['x'][:, 0]

        raw_feats_norm = torch.nn.functional.normalize(updated_features, dim=-1).cuda()
        feats = self.adapter(updated_features.cpu())
        feats = torch.nn.functional.normalize(feats, dim=-1).cuda()
        preds = self.predict_and_discover_with_cosine_similarity(replay_samples_tensor, feats, replay_labels, replay=True)
        # NOTE(review): ``test_time_training3`` is not defined in the part of
        # the class visible here -- confirm it exists, or whether
        # ``test_time_training`` was meant.
        self.test_time_training3(False, feats,raw_feats_norm, preds)

    def predict_and_discover_with_cosine_similarity(self, inputs, feats, labels, replay=False, only_test=False, threshold=0.7, centroid_momentum=1.0):
        """Assign each sample to its most similar centroid or open a new class.

        For every ``(input, feature)`` pair the cosine similarity to all
        entries of ``self.data.centroids`` is computed. If the best similarity
        reaches ``threshold`` (or no further class may be opened) the sample
        gets that class; otherwise ``self.data.discovered_class`` is
        incremented and the feature becomes the new class's centroid.
        All-zero centroids yield NaN similarities, which never win the
        ``max`` comparison as long as the first centroid is non-zero.

        Args:
            inputs: samples, iterated in lock-step with ``feats``.
            feats: 2-D feature tensor, one row per sample. Used detached, so
                nothing stored here keeps an autograd graph alive.
            labels: only used to bound the iteration (``zip`` stops at the
                shortest of the three sequences).
            replay: when true, samples assigned to novel classes
                (index >= ``data.knownclass``) are stored in ``data.memory``
                and their features in ``data.memory_feature``.
            only_test: when true, only predictions are produced; no class is
                opened and no state is modified.
            threshold: minimum cosine similarity to accept an existing class.
            centroid_momentum: weight ``m`` of the update
                ``m * old_centroid + (1 - m) * batch_mean`` applied to old
                classes seen in this batch. The default ``1.0`` keeps old
                centroids unchanged.

        Returns:
            list[int]: predicted class index per sample.
        """
        data = self.data
        feats = feats.detach()
        device = feats.device

        # Working copies of all centroids (and their norms) on the feature
        # device, built once per call instead of once per sample and class.
        centroids = {}
        centroid_norms = {}
        for key, centroid in data.centroids.items():
            on_device = torch.as_tensor(centroid).detach().to(device)
            centroids[key] = on_device
            centroid_norms[key] = torch.norm(on_device, p=2)

        # Highest class index that may ever be opened.
        max_class = data.totalclass - 1 + data.addclass

        preds = []
        centroid_cache = {}
        centroid_cache_num = {}

        for x, feat, _label in zip(inputs, feats, labels):
            feat_norm = torch.norm(feat, p=2)
            similarities = {
                key: torch.dot(centroid, feat) / (centroid_norms[key] * feat_norm)
                for key, centroid in centroids.items()
            }
            pred = max(similarities, key=similarities.get)
            max_similarity = similarities[pred].item()

            if only_test:
                preds.append(pred)
                continue

            if max_similarity >= threshold or data.discovered_class >= max_class:
                # Novel classes start at index ``knownclass`` (inclusive),
                # matching the centroid refresh further down.
                if replay and pred >= data.knownclass:
                    data.memory[pred].append(x)
                    data.memory_feature[pred].append(feat)
            else:
                data.discovered_class += 1
                pred = data.discovered_class
                data.centroids[pred] = feat
                # Keep the working copy in sync so later samples of this
                # batch can already match the new class.
                centroids[pred] = feat
                centroid_norms[pred] = feat_norm
                if replay:
                    data.count[pred] = 1
                    data.memory[pred].append(x)
                    data.memory_feature[pred].append(feat)
            preds.append(pred)

            if pred not in centroid_cache:
                centroid_cache[pred] = feat
                centroid_cache_num[pred] = 1
            else:
                centroid_cache[pred] = centroid_cache[pred] + feat
                centroid_cache_num[pred] += 1

        if not only_test:
            # Old classes: momentum update towards the mean feature of the
            # samples assigned to them in this batch.
            for key, feat_sum in centroid_cache.items():
                if key in data.old_classes:
                    batch_mean = feat_sum / centroid_cache_num[key]
                    data.centroids[key] = centroid_momentum * centroids[key] + (1 - centroid_momentum) * batch_mean

            # Novel classes: centroid is the mean of the stored features.
            for idx in range(data.knownclass, data.discovered_class + 1):
                features = list(data.memory_feature[idx])
                if not features:
                    # Nothing stored (e.g. discovered with replay=False):
                    # keep the centroid assigned at discovery time.
                    continue
                centroid = torch.mean(torch.stack(features), dim=0)
                data.centroids[idx] = centroid.to('cpu')
        return preds